uvm_cache_stats for direct mapped (#1951)
Summary:
Pull Request resolved: #1951

D40518654 introduced `uvm_cache_stats` to provide cache metrics for the FBGEMM 32-way cache.
This diff extends its usage to also provide cache metrics for the direct-mapped cache.
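
As a rough caller-side illustration (not part of this diff), the sketch below exercises the new arguments against the `direct_mapped_lxu_cache_lookup_cuda` signature declared in the header change: the caller pre-allocates an int32 stats tensor and opts in via `gather_cache_stats`. The slot count of 6 and the wrapper name `lookup_with_stats` are assumptions for illustration only; the actual slot layout is defined by the `uvm_cache_stats_index` enum in FBGEMM.

#include <ATen/ATen.h>
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

// Hypothetical caller-side helper (illustrative only).
at::Tensor lookup_with_stats(
    const at::Tensor& linear_cache_indices,
    const at::Tensor& lxu_cache_state) {
  // Assumption: one int32 counter per uvm_cache_stats_index slot
  // (num_calls, num_requested_indices, ..., num_conflict_misses).
  constexpr int64_t kUvmCacheStatsSize = 6;
  auto uvm_cache_stats = at::zeros(
      {kUvmCacheStatsSize},
      linear_cache_indices.options().dtype(at::kInt));
  // New arguments added by this diff: gather_cache_stats + uvm_cache_stats.
  auto lxu_cache_locations = direct_mapped_lxu_cache_lookup_cuda(
      linear_cache_indices,
      lxu_cache_state,
      /*invalid_index=*/-1,
      /*gather_cache_stats=*/true,
      uvm_cache_stats);
  // After the lookup, the num_conflict_misses slot holds the direct-mapped miss count.
  return lxu_cache_locations;
}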

Differential Revision: https://internalfb.com/D48023956

fbshipit-source-id: ae7e41ec13f1f2eee555a50d0f0e3636025361cf
Sungmin Cho authored and facebook-github-bot committed Sep 4, 2023
1 parent 196ad13 commit 2131e9d
Showing 3 changed files with 105 additions and 10 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -114,7 +114,9 @@ void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
at::Tensor lru_state,
at::Tensor lxu_cache_miss_timestamp,
int64_t row_alignment);
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

///@ingroup table-batched-embed-cuda
/// LFU cache: fetch the rows corresponding to `linear_cache_indices` from
@@ -174,7 +176,9 @@ at::Tensor emulate_cache_miss(
at::Tensor direct_mapped_lxu_cache_lookup_cuda(
at::Tensor linear_cache_indices,
at::Tensor lxu_cache_state,
int64_t invalid_index);
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

//////@ingroup table-batched-embed-cuda
/// Flush the cache: store the weights from the cache to the backing storage.
103 changes: 97 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -11,6 +11,7 @@
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_select.cuh>
#include <cub/block/block_reduce.cuh>
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

@@ -742,11 +743,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
lxu_cache_state,
const int64_t time_stamp,
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp) {
const int32_t N = linear_cache_indices.size(0);
const int32_t C = lxu_cache_state.size(0);
if (gather_cache_stats) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_calls], 1); // N_called.
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_requested_indices],
N); // N_requested_indices.
}
}
CUDA_KERNEL_LOOP(n, N) {
int64_t idx = linear_cache_indices[n];
if (idx == max_indices) {
@@ -893,7 +907,9 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
Tensor lxu_cache_state,
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp) {
Tensor lxu_cache_miss_timestamp,
bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices,
lxu_cache_state,
@@ -929,6 +945,8 @@ Tensor direct_mapped_lru_cache_find_uncached_cuda(
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
time_stamp,
MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, lxu_cache_miss_timestamp, int64_t, 2, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1431,6 +1449,9 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
lxu_cache_miss_timestamp,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> cache_sets,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats,
const int64_t row_alignment) {
const int32_t N = cache_sets.size(0);
@@ -1458,6 +1479,24 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_insert_byte_kernel(
// continue;
// }
if (gather_cache_stats && threadIdx.x == 0) {
// We are using this slot for a slightly different purpose.
// In 32 way:
// UVM traffic for insert
// = # of inserted rows
// = # of unique misses - # of unique misses that were not inserted
// = uvm_cache_stats_index::num_unique_misses
// - uvm_cache_stats_index::num_conflict_unique_misses
// In Direct Mapped (here):
// UVM traffic for insert
// = # of inserted rows
// = uvm_cache_stats_index::num_conflict_unique_misses
// (just store here directly)
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses],
1);
}
// insert the index in the buffer into our only slot
const int32_t insert_slot = 0;
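
To make the accounting comment above concrete: in the 32-way cache the number of inserted rows (UVM insert traffic) is derived from two counters, while the direct-mapped path bumps `num_conflict_unique_misses` once per inserted row and so stores it directly. A hypothetical host-side helper, not part of this diff and assuming the `uvm_cache_stats_index` enum is visible and the stats tensor has been copied to CPU, could read both cases like this:

// Illustrative only: derive UVM insert traffic from the gathered counters.
int64_t inserted_rows(const at::Tensor& uvm_cache_stats_cpu, bool direct_mapped) {
  const auto stats = uvm_cache_stats_cpu.accessor<int32_t, 1>();
  if (direct_mapped) {
    // Direct mapped: the insert kernel increments this slot once per inserted row.
    return stats[uvm_cache_stats_index::num_conflict_unique_misses];
  }
  // 32-way: inserted rows = unique misses minus those that could not be inserted.
  return stats[uvm_cache_stats_index::num_unique_misses] -
      stats[uvm_cache_stats_index::num_conflict_unique_misses];
}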
@@ -1579,6 +1618,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_miss_timestamp,
Tensor cache_sets,
bool gather_cache_stats,
Tensor uvm_cache_stats,
int64_t row_alignment) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
@@ -1628,6 +1669,8 @@ void direct_mapped_lru_cache_insert_byte_cuda(
MAKE_PTA_WITH_NAME(
func_name, lxu_cache_miss_timestamp, int64_t, 2, 32),
MAKE_PTA_WITH_NAME(func_name, cache_sets, int32_t, 1, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
row_alignment);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -1739,7 +1782,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
int64_t time_stamp,
Tensor lru_state,
Tensor lxu_cache_miss_timestamp,
int64_t row_alignment) {
int64_t row_alignment,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
weights,
cache_hash_size_cumsum,
@@ -1753,6 +1798,14 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lru_state,
lxu_cache_miss_timestamp);
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
uvm_cache_stats, lxu_cache_weights);
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, weights.options().dtype(at::kInt)));
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(weights.get_device());
@@ -1795,7 +1848,9 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
lxu_cache_state,
time_stamp,
lru_state,
lxu_cache_miss_timestamp);
lxu_cache_miss_timestamp,
gather_cache_stats,
uvm_cache_stats_);
// insert caching weights
direct_mapped_lru_cache_insert_byte_cuda(
@@ -1812,6 +1867,8 @@ DLL_PUBLIC void direct_mapped_lru_cache_populate_byte_cuda(
linear_cache_indices,
lxu_cache_miss_timestamp,
cache_sets,
gather_cache_stats,
uvm_cache_stats_,
row_alignment);
}
@@ -2632,10 +2689,16 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
lxu_cache_state,
int64_t invalid_index,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations) {
lxu_cache_locations,
const bool gather_cache_stats,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
const int32_t N = linear_cache_indices.size(0);
int32_t n_indices = 0;
int32_t n_hits = 0;
CUDA_KERNEL_LOOP(n, N) {
int32_t cache_location = kCacheLocationMissing;
const auto slot = 0;
@@ -2646,13 +2709,29 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
}
const int32_t cache_set = cache_slot(idx, C);
n_indices++;
const bool found =
(::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx);
if (found) {
cache_location = cache_set;
n_hits++;
}
lxu_cache_locations[n] = cache_location;
}
if (gather_cache_stats) {
typedef cub::BlockReduce<int32_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp;
int32_t conflict_miss = n_indices - n_hits;
int32_t conflict_miss_sum = BlockReduce(temp).Sum(conflict_miss);
if (threadIdx.x == 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
conflict_miss_sum);
}
}
}
} // namespace
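
The conflict-miss accumulation above follows a common CUDA pattern: each thread keeps a private count, the block reduces those counts with `cub::BlockReduce`, and only thread 0 issues an `atomicAdd`, so global-memory atomics scale with the number of blocks rather than the number of indices. A minimal standalone sketch of that pattern (hypothetical kernel, not part of this diff) for reference:

#include <cub/block/block_reduce.cuh>

constexpr int kThreadsPerBlock = 128;

// Counts zero entries in `hit_flags` and accumulates the total into *miss_total.
__global__ void count_misses_kernel(
    const int32_t* hit_flags,
    int32_t n,
    int32_t* miss_total) {
  int32_t local_misses = 0;
  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    if (hit_flags[i] == 0) {
      ++local_misses;
    }
  }
  using BlockReduce = cub::BlockReduce<int32_t, kThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // Every thread in the block must participate in the reduction.
  const int32_t block_misses = BlockReduce(temp_storage).Sum(local_misses);
  if (threadIdx.x == 0) {
    atomicAdd(miss_total, block_misses); // one atomic per block, not per thread
  }
}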
@@ -2764,10 +2843,20 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
Tensor linear_cache_indices,
Tensor lxu_cache_state,
int64_t invalid_index) {
int64_t invalid_index,
bool gather_cache_stats,
c10::optional<Tensor> uvm_cache_stats) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
linear_cache_indices, lxu_cache_state);
if (gather_cache_stats) {
TORCH_CHECK(uvm_cache_stats.has_value());
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
uvm_cache_stats, linear_cache_indices);
}
auto uvm_cache_stats_ = uvm_cache_stats.value_or(
at::empty({0}, linear_cache_indices.options().dtype(at::kInt)));
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(linear_cache_indices.get_device());
@@ -2796,7 +2885,9 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
invalid_index,
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32));
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
gather_cache_stats,
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32));
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
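
One design note on the host wrappers above: the kernels take a concrete `PackedTensorAccessor`, not an optional, so the wrappers substitute an empty int32 tensor via `value_or` when stats gathering is off and the kernels only touch it behind `if (gather_cache_stats)`. A condensed sketch of that convention (hypothetical wrapper name; the checks mirror the ones added in this diff):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

void run_with_optional_stats(
    const at::Tensor& weights,
    bool gather_cache_stats,
    c10::optional<at::Tensor> uvm_cache_stats) {
  if (gather_cache_stats) {
    // Stats were requested, so the caller must supply the output tensor.
    TORCH_CHECK(uvm_cache_stats.has_value());
  }
  // Empty placeholder keeps accessor construction valid when stats are off.
  auto uvm_cache_stats_ = uvm_cache_stats.value_or(
      at::empty({0}, weights.options().dtype(at::kInt)));
  // ... launch kernels, forwarding gather_cache_stats and uvm_cache_stats_ ...
}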
4 changes: 2 additions & 2 deletions fbgemm_gpu/src/split_table_batched_embeddings.cpp
@@ -31,7 +31,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, int row_alignment=16, bool gather_cache_stats=False, Tensor(d!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda);
m.def(
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16) -> ()");
"direct_mapped_lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, Tensor(d!) lxu_cache_miss_timestamp, int row_alignment=16, bool gather_cache_stats=False, Tensor(e!)? uvm_cache_stats=None) -> ()");
DISPATCH_TO_CUDA(
"direct_mapped_lru_cache_populate_byte",
direct_mapped_lru_cache_populate_byte_cuda);
@@ -45,7 +45,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda);
m.def(
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1) -> Tensor");
"direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
DISPATCH_TO_CUDA(
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cuda);
m.def(
