Implement unique indices for KJT to dedup indices per feature (#1815)
Summary:
Pull Request resolved: #1815

Implement unique indices for KJT to dedup indices per feature. The op design and implementation details can be found in this doc: https://docs.google.com/document/d/1og-jlVX4lI8xDcK6wf-HL4a_EK6vL-rBmVHvhIkCp9Q/edit?usp=sharing

Reviewed By: sryap

Differential Revision: D46526033

fbshipit-source-id: 89a5430eae6930c8d56efe5f52978c78213e36f8
GD06 authored and facebook-github-bot committed Jun 12, 2023
1 parent a1268a1 commit b9bcddd
Showing 4 changed files with 280 additions and 0 deletions.
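The new op is exposed as torch.ops.fbgemm.jagged_unique_indices (schema registered in jagged_tensor_ops_cpu.cpp below). A minimal usage sketch with hypothetical tensor values, assuming fbgemm_gpu is built with this change and a CUDA device is available; the expected outputs are worked out by hand for this toy input:

import torch

device = torch.device("cuda")
dtype = torch.int64

# Hypothetical inputs: F = 2 features, B = 2 batch entries per feature.
hash_size_cumsum = torch.tensor([0, 4, 9], dtype=dtype, device=device)  # F + 1 entries
offsets = torch.tensor([0, 2, 3, 5, 6], dtype=dtype, device=device)     # F * B + 1 entries
indices = torch.tensor([1, 1, 3, 0, 4, 2], dtype=dtype, device=device)  # offsets[-1] entries

output_offsets, unique_indices, reverse_index = torch.ops.fbgemm.jagged_unique_indices(
    hash_size_cumsum, offsets, indices
)
# Expected here: unique_indices == [1, 3, 0, 2, 4], reverse_index == [0, 0, 1, 2, 4, 3],
# output_offsets == [0, 1, 2, 4, 5]; in general, indices[i] == unique_indices[reverse_index[i]].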
1 change: 1 addition & 0 deletions fbgemm_gpu/CMakeLists.txt
@@ -565,6 +565,7 @@ if(NOT FBGEMM_CPU_ONLY)
src/jagged_tensor_ops/jagged_tensor_ops.cu
src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu
src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
src/jagged_tensor_ops/jagged_unique_indices.cu
src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
src/jagged_tensor_ops/stacked_jagged_1d_to_dense.cu
src/jagged_tensor_ops/stacked_jagged_2d_to_dense.cu
206 changes: 206 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu
@@ -0,0 +1,206 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "common.cuh"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Linearize the indices with the cumulative sum of the hash sizes so that the
// linearized indices of all features can be sorted (and deduplicated) together.
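// For example (hypothetical values), with hash_size_cumsum = [0, 4, 9] an index
// value 2 belonging to feature 1 is mapped to the linearized index 4 + 2 = 6.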
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
hash_size_cumsum,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
linear_indices,
FixedDivisor fd) {
const int32_t T = hash_size_cumsum.size(0) - 1;
const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
int32_t b;
int32_t t;
const auto total_B = offsets.size(0) - 1;
const auto valid = b_t < total_B;

fd.DivMod(b_t, &t, &b);

const auto hash_offset = valid ? hash_size_cumsum[t] : -1;
const auto indices_start = valid ? offsets[b_t] : -1;
const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0;
const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;

for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
const auto t_warp = fbgemm_gpu::shfl_sync(t, j);
const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
const auto hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
const auto idx = __ldg(&indices[indices_start_warp + i]);
linear_indices[indices_start_warp + i] = hash_offset_warp + idx;
}
}
}

// De-linearize the unique indices using the reverse-index info and the original
// indices. For each element of the input indices, the value must equal the
// element of the unique indices that its reverse-index entry points to.
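// In other words, indices[i] == unique_indices[reverse_index[i]] for every i.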
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reverse_index,
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
unique_indices) {
const auto total_indices = indices.size(0);
const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
if (b_t < total_indices) {
const auto original_index = indices[b_t];
const auto pos = reverse_index[b_t];
unique_indices[pos] = original_index;
}
}

// Compute the lengths for each feature in the unique indices. The number of
// unique indices belonging to a feature equals (max - min) + 1 over that
// feature's values in the reverse-index array.
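// That count is then split as evenly as possible across the batch entries of
// the feature, so that the complete cumulative sum of the lengths yields
// per-feature offsets into the unique indices.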
template <typename index_t, auto max_value, auto min_value>
__global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reverse_index,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths) {
typedef cub::BlockReduce<index_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_max;
__shared__ typename BlockReduce::TempStorage temp_storage_min;
__shared__ index_t block_results[2];

const int32_t tid = threadIdx.x;
const int32_t bid = blockIdx.x;
const int32_t num_blocks = gridDim.x;
const int32_t batch_size = (offsets.size(0) - 1) / num_blocks;

const auto offset_begin = bid * batch_size;
const auto offset_end = (bid + 1) * batch_size;

const auto reverse_index_begin = offsets[offset_begin];
const auto reverse_index_end = offsets[offset_end];

if (reverse_index_begin == reverse_index_end) {
return;
}

index_t t_max = min_value;
index_t t_min = max_value;
for (index_t i = (reverse_index_begin + tid); i < reverse_index_end;
i += kMaxThreads) {
const index_t value = reverse_index[i];
t_max = (value > t_max) ? value : t_max;
t_min = (value < t_min) ? value : t_min;
}

index_t block_max = BlockReduce(temp_storage_max).Reduce(t_max, cub::Max());
index_t block_min = BlockReduce(temp_storage_min).Reduce(t_min, cub::Min());
if (tid == 0) {
block_results[0] = block_max;
block_results[1] = block_min;
}
__syncthreads();

t_max = block_results[0];
t_min = block_results[1];
const index_t total_length = (t_max - t_min) + 1;
const index_t div_length = total_length / batch_size;
const index_t r_length = total_length % batch_size;
for (int32_t i = tid; i < batch_size; i += kMaxThreads) {
index_t seg_length = (i < r_length) ? (div_length + 1) : div_length;
lengths[offset_begin + i] = seg_length;
}
}

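// Returns (output_offsets, unique_indices, reverse_index) by
//   (1) linearizing the indices with the per-feature hash-size cumsum,
//   (2) deduplicating the linearized indices via at::_unique (sorted, with inverse),
//   (3) mapping each unique linearized index back to its original value, and
//   (4) computing per-feature lengths of the unique ranges and their complete cumsum.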
std::tuple<Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
const Tensor& hash_size_cumsum,
const Tensor& offsets,
const Tensor& indices) {
const auto total_B = offsets.size(0) - 1;
const auto T = hash_size_cumsum.size(0) - 1;

Tensor linear_indices = at::empty_like(indices);

using at::RestrictPtrTraits;

AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "linearize_index", ([&] {
const auto linearize_index_kernel_ =
linearize_index_wo_infos_kernel<index_t>;
linearize_index_kernel_<<<
div_round_up(total_B, kMaxThreads),
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
hash_size_cumsum.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
linear_indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
FixedDivisor(total_B / T));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
Tensor linear_unique_indices;
Tensor reverse_index;
std::tie(linear_unique_indices, reverse_index) =
at::_unique(linear_indices, true, true);
const auto total_indices = indices.size(0);
Tensor unique_indices = at::empty_like(linear_unique_indices);
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "delinearize_unique_index", ([&] {
const auto delinearize_unique_index_kernel_ =
delinearize_unique_index_kernel<index_t>;
delinearize_unique_index_kernel_<<<
div_round_up(total_indices, kMaxThreads),
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
indices.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
reverse_index.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
unique_indices.packed_accessor32<index_t, 1, RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
Tensor lengths = at::zeros({total_B}, offsets.options());
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "unique_indices_length", ([&] {
const auto unique_indices_length_kernel_ = unique_indices_length_kernel<
index_t,
std::numeric_limits<index_t>::max(),
std::numeric_limits<index_t>::min()>;
unique_indices_length_kernel_<<<
T,
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
reverse_index.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, RestrictPtrTraits>(),
lengths.packed_accessor32<index_t, 1, RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
Tensor output_offsets = asynchronous_complete_cumsum_gpu(lengths);
return {output_offsets, unique_indices, reverse_index};
}
} // namespace fbgemm_gpu
JAGGED_TENSOR_OPS_CUDA_DISPATCH(
"jagged_unique_indices",
fbgemm_gpu::jagged_unique_indices_cuda);
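For reference, a minimal eager-PyTorch sketch of the semantics implemented by the kernels above. jagged_unique_indices_ref is a hypothetical helper written for illustration (not part of this diff) and assumes every index is smaller than its feature's hash size:

import torch

def jagged_unique_indices_ref(hash_size_cumsum, offsets, indices):
    T = hash_size_cumsum.numel() - 1
    B = (offsets.numel() - 1) // T
    # (1) Linearize: add the hash-size cumsum of the feature each element belongs to.
    feature_of_bag = torch.arange(T * B, device=indices.device) // B
    feature_of_elem = torch.repeat_interleave(feature_of_bag, offsets[1:] - offsets[:-1])
    linear = indices + hash_size_cumsum[feature_of_elem]
    # (2) Dedup the linearized indices, keeping the inverse (reverse index) mapping.
    linear_unique, reverse_index = torch.unique(linear, sorted=True, return_inverse=True)
    # (3) De-linearize: recover the original value of each unique index.
    unique_indices = torch.empty_like(linear_unique)
    unique_indices[reverse_index] = indices
    # (4) Per-feature unique counts, spread evenly over the B entries, then complete cumsum.
    lengths = torch.zeros(T * B, dtype=offsets.dtype, device=indices.device)
    for t in range(T):
        lo, hi = int(offsets[t * B]), int(offsets[(t + 1) * B])
        if lo == hi:
            continue
        n = int(reverse_index[lo:hi].max() - reverse_index[lo:hi].min()) + 1
        d, r = divmod(n, B)
        lengths[t * B : (t + 1) * B] = d
        lengths[t * B : t * B + r] += 1
    output_offsets = torch.zeros(T * B + 1, dtype=offsets.dtype, device=indices.device)
    output_offsets[1:] = torch.cumsum(lengths, dim=0)
    return output_offsets, unique_indices, reverse_index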
2 changes: 2 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -1678,6 +1678,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"jagged_slice(Tensor x_values, Tensor x_lengths, Tensor start, int slice_length) -> (Tensor, Tensor)");
m.def(
"jagged_slice_forward(Tensor x_values, Tensor x_lengths, Tensor src_start, Tensor output_lengths, Tensor tgt_start, int num_output_rows, int slice_length, bool fill_zeros) -> Tensor");
m.def(
"jagged_unique_indices(Tensor hash_size_cumsum, Tensor offsets, Tensor indices) -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
71 changes: 71 additions & 0 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -2360,6 +2360,77 @@ def test_jagged_slice_errors(
values, lengths, torch.tensor([-2, 1, 1, 0, 1, 2]), 7
)

@unittest.skipIf(*gpu_unavailable)
@given(
B=st.integers(min_value=100, max_value=200),
F=st.integers(min_value=50, max_value=100),
max_length=st.integers(min_value=5, max_value=10),
)
@settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
def test_jagged_unique_indices(
self,
B: int, # Batch size
F: int, # The number of features
max_length: int, # The maximum value of pooling factor
) -> None:
hash_size_list = []
lengths_list = []
indices_list = []
linearized_indices_list = []
for _ in range(F):
# We generate a small hash size to increase index duplication
hash_size = random.randint(3, 5)
hash_size_list.append(hash_size)
for _ in range(B):
length = random.randint(0, max_length)
lengths_list.append(length)
if length > 0:
indices = np.random.randint(0, hash_size, size=length)
linearized_indices = indices + sum(hash_size_list[:-1])
indices_list.extend(indices)
linearized_indices_list.extend(linearized_indices)

device = torch.device("cuda")
dtype = torch.int64
hash_size = torch.as_tensor(hash_size_list, dtype=dtype, device=device)
lengths = torch.as_tensor(lengths_list, dtype=dtype, device=device)
indices = torch.as_tensor(indices_list, dtype=dtype, device=device)
linearized_indices = torch.as_tensor(
linearized_indices_list, dtype=dtype, device=device
)

hash_size_cum_sum = torch.zeros(F + 1, dtype=dtype, device=device)
hash_size_cum_sum[1:] = torch.cumsum(hash_size, dim=0)
offsets = torch.zeros(F * B + 1, dtype=dtype, device=device)
offsets[1:] = torch.cumsum(lengths, dim=0)

(
output_offsets,
unique_indices,
reverse_index,
) = torch.ops.fbgemm.jagged_unique_indices(hash_size_cum_sum, offsets, indices)

unique_linearized_indices = torch.unique(linearized_indices, sorted=True)
self.assertTrue(unique_linearized_indices.numel() == unique_indices.numel())

unique_indices_list = unique_indices.tolist()
reverse_index_list = reverse_index.tolist()
for i in range(len(reverse_index_list)):
pos = reverse_index_list[i]
self.assertTrue(unique_indices_list[pos] == indices_list[i])

input_offsets_list = offsets.tolist()
output_offsets_list = output_offsets.tolist()
for i in range(F):
input_start = input_offsets_list[i * B]
input_end = input_offsets_list[(i + 1) * B]
output_start = output_offsets_list[i * B]
output_end = output_offsets_list[(i + 1) * B]
for each_offset in range(input_start, input_end):
pos = reverse_index_list[each_offset]
self.assertTrue((output_start <= pos) and (pos < output_end))


if __name__ == "__main__":
unittest.main()
