diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh index 0761f78514..91949ef938 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh @@ -43,6 +43,15 @@ at::Tensor linearize_cache_indices_cuda( at::Tensor indices, at::Tensor offsets); +///@ingroup table-batched-embed-cuda +/// Linearize the indices of all tables to make it be unique. +/// Note the update_table_indices and update_row_indices are +/// from the row indices format for inplace update. +at::Tensor linearize_cache_indices_from_row_idx_cuda( + at::Tensor cache_hash_size_cumsum, + at::Tensor update_table_indices, + at::Tensor update_row_indices); + ///@ingroup table-batched-embed-cuda /// LRU cache: fetch the rows corresponding to `linear_cache_indices` from ///`weights`, and insert them into the cache at timestep `time_stamp`. diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 5bb06aa381..2ae208bfbe 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -314,6 +314,81 @@ Tensor linearize_cache_indices_cuda( return linear_cache_indices; } +namespace { + +template +__global__ +__launch_bounds__(kMaxThreads) void linearize_cache_indices_from_row_idx_kernel( + const at::PackedTensorAccessor32 + cache_hash_size_cumsum, + const at::PackedTensorAccessor32 + update_table_indices, + const at::PackedTensorAccessor32 + update_row_indices, + at::PackedTensorAccessor32 + linear_cache_indices) { + const index_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= update_row_indices.size(0)) { + return; + } + const int table_index = update_table_indices[index]; + + const auto max_offset = + ::__ldg(&cache_hash_size_cumsum[cache_hash_size_cumsum.size(0) - 1]); + const auto curr_offset = ::__ldg(&cache_hash_size_cumsum[table_index]); + if (curr_offset >= 0 && update_row_indices[index] >= 0) { + linear_cache_indices[index] = update_row_indices[index] + curr_offset; + } else { + // Either table index is wrong, or index value is negative (due to pruning): + // set it to invalid value. + linear_cache_indices[index] = max_offset; + } +} + +} // namespace + +Tensor linearize_cache_indices_from_row_idx_cuda( + Tensor cache_hash_size_cumsum, + Tensor update_table_indices, + Tensor update_row_indices) { + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(update_table_indices); + TENSOR_ON_CUDA_GPU(update_row_indices); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(cache_hash_size_cumsum.get_device()); + + const auto T = cache_hash_size_cumsum.size(0) - 1; + TORCH_CHECK(T > 0); + + auto linear_cache_indices = at::empty_like(update_row_indices); + const auto num_indices = update_row_indices.numel(); + if (num_indices == 0) { + return linear_cache_indices; + } + + AT_DISPATCH_INDEX_TYPES( + update_row_indices.scalar_type(), + "linearize_cache_indices_from_row_idx_kernel", + [&] { + linearize_cache_indices_from_row_idx_kernel<<< + div_round_up(num_indices, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_hash_size_cumsum + .packed_accessor32(), + update_table_indices + .packed_accessor32(), + update_row_indices + .packed_accessor32(), + linear_cache_indices + .packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return linear_cache_indices; +} + std::tuple> get_unique_indices_cuda( Tensor linear_indices, int64_t max_indices, diff --git a/fbgemm_gpu/src/split_table_batched_embeddings.cpp b/fbgemm_gpu/src/split_table_batched_embeddings.cpp index c391c846c9..f51013e610 100644 --- a/fbgemm_gpu/src/split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/split_table_batched_embeddings.cpp @@ -17,6 +17,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) -> Tensor"); DISPATCH_TO_CUDA("linearize_cache_indices", linearize_cache_indices_cuda); + m.def( + "linearize_cache_indices_from_row_idx(Tensor cache_hash_size_cumsum, Tensor update_table_indices, Tensor update_row_indices) -> Tensor"); + DISPATCH_TO_CUDA( + "linearize_cache_indices_from_row_idx", + linearize_cache_indices_from_row_idx_cuda); m.def( "lru_cache_populate(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, bool stochastic_rounding) -> ()"); DISPATCH_TO_CUDA("lru_cache_populate", lru_cache_populate_cuda); diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 30cffa1e59..6926b4ebfd 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -4717,6 +4717,75 @@ def test_linearize_cache_indices(self) -> None: ) ) + @unittest.skipIf(*gpu_unavailable) + def test_linearize_cache_indices_from_row_idx(self) -> None: + update_row_indices = torch.tensor( + [10, 2, 3, 7, 1, 4, 5, 9, 2, 7, 6, 8, 5, 1, 0, 4], + dtype=torch.int, + device="cuda", + ) + update_table_indices = torch.tensor( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + dtype=torch.int, + device="cuda", + ) + varying_update_table_indices = torch.tensor( + [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3], + dtype=torch.int, + device="cuda", + ) + + # Testing equal sized tables. + cache_hash_size_cumsum_0 = torch.tensor([0, 12, 24, 36, 48]).cuda() + linear_cache_indices_0 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx( + cache_hash_size_cumsum_0, + update_table_indices, + update_row_indices, + ) + self.assertTrue( + torch.equal( + linear_cache_indices_0.cpu(), + torch.tensor( + [10, 2, 3, 7, 13, 16, 17, 21, 26, 31, 30, 32, 41, 37, 36, 40], + dtype=torch.int, + ), + ) + ) + + # Testing partially cached tables. + cache_hash_size_cumsum_1 = torch.tensor([0, 12, -1, 24, 36]).cuda() + linear_cache_indices_1 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx( + cache_hash_size_cumsum_1, + update_table_indices, + update_row_indices, + ) + self.assertTrue( + torch.equal( + linear_cache_indices_1.cpu(), + torch.tensor( + [10, 2, 3, 7, 13, 16, 17, 21, 36, 36, 36, 36, 29, 25, 24, 28], + dtype=torch.int, + ), + ) + ) + + # Testing batched with varying pooling factor. + cache_hash_size_cumsum_2 = torch.tensor([0, 12, -1, 24, 36]).cuda() + linear_cache_indices_2 = torch.ops.fbgemm.linearize_cache_indices_from_row_idx( + cache_hash_size_cumsum_2, + varying_update_table_indices, + update_row_indices, + ) + self.assertTrue( + torch.equal( + linear_cache_indices_2.cpu(), + torch.tensor( + [10, 2, 3, 19, 13, 16, 17, 21, 36, 36, 36, 36, 36, 36, 24, 28], + dtype=torch.int, + ), + ) + ) + @unittest.skipIf(*gpu_unavailable) @given( associativity=st.sampled_from(