support different dtype for reorder_batched_ad_lengths_gpu (#755)
Summary:
Pull Request resolved: #755

As the title says.

This could be used for reordering the inputs for per_sample_weights.

Reviewed By: jianyuh

Differential Revision: D32276580

fbshipit-source-id: 39ae66b8776c6d061d960c0885cb6d5c95aa5d15
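
For context, a minimal usage sketch of the operator this commit generalizes. It assumes fbgemm_gpu is installed so the op is exposed as torch.ops.fbgemm.reorder_batched_ad_lengths(cat_ad_lengths, batch_offsets, num_ads_in_batch), the calling convention used by the test below, and that after this change the values tensor may be int32, int64, or float32 (e.g. per_sample_weights-style values); the shapes and values are hypothetical.

import torch
import fbgemm_gpu  # noqa: F401  # assumed: registers the torch.ops.fbgemm.* operators

B, T, A = 2, 3, 2  # hypothetical: 2 batches, 3 tables, 2 ads per batch
# Flat values laid out as [B][T][num_ads_b]; a float dtype is newly supported here.
cat_ad_lengths = torch.arange(B * T * A, dtype=torch.float32, device="cuda")
batch_offsets = torch.tensor(
    [A * b for b in range(B + 1)], dtype=torch.int32, device="cuda"
)
num_ads_in_batch = B * A

# The output keeps the input dtype and is laid out as [T][B][num_ads_b].
reordered = torch.ops.fbgemm.reorder_batched_ad_lengths(
    cat_ad_lengths, batch_offsets, num_ads_in_batch
)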
Luoshang Pan authored and facebook-github-bot committed Nov 12, 2021
1 parent 4399bed commit f3d3b95
Showing 2 changed files with 44 additions and 14 deletions.
53 changes: 41 additions & 12 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -983,14 +983,17 @@ at::Tensor _fusednbitrowwise_to_float_gpu(
   return output;
 }
+template <typename Dtype>
 __global__ void reorder_batched_ad_lengths_kernel(
     // reorder lengths from (ragged) [B x T x #num_ads_b)] to
     // [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)].
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
         cat_ad_lengths,
     const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         batch_offsets,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
         reordered_cat_ad_lengths,
     int32_t T) {
   const int32_t B = batch_offsets.size(0) - 1;
@@ -1033,16 +1036,42 @@ at::Tensor reorder_batched_ad_lengths_gpu(
   const dim3 threads(32, 32);
   const dim3 blocks((B * T + 32 - 1) / 32);
-  reorder_batched_ad_lengths_kernel<<<
-      blocks,
-      threads,
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      cat_ad_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      reordered_cat_ad_lengths
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      T);
+  if (cat_ad_lengths.dtype() == at::kInt) {
+    reorder_batched_ad_lengths_kernel<int32_t><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        T);
+  } else if (cat_ad_lengths.dtype() == at::kLong) {
+    reorder_batched_ad_lengths_kernel<int64_t><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        T);
+  } else if (cat_ad_lengths.dtype() == at::kFloat) {
+    reorder_batched_ad_lengths_kernel<float><<<
+        blocks,
+        threads,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        cat_ad_lengths.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        batch_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+        reordered_cat_ad_lengths
+            .packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        T);
+  } else {
+    TORCH_CHECK(false, "not implmented for ", cat_ad_lengths.dtype().name());
+  }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return reordered_cat_ad_lengths;
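
To make the dispatched kernel easier to follow, here is a pure-Python sketch of the reordering each template instantiation performs, inferred from the kernel comment and the offsets in the hunks above (flat [B][T][#num_ads_b] input to flat [T][B][#num_ads_b] output). The helper name and the explicit loops are illustrative only; the CUDA kernel parallelizes the same segment copies over thread blocks, with the if/else above selecting the Dtype template argument.

import torch

def reorder_batched_ad_lengths_ref(cat_ad_lengths, batch_offsets, T):
    # cat_ad_lengths: flat values in [B][T][num_ads_b] order (any dispatched dtype)
    # batch_offsets:  int tensor of length B + 1, prefix sums of num_ads_b per batch
    B = batch_offsets.numel() - 1
    num_ads_in_batch = int(batch_offsets[B])
    out = torch.empty_like(cat_ad_lengths)
    for b in range(B):
        lo, hi = int(batch_offsets[b]), int(batch_offsets[b + 1])
        num_ads_b = hi - lo
        for t in range(T):
            in_start = T * lo + t * num_ads_b       # segment start in [B][T][...] order
            out_start = t * num_ads_in_batch + lo   # segment start in [T][B][...] order
            out[out_start : out_start + num_ads_b] = cat_ad_lengths[
                in_start : in_start + num_ads_b
            ]
    return out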
5 changes: 3 additions & 2 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -653,13 +653,14 @@ def test_block_bucketize_sparse_features(
         T=st.integers(min_value=1, max_value=20),
         L=st.integers(min_value=2, max_value=20),
         A=st.integers(min_value=1, max_value=20),
+        Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
-    def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int) -> None:
+    def test_reorder_batched_ad_lengths(self, B: int, T: int, L: int, A: int, Dtype: torch.dtype) -> None:
         cat_ad_lengths = (
             torch.cat([torch.tensor([L for _ in range(T * A)]) for _ in range(B)], 0)
-            .int()
             .cuda()
+            .to(Dtype)
         )
         batch_offsets = torch.tensor([A * b for b in range(B + 1)]).int().cuda()
         num_ads_in_batch = B * A
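
Since every batch in this test has exactly A ads, the expected result can also be written as a plain permute, which suggests one way a check over the three sampled dtypes could look. The snippet below is a sketch under that uniform-A assumption; it reuses the torch.ops.fbgemm.reorder_batched_ad_lengths calling convention assumed earlier and uses distinct values (rather than the constant L above) so the comparison is meaningful.

import torch
import fbgemm_gpu  # noqa: F401  # assumed: registers the torch.ops.fbgemm.* operators

B, T, A = 4, 5, 2
for dtype in (torch.int32, torch.int64, torch.float):
    cat_ad_lengths = torch.arange(B * T * A, device="cuda").to(dtype)
    batch_offsets = torch.tensor(
        [A * b for b in range(B + 1)], dtype=torch.int32, device="cuda"
    )
    reordered = torch.ops.fbgemm.reorder_batched_ad_lengths(
        cat_ad_lengths, batch_offsets, B * A
    )
    # With a uniform A per batch, [B][T][A] -> [T][B][A] is just a permute.
    expected = cat_ad_lengths.view(B, T, A).permute(1, 0, 2).reshape(-1)
    torch.testing.assert_close(reordered, expected)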