HIP extension support for FBGEMM_GPU (#846)
Summary: Pull Request resolved: #846

Reviewed By: jspark1105

Differential Revision: D33231489

fbshipit-source-id: 6bd46ddee45c767ad25c2d52b6c05030bba94082
jianyuh authored and facebook-github-bot committed Jan 3, 2022
1 parent 6620471 commit c6df576
Showing 18 changed files with 254 additions and 106 deletions.
10 changes: 5 additions & 5 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -397,7 +397,7 @@ def rowwise_adagrad() -> None:
momentum1[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
}
multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
multiplier = shfl_sync(multiplier, 0);
"""
split_weight_update_cpu = """
at::acc_type<scalar_t, true> g_local_sum_square = 0.0;
@@ -474,8 +474,8 @@ def rowwise_weighted_adagrad() -> None:
multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps);
correction = 1.0 - multiplier * weight_decay;
}
multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0);
correction = __shfl_sync(0xFFFFFFFF, correction, 0);
multiplier = shfl_sync(multiplier, 0);
correction = shfl_sync(correction, 0);
"""
split_weight_update_cpu = """
// weight_decay not supported for cpu version
@@ -636,7 +636,7 @@ def partial_rowwise_lamb() -> None:
m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square;
momentum2[idx] = m2;
}
m2 = __shfl_sync(0xFFFFFFFF, m2, 0);
m2 = shfl_sync(m2, 0);
at::acc_type<cache_t, true> m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps);
at::acc_type<cache_t, true> weight_sum_sq = 0.0;
@@ -772,7 +772,7 @@ def partial_rowwise_adam() -> None:
momentum2[idx] = v_t;
v_hat_t = v_t / (1.0 - powf(beta2, iter));
}
v_hat_t = __shfl_sync(0xFFFFFFFF, v_hat_t, 0);
v_hat_t = shfl_sync(v_hat_t, 0);
"""

split_weight_update = """
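Note: the recurring change in this generator (and in the .cu templates below) swaps the raw `__shfl_sync(0xFFFFFFFF, val, lane)` intrinsic for a `shfl_sync(val, lane)` helper, so the generated optimizer code no longer hard-codes a 32-lane mask that HIP cannot honor. Below is a minimal sketch of what such a portability wrapper could look like; the header location, `kWarpSize` constant, default arguments, and inline qualifiers are illustrative assumptions, not the literal fbgemm_gpu definition.

```cpp
// Illustrative sketch only; names and defaults are assumptions.
#ifdef __HIP_PLATFORM_HCC__
static constexpr int kWarpSize = 64;  // AMD wavefront
#else
static constexpr int kWarpSize = 32;  // NVIDIA warp
#endif

template <typename T>
__device__ __forceinline__ T shfl_sync(
    const T val,
    int src_lane,
    int width = kWarpSize,
    unsigned mask = 0xFFFFFFFFu) {
#ifdef __HIP_PLATFORM_HCC__
  // HIP has no masked shuffle; broadcast across the full 64-lane wavefront.
  (void)mask;
  return __shfl(val, src_lane, width);
#else
  return __shfl_sync(mask, val, src_lane, width);
#endif
}
```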
9 changes: 9 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -111,7 +111,11 @@ class SplitLookupFunction_Dense_Op
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
ctx->saved_data["pooling_mode"] = pooling_mode;

#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
if (!indice_weights.has_value()) {
return {dense_embedding_codegen_forward_unweighted_cuda(
dev_weights,
@@ -159,8 +163,13 @@ class SplitLookupFunction_Dense_Op

TORCH_CHECK(grad_outputs.size() == 1);

#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
#endif
using torch::autograd::Variable;

auto grad_output = grad_outputs[0];
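The `#ifdef __HIP_PLATFORM_HCC__` blocks added here (and in the split host template below) double `BT_block_size` and `max_segment_length_per_warp` from 32 to 64, presumably to line up with the 64-lane wavefronts of AMD GPUs versus the 32-lane warps of NVIDIA GPUs. A hedged sketch of the same selection written once as shared constants (the names are hypothetical; the actual diff keeps the constants local to each call site):

```cpp
// Hypothetical consolidation of the repeated per-call-site #ifdef blocks.
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t kBTBlockSize = 64;              // matches the 64-lane wavefront
constexpr int32_t kMaxSegmentLengthPerWarp = 64;
#else
constexpr int32_t kBTBlockSize = 32;              // matches the 32-lane warp
constexpr int32_t kMaxSegmentLengthPerWarp = 32;
#endif
```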
9 changes: 9 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -189,7 +189,11 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :
{% endfor %}

{% if not nobag %}
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
if (!indice_weights) {
return {split_embedding_codegen_forward_unweighted_cuda(
dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
@@ -257,8 +261,13 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op :

TORCH_CHECK(grad_outputs.size() == 1);

#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
#endif
using torch::autograd::Variable;

auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
@@ -103,9 +103,9 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j);
int64_t idx_j = shfl_sync(idx, j);
{% if not dense %}
int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j);
int32_t cache_idx_j = shfl_sync(cache_idx, j);
{% endif %}
at::acc_type<cache_t, true> grad_indice_weight = 0.0;

29 changes: 20 additions & 9 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -189,13 +189,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{% endif %}
for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
{% if not nobag %}
int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j);
int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j);
int32_t b_j = shfl_sync(b, j);
int32_t D_start_j = shfl_sync(D_start, j);
{% else %}
int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j);
int32_t l_j = shfl_sync(l, j);
{% endif %}

{% if weighted %}
at::acc_type<cache_t, true> idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
{% endif %}

#pragma unroll kMaxVecsPerThread
@@ -546,13 +547,13 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_

for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) {
{% if not nobag %}
int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j);
int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j);
int32_t b_j = shfl_sync(b, j);
int32_t D_start_j = shfl_sync(D_start, j);
{% else %}
int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j);
int32_t l_j = shfl_sync(l, j);
{% endif %}
{% if weighted %}
at::acc_type<cache_t, true> idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
{% endif %}

#pragma unroll kMaxVecsPerThread
@@ -757,7 +758,12 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_

// V100: 96 KB; A100: 160 KB.
int max_shared_bytes = 0;
#ifndef __HIP_PLATFORM_HCC__
cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device());
#else
// MI100 has 64 KB local memory (shared memory) per workgroup
max_shared_bytes = 64 << 10;
#endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
int shared_kb = max_shared_bytes >> 10;
// V100: 64 KB; A100: 96 KB.
@@ -958,6 +964,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
// over 48 KB per block are architecture-specific, as such they
// must use dynamic shared memory (rather than statically sized
// arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
#ifndef __HIP_PLATFORM_HCC__
cudaFuncSetAttribute(
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1<
{% if not dense %}
@@ -970,7 +978,9 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{{ kMaxVecsPerThread }}>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
#endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
// dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0]
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1<
{% if not dense %}
emb_t,
Expand All @@ -980,7 +990,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
scalar_t,
{% endif %}
{{ kMaxVecsPerThread }}>
// dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0]
<<<div_round_up(long_run_ids.numel(), kMaxThreads),
dim3(kWarpSize, BT_block_size),
BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize *
@@ -1031,6 +1040,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{% endif %}
{{ args.split_kernel_arg_constructors | join(", ") }});
C10_CUDA_KERNEL_LAUNCH_CHECK();
#ifndef __HIP_PLATFORM_HCC__
cudaFuncSetAttribute(
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
{% if not dense %}
@@ -1043,6 +1053,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{{ kMaxVecsPerThread }}>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
#endif
C10_CUDA_KERNEL_LAUNCH_CHECK();
split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1<
{% if not dense %}
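Two HIP-specific adjustments in this template are worth calling out: the opt-in query for more than 48 KB of dynamic shared memory (`cudaDevAttrMaxSharedMemoryPerBlockOptin` plus `cudaFuncSetAttribute`) is CUDA-only, so on ROCm the MI100's 64 KB of LDS per workgroup is assumed directly and the `cudaFuncSetAttribute` calls are compiled out. A self-contained sketch of that pattern follows; the kernel and device index are placeholders, not code from this repository.

```cpp
#include <cstdio>

__global__ void dummy_kernel() {}

int main() {
  int max_shared_bytes = 0;
#ifndef __HIP_PLATFORM_HCC__
  // CUDA: dynamic shared memory above the 48 KB default is architecture-
  // specific and opt-in, so query the limit and request it per kernel.
  cudaDeviceGetAttribute(
      &max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device=*/0);
  cudaFuncSetAttribute(
      dummy_kernel,
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      max_shared_bytes);
#else
  // ROCm: no equivalent opt-in attribute; assume MI100's 64 KB LDS.
  max_shared_bytes = 64 << 10;
#endif
  std::printf("usable dynamic shared memory: %d bytes\n", max_shared_bytes);
  return 0;
}
```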
3 changes: 2 additions & 1 deletion fbgemm_gpu/codegen/embedding_bounds_check.cu
@@ -69,7 +69,8 @@ __global__ void bounds_check_indices_kernel(
}

auto L = indices_end - indices_start;
for (auto i = threadIdx.x; i < L; i += fbgemm_gpu::kWarpSize) {
for (index_t i = (index_t)threadIdx.x; i < L;
i += (index_t)fbgemm_gpu::kWarpSize) {
auto idx = indices[indices_start + i];
if (idx == -1) {
// -1 indicates pruned rows.
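The bounds-check loop now widens `threadIdx.x` and the stride to `index_t` before comparing against `L`; mixing the unsigned 32-bit `threadIdx.x` with a signed (possibly 64-bit) `index_t` bound would otherwise rely on implicit conversions. A small sketch of the pattern, with a hypothetical helper and a locally assumed warp size:

```cpp
// Illustrative device helper; index_t is typically int32_t or int64_t.
template <typename index_t>
__device__ void strided_warp_loop(index_t indices_start, index_t indices_end) {
  constexpr int kWarpSize = 32;  // assumed here; fbgemm_gpu::kWarpSize in the real kernel
  const index_t L = indices_end - indices_start;
  // Cast both the initial value and the stride so `i < L` compares in index_t.
  for (index_t i = static_cast<index_t>(threadIdx.x); i < L;
       i += static_cast<index_t>(kWarpSize)) {
    // ... bounds-check indices[indices_start + i] here ...
  }
}
```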
16 changes: 16 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu
@@ -309,7 +309,11 @@ __global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" i
}
// equivalent to fence + wait.
cp_async_wait<0>();
#ifdef __HIP_PLATFORM_HCC__
__syncthreads();
#else
__syncwarp();
#endif
for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) {
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
@@ -507,9 +511,21 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{
found = true;
dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
}
#ifdef __HIP_PLATFORM_HCC__
// FIXME: __any_sync with mask isn't supported by HIP yet.
// See https://fburl.com/fvy7j0lq for the similar context.
// assert false here with https://fburl.com/pfm7enw2
assert(false);
if (__any(found)) {
#else
if (__any_sync(subwarp_mask, found)) {
#endif
break;
#ifdef __HIP_PLATFORM_HCC__
} else if (__any(empty)) {
#else
} else if (__any_sync(subwarp_mask, empty)) {
#endif
dense_indices[indices_start + l_start + subwarp_id] = -1;
break;
}
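In the pruned-hashmap lookup above, the sub-warp ballot `__any_sync(subwarp_mask, ...)` has no HIP equivalent, so the HIP branch falls back to the wavefront-wide `__any` and asserts first as a reminder that the masked, sub-warp semantics are not preserved; similarly, `__syncwarp()` becomes `__syncthreads()` under HIP. A hedged sketch of a wrapper expressing the same vote fallback (the helper name is hypothetical):

```cpp
// Hypothetical helper mirroring the #ifdef above; on HIP the mask is ignored
// and the vote runs across the whole wavefront rather than the sub-warp.
__device__ __forceinline__ bool any_sync(unsigned mask, int predicate) {
#ifdef __HIP_PLATFORM_HCC__
  (void)mask;
  return __any(predicate) != 0;
#else
  return __any_sync(mask, predicate) != 0;
#endif
}
```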
6 changes: 3 additions & 3 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -127,16 +127,16 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
{% endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j);
int64_t idx_j = shfl_sync(idx, j);
{% if nobag %}
int64_t output_j = indices_start + l_start + j;
{% endif %}
{% if not dense %}
int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j);
int32_t cache_idx_j = shfl_sync(cache_idx, j);
{% endif %}

{% if weighted %}
at::acc_type<cache_t, true> idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j);
at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
{% endif %}

{% if not dense %}