Support variable batch_size for block_bucketize_sparse_features #2012

Status: Closed · wants to merge 1 commit
32 changes: 18 additions & 14 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -138,13 +138,15 @@ std::tuple<

///@ingroup sparse-data-cuda
block_bucketize_sparse_features_cuda(
-    at::Tensor lengths,
-    at::Tensor indices,
-    bool bucketize_pos,
-    bool sequence,
-    at::Tensor block_sizes,
-    int64_t my_size,
-    c10::optional<at::Tensor> weights);
+    const at::Tensor& lengths,
+    const at::Tensor& indices,
+    const bool bucketize_pos,
+    const bool sequence,
+    const at::Tensor& block_sizes,
+    const int64_t my_size,
+    const c10::optional<at::Tensor>& weights,
+    const c10::optional<at::Tensor>& batch_size_per_feature,
+    const int64_t max_batch_size);

std::tuple<
at::Tensor,
@@ -155,13 +157,15 @@ std::tuple<

///@ingroup sparse-data-cpu
block_bucketize_sparse_features_cpu(
-    at::Tensor lengths,
-    at::Tensor indices,
-    bool bucketize_pos,
-    bool sequence,
-    at::Tensor block_sizes,
-    int64_t my_size,
-    c10::optional<at::Tensor> weights);
+    const at::Tensor& lengths,
+    const at::Tensor& indices,
+    const bool bucketize_pos,
+    const bool sequence,
+    const at::Tensor& block_sizes,
+    const int64_t my_size,
+    const c10::optional<at::Tensor>& weights,
+    const c10::optional<at::Tensor>& batch_size_per_feature,
+    const int64_t max_batch_size);

std::tuple<
at::Tensor,
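The two new trailing parameters drop the assumption that every feature shares one batch size B: `batch_size_per_feature` optionally carries one batch size per feature, and `max_batch_size` bounds the largest of them. A minimal host-side sketch of how a variable-batch `lengths` layout is indexed (values and the standalone program are illustrative only, not part of this PR):

```cpp
// Sketch (not part of the diff): how the two new arguments relate to the
// flattened `lengths` layout. With a fixed batch size B, `lengths` has
// T * B entries and row b of feature t sits at t * B + b. With variable
// batch sizes, feature t contributes batch_size_per_feature[t] entries,
// so each feature's start is the exclusive prefix sum of the batch sizes.
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> batch_size_per_feature = {2, 4, 3}; // T = 3
  const int64_t max_batch_size = 4; // max over batch_size_per_feature

  std::vector<int64_t> batch_size_offsets(batch_size_per_feature.size());
  int64_t lengths_size = 0;
  for (size_t t = 0; t < batch_size_per_feature.size(); ++t) {
    batch_size_offsets[t] = lengths_size;       // {0, 2, 6}
    lengths_size += batch_size_per_feature[t];  // 9 total rows
  }
  // lengths_size no longer factors as T * B, which is why the kernels
  // in the .cu file below need an explicit row -> feature mapping.
  return 0;
}
```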
152 changes: 120 additions & 32 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
@@ -12,23 +12,46 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

+// Kernel for calculating the length-index to feature-id mapping. Used for
+// block bucketizing sparse features with variable batch size, row-wise partition.
+template <typename offset_t>
+__global__
+__launch_bounds__(kMaxThreads) void _populate_length_to_feature_id_inplace_kernel(
+    const uint64_t max_B,
+    const int T,
+    const offset_t* const __restrict__ batch_sizes,
+    const offset_t* const __restrict__ batch_size_offsets,
+    offset_t* const __restrict__ length_to_feature_idx) {
+  const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
+
+  const auto t = b_t / max_B;
+  const auto b = b_t % max_B;
+
+  if (t >= T || b >= batch_sizes[t]) {
+    return;
+  }
+
+  length_to_feature_idx[batch_size_offsets[t] + b] = t;
+}
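The same mapping can be stated on the host; a minimal CPU sketch of what the kernel above fills in (hypothetical standalone code, not part of the diff):

```cpp
// CPU sketch (hypothetical, not part of the diff) of what
// _populate_length_to_feature_id_inplace_kernel fills in: which feature
// each flattened `lengths` row belongs to.
#include <cstdint>
#include <vector>

std::vector<int64_t> populate_length_to_feature_idx(
    const std::vector<int64_t>& batch_sizes,        // per-feature batch sizes
    const std::vector<int64_t>& batch_size_offsets, // exclusive prefix sum
    int64_t lengths_size) {
  std::vector<int64_t> length_to_feature_idx(lengths_size);
  for (size_t t = 0; t < batch_sizes.size(); ++t) {
    for (int64_t b = 0; b < batch_sizes[t]; ++b) {
      length_to_feature_idx[batch_size_offsets[t] + b] =
          static_cast<int64_t>(t);
    }
  }
  // {2, 4, 3} with offsets {0, 2, 6} yields {0, 0, 1, 1, 1, 1, 2, 2, 2}.
  return length_to_feature_idx;
}
```

The CUDA version launches one thread per cell of a T x max_B grid and early-exits when b >= batch_sizes[t], which corresponds exactly to the inner loop bound here.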

// Kernel for bucketizing lengths, with the Block distribution (vs. cyclic or
// block-cyclic distributions). Used for bucketizing sparse features, especially
// for checkpointing with row-wise partitioning (the sparse feature is
// partitioned continuously along the sparse dimension into my_size blocks).
template <typename offset_t, typename index_t>
__global__
__launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel1(
-    int32_t lengths_size,
-    int32_t B,
-    const index_t* __restrict__ block_sizes_data,
-    int my_size,
-    const offset_t* __restrict__ offsets_data,
-    const index_t* __restrict__ indices_data,
-    offset_t* __restrict__ new_lengths_data) {
+    const int32_t lengths_size,
+    const int32_t B,
+    const index_t* const __restrict__ block_sizes_data,
+    const int my_size,
+    const offset_t* const __restrict__ offsets_data,
+    const index_t* const __restrict__ indices_data,
+    offset_t* const __restrict__ new_lengths_data,
+    offset_t* __restrict__ length_to_feature_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
-    int32_t t = b_t / B;
+    const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
@@ -71,11 +94,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel
index_t* __restrict__ new_indices_data,
scalar_t* __restrict__ new_weights_data,
index_t* __restrict__ new_pos_data,
-    index_t* __restrict__ unbucketize_permute_data) {
+    index_t* const __restrict__ unbucketize_permute_data,
+    const offset_t* const __restrict__ length_to_feature_idx) {
using uindex_t = std::make_unsigned_t<index_t>;
using uoffset_t = std::make_unsigned_t<offset_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
-    int32_t t = b_t / B;
+    const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B;
index_t blk_size = block_sizes_data[t];
offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]);
offset_t rowend = offsets_data[b_t];
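Both bucketize kernels now resolve the feature id identically: look it up in `length_to_feature_idx` when the caller supplied per-feature batch sizes, otherwise fall back to the old `b_t / B` arithmetic. A sketch of that resolution plus the in-range block rule implied by the comment above (the full kernel bodies are collapsed in this diff):

```cpp
// Sketch of the feature-id resolution shared by both kernels, and the
// block rule for in-range indices; out-of-range handling is inside the
// collapsed kernel bodies.
#include <cstdint>

int64_t feature_id(
    int64_t b_t,
    int64_t B,                               // fixed batch size (legacy path)
    const int64_t* length_to_feature_idx) {  // nullptr when batch is fixed
  return length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B;
}

int64_t in_range_bucket(int64_t idx, int64_t blk_size) {
  // The sparse dimension is partitioned continuously into my_size blocks
  // of blk_size, so an in-range index lands in bucket idx / blk_size.
  return idx / blk_size;
}
```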
@@ -115,36 +139,73 @@ DLL_PUBLIC std::tuple<
c10::optional<Tensor>,
c10::optional<Tensor>>
block_bucketize_sparse_features_cuda(
-    Tensor lengths,
-    Tensor indices,
-    bool bucketize_pos,
-    bool sequence,
-    Tensor block_sizes,
-    int64_t my_size,
-    c10::optional<Tensor> weights) {
+    const Tensor& lengths,
+    const Tensor& indices,
+    const bool bucketize_pos,
+    const bool sequence,
+    const Tensor& block_sizes,
+    const int64_t my_size,
+    const c10::optional<Tensor>& weights,
+    const c10::optional<Tensor>& batch_sizes,
+    const int64_t max_B) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(lengths.get_device());
// allocate tensors and buffers
-  const int lengths_size = lengths.numel();
-  const int T = block_sizes.numel();
-  const int B = lengths_size / T;
-  const int new_lengths_size = lengths_size * my_size;
+  const auto lengths_size = lengths.numel();
+  const auto T = block_sizes.numel();
+  const auto B = lengths_size / T;
+  const auto new_lengths_size = lengths_size * my_size;
auto offsets = at::empty({lengths_size}, lengths.options());
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
auto new_offsets = at::empty({new_lengths_size}, lengths.options());
auto new_indices = at::empty_like(indices);
auto lengths_contig = lengths.contiguous();
auto indices_contig = indices.contiguous();
auto offsets_contig = offsets.contiguous();
+  auto batch_sizes_contig =
+      batch_sizes.value_or(at::empty({T}, lengths.options())).contiguous();
+  auto batch_sizes_offsets_contig =
+      at::empty({T}, batch_sizes_contig.options());
Tensor new_weights;
Tensor new_pos;
Tensor unbucketize_permute;
// count nonzeros
offsets_contig = asynchronous_inclusive_cumsum_gpu(lengths);
-  int threads_per_block = 256;
-  int num_blocks = (lengths_size + threads_per_block - 1) / threads_per_block;
+  if (batch_sizes.has_value()) {
+    TORCH_CHECK(max_B > 0);
+    batch_sizes_offsets_contig =
+        asynchronous_exclusive_cumsum_gpu(batch_sizes.value());
+  }
+  auto length_to_feature_idx =
+      at::empty({lengths_size}, lengths_contig.options());
+  if (batch_sizes.has_value()) {
+    constexpr auto threads_per_block = 256;
+    const auto num_blocks =
+        cuda_calc_xblock_count(max_B * T, threads_per_block);
+    AT_DISPATCH_INDEX_TYPES(
+        offsets_contig.scalar_type(),
+        "_populate_length_to_feature_id_inplace_kernel",
+        [&] {
+          using offset_t = index_t;
+          _populate_length_to_feature_id_inplace_kernel<<<
+              num_blocks,
+              threads_per_block,
+              0,
+              at::cuda::getCurrentCUDAStream()>>>(
+              max_B,
+              T,
+              batch_sizes_contig.data_ptr<offset_t>(),
+              batch_sizes_offsets_contig.data_ptr<offset_t>(),
+              length_to_feature_idx.data_ptr<offset_t>());
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+  }
+
+  constexpr auto threads_per_block = 256;
+  const auto num_blocks =
+      cuda_calc_xblock_count(lengths_size, threads_per_block);
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(),
"_block_bucketize_sparse_features_cuda_kernel1",
@@ -165,7 +226,10 @@ block_bucketize_sparse_features_cuda(
my_size,
offsets_contig.data_ptr<offset_t>(),
indices_contig.data_ptr<index_t>(),
-              new_lengths.data_ptr<offset_t>());
+              new_lengths.data_ptr<offset_t>(),
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
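The remaining hunks repeat one launch-site pattern: pass the mapping's data pointer only when variable batch sizes were supplied, otherwise a typed null so the kernels take the legacy `b_t / B` path. Isolated as a sketch (the helper name is ours, not the PR's):

```cpp
// Sketch of the argument pattern repeated at every launch site below.
// The explicit static_cast matters because both branches of ?: must
// agree on the pointer type.
#include <ATen/ATen.h>

template <typename offset_t>
offset_t* maybe_length_to_feature_idx(
    bool has_batch_sizes, const at::Tensor& length_to_feature_idx) {
  return has_batch_sizes ? length_to_feature_idx.data_ptr<offset_t>()
                         : static_cast<offset_t*>(nullptr);
}
```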
@@ -215,7 +279,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
-              unbucketize_permute.data_ptr<index_t>());
+              unbucketize_permute.data_ptr<index_t>(),
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -259,7 +326,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
-              unbucketize_permute.data_ptr<index_t>());
+              unbucketize_permute.data_ptr<index_t>(),
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -297,7 +367,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
-              unbucketize_permute.data_ptr<index_t>());
+              unbucketize_permute.data_ptr<index_t>(),
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -333,7 +406,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
-              unbucketize_permute.data_ptr<index_t>());
+              unbucketize_permute.data_ptr<index_t>(),
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -379,7 +455,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
new_pos.data_ptr<index_t>(),
-              nullptr);
+              nullptr,
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -423,7 +502,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
new_weights.data_ptr<scalar_t>(),
nullptr,
-              nullptr);
+              nullptr,
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -461,7 +543,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
new_pos.data_ptr<index_t>(),
-              nullptr);
+              nullptr,
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
@@ -497,7 +582,10 @@ block_bucketize_sparse_features_cuda(
new_indices.data_ptr<index_t>(),
nullptr,
nullptr,
-              nullptr);
+              nullptr,
+              batch_sizes.has_value()
+                  ? length_to_feature_idx.data_ptr<offset_t>()
+                  : static_cast<offset_t*>(nullptr));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
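Taken together, a caller would exercise the new arguments roughly as follows. This is a hedged sketch against the sparse_ops.h declaration above, assuming the functions live in namespace `fbgemm_gpu` as in the .cu file; values are illustrative and error handling is omitted:

```cpp
// Sketch: calling the CPU variant with per-feature batch sizes, per the
// header declaration in this PR. Values are illustrative only.
#include <torch/torch.h>
#include "fbgemm_gpu/sparse_ops.h"

void variable_batch_example() {
  // Two features; feature 0 has batch size 2, feature 1 has batch size 3.
  const auto opts = torch::dtype(torch::kLong);
  const auto batch_size_per_feature = torch::tensor({2, 3}, opts);
  const auto lengths = torch::tensor({1, 2, 1, 0, 2}, opts);      // 2 + 3 rows
  const auto indices = torch::tensor({7, 2, 9, 0, 3, 14}, opts);  // sum(lengths) = 6
  const auto block_sizes = torch::tensor({8, 8}, opts);

  // Bucketize across my_size = 2 ranks; no weights, no positions.
  const auto result = fbgemm_gpu::block_bucketize_sparse_features_cpu(
      lengths,
      indices,
      /*bucketize_pos=*/false,
      /*sequence=*/false,
      block_sizes,
      /*my_size=*/2,
      /*weights=*/c10::nullopt,
      /*batch_size_per_feature=*/batch_size_per_feature,
      /*max_batch_size=*/3);
}
```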