Enable merge_pooled_embeddings op. in ROCm (#15)
* Enable merge_pooled_embeddings op. in ROCm

* Enabling the merge pool ops.

Co-authored-by: liligwu <[email protected]>
reza-amd and liligwu authored Apr 1, 2022
1 parent 40928ba commit 69abf78
Showing 4 changed files with 99 additions and 11 deletions.
5 changes: 3 additions & 2 deletions fbgemm_gpu/CMakeLists.txt
@@ -85,6 +85,7 @@ if(USE_ROCM)
  find_package(rocRAND REQUIRED)
  find_package(hipSPARSE REQUIRED)
  find_package(OpenMP REQUIRED)
  find_package(rocPRIM REQUIRED)

  message("${message_line}")
  message(STATUS "hip found ${ROCM_FOUND}")
@@ -281,7 +282,7 @@ set(fbgemm_gpu_sources_cpu
    src/input_combine_cpu.cpp
    src/layout_transform_ops_cpu.cpp
    src/layout_transform_ops_gpu.cpp
    # src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp
    src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp
    src/permute_pooled_embedding_ops_gpu.cpp
    src/quantize_ops_cpu.cpp
    src/quantize_ops_gpu.cpp
@@ -341,7 +342,7 @@ elseif(USE_ROCM)

  hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES}
    HIPCC_OPTIONS ${HIP_CXX_FLAGS})
  target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE})
  target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE})
  set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH})

  # For ROCm5.1
3 changes: 2 additions & 1 deletion fbgemm_gpu/cmake/Hip.cmake
@@ -154,10 +154,11 @@ IF(HIP_FOUND)
  find_package(hip REQUIRED)

  set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include)
  set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include)

  set(FBGEMM_HIP_INCLUDE ${ROCM_PATH}/include ${FBGEMM_HIP_INCLUDE})
  set(FBGEMM_HIP_INCLUDE ${hip_INCLUDE_DIRS} $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}> $<INSTALL_INTERFACE:include> ${FBGEMM_HIP_INCLUDE})

  hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE})
  hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE})
ENDIF()
96 changes: 91 additions & 5 deletions fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp
@@ -15,9 +15,93 @@
#include <c10/util/irange.h>
#include <torch/library.h>

// FIXME: Enable merge_pooled_embeddings for HIP.
// AMD GPUs don't seem to have nvml equivalent library support.
#ifndef __HIP_PLATFORM_HCC__
#ifdef __HIP_PLATFORM_HCC__
#include "rocm_smi/rocm_smi.h"
#include "hip/hip_runtime.h"

#include <algorithm>

#include "fbgemm_gpu/merge_pooled_embeddings.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

#define RSMI_CHECK(fn)                          \
  do {                                          \
    rsmi_status_t ret = (fn);                   \
    TORCH_CHECK((ret) == RSMI_STATUS_SUCCESS);  \
  } while (0)

#define RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16

using Node = int64_t;
using Links = int64_t;
template <typename T>
using AdjacencyMatrix = std::function<T(Node, Node)>;

namespace {

AdjacencyMatrix<Links> get_nvlink_matrix() {
  auto world_size = at::cuda::getNumGPUs();
  RSMI_CHECK(rsmi_init(0));

  // Note that ROCm SMI uses a different device numbering than the ROCm
  // runtime, so we need to learn the mapping by using the PCI bus ID.
  uint32_t device_count;
  RSMI_CHECK(rsmi_num_monitor_devices(&device_count));

  std::unordered_map<Node, uint32_t> rocm_device_to_rsmi_device;

  for (const auto i : c10::irange(device_count)) {
    uint64_t pci_info;
    RSMI_CHECK(rsmi_dev_pci_id_get(i, &pci_info));
    uint64_t domain, bus, device, function;
    domain = (pci_info >> 32) & 0xffffffff;
    bus = (pci_info >> 8) & 0xff;
    device = (pci_info >> 3) & 0x1f;
    function = pci_info & 0x7;
    // Different from CUDA, we do not get the PCI bus ID as a char*, so we
    // need to reconstruct it. Cast to unsigned to match the %X conversions
    // and use snprintf so a large domain cannot overflow the buffer.
    char pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
    snprintf(
        pci_bus_id_str,
        sizeof(pci_bus_id_str),
        "%04X:%02X:%02X.%0X",
        static_cast<unsigned int>(domain),
        static_cast<unsigned int>(bus),
        static_cast<unsigned int>(device),
        static_cast<unsigned int>(function));

    std::array<char, RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE> pci_bus_id;
    std::copy(
        &pci_bus_id_str[0],
        &pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE],
        pci_bus_id.data());
    int32_t node = 0;
    auto err = hipDeviceGetByPCIBusId(&node, pci_bus_id.data());
    if (err == hipSuccess) {
      rocm_device_to_rsmi_device.insert({node, i});
    } else {
      // Flush the last error - this can occur when e.g. we set
      // HIP_VISIBLE_DEVICES to a subset of the available GPUs in the system.
      hipGetLastError();
    }
  }

  std::vector<Links> links(world_size * world_size);
  for (const auto i : c10::irange(world_size)) {
    auto src_rsmi_device = rocm_device_to_rsmi_device.find(i);
    if (src_rsmi_device != rocm_device_to_rsmi_device.end()) {
      for (const auto j : c10::irange(world_size)) {
        auto dst_rsmi_device = rocm_device_to_rsmi_device.find(j);
        if (dst_rsmi_device != rocm_device_to_rsmi_device.end()) {
          bool is_active;
          RSMI_CHECK(rsmi_is_P2P_accessible(
              src_rsmi_device->second, dst_rsmi_device->second, &is_active));
          if (is_active) {
            links[i * world_size + j] += 1;
          }
        }
      }
    }
  }
  RSMI_CHECK(rsmi_shut_down());
  return [=](Node i, Node j) { return links[i * world_size + j]; };
}
} // namespace

#else // CUDA
#include <nvml.h>

#include <algorithm>
@@ -106,7 +190,9 @@ AdjacencyMatrix<Links> get_nvlink_matrix() {

return [=](Node i, Node j) { return links[i * world_size + j]; };
}

} // namespace
#endif
namespace {
// Hilariously unoptimized, but algorithmic correctness matters more here, and
// we only do it once.
AdjacencyMatrix<Node> get_intermediate_node(AdjacencyMatrix<Links> links) {
@@ -409,4 +495,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]");
DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device);
}
#endif
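
The new HIP branch above queries the GPU topology through ROCm SMI instead of NVML. As a rough aid for reviewing it, the sketch below is a minimal standalone probe, not part of this commit and only an assumption about how one might exercise the same calls: it maps HIP device ordinals to ROCm SMI indices via the PCI bus ID, exactly as get_nvlink_matrix() does, and prints pairwise rsmi_is_P2P_accessible() results. The file name and build command in the leading comment are illustrative.

// Minimal standalone sketch (not part of this commit): probe the ROCm SMI
// P2P topology the same way get_nvlink_matrix() does. Assumed build line:
//   hipcc p2p_probe.cpp -lrocm_smi64 -o p2p_probe
#include <hip/hip_runtime.h>
#include <rocm_smi/rocm_smi.h>

#include <cstdio>
#include <unordered_map>

#define RSMI_CHECK(fn)                                                      \
  do {                                                                      \
    rsmi_status_t ret = (fn);                                               \
    if (ret != RSMI_STATUS_SUCCESS) {                                       \
      std::fprintf(stderr, "rocm_smi error %d\n", static_cast<int>(ret));   \
      return 1;                                                             \
    }                                                                       \
  } while (0)

int main() {
  int world_size = 0;
  if (hipGetDeviceCount(&world_size) != hipSuccess) {
    return 1;
  }
  RSMI_CHECK(rsmi_init(0));

  // Map HIP device ordinals to ROCm SMI indices via the PCI bus ID,
  // since the two libraries may enumerate devices in different orders.
  uint32_t device_count = 0;
  RSMI_CHECK(rsmi_num_monitor_devices(&device_count));
  std::unordered_map<int, uint32_t> hip_to_rsmi;
  for (uint32_t i = 0; i < device_count; ++i) {
    uint64_t pci_info = 0;
    RSMI_CHECK(rsmi_dev_pci_id_get(i, &pci_info));
    char bus_id[32];
    std::snprintf(
        bus_id,
        sizeof(bus_id),
        "%04X:%02X:%02X.%X",
        static_cast<unsigned int>((pci_info >> 32) & 0xffffffff),
        static_cast<unsigned int>((pci_info >> 8) & 0xff),
        static_cast<unsigned int>((pci_info >> 3) & 0x1f),
        static_cast<unsigned int>(pci_info & 0x7));
    int hip_device = 0;
    if (hipDeviceGetByPCIBusId(&hip_device, bus_id) == hipSuccess) {
      hip_to_rsmi[hip_device] = i;
    } else {
      hipGetLastError();  // e.g. the device is hidden by HIP_VISIBLE_DEVICES
    }
  }

  // Print the pairwise P2P accessibility matrix.
  for (int i = 0; i < world_size; ++i) {
    for (int j = 0; j < world_size; ++j) {
      const char* status = "no P2P";
      if (i == j) {
        status = "self";
      } else if (hip_to_rsmi.count(i) != 0 && hip_to_rsmi.count(j) != 0) {
        bool accessible = false;
        RSMI_CHECK(rsmi_is_P2P_accessible(
            hip_to_rsmi[i], hip_to_rsmi[j], &accessible));
        if (accessible) {
          status = "P2P";
        }
      }
      std::printf("GPU %d -> GPU %d: %s\n", i, j, status);
    }
  }
  RSMI_CHECK(rsmi_shut_down());
  return 0;
}

On a machine where this probe reports P2P for a GPU pair, the links matrix built by the new get_nvlink_matrix() should be non-zero for that pair.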

6 changes: 3 additions & 3 deletions fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -31,15 +31,15 @@


@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(open_source, "Not supported in open source yet")
#@unittest.skipIf(open_source, "Not supported in open source yet")
class MergePooledEmbeddingsTest(unittest.TestCase):
    @given(
        num_ads=st.integers(min_value=1, max_value=10),
        embedding_dimension=st.integers(min_value=1, max_value=32),
        ads_tables=st.integers(min_value=1, max_value=32),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        non_default_stream=st.booleans(),
        r=st.randoms(use_true_random=False),
        r=st.randoms(),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
@@ -93,7 +93,7 @@ def ref(pooled_ad_embeddings, batch_indices):
        num_inputs=st.integers(min_value=1, max_value=10),
        num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
        non_default_stream=st.booleans(),
        r=st.randoms(use_true_random=False),
        r=st.randoms(),
    )
    # Can instantiate 8 contexts which takes a long time.
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)