Update asynchronous_complete_cumsum to support 2D inputs (#1573)
Summary:
Pull Request resolved: #1573

Before this diff, `asynchronous_complete_cumsum` supported only 1D
inputs. This diff adds support for 2D inputs.
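
A minimal usage sketch of the extended op (illustrative only; it assumes FBGEMM-GPU is installed and imported so the `torch.ops.fbgemm.*` operators are registered, and the example values are made up):

    import torch
    import fbgemm_gpu  # noqa: F401  # loads the library and registers torch.ops.fbgemm

    lengths_1d = torch.tensor([2, 3, 0, 4], dtype=torch.int64)
    # 1D input of length N -> output of length N + 1: [0, 2, 5, 5, 9]
    offsets_1d = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths_1d)

    lengths_2d = torch.tensor([[2, 3, 0], [1, 1, 1]], dtype=torch.int64)
    # 2D input of shape (B, N) -> row-wise complete cumsum of shape (B, N + 1):
    # [[0, 2, 5, 5], [0, 1, 2, 3]]
    offsets_2d = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths_2d)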

Reviewed By: yinghai, jianyuh

Differential Revision: D42956351

fbshipit-source-id: 2e1775534f346f5e7d5c4900a4b2547ea708d16a
sryap authored and facebook-github-bot committed Feb 3, 2023
1 parent 84fe62b commit 64c6a5b
Showing 5 changed files with 302 additions and 38 deletions.
41 changes: 41 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -312,5 +312,46 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
)


@cli.command()
@click.option("--num-vecs", default=2048)
@click.option("--num-entries-per-vec", default=1024)
@click.option("--dtype", type=str, default="long")
def asynchronous_complete_cumsum_2d_bench(
num_vecs: int,
num_entries_per_vec: int,
dtype: str,
) -> None:
# Reference code from TorchRec https://github.com/pytorch/torchrec/pull/332
@torch.jit.script
def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
(f, b) = lengths.shape
offsets_0 = lengths.new_zeros((f, 1))
offsets_1 = torch.cumsum(lengths, dim=-1).to(lengths.dtype)
offsets = torch.cat([offsets_0, offsets_1], dim=-1)
return offsets

assert dtype == "int" or dtype == "long", "Only int and long are supported"
index_dtype = torch.int64 if dtype == "long" else torch.int32

x = torch.randint(low=0, high=100, size=(num_vecs, num_entries_per_vec)).type(
index_dtype
)
x = x.cuda()

time_ref, _ = benchmark_torch_function(
asynchronous_complete_cumsum_2d_ref, (x,), num_warmups=100, iters=1000
)

time, _ = benchmark_torch_function(
torch.ops.fbgemm.asynchronous_complete_cumsum, (x,), num_warmups=100, iters=1000
)

logging.info(
f"asynchronous_complete_cumsum_2d_bench: input shape {x.shape}, dtype {dtype}"
)
logging.info(f"ref time: {time_ref:.5f} sec")
logging.info(f"fbgemm_gpu time: {time:.5f} sec")


if __name__ == "__main__":
cli()
36 changes: 33 additions & 3 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -3054,14 +3054,44 @@ class FixedDivisor {
int shift_;
};

/**
* inclusive_sum_scan_kernel performs an intra- and inter-thread-block sum
* scan (i.e., a prefix sum scan). We use cub::BlockScan to do the inclusive
* sum within each thread block and a waterfall sync method to perform the
* prefix sum across thread blocks.
*
* @param arr an array of input values. Its length must be fixed to
* ITEMS_PER_THREAD
* @param temp_storage a shared memory struct for cub::BlockScan
* @param block_flags a global flag buffer for inter-block sync (must be
* initialized with zeros)
* @param block_sums a global sum buffer for inter-block sync
* @param block_prev a shared memory pointer for sharing sum from the previous
* block within a block
* @param num_entries_per_block the number of input entries for this block
* @param block_id a relative thread block ID (the first block that contains
* the first set of input entries has block_id = 0)
* @param is_multi_block a boolean to indicate if inter-block sum scan has to
* be performed
* @param signal If the value of block_flags of the previous block is equal to
* signal, it means that the previous block has written its sum
* to block_sums. We have thread blocks increment the value of
* block_flags by one after they write their sums to block_sums.
* We increment the flag instead of setting the flag to a single
* value to support multiple sequential inclusive_sum_scan_kernel
* calls (e.g., in the AUC kernel). signal is the order in which
* inclusive_sum_scan_kernel is called. Since we initialize
* block_flags with zeros, the signal of the first call should be
* one.
*/
template <typename scalar_t, int ITEMS_PER_THREAD, int NUM_THREADS_PER_BLOCK>
__inline__ __device__ void inclusive_sum_scan_kernel(
scalar_t (&arr)[ITEMS_PER_THREAD],
typename cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>::TempStorage&
temp_storage,
int* block_flags, // global flags for inter-block sync
scalar_t* block_sums, // global sums for inter-block sync
scalar_t* block_prev, // shared memory for previous sum sync within a block
int* block_flags,
scalar_t* block_sums,
scalar_t* block_prev,
const int num_entries_per_block,
const int block_id,
const bool is_multi_block,
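
A rough sequential sketch of the waterfall scheme described in the inclusive_sum_scan_kernel comment above (a simulation of the idea only, not the CUDA implementation; the helper name and block_size parameter are illustrative):

    from typing import List

    def inclusive_scan_blocked(values: List[int], block_size: int, signal: int = 1) -> List[int]:
        # Each "block" scans its own chunk (cub::BlockScan in the real kernel),
        # publishes its running total to block_sums, and bumps block_flags.
        # Every later block waits until its predecessor's flag reaches `signal`
        # before adding the carried-in sum -- the "waterfall" sync.
        num_blocks = (len(values) + block_size - 1) // block_size
        block_flags = [0] * num_blocks  # must start at zero, as documented above
        block_sums = [0] * num_blocks
        out = list(values)
        for block_id in range(num_blocks):  # blocks run one after another here
            lo = block_id * block_size
            hi = min(lo + block_size, len(values))
            for i in range(lo + 1, hi):  # intra-block inclusive scan
                out[i] += out[i - 1]
            if block_id > 0:
                assert block_flags[block_id - 1] == signal  # stand-in for the spin-wait
                carry = block_sums[block_id - 1]
                for i in range(lo, hi):
                    out[i] += carry
            block_sums[block_id] = out[hi - 1]  # publish running total
            block_flags[block_id] += 1          # raise the flag for the next block
        return out

    # inclusive_scan_blocked([2, 3, 0, 4, 1], block_size=2) -> [2, 5, 5, 9, 10]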
202 changes: 173 additions & 29 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -270,6 +270,71 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
return t_out;
}
template <
typename scalar_t,
int ITEMS_PER_THREAD,
int NUM_THREADS_PER_BLOCK,
int MAX_ENTRIES_PER_BLOCK>
__global__
__launch_bounds__(NUM_THREADS_PER_BLOCK) void batched_complete_cumsum_kernel(
const scalar_t* __restrict__ input,
const int32_t num_entries,
const int32_t last_block_num_entries,
const int32_t padded_num_entries_per_block,
const int32_t num_blocks,
int32_t* __restrict__ block_flags,
scalar_t* __restrict__ block_sums,
scalar_t* __restrict__ output) {
typedef cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage bs_temp_storage;
__shared__ scalar_t block_prev;
scalar_t arr[ITEMS_PER_THREAD];
const int32_t block_id = blockIdx.x % num_blocks;
const int32_t vec_id = blockIdx.x / num_blocks;
const int num_entries_per_block = block_id == num_blocks - 1
? last_block_num_entries
: MAX_ENTRIES_PER_BLOCK;
const int input_offset = vec_id * num_entries;
const int output_offset = vec_id * (num_entries + 1);
const int flag_offset = vec_id * num_blocks;
const int block_offset = block_id * padded_num_entries_per_block;
const bool is_multi_block = num_blocks > 1;
const int section_offset = ITEMS_PER_THREAD * threadIdx.x;
// Load input entries into array
for (int i = 0;
i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
i++) {
arr[i] = input[input_offset + block_offset + section_offset + i];
}
inclusive_sum_scan_kernel<scalar_t, ITEMS_PER_THREAD, NUM_THREADS_PER_BLOCK>(
arr,
bs_temp_storage,
is_multi_block ? block_flags + flag_offset : nullptr,
is_multi_block ? block_sums + flag_offset : nullptr,
is_multi_block ? &block_prev : nullptr,
num_entries_per_block,
block_id,
is_multi_block,
/*signal=*/1);
// Write zero to the first entry of each vector
if (block_id == 0 && threadIdx.x == 0) {
output[output_offset] = 0;
}
// Load results to output
for (int i = 0;
i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
i++) {
output[output_offset + block_offset + section_offset + i + 1] = arr[i];
}
}
Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
TENSOR_ON_CUDA_GPU(t_in);
@@ -278,35 +343,114 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
size_t temp_storage_bytes = 0;
TORCH_CHECK(t_in.is_contiguous());
TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
TORCH_CHECK(t_in.dim() == 1);
auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
t_out[0].zero_();
AT_DISPATCH_INDEX_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<index_t>(),
t_out.data_ptr<index_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
t_in.options().dtype(at::kByte));
AT_DISPATCH_INDEX_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<index_t>(),
t_out.data_ptr<index_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
return t_out;
TORCH_CHECK(t_in.dim() == 1 || t_in.dim() == 2);
if (t_in.dim() == 1) {
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
t_out[0].zero_();
AT_DISPATCH_INDEX_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<index_t>(),
t_out.data_ptr<index_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
t_in.options().dtype(at::kByte));
AT_DISPATCH_INDEX_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<index_t>(),
t_out.data_ptr<index_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
});
return t_out;
} else {
// NUM_THREADS_PER_BLOCK is fixed because cub::BlockScan requires a
// compile-time thread count
constexpr int32_t MAX_ENTRIES_PER_BLOCK = 512;
constexpr int32_t NUM_THREADS_PER_BLOCK = 256;
const int32_t LOG_NUM_THREADS = std::log2(NUM_THREADS_PER_BLOCK);
// Enforce the same constraint as CUB
const auto num_vecs = t_in.size(0);
const auto num_entries = t_in.size(1);
TORCH_CHECK(num_entries < std::numeric_limits<int32_t>::max());
auto t_out = at::empty({num_vecs, num_entries + 1}, t_in.options());
const auto num_blocks = div_round_up(num_entries, MAX_ENTRIES_PER_BLOCK);
const int num_entries_per_block =
num_blocks > 1 ? MAX_ENTRIES_PER_BLOCK : num_entries;
// rounded_num_entries_per_block is 0, 256, or 512
const int rounded_num_entries_per_block =
(num_entries_per_block >> LOG_NUM_THREADS) << LOG_NUM_THREADS;
// padded_num_entries_per_block is either 256 or 512
const int padded_num_entries_per_block = rounded_num_entries_per_block +
(rounded_num_entries_per_block != num_entries_per_block
? NUM_THREADS_PER_BLOCK
: 0);
const int items_per_thread =
padded_num_entries_per_block / NUM_THREADS_PER_BLOCK;
const int last_block_num_entries =
num_entries - ((num_blocks - 1) * MAX_ENTRIES_PER_BLOCK);
const auto grid_size = num_blocks * num_vecs;
at::Tensor block_flags;
at::Tensor block_sums;
if (num_blocks > 1) {
block_flags = at::zeros({grid_size}, t_in.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, t_out.options());
}
auto max_smem_size =
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
#define INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(ITEMS_PER_THREAD) \
batched_complete_cumsum_kernel< \
index_t, \
ITEMS_PER_THREAD, \
NUM_THREADS_PER_BLOCK, \
MAX_ENTRIES_PER_BLOCK> \
<<<grid_size, \
NUM_THREADS_PER_BLOCK, \
0, \
at::cuda::getCurrentCUDAStream()>>>( \
t_in.data_ptr<index_t>(), \
num_entries, \
last_block_num_entries, \
padded_num_entries_per_block, \
num_blocks, \
num_blocks > 1 ? block_flags.data_ptr<int32_t>() : nullptr, \
num_blocks > 1 ? block_sums.data_ptr<index_t>() : nullptr, \
t_out.data_ptr<index_t>())
AT_DISPATCH_INDEX_TYPES(
t_in.scalar_type(), "batched_complete_cumsum_kernel_warpper", [&] {
typedef cub::BlockScan<index_t, NUM_THREADS_PER_BLOCK> BlockScan;
TORCH_CHECK(
sizeof(BlockScan::TempStorage) + sizeof(index_t) <=
max_smem_size);
TORCH_CHECK(items_per_thread == 1 || items_per_thread == 2);
if (items_per_thread == 1) {
INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(1);
} else {
INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(2);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#undef INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL
return t_out;
}
}
// Kernel for permuting the indices and weights. Used for permutation of sparse
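
For reference, the launch-geometry arithmetic in the 2D branch of asynchronous_complete_cumsum_gpu can be reproduced in a few lines of Python (a sketch using the same constants; the helper name is illustrative):

    import math

    MAX_ENTRIES_PER_BLOCK = 512
    NUM_THREADS_PER_BLOCK = 256

    def launch_geometry(num_vecs: int, num_entries: int):
        # Mirrors the arithmetic in the 2D branch of asynchronous_complete_cumsum_gpu.
        log_num_threads = int(math.log2(NUM_THREADS_PER_BLOCK))
        num_blocks = -(-num_entries // MAX_ENTRIES_PER_BLOCK)  # div_round_up
        num_entries_per_block = (
            MAX_ENTRIES_PER_BLOCK if num_blocks > 1 else num_entries
        )
        # Round down to a multiple of the thread count, then pad back up so that
        # every thread handles a whole number of items (1 or 2).
        rounded = (num_entries_per_block >> log_num_threads) << log_num_threads
        padded = rounded + (
            NUM_THREADS_PER_BLOCK if rounded != num_entries_per_block else 0
        )
        items_per_thread = padded // NUM_THREADS_PER_BLOCK
        last_block_num_entries = (
            num_entries - (num_blocks - 1) * MAX_ENTRIES_PER_BLOCK
        )
        grid_size = num_blocks * num_vecs  # one group of blocks per input vector
        return num_blocks, items_per_thread, last_block_num_entries, grid_size

    # e.g. 2048 vectors of 1000 entries each:
    # num_blocks=2, items_per_thread=2, last_block_num_entries=488, grid_size=4096
    print(launch_geometry(2048, 1000))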
29 changes: 23 additions & 6 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -1090,18 +1090,35 @@ Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) {

Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TORCH_CHECK(t_in.dim() == 1);
const auto num_dims = t_in.dim();
TORCH_CHECK(num_dims == 1 || num_dims == 2);

const auto t_in_contig = t_in.expect_contiguous();
auto output = at::zeros({t_in.numel() + 1}, t_in.options());
auto output = num_dims == 1
? at::zeros({t_in.numel() + 1}, t_in.options())
: at::zeros({t_in.size(0), t_in.size(1) + 1}, t_in.options());

AT_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_complete_cumsum_cpu_kernel",
[&] {
const auto N = t_in_contig->numel();
const auto last_sum = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
output.data_ptr<scalar_t>()[N] = last_sum;
if (num_dims == 1) {
const auto N = t_in_contig->numel();
output.data_ptr<scalar_t>()[N] = exclusive_scan_ptrs_cpu(
N,
t_in_contig->data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
} else {
const auto num_vecs = t_in_contig->size(0);
const auto N = t_in_contig->size(1);
at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) {
for (auto i = start; i < end; i++) {
scalar_t* out_ptr = output.data_ptr<scalar_t>() + i * (N + 1);
out_ptr[N] = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>() + i * N, out_ptr);
}
});
}
});
return output;
}
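
The per-row work done by the at::parallel_for loop above amounts to an exclusive scan with the row total appended; a small NumPy sketch of that behavior (illustrative only, the function name is made up):

    import numpy as np

    def complete_cumsum_2d_cpu_ref(lengths: np.ndarray) -> np.ndarray:
        # Each row gets an exclusive prefix sum, with the row total written to
        # the last (N-th) slot -- independently per row, as in the C++ loop.
        num_vecs, n = lengths.shape
        out = np.zeros((num_vecs, n + 1), dtype=lengths.dtype)
        for i in range(num_vecs):  # at::parallel_for over rows in the C++ code
            running = 0
            for j in range(n):
                out[i, j] = running      # exclusive prefix
                running += lengths[i, j]
            out[i, n] = running          # row total (out_ptr[N] = ...)
        return out

    # complete_cumsum_2d_cpu_ref(np.array([[2, 3, 0], [1, 1, 1]]))
    # -> [[0, 2, 5, 5],
    #     [0, 1, 2, 3]]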
32 changes: 32 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -558,6 +558,38 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
zc.cpu(),
)

# pyre-ignore [56]
@given(
n=st.integers(min_value=1, max_value=600),
b=st.integers(min_value=1, max_value=10),
long_index=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_asynchronous_complete_cumsum_2d(
self, n: int, b: int, long_index: bool
) -> None:
index_dtype = torch.int64 if long_index else torch.int32

def test_asynchronous_complete_cumsum_2d_helper(x: torch.Tensor) -> None:
np_index_dtype = np.int64 if long_index else np.int32
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
zeros = torch.zeros(b, 1)
torch.testing.assert_close(
torch.from_numpy(
np.cumsum(
torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1
).astype(np_index_dtype)
),
zc.cpu(),
)

x = torch.randint(low=0, high=100, size=(b, n)).type(index_dtype)
# cpu test
test_asynchronous_complete_cumsum_2d_helper(x)
if gpu_available:
# gpu test
test_asynchronous_complete_cumsum_2d_helper(x.cuda())

# pyre-ignore [56]
@given(
N=st.integers(min_value=1, max_value=20),
