Add group_index_select_dim0 #1421

Closed · wants to merge 1 commit
88 changes: 88 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -224,5 +224,93 @@ def jagged_index_select_2d_ref(
logging.info(f"backward: fbgemm {time * 1e3:.3f} ms, ref {time_ref * 1e3:.3f} ms")


@cli.command()
@click.option("--row-size", default=512)
@click.option("--batch-size", default=4096)
@click.option("--unique-batch-size", default=1024)
@click.option("--input-precision", type=str, default="fp32")
@click.option("--sort-indices", type=bool, default=True)
@click.option("--num-groups", default=32)
def group_index_select_2d_bench(
    row_size: int,
    batch_size: int,
    unique_batch_size: int,
    input_precision: str,
    sort_indices: bool,
    num_groups: int,
) -> None:
    def gen_inverse_index(curr_size: int, final_size: int) -> np.ndarray:
        # Start with all unique indices, then pad with random repeats up to
        # final_size and shuffle.
        inverse_index = list(range(curr_size))
        for _ in range(final_size - curr_size):
            inverse_index.append(np.random.randint(0, curr_size))
        np_arr = np.array(inverse_index)
        np.random.shuffle(np_arr)
        return np_arr

    dtype = torch.float
    if input_precision == "fp32":
        dtype = torch.float
    elif input_precision == "fp16":
        dtype = torch.half
    else:
        raise RuntimeError(f"Unsupported data type {input_precision}")

    offset_indices_group = []
    indices_group = []
    for i in range(num_groups):
        # pyre-fixme[16]: Module `cuda` has no attribute `IntTensor`.
        indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size))
        if sort_indices:
            indices, _ = indices.sort()
        indices_group.append(indices)
        indices = torch.add(indices, batch_size * i)
        offset_indices_group.append(indices)

    offset_indices = torch.concat(offset_indices_group)

    input = torch.rand(num_groups * batch_size, row_size, dtype=dtype, device="cuda")
    input.requires_grad = True

    num_bytes = 2 * batch_size * row_size * input.element_size() * num_groups

    bench_kwargs = {"num_warmups": 10, "iters": 100}

    # Benchmark forward
    time_ref, output_ref = benchmark_torch_function(
        torch.index_select, (input, 0, offset_indices), **bench_kwargs
    )

    input_group = input.split(batch_size, 0)
    time, output_group = benchmark_torch_function(
        torch.ops.fbgemm.group_index_select_dim0,
        (input_group, indices_group),
        **bench_kwargs,
    )
    logging.info(
        f"forward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
        f"fbgemm group {time:.5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"
    )

    # Benchmark backward
    grad = torch.rand_like(output_ref)
    time_ref, _ = benchmark_torch_function(
        functools.partial(output_ref.backward, retain_graph=True),
        (grad,),
        **bench_kwargs,
    )

    cat_output = torch.cat(output_group)
    time, _ = benchmark_torch_function(
        functools.partial(cat_output.backward, retain_graph=True),
        (grad,),
        **bench_kwargs,
    )
    logging.info(
        f"backward: PyTorch batch {time_ref:.5f} sec ({num_bytes / time_ref / 1e9:.5f} GB/s), "
        f"fbgemm group {time:.5f} sec ({num_bytes / time / 1e9:.5f} GB/s)"
    )


if __name__ == "__main__":
    cli()
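Beyond the throughput comparison above, the following is a minimal correctness sketch (not part of this PR) that checks the new op against a per-group torch.index_select reference. It assumes a build of fbgemm_gpu in which importing the package registers the torch.ops.fbgemm operators, and it exercises the CPU dispatch added in sparse_ops_cpu.cpp below; the benchmark runs the same op on CUDA tensors.

import torch
import fbgemm_gpu  # noqa: F401  # assumed to load the fbgemm operator library

num_groups, num_rows, num_out_rows, row_size = 4, 64, 16, 8
input_group = [torch.rand(num_rows, row_size) for _ in range(num_groups)]
indices_group = [
    torch.randint(0, num_rows, (num_out_rows,), dtype=torch.long)
    for _ in range(num_groups)
]

# The op returns one output tensor per group, each equal to
# index_select(input, 0, indices) for that group.
output_group = torch.ops.fbgemm.group_index_select_dim0(input_group, indices_group)
for out, inp, idx in zip(output_group, input_group, indices_group):
    torch.testing.assert_close(out, torch.index_select(inp, 0, idx))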
27 changes: 27 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -685,5 +685,32 @@ at::Tensor index_add_with_unique_indices_cuda(
std::vector<int64_t>& input_shape,
const int consecutive_range_start,
const int consecutive_range_length);

///@ingroup sparse-data-cuda
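/// `input_ptrs` and `indices_ptrs` are device-accessible arrays of
/// `num_groups` per-group data pointers encoded as `int64_t`; the kernels cast
/// them back to `scalar_t*` / `index_t*`. `output_shape` is the shape of the
/// concatenated output, which is split into `num_groups` tensors on return.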
std::vector<at::Tensor> group_index_select_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups);

std::vector<at::Tensor> group_index_add_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups);
#endif
} // namespace fbgemm_gpu
192 changes: 192 additions & 0 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -2850,6 +2850,198 @@ Tensor index_add_with_unique_indices_cuda(
return input_grad.reshape(input_shape);
}

template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__ __launch_bounds__(kMaxThreads) void group_index_select_2d_kernel(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
scalar_t* output,
const int64_t num_input_rows,
const int64_t num_output_rows,
const int64_t num_cols,
const int64_t num_groups) {
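// Each (blockIdx.x, threadIdx.y) pair walks output rows in a grid-stride
// fashion; bid encodes (group_id, row), and threadIdx.x covers the columns of
// that row, UNROLL_FACTOR columns per step.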
for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
bid < num_groups * num_output_rows;
bid += gridDim.x * blockDim.y) {
const int64_t group_id = bid / num_output_rows;
const int64_t row = bid % num_output_rows;
scalar_t* input = (scalar_t*)input_ptrs[group_id];
index_t* indices = (index_t*)indices_ptrs[group_id];
const index_t idx = indices[row];
CUDA_KERNEL_ASSERT(idx < num_input_rows)
int col;
scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
for (col = threadIdx.x * UNROLL_FACTOR;
col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
col += blockDim.x * UNROLL_FACTOR) {
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
output_[row * num_cols + col + i] =
LDG(&input[idx * num_cols + col + i]);
}
}
for (; col < num_cols; ++col) {
output_[row * num_cols + col] = LDG(&input[idx * num_cols + col]);
}
}
}

template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__ __launch_bounds__(kMaxThreads) void group_index_add_2d_kernel(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
scalar_t* output,
const int64_t num_input_rows,
const int64_t num_output_rows,
const int64_t num_cols,
const int64_t num_groups) {
for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
bid < num_groups * num_input_rows;
bid += gridDim.x * blockDim.y) {
const int64_t group_id = bid / num_input_rows;
const int64_t row = bid % num_input_rows;
scalar_t* input = (scalar_t*)input_ptrs[group_id];
index_t* indices = (index_t*)indices_ptrs[group_id];
const index_t idx = indices[row];
CUDA_KERNEL_ASSERT(idx < num_output_rows)
int col;
scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
for (col = threadIdx.x * UNROLL_FACTOR;
col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
col += blockDim.x * UNROLL_FACTOR) {
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
// PyTorch also uses atomicAdd here: it does not require sorting the
// indices and provides better parallelism, but it can lead to numerical
// indeterminism.
gpuAtomicAddNoReturn(
&output_[idx * num_cols + col + i],
input[row * num_cols + col + i]);
}
}
for (; col < num_cols; ++col) {
gpuAtomicAddNoReturn(
&output_[idx * num_cols + col], input[row * num_cols + col]);
}
}
}

std::vector<Tensor> group_index_select_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups) {
if (num_groups == 0) {
return std::vector<Tensor>();
}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(device);

Tensor output = at::empty(output_shape, input_tensor_options);

// Partition work based on num_output_rows
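// Launch config, as read from the code below: each thread block is 2D, with
// block_size_x threads striding over a row's columns (UNROLL_FACTOR columns
// per step) and block_size_y rows handled per block, while the 1D grid
// (capped at 8 * multiProcessorCount blocks) grid-strides over all
// num_groups * num_output_rows output rows.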
const int UNROLL_FACTOR = 1;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(num_groups * num_output_rows, 1), max_grid_size);
uint32_t block_size_x =
std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
uint32_t block_size_y =
std::max((num_groups * num_output_rows) / grid_size, (uint32_t)1);
dim3 block_size(
block_size_x,
std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
1);

AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
group_index_select_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
<<<grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input_ptrs,
indices_ptrs,
output.data_ptr<scalar_t>(),
num_input_rows,
num_output_rows,
num_cols,
num_groups);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});

return output.split(num_output_rows, 0);
}

std::vector<Tensor> group_index_add_cuda(
const int64_t* input_ptrs,
const int64_t* indices_ptrs,
const c10::TensorOptions& input_tensor_options,
const c10::ScalarType& input_scalar_type,
const c10::ScalarType& indices_scalar_type,
const c10::DeviceIndex& device,
const std::vector<int64_t>& output_shape,
const int num_input_rows,
const int num_output_rows,
const int num_cols,
const int num_groups) {
if (num_groups == 0) {
return std::vector<Tensor>();
}

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(device);

Tensor output = at::zeros(output_shape, input_tensor_options);

// Partition work based on num_input_rows
const int UNROLL_FACTOR = 1;
uint32_t max_grid_size =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
uint32_t grid_size = std::min(
cuda_calc_xblock_count(num_groups * num_input_rows, 1), max_grid_size);
uint32_t block_size_x =
std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
uint32_t block_size_y =
std::max((num_groups * num_input_rows) / grid_size, (uint32_t)1);
dim3 block_size(
block_size_x,
std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
1);

AT_DISPATCH_INDEX_TYPES(
indices_scalar_type, "group_index_add_2d_wrapper_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_scalar_type, "group_index_add_2d_wrapper_2", [&] {
group_index_add_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
<<<grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input_ptrs,
indices_ptrs,
output.data_ptr<scalar_t>(),
num_input_rows,
num_output_rows,
num_cols,
num_groups);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});

return output.split(num_output_rows, 0);
}
// Copied from cupy/random/_kernels.py v11
// (commit id 420e41fd41157d4cf526b0e94eb86a3f8eb5a231)
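For reference, group_index_add_2d_kernel scatter-adds per-group rows into a zero-initialized buffer (at::zeros in group_index_add_cuda), which matches the gradient of an index_select along dim 0. A minimal sketch in plain PyTorch (not from this PR) of that equivalence:

import torch

x = torch.randn(8, 4, requires_grad=True)
idx = torch.tensor([1, 1, 5, 7])
y = x.index_select(0, idx)
y.backward(torch.ones_like(y))

# Manual gradient: scatter-add the upstream gradient into a zero tensor at the
# selected rows (row 1 accumulates two contributions).
manual_grad = torch.zeros_like(x)
manual_grad.index_add_(0, idx, torch.ones_like(y))
assert torch.allclose(x.grad, manual_grad)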

18 changes: 18 additions & 0 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -2357,6 +2357,19 @@ Tensor index_select_dim0(
return at::index_select(input, 0, indices);
}

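// CPU reference path for the new op: a per-group at::index_select, registered
// for the CPU backend via DISPATCH_TO_CPU below.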
std::vector<Tensor> group_index_select_dim0(
const std::vector<Tensor>& input_group,
const std::vector<Tensor>& indices_group) {
int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> output_group;
for (int i = 0; i < num_groups; i++) {
output_group.push_back(
at::index_select(input_group[i], 0, indices_group[i]));
}
return output_group;
}

Tensor bottom_unique_k_per_row(const Tensor& input, const int64_t k) {
auto num_cols = input.size(-1);
Tensor input_reshaped = input.reshape({-1, num_cols});
@@ -2389,6 +2402,7 @@ Tensor bottom_unique_k_per_row(const Tensor& input, const int64_t k) {
output_shape[output_shape.size() - 1] = k;
return output.reshape(output_shape);
}

} // namespace

} // namespace fbgemm_gpu
@@ -2458,6 +2472,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// skip_indices_sorting_fwd is for skipping indices sorting in forward
m.def(
"index_select_dim0(Tensor input, Tensor indices, int? consecutive_range_start=0, int? consecutive_range_length=0, bool? skip_indices_sorting_fwd=None) -> Tensor");
m.def(
"group_index_select_dim0(Tensor[] input_group, Tensor[] indices_group) -> Tensor[]");
m.def(
"jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]");
// This is an one-off op to be used in bench_utils.py for zipf generation w/o
@@ -2526,6 +2542,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
fbgemm_gpu::permute_sequence_embeddings_cpu);
DISPATCH_TO_CPU("pack_segments", fbgemm_gpu::pack_segments_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
DISPATCH_TO_CPU(
"bottom_unique_k_per_row", fbgemm_gpu::bottom_unique_k_per_row);
}