From f3d3b958c7f931ae599d1d3b79346d9d01ba62e9 Mon Sep 17 00:00:00 2001 From: Luoshang Pan Date: Fri, 12 Nov 2021 10:31:37 -0800 Subject: [PATCH] support different dtype for reorder_batched_ad_lengths_gpu (#755) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/755 as title this could be used for reordering inputs for per_sample_weights Reviewed By: jianyuh Differential Revision: D32276580 fbshipit-source-id: 39ae66b8776c6d061d960c0885cb6d5c95aa5d15 --- fbgemm_gpu/src/sparse_ops.cu | 53 +++++++++++++++++++++++------- fbgemm_gpu/test/sparse_ops_test.py | 5 +-- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu index b264ef6ad5..cf63e7ea08 100644 --- a/fbgemm_gpu/src/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops.cu @@ -983,14 +983,17 @@ at::Tensor _fusednbitrowwise_to_float_gpu( return output; } + + +template __global__ void reorder_batched_ad_lengths_kernel( // reorder lengths from (ragged) [B x T x #num_ads_b)] to // [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)]. 
- const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 cat_ad_lengths, const at::PackedTensorAccessor32 batch_offsets, - at::PackedTensorAccessor32 + at::PackedTensorAccessor32 reordered_cat_ad_lengths, int32_t T) { const int32_t B = batch_offsets.size(0) - 1; @@ -1033,16 +1036,42 @@ at::Tensor reorder_batched_ad_lengths_gpu( const dim3 threads(32, 32); const dim3 blocks((B * T + 32 - 1) / 32); - reorder_batched_ad_lengths_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cat_ad_lengths.packed_accessor32(), - batch_offsets.packed_accessor32(), - reordered_cat_ad_lengths - .packed_accessor32(), - T); + if (cat_ad_lengths.dtype() == at::kInt) { + reorder_batched_ad_lengths_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cat_ad_lengths.packed_accessor32(), + batch_offsets.packed_accessor32(), + reordered_cat_ad_lengths + .packed_accessor32(), + T); + } else if (cat_ad_lengths.dtype() == at::kLong) { + reorder_batched_ad_lengths_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cat_ad_lengths.packed_accessor32(), + batch_offsets.packed_accessor32(), + reordered_cat_ad_lengths + .packed_accessor32(), + T); + } else if (cat_ad_lengths.dtype() == at::kFloat) { + reorder_batched_ad_lengths_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cat_ad_lengths.packed_accessor32(), + batch_offsets.packed_accessor32(), + reordered_cat_ad_lengths + .packed_accessor32(), + T); + } else { + TORCH_CHECK(false, "not implemented for ", cat_ad_lengths.dtype().name()); + } C10_CUDA_KERNEL_LAUNCH_CHECK(); return reordered_cat_ad_lengths; diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 4c1645b531..3bad25f876 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -653,13 +653,14 @@ def test_block_bucketize_sparse_features( T=st.integers(min_value=1, max_value=20), L=st.integers(min_value=2,
max_value=20), A=st.integers(min_value=1, max_value=20), + Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]), ) @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) - def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int) -> None: + def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int, Dtype: torch.dtype) -> None: cat_ad_lengths = ( torch.cat([torch.tensor([L for _ in range(T * A)]) for _ in range(B)], 0) - .int() .cuda() + .to(Dtype) ) batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int().cuda() num_ads_in_batch = B * A