Skip to content

Commit

Permalink
Add group_index_select_dim0 (pytorch#1421)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1421

`group_index_select_dim0` does `index_select_dim0` for each group of
inputs in a single kernel.  The operator takes a list of inputs and a
list of indices and returns a list of outputs.

There are some limitations to group_index_select_dim0:
- All inputs must have the same shape.
- All indices must have the same shape.
- Because we use variadic template for the autograd function, it
supports up to 55 groups.

Differential Revision: D40683435

fbshipit-source-id: bcc52c0b1ae4f9270901bbd45d85d45c89051eb6
  • Loading branch information
sryap authored and facebook-github-bot committed Nov 3, 2022
1 parent 3542b50 commit 035e812
Show file tree
Hide file tree
Showing 6 changed files with 642 additions and 1 deletion.
88 changes: 88 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,5 +224,93 @@ def jagged_index_select_2d_ref(
logging.info(f"backward: fbgemm {time * 1e3:.3f} ms, ref {time_ref * 1e3:.3f} ms")


@cli.command()
@click.option("--row-size", default=512)
@click.option("--batch-size", default=4096)
@click.option("--unique-batch-size", default=1024)
@click.option("--input-precision", type=str, default="fp32")
@click.option("--sort-indices", type=bool, default=True)
@click.option("--num-groups", default=32)
def group_index_select_2d_bench(
row_size: int,
batch_size: int,
unique_batch_size: int,
input_precision: str,
sort_indices: bool,
num_groups: int,
) -> None:
def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
inverse_index = list(range(curr_size))
np_arr = np.array(inverse_index)
for _ in range(final_size - curr_size):
inverse_index.append(np.random.randint(0, curr_size))
np_arr = np.array(inverse_index)
np.random.shuffle(np_arr)
return np_arr

dtype = torch.float
if input_precision == "fp32":
dtype = torch.float
elif input_precision == "fp16":
dtype = torch.half
else:
raise RuntimeError(f"Does not support data type {input_precision}")

offset_indices_group = []
indices_group = []
for i in range(num_groups):
# pyre-fixme[16]: Module `cuda` has no attribute `IntTensor`.
indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size))
if sort_indices:
indices, _ = indices.sort()
indices_group.append(indices)
indices = torch.add(indices, batch_size * i)
offset_indices_group.append(indices)

offset_indices = torch.concat(offset_indices_group)

input = torch.rand(num_groups * batch_size, row_size, dtype=dtype, device="cuda")
input.requires_grad = True

num_bytes = 2 * batch_size * row_size * input.element_size() * num_groups

bench_kwargs = {"num_warmups": 10, "iters": 100}

# Benchmark forward
time_ref, output_ref = benchmark_torch_function(
torch.index_select, (input, 0, offset_indices), **bench_kwargs
)

input_group = input.split(batch_size, 0)
time, output_group = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(input_group, indices_group),
**bench_kwargs,
)
logging.info(
f"forward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
f"fbgemm group {time:5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"
)

# Benchmark backward
grad = torch.rand_like(output_ref)
time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True),
(grad,),
**bench_kwargs,
)

cat_output = torch.cat(output_group)
time, _ = benchmark_torch_function(
functools.partial(cat_output.backward, retain_graph=True),
(grad,),
**bench_kwargs,
)
logging.info(
f"backward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
f"fbgemm group {time:.5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"
)


if __name__ == "__main__":
cli()
27 changes: 27 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -697,5 +697,32 @@ at::Tensor index_add_with_unique_indices_cuda(
std::vector<int64_t>& input_shape,
const int consecutive_range_start,
const int consecutive_range_length);

///@ingroup sparse-data-cuda
std::vector<at::Tensor> group_index_select_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups);

std::vector<at::Tensor> group_index_add_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups);
#endif
} // namespace fbgemm_gpu
192 changes: 192 additions & 0 deletions fbgemm_gpu/src/sparse_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2850,6 +2850,198 @@ Tensor index_add_with_unique_indices_cuda(
return input_grad.reshape(input_shape);
}
template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__ __launch_bounds__(kMaxThreads) void group_index_select_2d_kernel(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
scalar_t* output,
const int64_t num_input_rows,
const int64_t num_output_rows,
const int64_t num_cols,
const int64_t num_groups) {
for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
bid < num_groups * num_output_rows;
bid += gridDim.x * blockDim.y) {
const int64_t group_id = bid / num_output_rows;
const int64_t row = bid % num_output_rows;
scalar_t* input = (scalar_t*)input_ptrs[group_id];
index_t* indices = (index_t*)indices_ptrs[group_id];
const index_t idx = indices[row];
CUDA_KERNEL_ASSERT(idx < num_input_rows)
int col;
scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
for (col = threadIdx.x * UNROLL_FACTOR;
col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
col += blockDim.x * UNROLL_FACTOR) {
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
output_[row * num_cols + col + i] =
LDG(&input[idx * num_cols + col + i]);
}
}
for (; col < num_cols; ++col) {
output_[row * num_cols + col] = LDG(&input[idx * num_cols + col]);
}
}
}
template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__ __launch_bounds__(kMaxThreads) void group_index_add_2d_kernel(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
scalar_t* output,
const int64_t num_input_rows,
const int64_t num_output_rows,
const int64_t num_cols,
const int64_t num_groups) {
for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
bid < num_groups * num_input_rows;
bid += gridDim.x * blockDim.y) {
const int64_t group_id = bid / num_input_rows;
const int64_t row = bid % num_input_rows;
scalar_t* input = (scalar_t*)input_ptrs[group_id];
index_t* indices = (index_t*)indices_ptrs[group_id];
const index_t idx = indices[row];
CUDA_KERNEL_ASSERT(idx < num_output_rows)
int col;
scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
for (col = threadIdx.x * UNROLL_FACTOR;
col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
col += blockDim.x * UNROLL_FACTOR) {
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
// PyTorch also uses atomicAdd. It does not require sorting and
// provides better parallelism. But this can lead to numerical
// indeterminisim.
gpuAtomicAddNoReturn(
&output_[idx * num_cols + col + i],
input[row * num_cols + col + i]);
}
}
for (; col < num_cols; ++col) {
gpuAtomicAddNoReturn(
&output[idx * num_cols + col], input[row * num_cols + col]);
}
}
}
std::vector<Tensor> group_index_select_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups) {
if (num_groups == 0) {
return std::vector<Tensor>();
}
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(device);
Tensor output = at::empty(output_shape, input_tensor_options);
// Partition work based on num_output_rows
const int UNROLL_FACTOR = 1;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(num_groups * num_output_rows, 1), max_grid_size);
uint32_t block_size_x =
std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
uint32_t block_size_y =
std::max((num_groups * num_output_rows) / grid_size, (uint32_t)1);
dim3 block_size(
block_size_x,
std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
1);
AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
group_index_select_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
<<<grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input_ptrs,
indices_ptrs,
output.data_ptr<scalar_t>(),
num_input_rows,
num_output_rows,
num_cols,
num_groups);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return output.split(num_output_rows, 0);
}
std::vector<Tensor> group_index_add_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups) {
if (num_groups == 0) {
return std::vector<Tensor>();
}
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(device);
Tensor output = at::zeros(output_shape, input_tensor_options);
// Partition work based on num_input_rows
const int UNROLL_FACTOR = 1;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(num_groups * num_input_rows, 1), max_grid_size);
uint32_t block_size_x =
std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
uint32_t block_size_y =
std::max((num_groups * num_input_rows) / grid_size, (uint32_t)1);
dim3 block_size(
block_size_x,
std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
1);
AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_add_2d_wrapper_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_scalar_type, "group_index_add_2d_wrapper_2", [&] {
group_index_add_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
<<<grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input_ptrs,
indices_ptrs,
output.data_ptr<scalar_t>(),
num_input_rows,
num_output_rows,
num_cols,
num_groups);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return output.split(num_output_rows, 0);
}
// Copied from cupy/random/_kernels.py v11
// (commit id 420e41fd41157d4cf526b0e94eb86a3f8eb5a231)
Expand Down
18 changes: 18 additions & 0 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,19 @@ Tensor index_select_dim0(
return at::index_select(input, 0, indices);
}

std::vector<Tensor> group_index_select_dim0(
const std::vector<Tensor>& input_group,
const std::vector<Tensor>& indices_group) {
int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> output_group;
for (int i = 0; i < num_groups; i++) {
output_group.push_back(
at::index_select(input_group[i], 0, indices_group[i]));
}
return output_group;
}

Tensor bottom_unique_k_per_row(const Tensor& input, const int64_t k) {
auto num_cols = input.size(-1);
Tensor input_reshaped = input.reshape({-1, num_cols});
Expand Down Expand Up @@ -2389,6 +2402,7 @@ Tensor bottom_unique_k_per_row(const Tensor& input, const int64_t k) {
output_shape[output_shape.size() - 1] = k;
return output.reshape(output_shape);
}

} // namespace

} // namespace fbgemm_gpu
Expand Down Expand Up @@ -2458,6 +2472,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// skip_indices_sorting_fwd is for skipping indices sorting in forward
m.def(
"index_select_dim0(Tensor input, Tensor indices, int? consecutive_range_start=0, int? consecutive_range_length=0, bool? skip_indices_sorting_fwd=None) -> Tensor");
m.def(
"group_index_select_dim0(Tensor[] input_group, Tensor[] indices_group) -> Tensor[]");
m.def(
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]");
// This is an one-off op to be used in bench_utils.py for zipf generation w/o
Expand Down Expand Up @@ -2528,6 +2544,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
fbgemm_gpu::permute_sequence_embeddings_cpu);
DISPATCH_TO_CPU("pack_segments", fbgemm_gpu::pack_segments_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
DISPATCH_TO_CPU(
"bottom_unique_k_per_row", fbgemm_gpu::bottom_unique_k_per_row);
}
Loading

0 comments on commit 035e812

Please sign in to comment.