Skip to content

Commit

Permalink
Support variable batch_size for block_bucketize_sparse_features (#2012)
Browse files Browse the repository at this point in the history
Summary:

This diff add support variable batch size for block bucketize_sparse features for RW sharding.

Differential Revision: D48683632
  • Loading branch information
Qiang Zhang authored and facebook-github-bot committed Sep 21, 2023
1 parent a511f9d commit 9461d55
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 81 deletions.
32 changes: 18 additions & 14 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ std::tuple<

///@ingroup sparse-data-cuda
block_bucketize_sparse_features_cuda(
at::Tensor lengths,
at::Tensor indices,
bool bucketize_pos,
bool sequence,
at::Tensor block_sizes,
int64_t my_size,
c10::optional<at::Tensor> weights);
const at::Tensor& lengths,
const at::Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const at::Tensor& block_sizes,
const int64_t my_size,
const c10::optional<at::Tensor>& weights,
const c10::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size);

std::tuple<
at::Tensor,
Expand All @@ -155,13 +157,15 @@ std::tuple<

///@ingroup sparse-data-cpu
block_bucketize_sparse_features_cpu(
at::Tensor lengths,
at::Tensor indices,
bool bucketize_pos,
bool sequence,
at::Tensor block_sizes,
int64_t my_size,
c10::optional<at::Tensor> weights);
const at::Tensor& lengths,
const at::Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const at::Tensor& block_sizes,
const int64_t my_size,
const c10::optional<at::Tensor>& weights,
const c10::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size);

std::tuple<
at::Tensor,
Expand Down
152 changes: 120 additions & 32 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,46 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Kernel for calulating lengthh idx to feature id mapping. Used for block
// bucketize sparse features with variable batch size for row-wise partition
template <typename offset_t>
__global__
__launch_bounds__(kMaxThreads) void _populate_length_to_feature_id_inplace_kernel(
const uint64_t max_B,
const int T,
const offset_t* const __restrict__ batch_sizes,
const offset_t* const __restrict__ batch_size_offsets,
offset_t* const __restrict__ length_to_feature_idx) {
const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;

const auto t = b_t / max_B;
const auto b = b_t % max_B;

if (t >= T || b >= batch_sizes[t]) {
return;
}

length_to_feature_idx[batch_size_offsets[t] + b] = t;
}

// Kernel for bucketize lengths, with the Block distribution (vs. cyclic,
// block-cyclic distribution). Used for bucketize sparse feature, especially for
// checkpointing with row-wise partition (sparse_feature is partitioned
// continuously along the sparse dimension into my_size blocks)
template <typename offset_t, typename index_t>
__global__
__launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel1(
int32_t lengths_size,
int32_t B,
const index_t* __restrict__ block_sizes_data,
int my_size,
const offset_t* __restrict__ offsets_data,
const index_t* __restrict__ indices_data,
offset_t* __restrict__ new_lengths_data) {
const int32_t lengths_size,
const int32_t B,
const index_t* const __restrict__ block_sizes_data,
const int my_size,
const offset_t* const __restrict__ offsets_data,
const index_t* const __restrict__ indices_data,
offset_t* const __restrict__ new_lengths_data,
offset_t* __restrict__ length_to_feature_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
int32_t t = b_t / B;
const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
Expand Down Expand Up @@ -71,11 +94,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel
index_t* __restrict__ new_indices_data,
scalar_t* __restrict__ new_weights_data,
index_t* __restrict__ new_pos_data,
index_t* __restrict__ unbucketize_permute_data) {
index_t* const __restrict__ unbucketize_permute_data,
const offset_t* const __restrict__ length_to_feature_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
using uoffset_t = std::make_unsigned_t<offset_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
int32_t t = b_t / B;
const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
Expand Down Expand Up @@ -115,36 +139,73 @@ DLL_PUBLIC std::tuple<
c10::optional<Tensor>,
c10::optional<Tensor>>
block_bucketize_sparse_features_cuda(
Tensor lengths,
Tensor indices,
bool bucketize_pos,
bool sequence,
Tensor block_sizes,
int64_t my_size,
c10::optional<Tensor> weights) {
const Tensor& lengths,
const Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const Tensor& block_sizes,
const int64_t my_size,
const c10::optional<Tensor>& weights,
const c10::optional<Tensor>& batch_sizes,
const int64_t max_B) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(lengths.get_device());
// allocate tensors and buffers
const int lengths_size = lengths.numel();
const int T = block_sizes.numel();
const int B = lengths_size / T;
const int new_lengths_size = lengths_size * my_size;
const auto lengths_size = lengths.numel();
const auto T = block_sizes.numel();
const auto B = lengths_size / T;
const auto new_lengths_size = lengths_size * my_size;
auto offsets = at::empty({lengths_size}, lengths.options());
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
auto new_offsets = at::empty({new_lengths_size}, lengths.options());
auto new_indices = at::empty_like(indices);
auto lengths_contig = lengths.contiguous();
auto indices_contig = indices.contiguous();
auto offsets_contig = offsets.contiguous();
auto batch_sizes_contig =
batch_sizes.value_or(at::empty({T}, lengths.options())).contiguous();
auto batch_sizes_offsets_contig =
at::empty({T}, batch_sizes_contig.options());
Tensor new_weights;
Tensor new_pos;
Tensor unbucketize_permute;
// count nonzeros
offsets_contig = asynchronous_inclusive_cumsum_gpu(lengths);
int threads_per_block = 256;
int num_blocks = (lengths_size + threads_per_block - 1) / threads_per_block;
if (batch_sizes.has_value()) {
TORCH_CHECK(max_B > 0);
batch_sizes_offsets_contig =
asynchronous_exclusive_cumsum_gpu(batch_sizes.value());
}
auto length_to_feature_idx =
at::empty({lengths_size}, lengths_contig.options());
if (batch_sizes.has_value()) {
constexpr auto threads_per_block = 256;
const auto num_blocks =
cuda_calc_xblock_count(max_B * T, threads_per_block);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_populate_length_to_feature_id_inplace_kernel",
[&] {
using offset_t = index_t;
_populate_length_to_feature_id_inplace_kernel<<<
num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
max_B,
T,
batch_sizes_contig.data_ptr<offset_t>(),
batch_sizes_offsets_contig.data_ptr<offset_t>(),
length_to_feature_idx.data_ptr<offset_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
constexpr auto threads_per_block = 256;
const auto num_blocks =
cuda_calc_xblock_count(lengths_size, threads_per_block);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel1",
Expand All @@ -165,7 +226,10 @@ block_bucketize_sparse_features_cuda(
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
new_lengths.data_ptr<offset_t>());
new_lengths.data_ptr<offset_t>(),
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -215,7 +279,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
unbucketize_permute.data_ptr<index_t>());
unbucketize_permute.data_ptr<index_t>(),
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -259,7 +326,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
unbucketize_permute.data_ptr<index_t>());
unbucketize_permute.data_ptr<index_t>(),
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -297,7 +367,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
unbucketize_permute.data_ptr<index_t>());
unbucketize_permute.data_ptr<index_t>(),
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -333,7 +406,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
unbucketize_permute.data_ptr<index_t>());
unbucketize_permute.data_ptr<index_t>(),
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -379,7 +455,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
nullptr);
nullptr,
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -423,7 +502,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
nullptr);
nullptr,
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -461,7 +543,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
nullptr);
nullptr,
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down Expand Up @@ -497,7 +582,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
nullptr);
nullptr,
batch_sizes.has_value()
? length_to_feature_idx.data_ptr<offset_t>()
: static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
Expand Down
Loading

0 comments on commit 9461d55

Please sign in to comment.