Update asynchronous_complete_cumsum to support 2D inputs (pytorch#1573)
Summary:
Pull Request resolved: pytorch#1573

Before this diff, `asynchronous_complete_cumsum` supported only 1D
inputs. This diff adds support for 2D inputs.
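
A minimal usage sketch of the new behavior (an illustration only, assuming
`fbgemm_gpu` is installed and a CUDA device is available; the expected offsets
follow the semantics exercised by the new unit test):

import torch
import fbgemm_gpu  # noqa: F401  (loads the torch.ops.fbgemm operators)

lengths_1d = torch.tensor([2, 3, 4], dtype=torch.long, device="cuda")
offsets_1d = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths_1d)
# offsets_1d == tensor([0, 2, 5, 9]) -- shape (N + 1,)

lengths_2d = torch.tensor([[2, 3, 4], [1, 0, 5]], dtype=torch.long, device="cuda")
offsets_2d = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths_2d)
# offsets_2d == tensor([[0, 2, 5, 9],
#                       [0, 1, 1, 6]]) -- shape (B, N + 1), one complete cumsum per row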

Differential Revision: D42956351

fbshipit-source-id: 2a6e7107b919411a8cffe17f13ac12de7fc74c24
sryap authored and facebook-github-bot committed Feb 2, 2023
1 parent 84fe62b commit 8f647da
Showing 3 changed files with 241 additions and 29 deletions.
41 changes: 41 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -312,5 +312,46 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
)


@cli.command()
@click.option("--num-vecs", default=2048)
@click.option("--num-entries-per-vec", default=1024)
@click.option("--dtype", type=str, default="long")
def asynchronous_complete_cumsum_2d_bench(
    num_vecs: int,
    num_entries_per_vec: int,
    dtype: str,
) -> None:
    # Reference code from TorchRec https://github.com/pytorch/torchrec/pull/332
    @torch.jit.script
    def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
        (f, b) = lengths.shape
        offsets_0 = lengths.new_zeros((f, 1))
        offsets_1 = torch.cumsum(lengths, dim=-1).to(lengths.dtype)
        offsets = torch.cat([offsets_0, offsets_1], dim=-1)
        return offsets

    assert dtype == "int" or dtype == "long", "Only int and long are supported"
    index_dtype = torch.int64 if dtype == "long" else torch.int32

    x = torch.randint(low=0, high=100, size=(num_vecs, num_entries_per_vec)).type(
        index_dtype
    )
    x = x.cuda()

    time_ref, _ = benchmark_torch_function(
        asynchronous_complete_cumsum_2d_ref, (x,), num_warmups=100, iters=1000
    )

    time, _ = benchmark_torch_function(
        torch.ops.fbgemm.asynchronous_complete_cumsum, (x,), num_warmups=100, iters=1000
    )

    logging.info(
        f"asynchronous_complete_cumsum_2d_bench: input shape {x.shape}, dtype {dtype}"
    )
    logging.info(f"ref time: {time_ref:.5f} sec")
    logging.info(f"fbgemm_gpu time: {time:.5f} sec")


if __name__ == "__main__":
    cli()
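
A quick correctness cross-check one might run before timing (a sketch, not part
of this diff; it assumes access to the TorchRec-style reference defined above and
a CUDA build of fbgemm_gpu):

x = torch.randint(low=0, high=100, size=(8, 1024), dtype=torch.long, device="cuda")
assert torch.equal(
    asynchronous_complete_cumsum_2d_ref(x),
    torch.ops.fbgemm.asynchronous_complete_cumsum(x),
)

The benchmark itself is registered as a Click command, so it would be invoked
through this file's CLI with the options declared above, e.g.
`python sparse_ops_benchmark.py asynchronous_complete_cumsum_2d_bench --num-vecs 2048 --num-entries-per-vec 1024 --dtype long`
(the exact command name may be hyphenated depending on the installed Click version).
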
202 changes: 173 additions & 29 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -270,6 +270,71 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
  return t_out;
}
template <
    typename scalar_t,
    int ITEMS_PER_THREAD,
    int NUM_THREADS_PER_BLOCK,
    int MAX_ENTRIES_PER_BLOCK>
__global__
__launch_bounds__(NUM_THREADS_PER_BLOCK) void batched_complete_cumsum_kernel(
    const scalar_t* __restrict__ input,
    const int32_t num_entries,
    const int32_t last_block_num_entries,
    const int32_t padded_num_entries_per_block,
    const int32_t num_blocks,
    int32_t* __restrict__ block_flags,
    scalar_t* __restrict__ block_sums,
    scalar_t* __restrict__ output) {
  typedef cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK> BlockScan;
  __shared__ typename BlockScan::TempStorage bs_temp_storage;
  __shared__ scalar_t block_prev;
  scalar_t arr[ITEMS_PER_THREAD];
  const int32_t block_id = blockIdx.x % num_blocks;
  const int32_t vec_id = blockIdx.x / num_blocks;
  const int num_entries_per_block = block_id == num_blocks - 1
      ? last_block_num_entries
      : MAX_ENTRIES_PER_BLOCK;
  const int input_offset = vec_id * num_entries;
  const int output_offset = vec_id * (num_entries + 1);
  const int flag_offset = vec_id * num_blocks;
  const int block_offset = block_id * padded_num_entries_per_block;
  const bool is_multi_block = num_blocks > 1;
  const int section_offset = ITEMS_PER_THREAD * threadIdx.x;
  // Load input entries into array
  for (int i = 0;
       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
       i++) {
    arr[i] = input[input_offset + block_offset + section_offset + i];
  }
  inclusive_sum_scan_kernel<scalar_t, ITEMS_PER_THREAD, NUM_THREADS_PER_BLOCK>(
      arr,
      bs_temp_storage,
      is_multi_block ? block_flags + flag_offset : nullptr,
      is_multi_block ? block_sums + flag_offset : nullptr,
      is_multi_block ? &block_prev : nullptr,
      num_entries_per_block,
      block_id,
      is_multi_block,
      /*signal=*/1);
  // Write zero to the first entry of each vector
  if (block_id == 0 && threadIdx.x == 0) {
    output[output_offset] = 0;
  }
  // Load results to output
  for (int i = 0;
       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
       i++) {
    output[output_offset + block_offset + section_offset + i + 1] = arr[i];
  }
}
Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
  TENSOR_ON_CUDA_GPU(t_in);
@@ -278,35 +343,114 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
  size_t temp_storage_bytes = 0;
  TORCH_CHECK(t_in.is_contiguous());
  TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
  // CUB only handles up to INT_MAX elements.
  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
  TORCH_CHECK(t_in.dim() == 1);
  auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
  t_out[0].zero_();
  AT_DISPATCH_INDEX_TYPES(
      t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
            nullptr,
            temp_storage_bytes,
            t_in.data_ptr<index_t>(),
            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
      });
  auto temp_storage = at::empty(
      {static_cast<int64_t>(temp_storage_bytes)},
      t_in.options().dtype(at::kByte));
  AT_DISPATCH_INDEX_TYPES(
      t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
            temp_storage.data_ptr(),
            temp_storage_bytes,
            t_in.data_ptr<index_t>(),
            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
      });
  return t_out;
  TORCH_CHECK(t_in.dim() == 1 || t_in.dim() == 2);
  if (t_in.dim() == 1) {
    // CUB only handles up to INT_MAX elements.
    TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
    auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
    t_out[0].zero_();
    AT_DISPATCH_INDEX_TYPES(
        t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
              nullptr,
              temp_storage_bytes,
              t_in.data_ptr<index_t>(),
              t_out.data_ptr<index_t>() + 1,
              t_in.numel(),
              at::cuda::getCurrentCUDAStream()));
        });
    auto temp_storage = at::empty(
        {static_cast<int64_t>(temp_storage_bytes)},
        t_in.options().dtype(at::kByte));
    AT_DISPATCH_INDEX_TYPES(
        t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
              temp_storage.data_ptr(),
              temp_storage_bytes,
              t_in.data_ptr<index_t>(),
              t_out.data_ptr<index_t>() + 1,
              t_in.numel(),
              at::cuda::getCurrentCUDAStream()));
        });
    return t_out;
  } else {
    // Fix NUM_THREADS_PER_BLOCK because of CUB
    constexpr int32_t MAX_ENTRIES_PER_BLOCK = 512;
    constexpr int32_t NUM_THREADS_PER_BLOCK = 256;
    const int32_t LOG_NUM_THREADS = std::log2(NUM_THREADS_PER_BLOCK);
    // Enforce the same constraint as CUB
    const auto num_vecs = t_in.size(0);
    const auto num_entries = t_in.size(1);
    TORCH_CHECK(num_entries < std::numeric_limits<int32_t>::max());
    auto t_out = at::empty({num_vecs, num_entries + 1}, t_in.options());
    const auto num_blocks = div_round_up(num_entries, MAX_ENTRIES_PER_BLOCK);
    const int num_entries_per_block =
        num_blocks > 1 ? MAX_ENTRIES_PER_BLOCK : num_entries;
    // rounded_num_entries_per_block is either 0 or 256
    const int rounded_num_entries_per_block =
        (num_entries_per_block >> LOG_NUM_THREADS) << LOG_NUM_THREADS;
    // padded_num_entries_per_block is either 256 or 512
    const int padded_num_entries_per_block = rounded_num_entries_per_block +
        (rounded_num_entries_per_block != num_entries_per_block
             ? NUM_THREADS_PER_BLOCK
             : 0);
    const int items_per_thread =
        padded_num_entries_per_block / NUM_THREADS_PER_BLOCK;
    const int last_block_num_entries =
        num_entries - ((num_blocks - 1) * MAX_ENTRIES_PER_BLOCK);
    const auto grid_size = num_blocks * num_vecs;
    at::Tensor block_flags;
    at::Tensor block_sums;
    if (num_blocks > 1) {
      block_flags = at::zeros({grid_size}, t_in.options().dtype(at::kInt));
      block_sums = at::empty({grid_size}, t_out.options());
    }
    auto max_smem_size =
        at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
#define INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(ITEMS_PER_THREAD)         \
  batched_complete_cumsum_kernel<                                       \
      index_t,                                                          \
      ITEMS_PER_THREAD,                                                 \
      NUM_THREADS_PER_BLOCK,                                            \
      MAX_ENTRIES_PER_BLOCK>                                            \
      <<<grid_size,                                                     \
         NUM_THREADS_PER_BLOCK,                                         \
         0,                                                             \
         at::cuda::getCurrentCUDAStream()>>>(                           \
          t_in.data_ptr<index_t>(),                                     \
          num_entries,                                                  \
          last_block_num_entries,                                       \
          padded_num_entries_per_block,                                 \
          num_blocks,                                                   \
          num_blocks > 1 ? block_flags.data_ptr<int32_t>() : nullptr,   \
          num_blocks > 1 ? block_sums.data_ptr<index_t>() : nullptr,    \
          t_out.data_ptr<index_t>())
    AT_DISPATCH_INDEX_TYPES(
        t_in.scalar_type(), "batched_complete_cumsum_kernel_warpper", [&] {
          typedef cub::BlockScan<index_t, NUM_THREADS_PER_BLOCK> BlockScan;
          TORCH_CHECK(
              sizeof(BlockScan::TempStorage) + sizeof(index_t) <=
              max_smem_size);
          TORCH_CHECK(items_per_thread == 1 || items_per_thread == 2);
          if (items_per_thread == 1) {
            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(1);
          } else {
            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(2);
          }
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
#undef INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL
    return t_out;
  }
}
// Kernel for permuting the indices and weights. Used for permutation of sparse
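
For orientation, a small Python re-derivation of the launch-parameter arithmetic
in the 2D branch above (a sketch for illustration only; `cumsum_2d_launch_params`
is a hypothetical helper, and the constants mirror MAX_ENTRIES_PER_BLOCK = 512 and
NUM_THREADS_PER_BLOCK = 256 in asynchronous_complete_cumsum_gpu):

def cumsum_2d_launch_params(num_vecs: int, num_entries: int):
    MAX_ENTRIES_PER_BLOCK = 512
    NUM_THREADS_PER_BLOCK = 256
    LOG_NUM_THREADS = 8  # log2(NUM_THREADS_PER_BLOCK)

    num_blocks = (num_entries + MAX_ENTRIES_PER_BLOCK - 1) // MAX_ENTRIES_PER_BLOCK
    num_entries_per_block = MAX_ENTRIES_PER_BLOCK if num_blocks > 1 else num_entries
    # Round down to a multiple of the thread count ...
    rounded = (num_entries_per_block >> LOG_NUM_THREADS) << LOG_NUM_THREADS
    # ... then pad back up so every thread owns a whole number of items.
    padded = rounded + (NUM_THREADS_PER_BLOCK if rounded != num_entries_per_block else 0)
    items_per_thread = padded // NUM_THREADS_PER_BLOCK  # 1 or 2, matching the TORCH_CHECK
    last_block_num_entries = num_entries - (num_blocks - 1) * MAX_ENTRIES_PER_BLOCK
    grid_size = num_blocks * num_vecs  # one row of blocks per vector
    return num_blocks, items_per_thread, last_block_num_entries, grid_size

# e.g. 4 vectors of 600 entries each: 2 blocks per vector, 2 items per thread,
# 88 entries in each vector's last block, and a grid of 8 thread blocks.
print(cumsum_2d_launch_params(4, 600))  # (2, 2, 88, 8)
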
27 changes: 27 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -558,6 +558,33 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
            zc.cpu(),
        )

    @unittest.skipIf(*gpu_unavailable)
    # pyre-ignore [56]
    @given(
        n=st.integers(min_value=1, max_value=600),
        b=st.integers(min_value=1, max_value=10),
        long_index=st.booleans(),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
    def test_asynchronous_complete_cumsum_2d(
        self, n: int, b: int, long_index: bool
    ) -> None:
        index_dtype = torch.int64 if long_index else torch.int32
        np_index_dtype = np.int64 if long_index else np.int32

        x = torch.randint(low=0, high=100, size=(b, n)).type(index_dtype)
        x = x.cuda()
        zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
        zeros = torch.zeros(b, 1)
        torch.testing.assert_close(
            torch.from_numpy(
                np.cumsum(torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1).astype(
                    np_index_dtype
                )
            ),
            zc.cpu(),
        )

    # pyre-ignore [56]
    @given(
        N=st.integers(min_value=1, max_value=20),
