Skip to content

Commit

Permalink
Fix fbgemm TBE building issues (pytorch#2235)
Browse files Browse the repository at this point in the history
Summary:

Unblock of fbgemm TBE (inference, training) usages on AMD GPUs .

Reviewed By: zoranzhao, houseroad

Differential Revision: D52425243
  • Loading branch information
jianyuh authored and facebook-github-bot committed Dec 28, 2023
1 parent 7dd0c7f commit 06e4912
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 27 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,6 @@ set(embedding_codegen_dependencies
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh
${CMAKE_CODEGEN_DIR}/embedding_op_registration.h
${CMAKE_CODEGEN_DIR}/__init__.template
${CMAKE_CODEGEN_DIR}/lookup_args.py
${CMAKE_CODEGEN_DIR}/split_embedding_codegen_lookup_invoker.template
Expand All @@ -365,6 +364,7 @@ set(embedding_codegen_dependencies
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_inplace_update.h
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_op_registration.h
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/embedding_op_registration.h"
////////////////////////////////////////////////////////////////////////////////
#include "codegen/embedding_forward_template_helpers.cuh"

Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_backward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
{%- if not is_index_select %}
////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/embedding_op_registration.h"
////////////////////////////////////////////////////////////////////////////////
{%- endif %}
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
Expand Down
24 changes: 13 additions & 11 deletions fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ __inline__ __device__ void process_all_indices_no_pooling(
// Assuming kWarpSize is a multiple of STEP
for (uint32_t l_start = 0; l_start < TOTAL_L; l_start += STEP) {
Vec4StepT<STEP, emb_t> vecs;
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
// Get weight pointer
const auto* ptr = reinterpret_cast<const emb_vec_t*>(
Expand All @@ -151,7 +151,7 @@ __inline__ __device__ void process_all_indices_no_pooling(

if (process_d) {
// Write to output (not pooling)
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
{%- if weighted %}
const auto index_weight = index_weights[l_start + j];
Expand Down Expand Up @@ -354,7 +354,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
const auto cache_look_up_bits_step = cache_look_up_bits & STEP_MASK;
if (USE_MIXED_TYPE_CACHE && cache_look_up_bits_step != 0) {
if (cache_look_up_bits_step == STEP_MASK) {
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
const auto smem_offset = (l_start % kWarpSize) + j;
if (process_d) {
Expand Down Expand Up @@ -383,7 +383,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
cache_look_up_bits >>= STEP;
}
else {
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
const auto smem_offset = (l_start % kWarpSize) + j;
if (process_d) {
Expand Down Expand Up @@ -425,14 +425,14 @@ __noinline__ __device__ void process_all_indices_small_Ls(
else {
if (process_d) {
// Load STEP rows
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
const auto smem_offset = (l_start % kWarpSize) + j;
accumulator.load(&SMEM_EMB_WEIGHT_DATA(smem_offset, threadIdx.x), j);
}
}

#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
// Accumulate rows
if (process_d) {
Expand Down Expand Up @@ -588,20 +588,21 @@ __noinline__ __device__ void process_all_indices_large_Ls(
{%- endif %}
if (cache_look_up_bits_step == STEP_MASK) {
// Load STEP rows from lxu_cache_weights
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
const auto* weight =
&SMEM_CACHE_WEIGHT_DATA((l_start % kWarpSize) + SMEM_OFFSET, WEIGHT_OFFSET);
ACC_ADD_OR_FMA(weight, index_weights[SMEM_OFFSET])
}
cache_look_up_bits >>= STEP * NUM_LOAD_GROUPS;
// Bypass the hip clang error of "shift count >= width of type"
cache_look_up_bits >>= std::min(STEP * NUM_LOAD_GROUPS, 31u);
}
else {
// Load and accumulate STEP rows for UVM caching that emb_t and cache_t
// are not the same and rows within STEPS are read from different
// locations. It is unlikely that the compiler will be able to unroll
// the loop below because of the runtime conditionals
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
if (cache_look_up_bits & 1u) {
// Look up from lxu_cache_weights
Expand All @@ -621,7 +622,7 @@ __noinline__ __device__ void process_all_indices_large_Ls(
}
else {
// Load STEP rows from dev_weights
#pragma loop unroll
#pragma unroll
for (uint32_t j = 0; j < STEP; ++j) {
accumulator.load(
&SMEM_EMB_WEIGHT_DATA(
Expand All @@ -641,7 +642,8 @@ __noinline__ __device__ void process_all_indices_large_Ls(
{%- endif %}

if (USE_MIXED_TYPE_CACHE) {
cache_look_up_bits >>= STEP * NUM_LOAD_GROUPS;
// Bypass the hip clang error of "shift count >= width of type"
cache_look_up_bits >>= std::min(STEP * NUM_LOAD_GROUPS, 31u);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/embedding_op_registration.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/embedding_common.h"
////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_forward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
{%- if not is_index_select %}
////////////////////////////////////////////////////////////////////////////////
// Required for op registrations
#include "codegen/embedding_op_registration.h"
#include "fbgemm_gpu/embedding_op_registration.h"
////////////////////////////////////////////////////////////////////////////////
{%- endif %}
#include "codegen/embedding_forward_template_helpers.cuh"
Expand Down
24 changes: 17 additions & 7 deletions fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,26 @@ from torch.optim.optimizer import Optimizer
import logging

{%- if is_fbcode %}
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
)
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops")

if torch.version.hip:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip"
)
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_hip")
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training"
)
else:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
)
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops")
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training"
)
{%- endif %}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
#include <curand_kernel.h>
#include <mutex>

#include "dispatch_macros.h"
#include "embedding_common.h"
#include "fbgemm_cuda_utils.cuh"
#include "sparse_ops_utils.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/sparse_ops_utils.h"

#define SHFL_SYNC(val, srcLane) \
shfl_sync(val, srcLane, kThreadGroupSize, shfl_sync_mask)
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_permute102.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ DLL_PUBLIC Tensor permute102_baddbmm_permute102_cuda(
auto ldc = n * batch_size;
auto strideC = n;

// computeType is hipblasComputeType_t (e.g., HIPBLAS_COMPUTE_32F) instead of
// hipDataType (e.g., HIPBLAS_R_32F) after RoCM 6.0
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && defined(HIPBLAS_V2)
auto computeType = HIPBLAS_COMPUTE_32F;
#else
auto computeType = HIPBLAS_R_32F;
#endif

auto result = hipblasGemmStridedBatchedEx(
handle,
Expand Down

0 comments on commit 06e4912

Please sign in to comment.