diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 5b993373e1..842059af2f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -137,13 +137,15 @@ std::tuple< ///@ingroup sparse-data-cuda block_bucketize_sparse_features_cuda( - at::Tensor lengths, - at::Tensor indices, - bool bucketize_pos, - bool sequence, - at::Tensor block_sizes, - int64_t my_size, - c10::optional weights); + const at::Tensor& lengths, + const at::Tensor& indices, + const bool bucketize_pos, + const bool sequence, + const at::Tensor& block_sizes, + const int64_t my_size, + const c10::optional& weights, + const c10::optional& batch_size_per_feature, + const int64_t max_batch_size); std::tuple< at::Tensor, @@ -154,13 +156,15 @@ std::tuple< ///@ingroup sparse-data-cpu block_bucketize_sparse_features_cpu( - at::Tensor lengths, - at::Tensor indices, - bool bucketize_pos, - bool sequence, - at::Tensor block_sizes, - int64_t my_size, - c10::optional weights); + const at::Tensor& lengths, + const at::Tensor& indices, + const bool bucketize_pos, + const bool sequence, + const at::Tensor& block_sizes, + const int64_t my_size, + const c10::optional& weights, + const c10::optional& batch_size_per_feature, + const int64_t max_batch_size); std::tuple< at::Tensor, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu index 8b0b0c35f4..26977e6f0f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu @@ -12,6 +12,28 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +// Kernel for calulating lengthh idx to feature id mapping. Used for block +// bucketize sparse features with variable batch size for row-wise partition +template +__global__ +__launch_bounds__(kMaxThreads) void _populate_length_to_feature_id_inplace_kernel( + const uint64_t max_B, + const int T, + const offset_t* const __restrict__ batch_sizes, + const offset_t* const __restrict__ batch_size_offsets, + offset_t* const __restrict__ length_to_feature_idx) { + const auto b_t = blockIdx.x * blockDim.x + threadIdx.x; + + const auto t = b_t / max_B; + const auto b = b_t % max_B; + + if (t >= T || b >= batch_sizes[t]) { + return; + } + + length_to_feature_idx[batch_size_offsets[t] + b] = t; +} + // Kernel for bucketize lengths, with the Block distribution (vs. cyclic, // block-cyclic distribution). Used for bucketize sparse feature, especially for // checkpointing with row-wise partition (sparse_feature is partitioned @@ -19,16 +41,17 @@ namespace fbgemm_gpu { template __global__ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel1( - int32_t lengths_size, - int32_t B, - const index_t* __restrict__ block_sizes_data, - int my_size, - const offset_t* __restrict__ offsets_data, - const index_t* __restrict__ indices_data, - offset_t* __restrict__ new_lengths_data) { + const int32_t lengths_size, + const int32_t B, + const index_t* const __restrict__ block_sizes_data, + const int my_size, + const offset_t* const __restrict__ offsets_data, + const index_t* const __restrict__ indices_data, + offset_t* const __restrict__ new_lengths_data, + offset_t* __restrict__ length_to_feature_idx) { using uindex_t = std::make_unsigned_t; CUDA_KERNEL_LOOP(b_t, lengths_size) { - int32_t t = b_t / B; + const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B; index_t blk_size = block_sizes_data[t]; offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]); offset_t rowend = offsets_data[b_t]; @@ -71,11 +94,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel index_t* __restrict__ new_indices_data, scalar_t* __restrict__ new_weights_data, index_t* __restrict__ new_pos_data, - index_t* __restrict__ unbucketize_permute_data) { + index_t* const __restrict__ unbucketize_permute_data, + const offset_t* const __restrict__ length_to_feature_idx) { using uindex_t = std::make_unsigned_t; using uoffset_t = std::make_unsigned_t; CUDA_KERNEL_LOOP(b_t, lengths_size) { - int32_t t = b_t / B; + const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B; index_t blk_size = block_sizes_data[t]; offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]); offset_t rowend = offsets_data[b_t]; @@ -115,13 +139,15 @@ DLL_PUBLIC std::tuple< c10::optional, c10::optional> block_bucketize_sparse_features_cuda( - Tensor lengths, - Tensor indices, - bool bucketize_pos, - bool sequence, - Tensor block_sizes, - int64_t my_size, - c10::optional weights) { + const Tensor& lengths, + const Tensor& indices, + const bool bucketize_pos, + const bool sequence, + const Tensor& block_sizes, + const int64_t my_size, + const c10::optional& weights, + const c10::optional& batch_sizes, + const int64_t max_B) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices); at::cuda::OptionalCUDAGuard device_guard; @@ -138,11 +164,44 @@ block_bucketize_sparse_features_cuda( auto lengths_contig = lengths.contiguous(); auto indices_contig = indices.contiguous(); auto offsets_contig = offsets.contiguous(); + auto batch_sizes_contig = + batch_sizes.value_or(at::empty({T}, lengths.options())).contiguous(); + auto batch_sizes_offsets_contig = + at::empty({T}, batch_sizes_contig.options()); Tensor new_weights; Tensor new_pos; Tensor unbucketize_permute; // count nonzeros offsets_contig = asynchronous_inclusive_cumsum_gpu(lengths); + if (batch_sizes.has_value()) { + assert(max_B > 0); + batch_sizes_offsets_contig = + asynchronous_exclusive_cumsum_gpu(batch_sizes.value()); + } + auto length_to_feature_idx = + at::empty({lengths_size}, lengths_contig.options()); + if (batch_sizes.has_value()) { + constexpr auto threads_per_block = 256; + const int num_blocks = cuda_calc_xblock_count(max_B * T, threads_per_block); + AT_DISPATCH_INDEX_TYPES( + offsets_contig.scalar_type(), + "_populate_length_to_feature_id_inplace_kernel", + [&] { + using offset_t = index_t; + _populate_length_to_feature_id_inplace_kernel<<< + num_blocks, + threads_per_block, + 0, + at::cuda::getCurrentCUDAStream()>>>( + max_B, + T, + batch_sizes_contig.data_ptr(), + batch_sizes_offsets_contig.data_ptr(), + length_to_feature_idx.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + int threads_per_block = 256; int num_blocks = (lengths_size + threads_per_block - 1) / threads_per_block; AT_DISPATCH_INDEX_TYPES( @@ -165,7 +224,10 @@ block_bucketize_sparse_features_cuda( my_size, offsets_contig.data_ptr(), indices_contig.data_ptr(), - new_lengths.data_ptr()); + new_lengths.data_ptr(), + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -215,7 +277,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), new_weights.data_ptr(), new_pos.data_ptr(), - unbucketize_permute.data_ptr()); + unbucketize_permute.data_ptr(), + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -259,7 +324,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), new_weights.data_ptr(), nullptr, - unbucketize_permute.data_ptr()); + unbucketize_permute.data_ptr(), + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -297,7 +365,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), nullptr, new_pos.data_ptr(), - unbucketize_permute.data_ptr()); + unbucketize_permute.data_ptr(), + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -333,7 +404,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), nullptr, nullptr, - unbucketize_permute.data_ptr()); + unbucketize_permute.data_ptr(), + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -379,7 +453,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), new_weights.data_ptr(), new_pos.data_ptr(), - nullptr); + nullptr, + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -423,7 +500,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), new_weights.data_ptr(), nullptr, - nullptr); + nullptr, + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -461,7 +541,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), nullptr, new_pos.data_ptr(), - nullptr); + nullptr, + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -497,7 +580,10 @@ block_bucketize_sparse_features_cuda( new_indices.data_ptr(), nullptr, nullptr, - nullptr); + nullptr, + batch_sizes.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 4ea54e18ee..bcc3e4aaac 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -203,17 +203,18 @@ template < typename index_t, typename scalar_t> void _block_bucketize_sparse_features_cpu( - Tensor lengths, - Tensor indices, - c10::optional weights, - bool bucketize_pos, - Tensor block_sizes, - int64_t my_size, + const Tensor& lengths, + const Tensor& indices, + const c10::optional& weights, + const bool bucketize_pos, + const Tensor& block_sizes, + const int64_t my_size, Tensor new_lengths, Tensor new_indices, c10::optional new_weights, c10::optional new_pos, - c10::optional unbucketize_permute) { + const c10::optional& unbucketize_permute, + const c10::optional& batch_sizes) { // allocate tensors and buffers const auto lengths_size = lengths.numel(); const auto new_lengths_size = lengths_size * my_size; @@ -224,14 +225,17 @@ void _block_bucketize_sparse_features_cpu( const offset_t* lengths_data = lengths.data_ptr(); offset_t* offsets_data = offsets.data_ptr(); const index_t* indices_data = indices.data_ptr(); - scalar_t* weights_data; - scalar_t* new_weights_data; - index_t* new_pos_data; - index_t* unbucketize_permute_data; + scalar_t* weights_data = nullptr; + scalar_t* new_weights_data = nullptr; + index_t* new_pos_data = nullptr; + index_t* unbucketize_permute_data = nullptr; offset_t* new_lengths_data = new_lengths.data_ptr(); offset_t* new_offsets_data = new_offsets.data_ptr(); index_t* new_indices_data = new_indices.data_ptr(); index_t* block_sizes_data = block_sizes.data_ptr(); + offset_t* batch_sizes_data = nullptr; + const auto variable_batch_size = batch_sizes.has_value(); + using uindex_t = std::make_unsigned_t; using uoffset_t = std::make_unsigned_t; @@ -246,13 +250,19 @@ void _block_bucketize_sparse_features_cpu( new_pos_data = new_pos.value().data_ptr(); } + if (variable_batch_size) { + batch_sizes_data = batch_sizes.value().data_ptr(); + } + // count nonzeros prefix_sum(lengths_size, lengths_data, offsets_data); assert(offsets_data[lengths_size] == indices.numel()); + int64_t cur_offset = 0; for (const auto t : c10::irange(T)) { auto blk_size = block_sizes_data[t]; - for (const auto b : c10::irange(B)) { - const auto b_t = t * B + b; + const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B; + for (const auto b : c10::irange(cur_batch_size)) { + const auto b_t = (variable_batch_size ? cur_offset : t * B) + b; const offset_t rowstart = offsets_data[b_t]; const offset_t rowend = offsets_data[b_t + 1]; for (const auto i : c10::irange(rowstart, rowend)) { @@ -269,15 +279,18 @@ void _block_bucketize_sparse_features_cpu( new_lengths_data[p * lengths_size + b_t]++; } } + cur_offset += cur_batch_size; } // bucketize nonzeros prefix_sum(new_lengths_size, new_lengths_data, new_offsets_data); assert(new_offsets_data[new_lengths_size] == new_indices.numel()); + cur_offset = 0; for (const auto t : c10::irange(T)) { - auto blk_size = block_sizes_data[t]; - for (const auto b : c10::irange(B)) { - const auto b_t = t * B + b; + const auto blk_size = block_sizes_data[t]; + const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B; + for (const auto b : c10::irange(cur_batch_size)) { + const auto b_t = (variable_batch_size ? cur_offset : t * B) + b; const offset_t rowstart = offsets_data[b_t]; const offset_t rowend = offsets_data[b_t + 1]; for (const auto i : c10::irange(rowstart, rowend)) { @@ -308,6 +321,7 @@ void _block_bucketize_sparse_features_cpu( } } } + cur_offset += cur_batch_size; } } @@ -819,13 +833,17 @@ std::tuple< c10::optional, c10::optional> block_bucketize_sparse_features_cpu( - Tensor lengths, - Tensor indices, - bool bucketize_pos, - bool sequence, - Tensor block_sizes, - int64_t my_size, - c10::optional weights) { + const Tensor& lengths, + const Tensor& indices, + const bool bucketize_pos, + const bool sequence, + const Tensor& block_sizes, + const int64_t my_size, + const c10::optional& weights, + const c10::optional& batch_sizes, + const int64_t + /* max_batch_size */ // dummy variable only used in GPU implementation +) { const auto lengths_size = lengths.numel(); const auto new_lengths_size = lengths_size * my_size; auto new_lengths = at::zeros({new_lengths_size}, lengths.options()); @@ -871,7 +889,8 @@ block_bucketize_sparse_features_cpu( new_indices, new_weights, new_pos, - unbucketize_permute); + unbucketize_permute, + batch_sizes); }); }); }); @@ -905,7 +924,8 @@ block_bucketize_sparse_features_cpu( new_indices, new_weights, new_pos, - unbucketize_permute); + unbucketize_permute, + batch_sizes); }); }); }); @@ -937,7 +957,8 @@ block_bucketize_sparse_features_cpu( new_indices, new_weights, new_pos, - unbucketize_permute); + unbucketize_permute, + batch_sizes); }); }); } else { @@ -964,7 +985,8 @@ block_bucketize_sparse_features_cpu( new_indices, new_weights, new_pos, - unbucketize_permute); + unbucketize_permute, + batch_sizes); }); }); } @@ -2656,7 +2678,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor"); m.def( - "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, int my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)"); + "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, int my_size, Tensor? weights=None, Tensor? batch_sizes=None, int max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)"); m.def( "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, int my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)"); m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor"); diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 32c4c2916e..105b2ce3b6 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -839,6 +839,93 @@ def test_block_bucketize_sparse_features( unbucketized_indices, indices, rtol=0, atol=0 ) + @given( + index_type=st.sampled_from([torch.int, torch.long]), + has_weight=st.booleans(), + bucketize_pos=st.booleans(), + sequence=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) + def test_block_bucketize_sparse_features_with_variable_batch_sizes( + self, + index_type: Optional[torch.dtype], + has_weight: bool, + bucketize_pos: bool, + sequence: bool, + ) -> None: + lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type) + indices = torch.tensor( + [1, 8, 5, 6, 7, 8, 8, 4], + dtype=index_type, + ) + batch_sizes = torch.tensor([3, 1, 2], dtype=index_type) + weights = ( + torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + dtype=torch.float, + ) + if has_weight + else None + ) + + block_sizes = torch.tensor([5, 10, 8], dtype=index_type) + my_size = 2 + max_B = batch_sizes.max().item() + + new_lengths_ref = torch.tensor( + [1, 0, 0, 2, 0, 1, 1, 1, 1, 0, 0, 1], + dtype=index_type, + ) + new_indices_ref = torch.tensor( + [1, 7, 8, 4, 3, 0, 1, 0], + dtype=index_type, + ) + + ( + new_lengths_cpu, + new_indices_cpu, + new_weights_cpu, + new_pos_cpu, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + lengths, + indices, + bucketize_pos, + sequence, + block_sizes, + my_size, + weights, + batch_sizes, + ) + torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0) + torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0) + + if gpu_available: + ( + new_lengths_gpu, + new_indices_gpu, + new_weights_gpu, + new_pos_gpu, + unbucketize_permute_gpu, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + lengths.cuda(), + indices.cuda(), + bucketize_pos, + sequence, + block_sizes.cuda(), + my_size, + weights.cuda() if weights is not None else None, + batch_sizes.cuda(), + max_B, + ) + + torch.testing.assert_close( + new_lengths_gpu.cpu(), new_lengths_ref, rtol=0, atol=0 + ) + torch.testing.assert_close( + new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0 + ) + @given( index_type=st.sampled_from([torch.int, torch.long]), has_weight=st.booleans(),