diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 8924069841..55db89cce3 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -353,7 +353,6 @@ set(embedding_codegen_dependencies
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
   ${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh
-  ${CMAKE_CODEGEN_DIR}/embedding_op_registration.h
   ${CMAKE_CODEGEN_DIR}/__init__.template
   ${CMAKE_CODEGEN_DIR}/lookup_args.py
   ${CMAKE_CODEGEN_DIR}/split_embedding_codegen_lookup_invoker.template
@@ -365,6 +364,7 @@ set(embedding_codegen_dependencies
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_inplace_update.h
+  ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_op_registration.h
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h
   ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
index d6e72ccaa4..33fe178f20 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu
@@ -12,7 +12,7 @@

 ////////////////////////////////////////////////////////////////////////////////
 // Required for op registrations
-#include "codegen/embedding_op_registration.h"
+#include "fbgemm_gpu/embedding_op_registration.h"
 ////////////////////////////////////////////////////////////////////////////////

 #include "codegen/embedding_forward_template_helpers.cuh"
diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
index 1d928cd17a..0cd3371a09 100644
--- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -14,7 +14,7 @@
 {%- if not is_index_select %}
 ////////////////////////////////////////////////////////////////////////////////
 // Required for op registrations
-#include "codegen/embedding_op_registration.h"
+#include "fbgemm_gpu/embedding_op_registration.h"
 ////////////////////////////////////////////////////////////////////////////////
 {%- endif %}
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu
index 923b52baf4..509994e8a7 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu
@@ -135,7 +135,7 @@ __inline__ __device__ void process_all_indices_no_pooling(
   // Assuming kWarpSize is a multiple of STEP
   for (uint32_t l_start = 0; l_start < TOTAL_L; l_start += STEP) {
     Vec4StepT vecs;
-    #pragma loop unroll
+    #pragma unroll
     for (uint32_t j = 0; j < STEP; ++j) {
       // Get weight pointer
       const auto* ptr = reinterpret_cast(
@@ -151,7 +151,7 @@

   if (process_d) {
     // Write to output (not pooling)
-    #pragma loop unroll
+    #pragma unroll
     for (uint32_t j = 0; j < STEP; ++j) {
       {%- if weighted %}
       const auto index_weight = index_weights[l_start + j];
@@ -354,7 +354,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
     const auto cache_look_up_bits_step = cache_look_up_bits & STEP_MASK;
     if (USE_MIXED_TYPE_CACHE && cache_look_up_bits_step != 0) {
       if (cache_look_up_bits_step == STEP_MASK) {
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           const auto smem_offset = (l_start % kWarpSize) + j;
           if (process_d) {
@@ -383,7 +383,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
         cache_look_up_bits >>= STEP;
       }
       else {
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           const auto smem_offset = (l_start % kWarpSize) + j;
           if (process_d) {
@@ -425,14 +425,14 @@ __noinline__ __device__ void process_all_indices_small_Ls(
     else {
       if (process_d) {
         // Load STEP rows
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           const auto smem_offset = (l_start % kWarpSize) + j;
           accumulator.load(&SMEM_EMB_WEIGHT_DATA(smem_offset, threadIdx.x), j);
         }
       }

-      #pragma loop unroll
+      #pragma unroll
       for (uint32_t j = 0; j < STEP; ++j) {
         // Accumulate rows
         if (process_d) {
@@ -588,20 +588,21 @@ __noinline__ __device__ void process_all_indices_large_Ls(
       {%- endif %}
       if (cache_look_up_bits_step == STEP_MASK) {
         // Load STEP rows from lxu_cache_weights
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           const auto* weight = &SMEM_CACHE_WEIGHT_DATA((l_start % kWarpSize) + SMEM_OFFSET, WEIGHT_OFFSET);
           ACC_ADD_OR_FMA(weight, index_weights[SMEM_OFFSET])
         }
-        cache_look_up_bits >>= STEP * NUM_LOAD_GROUPS;
+        // Bypass the hip clang error of "shift count >= width of type"
+        cache_look_up_bits >>= std::min(STEP * NUM_LOAD_GROUPS, 31u);
       }
       else {
         // Load and accumulate STEP rows for UVM caching that emb_t and cache_t
         // are not the same and rows within STEPS are read from different
         // locations. It is unlikely that the compiler will be able to unroll
         // the loop below because of the runtime conditionals
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           if (cache_look_up_bits & 1u) {
             // Look up from lxu_cache_weights
@@ -621,7 +622,7 @@ __noinline__ __device__ void process_all_indices_large_Ls(
       }
       else {
         // Load STEP rows from dev_weights
-        #pragma loop unroll
+        #pragma unroll
         for (uint32_t j = 0; j < STEP; ++j) {
           accumulator.load(
               &SMEM_EMB_WEIGHT_DATA(
@@ -641,7 +642,8 @@ __noinline__ __device__ void process_all_indices_large_Ls(
       {%- endif %}

       if (USE_MIXED_TYPE_CACHE) {
-        cache_look_up_bits >>= STEP * NUM_LOAD_GROUPS;
+        // Bypass the hip clang error of "shift count >= width of type"
+        cache_look_up_bits >>= std::min(STEP * NUM_LOAD_GROUPS, 31u);
       }
     }
   }
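
Note on the embedding_forward_split_kernel_v2_template.cu hunks above: they make two
kinds of changes. "#pragma loop unroll" is replaced with "#pragma unroll", the form
that the CUDA and HIP compilers document for requesting loop unrolling. Separately,
the two shifts by STEP * NUM_LOAD_GROUPS are clamped because, per the added comment,
hip-clang reports "shift count >= width of type" when the count can reach the width
of the 32-bit mask. A minimal standalone sketch of the clamp (hypothetical function
and parameter names; only STEP, NUM_LOAD_GROUPS and the std::min clamp come from the
patch):

    #include <algorithm>
    #include <cstdint>

    template <uint32_t STEP, uint32_t NUM_LOAD_GROUPS>
    uint32_t consume_lookup_bits(uint32_t bits) {
      // bits >>= STEP * NUM_LOAD_GROUPS;              // diagnosed once the product reaches 32
      bits >>= std::min(STEP * NUM_LOAD_GROUPS, 31u);  // shift count is always < 32
      return bits;
    }

The clamp only changes the result when the count would have been 32 or more, i.e.
when every lookup bit has already been consumed; the surviving high bit is assumed
not to be read afterwards.
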
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
index d78471c476..686ac230ce 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
+++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp
@@ -23,7 +23,7 @@

 ////////////////////////////////////////////////////////////////////////////////
 // Required for op registrations
-#include "codegen/embedding_op_registration.h"
+#include "fbgemm_gpu/embedding_op_registration.h"
 #include "fbgemm_gpu/sparse_ops_utils.h"
 #include "fbgemm_gpu/embedding_common.h"
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
index 40fa31db99..1aabf9d0ec 100644
--- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -26,7 +26,7 @@
 {%- if not is_index_select %}
 ////////////////////////////////////////////////////////////////////////////////
 // Required for op registrations
-#include "codegen/embedding_op_registration.h"
+#include "fbgemm_gpu/embedding_op_registration.h"
 ////////////////////////////////////////////////////////////////////////////////
 {%- endif %}
 #include "codegen/embedding_forward_template_helpers.cuh"
diff --git a/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template
index 2acf1b1b6d..4405c3abcc 100644
--- a/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template
+++ b/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template
@@ -19,16 +19,26 @@ from torch.optim.optimizer import Optimizer
 import logging

 {%- if is_fbcode %}
-torch.ops.load_library(
-    "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training"
-)
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training"
 )
-torch.ops.load_library(
-    "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
-)
-torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops")
+
+if torch.version.hip:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip"
+    )
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_hip")
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training"
+    )
+else:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
+    )
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops")
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training"
+    )
 {%- endif %}
"//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training" + ) {%- endif %} diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index 94ba1bd4e5..12e841cfe7 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -23,10 +23,10 @@ #include #include -#include "dispatch_macros.h" -#include "embedding_common.h" -#include "fbgemm_cuda_utils.cuh" -#include "sparse_ops_utils.h" +#include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/sparse_ops_utils.h" #define SHFL_SYNC(val, srcLane) \ shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask) diff --git a/fbgemm_gpu/codegen/embedding_op_registration.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h similarity index 100% rename from fbgemm_gpu/codegen/embedding_op_registration.h rename to fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu index 52dd6674fd..0f50c8e581 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute102.cu @@ -79,7 +79,13 @@ DLL_PUBLIC Tensor permute102_baddbmm_permute102_cuda( auto ldc = n * batch_size; auto strideC = n; + // computeType is hipblasComputeType_t (e.g., HIPBLAS_COMPUTE_32F) instead of + // hipDataType (e.g., HIPBLAS_R_32F) after RoCM 6.0 +#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2) + auto computeType = HIPBLAS_COMPUTE_32F; +#else auto computeType = HIPBLAS_R_32F; +#endif auto result = hipblasGemmStridedBatchedEx( handle,