Skip to content

Commit

Permalink
Use a smaller grid size for TBE bwd cta_per_row
Browse files Browse the repository at this point in the history
Summary: Same idea as D39720886 (e2bfc2e)

Differential Revision: D39760002

fbshipit-source-id: daf815b4e019f38f2fdd2d626645be4cf9a53752
  • Loading branch information
sryap committed Sep 23, 2022
1 parent 8e9cc8b commit 27bef0d
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D},
grad_output.options().dtype(std::is_same<cache_t, double>::value ? at::kDouble : at::kFloat));
int32_t grid_size = std::min(
div_round_up(long_run_ids.numel(), kMaxThreads),
64 * at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
// Check https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
// "Compute capability 7.x devices allow a single thread block to
// address the full capacity of shared memory: 96 KB on Volta,
Expand Down Expand Up @@ -1127,7 +1131,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{% endif %}
kMaxVecsPerThread,
kThreadGroupSize>
<<<div_round_up(long_run_ids.numel(), kMaxThreads),
<<<grid_size,
dim3(kThreadGroupSize, BT_block_size),
BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize *
kMaxVecsPerThread,
Expand Down Expand Up @@ -1182,7 +1186,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
use_deterministic_algorithms,
{{ args.split_kernel_arg_constructors | join(", ") }});
C10_CUDA_KERNEL_LAUNCH_CHECK();
int32_t grid_size = std::min(
grid_size = std::min(
div_round_up(sorted_linear_indices_run.numel(), kBackwardMaxThreads / kThreadGroupSize),
64 * at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
#ifndef __HIP_PLATFORM_HCC__
Expand Down

0 comments on commit 27bef0d

Please sign in to comment.