From ec55e960573685c76e4ad2b1893991709a61cbbd Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Mon, 27 Dec 2021 17:18:05 -0800
Subject: [PATCH] Support int32_t indices/offsets in the cache handling logic (#811)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/811

In training we assume the indices / offsets of the table batched embedding
(TBE) operators are int64_t, but in inference we assume they are int32_t.
This Diff adds support for both int32_t and int64_t in the caching logic so
that the same functions can be reused for training and inference, avoiding
the extra overhead of converting indices/offsets from int to long or vice
versa.

Differential Revision: D33045589

fbshipit-source-id: 42ebcd899bb5dc6735eaf67cad48ac3b168d60ca
---
 .../split_table_batched_embeddings_ops.py     |   1 -
 fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 568 ++++++++++--------
 2 files changed, 304 insertions(+), 265 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
index 221a5ea897..cdc11ccf81 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -1824,7 +1824,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         if not self.lxu_cache_weights.numel():
             return
 
-        (indices, offsets) = indices.long(), offsets.long()
         linear_cache_indices = torch.ops.fb.linearize_cache_indices(
             self.cache_hash_size_cumsum,
             indices,
diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
index f671bcba10..823fed07c1 100644
--- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -74,18 +74,18 @@ __host__ __device__ inline int32_t padded_row_size_in_bytes(
 }
 } // namespace
 
-// TODO: do we care about 64-bit indices? Currently we just ignore.
-__host__ DEVICE_INLINE uint32_t cache_slot(int32_t h_in, int32_t C) {
-  // MurmorHash3 32-bit mixing function.
-  uint32_t h = (uint32_t)h_in;
-  h ^= h >> 16;
-  h *= 0x85ebca6b;
-  h ^= h >> 13;
-  h *= 0xc2b2ae35;
-  h ^= h >> 16;
-  // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
-  return ((uint64_t)h * (uint64_t)C) >> 32;
-}
+// // TODO: do we care about 64-bit indices? Currently we just ignore.
+// __host__ DEVICE_INLINE uint32_t cache_slot(int32_t h_in, int32_t C) {
+//   // MurmorHash3 32-bit mixing function.
+//   uint32_t h = (uint32_t)h_in;
+//   h ^= h >> 16;
+//   h *= 0x85ebca6b;
+//   h ^= h >> 13;
+//   h *= 0xc2b2ae35;
+//   h ^= h >> 16;
+//   // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
+//   return ((uint64_t)h * (uint64_t)C) >> 32;
+// }
 
 __host__ DEVICE_INLINE uint32_t cache_slot(int64_t h_in, int32_t C) {
   // MurmurHash3 64-bit mixing function.
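
The host-side pattern applied throughout split_embeddings_cache_cuda.cu in the hunks below can be summarized by the minimal sketch that follows: the kernel is templated on index_t, and the wrapper selects int32_t or int64_t at runtime with ATen's AT_DISPATCH_INDEX_TYPES, so callers no longer have to cast indices to int64_t first. The kernel and function names (widen_indices_kernel, widen_indices_cuda) are hypothetical; only the dispatch/accessor idiom mirrors the patch.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAException.h>

    // Hypothetical kernel: templated on the index type, like the kernels in
    // this file (e.g. linearize_cache_indices_kernel).
    template <typename index_t>
    __global__ void widen_indices_kernel(
        const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
        at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> out) {
      const int32_t n = blockIdx.x * blockDim.x + threadIdx.x;
      if (n < indices.size(0)) {
        // index_t is int32_t or int64_t depending on the dispatched dtype.
        out[n] = static_cast<int64_t>(indices[n]);
      }
    }

    at::Tensor widen_indices_cuda(at::Tensor indices) {
      auto out = at::empty(indices.sizes(), indices.options().dtype(at::kLong));
      const int64_t N = indices.numel();
      if (N == 0) {
        return out;
      }
      const uint32_t blocks = (N + 255) / 256;
      AT_DISPATCH_INDEX_TYPES(
          indices.scalar_type(), "widen_indices_cuda", [&]() {
            // index_t is defined by the dispatch macro inside this lambda.
            widen_indices_kernel<index_t><<<
                blocks,
                256,
                0,
                at::cuda::getCurrentCUDAStream()>>>(
                indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
                out.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>());
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
      return out;
    }
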
@@ -117,7 +117,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( weights_offsets, const at::PackedTensorAccessor32 D_offsets, - at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 lxu_cache_state, at::PackedTensorAccessor64 lxu_cache_weights, @@ -241,12 +241,13 @@ void lxu_cache_flush_cuda( return; } +template __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 linear_cache_indices) { int32_t T = cache_hash_size_cumsum.size(0) - 1; int64_t total_cache_hash_size = cache_hash_size_cumsum[T]; @@ -257,13 +258,13 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( bool valid = t < T; int64_t hash_offset = valid ? cache_hash_size_cumsum[t] : -1; - int64_t indices_start = valid ? offsets[t * B + b] : -1; + auto indices_start = valid ? offsets[t * B + b] : -1; int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; int32_t lane_id = threadIdx.x % kWarpSize; // hash_offset < 0 for non-caching tables for (int32_t j = 0; j < kWarpSize; ++j) { - int64_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); + auto indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); int64_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); if (hash_offset_warp >= 0) { @@ -300,18 +301,21 @@ Tensor linearize_cache_indices_cuda( if (B == 0) { return linear_cache_indices; } - linearize_cache_indices_kernel<<< - div_round_up(B * T, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cache_hash_size_cumsum - .packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - linear_cache_indices - .packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "linearize_cache_indices_kernel", [&]() { + linearize_cache_indices_kernel<<< + div_round_up(B * T, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_hash_size_cumsum + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + linear_cache_indices + .packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); return linear_cache_indices; } @@ -335,88 +339,93 @@ std::tuple> get_unique_indices_cuda( unique_indices_count = at::empty( {linear_indices.numel()}, linear_indices.options().dtype(at::kInt)); } - - // sort indices - size_t temp_storage_bytes_0 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - nullptr, - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_0 = at::empty( - {static_cast(temp_storage_bytes_0)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - temp_storage_0.data_ptr(), - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - // get unique indices - if (compute_count) { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( 
- nullptr, - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - } else { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - nullptr, - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - } + AT_DISPATCH_INDEX_TYPES( + linear_indices.scalar_type(), "get_unique_indices_cuda", [&]() { + // sort indices + size_t temp_storage_bytes_0 = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( + nullptr, + temp_storage_bytes_0, + linear_indices.data_ptr(), + sorted_indices.data_ptr(), + N, + 0, + int(log2(float(max_indices + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_0 = at::empty( + {static_cast(temp_storage_bytes_0)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( + temp_storage_0.data_ptr(), + temp_storage_bytes_0, + linear_indices.data_ptr(), + sorted_indices.data_ptr(), + N, + 0, + int(log2(float(max_indices + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + // get unique indices + if (compute_count) { + size_t temp_storage_bytes_1 = 0; + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + nullptr, + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_count->data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_1 = at::empty( + {static_cast(temp_storage_bytes_1)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + temp_storage_1.data_ptr(), + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_count->data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + } else { + size_t temp_storage_bytes_1 = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( + nullptr, + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_1 = at::empty( + {static_cast(temp_storage_bytes_1)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( + temp_storage_1.data_ptr(), + temp_storage_bytes_1, + sorted_indices.data_ptr(), + 
unique_indices.data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + } + }); return std::make_tuple( unique_indices, unique_indices_length, unique_indices_count); } +template __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, int64_t max_indices, @@ -439,7 +448,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( return; } int64_t idx = unique_indices[n]; - if (idx == max_indices) { + if (static_cast(idx) == max_indices) { if (threadIdx.x == 0) { cache_sets[n] = C; // invalid index, used as sentinel } @@ -484,49 +493,52 @@ std::pair lru_cache_find_uncached_cuda( auto sorted_cache_sets = empty_like(cache_sets); auto cache_set_sorted_unique_indices = empty_like(unique_indices); - // Find uncached indices - lru_cache_find_uncached_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - max_indices, - lxu_cache_state.packed_accessor32(), - cache_sets.packed_accessor32(), - time_stamp, - lru_state.packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - // Sort the cache sets and ids - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - nullptr, - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - unique_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - temp_storage.data_ptr(), - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lru_cache_find_uncached_cuda", [&]() { + // Find uncached indices + lru_cache_find_uncached_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices.packed_accessor32(), + unique_indices_length.data_ptr(), + max_indices, + lxu_cache_state.packed_accessor32(), + cache_sets.packed_accessor32(), + time_stamp, + lru_state.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Sort the cache sets and ids + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + nullptr, + temp_storage_bytes, + cache_sets.data_ptr(), + sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + unique_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + temp_storage.data_ptr(), + temp_storage_bytes, + cache_sets.data_ptr(), + sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + 
cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + }); return {sorted_cache_sets, cache_set_sorted_unique_indices}; } @@ -844,6 +856,7 @@ void lru_cache_populate_cuda( stochastic_rounding); } +template __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 @@ -858,7 +871,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( D_offsets, const at::PackedTensorAccessor32 sorted_cache_sets, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 cache_set_sorted_indices, const int32_t* __restrict__ N_unique, at::PackedTensorAccessor32 @@ -910,7 +923,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( if (insert_current_lru_cost == time_stamp) { return; } - int64_t insert_idx = cache_set_sorted_indices[n + l]; + index_t insert_idx = cache_set_sorted_indices[n + l]; int32_t t_insert = cache_index_table_map[insert_idx]; SparseType weight_ty_insert = static_cast(weights_tys[t_insert]); @@ -999,28 +1012,35 @@ void lru_cache_insert_byte_cuda( int32_t N = cache_set_sorted_unique_indices.numel(); - lru_cache_insert_byte_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - weights.packed_accessor64(), - cache_hash_size_cumsum - .packed_accessor32(), - cache_index_table_map - .packed_accessor32(), - weights_offsets.packed_accessor32(), - weights_tys.packed_accessor32(), - D_offsets.packed_accessor32(), - sorted_cache_sets.packed_accessor32(), - cache_set_sorted_unique_indices - .packed_accessor32(), - unique_indices_length.data_ptr(), - lxu_cache_state.packed_accessor32(), - lxu_cache_weights.packed_accessor64(), - time_stamp, - lru_state.packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + cache_set_sorted_unique_indices.scalar_type(), + "lru_cache_insert_byte_cuda", + [&]() { + lru_cache_insert_byte_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + weights.packed_accessor64(), + cache_hash_size_cumsum + .packed_accessor32(), + cache_index_table_map + .packed_accessor32(), + weights_offsets.packed_accessor32(), + weights_tys.packed_accessor32(), + D_offsets.packed_accessor32(), + sorted_cache_sets + .packed_accessor32(), + cache_set_sorted_unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + lxu_cache_state.packed_accessor32(), + lxu_cache_weights + .packed_accessor64(), + time_stamp, + lru_state.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } void lru_cache_populate_byte_cuda( @@ -1093,8 +1113,9 @@ void lru_cache_populate_byte_cuda( lru_state); } +template __global__ __launch_bounds__(kMaxThreads) void lfu_update_counts_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, const at::PackedTensorAccessor32 @@ -1104,7 +1125,7 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_update_counts_kernel( if (n >= *N_unique) { return; } - int64_t idx = unique_indices[n]; + auto idx = unique_indices[n]; lfu_state[idx] += unique_indices_count[n]; } @@ -1122,16 +1143,20 @@ void lfu_update_counts_cuda( device_guard.set_index(unique_indices.get_device()); int32_t N = unique_indices.size(0); - lfu_update_counts_kernel<<< - 
div_round_up(N, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - unique_indices_count - .packed_accessor32(), - lfu_state.packed_accessor64()); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lfu_update_counts_cuda", [&]() { + lfu_update_counts_kernel<<< + div_round_up(N, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + unique_indices_count + .packed_accessor32(), + lfu_state.packed_accessor64()); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1139,8 +1164,9 @@ constexpr int32_t kCacheSetBits = 24; constexpr int32_t kLFUCounterBits = 40; static_assert(kCacheSetBits + kLFUCounterBits == 8 * sizeof(int64_t), ""); +template __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, int64_t max_indices, @@ -1213,48 +1239,51 @@ std::pair lfu_cache_find_uncached_cuda( auto sorted_cache_sets = empty_like(cache_sets); auto cache_set_sorted_unique_indices = empty_like(unique_indices); - // Find uncached indices - lfu_cache_find_uncached_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - max_indices, - lxu_cache_state.packed_accessor32(), - (uint64_t*)cache_sets.data_ptr(), - lfu_state.packed_accessor64()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - // Sort the cache sets and ids - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - nullptr, - temp_storage_bytes, - (uint64_t*)cache_sets.data_ptr(), - (uint64_t*)sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - unique_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - temp_storage.data_ptr(), - temp_storage_bytes, - (uint64_t*)cache_sets.data_ptr(), - (uint64_t*)sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, - at::cuda::getCurrentCUDAStream(), - false)); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lfu_cache_find_uncached_cuda", [&]() { + // Find uncached indices + lfu_cache_find_uncached_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices.packed_accessor32(), + unique_indices_length.data_ptr(), + max_indices, + lxu_cache_state.packed_accessor32(), + (uint64_t*)cache_sets.data_ptr(), + lfu_state.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Sort the cache sets and ids + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + nullptr, + temp_storage_bytes, + (uint64_t*)cache_sets.data_ptr(), + (uint64_t*)sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 
1)) + 1) + kLFUCounterBits, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + unique_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + temp_storage.data_ptr(), + temp_storage_bytes, + (uint64_t*)cache_sets.data_ptr(), + (uint64_t*)sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, + at::cuda::getCurrentCUDAStream(), + false)); + }); return {sorted_cache_sets, cache_set_sorted_unique_indices}; } @@ -1595,6 +1624,7 @@ void lfu_cache_populate_cuda( // uint8_t only). Basically no "high-precision cache" support for now. // - The insert/evict of embedding row from the cache are done in a byte-by-byte // manner. +template __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( at::PackedTensorAccessor64 weights, @@ -1609,7 +1639,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( const at::PackedTensorAccessor32 D_offsets, const uint64_t* __restrict__ sorted_cache_sets, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 cache_set_sorted_indices, const int32_t* __restrict__ N_unique, at::PackedTensorAccessor32 @@ -1665,7 +1695,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); int64_t insert_current_lfu_cost = __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l); - int64_t insert_idx = cache_set_sorted_indices[n + l]; + index_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; if (insert_current_lfu_cost > insert_lfu_cost) { @@ -1759,26 +1789,32 @@ void lfu_cache_insert_byte_cuda( int32_t N = cache_set_sorted_unique_indices.numel(); - lfu_cache_insert_byte_kernel<<< - div_round_up(N, kCacheMaxThreads / kWarpSize), - dim3(kWarpSize, kCacheMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - weights.packed_accessor64(), - cache_hash_size_cumsum - .packed_accessor32(), - cache_index_table_map - .packed_accessor32(), - weights_offsets.packed_accessor32(), - weights_tys.packed_accessor32(), - D_offsets.packed_accessor32(), - (uint64_t*)sorted_cache_sets.data_ptr(), - cache_set_sorted_unique_indices - .packed_accessor32(), - unique_indices_length.data_ptr(), - lxu_cache_state.packed_accessor32(), - lxu_cache_weights.packed_accessor64(), - lfu_state.packed_accessor64()); + AT_DISPATCH_INDEX_TYPES( + cache_set_sorted_unique_indices.scalar_type(), + "lfu_cache_insert_byte_cuda", + [&]() { + lfu_cache_insert_byte_kernel<<< + div_round_up(N, kCacheMaxThreads / kWarpSize), + dim3(kWarpSize, kCacheMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + weights.packed_accessor64(), + cache_hash_size_cumsum + .packed_accessor32(), + cache_index_table_map + .packed_accessor32(), + weights_offsets.packed_accessor32(), + weights_tys.packed_accessor32(), + D_offsets.packed_accessor32(), + (uint64_t*)sorted_cache_sets.data_ptr(), + cache_set_sorted_unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + lxu_cache_state.packed_accessor32(), + lxu_cache_weights + .packed_accessor64(), + lfu_state.packed_accessor64()); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1854,8 +1890,9 @@ void lfu_cache_populate_byte_cuda( lfu_state); } +template __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( - const 
at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 linear_cache_indices, const at::PackedTensorAccessor32 lxu_cache_state, @@ -1901,17 +1938,20 @@ Tensor lxu_cache_lookup_cuda( const dim3 threads(kWarpSize, kMaxThreads / kWarpSize); const dim3 blocks(div_round_up(N, kMaxThreads / kWarpSize)); - lxu_cache_lookup_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - linear_cache_indices - .packed_accessor32(), - lxu_cache_state.packed_accessor32(), - lxu_cache_locations - .packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + linear_cache_indices.scalar_type(), "lxu_cache_lookup_cuda", [&]() { + lxu_cache_lookup_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + linear_cache_indices + .packed_accessor32(), + lxu_cache_state.packed_accessor32(), + lxu_cache_locations + .packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); return lxu_cache_locations; }
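
get_unique_indices_cuda() and the lru/lfu *_find_uncached_cuda() paths above all use CUB's two-pass calling convention: the first call, with a null workspace pointer, only reports the temporary storage size; the second call does the actual work. A minimal standalone sketch follows (the helper name is hypothetical, and it uses a raw cudaMalloc workspace, whereas the patch allocates the workspace as a uint8 at::Tensor):

    #include <cstdint>
    #include <cub/cub.cuh>
    #include <cuda_runtime.h>

    // Hypothetical helper: radix-sort an array of int64_t keys on the given stream.
    void sort_linear_indices(
        const int64_t* d_keys_in,
        int64_t* d_keys_out,
        int num_items,
        int end_bit,
        cudaStream_t stream) {
      // Pass 1: with a null workspace pointer, CUB only computes the number of
      // temporary storage bytes the sort needs.
      size_t temp_storage_bytes = 0;
      cub::DeviceRadixSort::SortKeys(
          nullptr, temp_storage_bytes, d_keys_in, d_keys_out, num_items,
          /*begin_bit=*/0, end_bit, stream);
      // Pass 2: allocate the workspace and run the sort.
      void* d_temp_storage = nullptr;
      cudaMalloc(&d_temp_storage, temp_storage_bytes);
      cub::DeviceRadixSort::SortKeys(
          d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items,
          /*begin_bit=*/0, end_bit, stream);
      cudaFree(d_temp_storage);
    }

The end_bit parameter plays the role of the int(log2(float(max_indices + 1)) + 1) bound in the code above: restricting the sorted bit range to the bits the keys can actually occupy reduces the number of radix passes.
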
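
The kCacheSetBits / kLFUCounterBits constants above imply a 64-bit sort key with the cache set in the high 24 bits and the LFU counter in the low 40 bits, so radix-sorting the packed keys groups indices by cache set and orders them by frequency within each set. A hedged illustration of that layout (the helper names and the saturation to 40 bits are assumptions of this sketch, not taken from the patch):

    #include <algorithm>
    #include <cstdint>

    constexpr int32_t kCacheSetBits = 24;
    constexpr int32_t kLFUCounterBits = 40;
    static_assert(kCacheSetBits + kLFUCounterBits == 8 * sizeof(int64_t), "");

    // Pack a cache set id and an LFU count into one sortable 64-bit key.
    uint64_t pack_cache_set_key(uint32_t cache_set, uint64_t lfu_count) {
      // Capping the counter to 40 bits is an assumption of this sketch.
      const uint64_t counter_mask = (1ULL << kLFUCounterBits) - 1;
      return (static_cast<uint64_t>(cache_set) << kLFUCounterBits) |
          std::min(lfu_count, counter_mask);
    }

    uint32_t unpack_cache_set(uint64_t key) {
      return static_cast<uint32_t>(key >> kLFUCounterBits);
    }

    uint64_t unpack_lfu_count(uint64_t key) {
      return key & ((1ULL << kLFUCounterBits) - 1);
    }
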