Skip to content

Commit

Permalink
Add memchecks to sparse ops, pt 2 (#2612)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2612

- Add memchecks to sparse ops, pt 2

Reviewed By: spcyppt

Differential Revision: D57602156

fbshipit-source-id: 997e4713d174ed0a0cf7b37ce0229c333904564a
  • Loading branch information
q10 authored and facebook-github-bot committed May 21, 2024
1 parent 5d35f5c commit f739e0a
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 92 deletions.
195 changes: 108 additions & 87 deletions fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ __global__
__launch_bounds__(kMaxThreads) void reorder_batched_ad_lengths_kernel(
// reorder lengths from (ragged) [B x T x #num_ads_b)] to
// [T][B][#num_ads_b], i.e. [T][sum(#num_ads_b)].
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_lengths,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_lengths,
const int32_t T,
const bool broadcast_lengths) {
Expand Down Expand Up @@ -95,14 +95,15 @@ DLL_PUBLIC Tensor reorder_batched_ad_lengths_gpu(
cat_ad_lengths.scalar_type(),
"reorder_batched_ad_lengths_gpu_kernel",
[&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "reorder_batched_ad_lengths_kernel";
#endif
reorder_batched_ad_lengths_kernel<scalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
cat_ad_lengths
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_lengths
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
MAKE_PTA_WITH_NAME(func_name, cat_ad_lengths, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, batch_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, reordered_cat_ad_lengths, scalar_t, 1, 32),
T,
broadcast_lengths);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand All @@ -112,11 +113,11 @@ DLL_PUBLIC Tensor reorder_batched_ad_lengths_gpu(

template <typename Dtype, typename index_t = int32_t>
__global__ __launch_bounds__(kMaxThreads) void narrow_broadcast_indices_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_ad_offsets,
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_indices,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_indices,
const int num_ads_in_batch,
const int reordered_cat_ad_batches) {
Expand All @@ -139,15 +140,15 @@ __global__ __launch_bounds__(kMaxThreads) void narrow_broadcast_indices_kernel(
template <typename Dtype, typename index_t = int32_t>
__global__
__launch_bounds__(kMaxThreads) void narrow_batched_broadcast_indices_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_ad_offsets,
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reordered_cat_ad_offsets,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
const int32_t T) {
const auto B = batch_offsets.size(0) - 1;
Expand Down Expand Up @@ -196,15 +197,15 @@ __launch_bounds__(kMaxThreads) void reorder_batched_ad_indices_kernel(
// if broadcast_indices is enabled, all the indices will be copies of the
// first batch of the cat_ad_indices, this is useful for request-only
// broadcast
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_ad_offsets,
const at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
cat_ad_indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reordered_cat_ad_offsets,
at::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
reordered_cat_ad_indices,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
const int32_t T,
const bool broadcast_indices) {
Expand Down Expand Up @@ -291,23 +292,24 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
cat_ad_offsets.scalar_type(),
"narrow_broadcast_indices_kernel_2",
[&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "narrow_broadcast_indices_kernel";
#endif
narrow_broadcast_indices_kernel<scalar_t, index_t>
<<<blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_ad_offsets.packed_accessor32<
index_t,
1,
at::RestrictPtrTraits>(),
cat_ad_indices.packed_accessor32<
MAKE_PTA_WITH_NAME(
func_name, cat_ad_offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, cat_ad_indices, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name,
reordered_cat_ad_indices,
scalar_t,
1,
at::RestrictPtrTraits>(),
reordered_cat_ad_indices.packed_accessor32<
scalar_t,
1,
at::RestrictPtrTraits>(),
32),
num_ads_in_batch,
reordered_cat_ad_offsets.numel() - 1);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand All @@ -329,31 +331,33 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
cat_ad_offsets.scalar_type(),
"narrow_batched_broadcast_indices_kernel_2",
[&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"narrow_batched_broadcast_indices_kernel";
#endif
narrow_batched_broadcast_indices_kernel<scalar_t, index_t>
<<<blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_ad_offsets.packed_accessor32<
MAKE_PTA_WITH_NAME(
func_name, cat_ad_offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, cat_ad_indices, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name,
reordered_cat_ad_offsets,
index_t,
1,
at::RestrictPtrTraits>(),
cat_ad_indices.packed_accessor32<
32),
MAKE_PTA_WITH_NAME(
func_name,
reordered_cat_ad_indices,
scalar_t,
1,
at::RestrictPtrTraits>(),
reordered_cat_ad_offsets.packed_accessor32<
index_t,
1,
at::RestrictPtrTraits>(),
reordered_cat_ad_indices.packed_accessor32<
scalar_t,
1,
at::RestrictPtrTraits>(),
batch_offsets.packed_accessor32<
int32_t,
1,
at::RestrictPtrTraits>(),
32),
MAKE_PTA_WITH_NAME(
func_name, batch_offsets, int32_t, 1, 32),
T);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
Expand All @@ -374,23 +378,23 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
cat_ad_offsets.scalar_type(),
"reorder_batched_ad_indices_gpu_kernel_2",
[&] {
reorder_batched_ad_indices_kernel<scalar_t, index_t><<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_ad_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
cat_ad_indices
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
reordered_cat_ad_indices
.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
T,
broadcast_indices);
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "reorder_batched_ad_indices_kernel";
#endif
reorder_batched_ad_indices_kernel<scalar_t, index_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name, cat_ad_offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, cat_ad_indices, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, reordered_cat_ad_offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, reordered_cat_ad_indices, scalar_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, batch_offsets, int32_t, 1, 32),
T,
broadcast_indices);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand All @@ -403,15 +407,15 @@ __launch_bounds__(kMaxThreads) void reorder_batched_sequence_embeddings_kernel(
// reorder embeddings from (ragged) [B x T x #num_ads_B_{i} x length_{B_{i},
// t, a})x D] to [T][B][#num_ads_b][length_{b, t, a}][D], i.e.
// [sum(length_{B_{i}, t, a}), D]
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
cat_sequence_embeddings_offsets,
const at::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
cat_sequence_embeddings,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reordered_cat_sequence_embeddings_offsets,
at::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<Dtype, 2, at::RestrictPtrTraits>
reordered_cat_sequence_embeddings,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
batch_offsets,
const int32_t T,
const int32_t D) {
Expand Down Expand Up @@ -485,23 +489,40 @@ DLL_PUBLIC Tensor reorder_batched_sequence_embeddings_gpu(
cat_sequence_embeddings_offsets.scalar_type(),
"reorder_batched_sequence_embeddings_gpu_kernel_2",
[&] {
reorder_batched_sequence_embeddings_kernel<scalar_t, index_t><<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
cat_sequence_embeddings_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
cat_sequence_embeddings_contig
->packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
reordered_cat_sequence_embeddings_offsets
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
reordered_cat_sequence_embeddings
.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
batch_offsets
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
T,
D);
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"reorder_batched_sequence_embeddings_kernel";
#endif
reorder_batched_sequence_embeddings_kernel<scalar_t, index_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name,
cat_sequence_embeddings_offsets,
index_t,
1,
32),
MAKE_PTA_WITH_NAME(
func_name,
(*cat_sequence_embeddings_contig),
scalar_t,
2,
32),
MAKE_PTA_WITH_NAME(
func_name,
reordered_cat_sequence_embeddings_offsets,
index_t,
1,
32),
MAKE_PTA_WITH_NAME(
func_name,
reordered_cat_sequence_embeddings,
scalar_t,
2,
32),
MAKE_PTA_WITH_NAME(
func_name, batch_offsets, int32_t, 1, 32),
T,
D);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down
13 changes: 8 additions & 5 deletions fbgemm_gpu/test/sparse/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,14 @@ def extend_test_class(
"", os.path.dirname(__file__), "failures_dict.json"
)

additional_decorators = (additional_decorators or {}) | {
"test_pt2_compliant_tag_fbgemm_permute_2D_sparse_data": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
additional_decorators = {
**(additional_decorators or {}),
**{
"test_pt2_compliant_tag_fbgemm_permute_2D_sparse_data": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
]
},
}

# Only generate tests for PyTorch 2.2+
Expand Down

0 comments on commit f739e0a

Please sign in to comment.