diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index b264ef6ad5..cf63e7ea08 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -983,14 +983,17 @@ at::Tensor _fusednbitrowwise_to_float_gpu(
   return output;
 }
 
+
+
+template <typename Dtype>
 __global__ void reorder_batched_ad_lengths_kernel(
     // reorder lengths from (ragged) [B x T x #num_ads_b] to
     // [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)].
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
         cat_ad_lengths,
     const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         batch_offsets,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
         reordered_cat_ad_lengths,
     int32_t T) {
   const int32_t B = batch_offsets.size(0) - 1;
@@ -1033,16 +1036,42 @@ at::Tensor reorder_batched_ad_lengths_gpu(
   const dim3 threads(32, 32);
   const dim3 blocks((B * T + 32 - 1) / 32);
 
-  reorder_batched_ad_lengths_kernel<<<
-      blocks,
-      threads,
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      cat_ad_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      reordered_cat_ad_lengths
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      T);
+  if (cat_ad_lengths.dtype() == at::kInt) {
+    reorder_batched_ad_lengths_kernel<int32_t><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        T);
+  } else if (cat_ad_lengths.dtype() == at::kLong) {
+    reorder_batched_ad_lengths_kernel<int64_t><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        T);
+  } else if (cat_ad_lengths.dtype() == at::kFloat) {
+    reorder_batched_ad_lengths_kernel<float><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        T);
+  } else {
+    TORCH_CHECK(false, "not implemented for ", cat_ad_lengths.dtype().name());
+  }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return reordered_cat_ad_lengths;
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 4c1645b531..3bad25f876 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -653,13 +653,14 @@ def test_block_bucketize_sparse_features(
         T=st.integers(min_value=1, max_value=20),
         L=st.integers(min_value=2, max_value=20),
         A=st.integers(min_value=1, max_value=20),
+        Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
-    def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int) -> None:
+    def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int, Dtype: torch.dtype) -> None:
         cat_ad_lengths = (
             torch.cat([torch.tensor([L for _ in range(T * A)]) for _ in range(B)], 0)
-            .int()
             .cuda()
+            .to(Dtype)
         )
         batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int().cuda()
         num_ads_in_batch = B * A
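
For reference, a minimal sketch of the behavior this change enables. Assumptions: the operator is exposed to Python as torch.ops.fbgemm.reorder_batched_ad_lengths(cat_ad_lengths, batch_offsets, num_ads_in_batch), as exercised in sparse_ops_test.py, and the fbgemm_gpu ops have already been loaded (e.g. via import fbgemm_gpu or torch.ops.load_library). The pure-PyTorch reference function below is illustrative only, not the library's implementation; it mirrors the index arithmetic described in the kernel's comment.

    import torch

    def reorder_batched_ad_lengths_ref(
        cat_ad_lengths: torch.Tensor,
        batch_offsets: torch.Tensor,
        num_ads_in_batch: int,
    ) -> torch.Tensor:
        # Reorder lengths from (ragged) [B][T][#num_ads_b] to
        # [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)].
        B = batch_offsets.numel() - 1
        T = cat_ad_lengths.numel() // num_ads_in_batch
        out = torch.empty_like(cat_ad_lengths)
        for b in range(B):
            begin = int(batch_offsets[b])
            num_ads_b = int(batch_offsets[b + 1]) - begin
            for t in range(T):
                src = T * begin + t * num_ads_b     # segment start, [B][T] layout
                dst = t * num_ads_in_batch + begin  # segment start, [T][B] layout
                out[dst : dst + num_ads_b] = cat_ad_lengths[src : src + num_ads_b]
        return out

    if __name__ == "__main__":
        B, T, L, A = 2, 3, 5, 4  # hypothetical sizes, in the style of the test
        cat_ad_lengths = torch.cat(
            [torch.tensor([L for _ in range(T * A)]) for _ in range(B)], 0
        )
        batch_offsets = torch.tensor([A * b for b in range(B + 1)], dtype=torch.int32)
        num_ads_in_batch = B * A
        # With the templated kernel, int32, int64, and float inputs all dispatch.
        for dtype in (torch.int32, torch.int64, torch.float):
            x = cat_ad_lengths.cuda().to(dtype)
            y = torch.ops.fbgemm.reorder_batched_ad_lengths(
                x, batch_offsets.cuda(), num_ads_in_batch
            )
            ref = reorder_batched_ad_lengths_ref(x.cpu(), batch_offsets, num_ads_in_batch)
            torch.testing.assert_close(y.cpu(), ref)

Previously the launch was hard-coded to int32_t accessors, so only torch.int32 lengths were supported; the dtype dispatch extends coverage to int64 and float32, and the TORCH_CHECK fallback reports a clear error for any other dtype.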