diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
index 6d4426cb27..43a182b6b1 100644
--- a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
+++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
@@ -4,12 +4,12 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+
 #include
 #include
 #include
 #include
 #include
-#include
 #include "c10/core/ScalarType.h"
 #ifdef FBCODE_CAFFE2
 #include "common/stats/Stats.h"
@@ -18,6 +18,8 @@
 #include "fbgemm_gpu/sparse_ops_utils.h"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
 
+#include
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
@@ -37,7 +39,7 @@ DEFINE_quantile_stat(
     facebook::fb303::ExportTypeConsts::kNone,
     std::array<double, 4>{{.25, .50, .75, .99}});
 
-// Miss rate due to conflict in cache associativity.
+// (Unique) Miss rate due to conflict in cache associativity.
 // # unique misses due to conflict / # requested indices.
 DEFINE_quantile_stat(
     tbe_uvm_cache_conflict_unique_miss_rate,
@@ -45,6 +47,21 @@ DEFINE_quantile_stat(
     facebook::fb303::ExportTypeConsts::kNone,
     std::array<double, 4>{{.25, .50, .75, .99}});
 
+// Miss rate due to conflict in cache associativity.
+// # misses due to conflict / # requested indices.
+DEFINE_quantile_stat(
+    tbe_uvm_cache_conflict_miss_rate,
+    "tbe_uvm_cache_conflict_miss_rate_per_mille",
+    facebook::fb303::ExportTypeConsts::kNone,
+    std::array<double, 4>{{.25, .50, .75, .99}});
+
+// Total miss rate.
+DEFINE_quantile_stat(
+    tbe_uvm_cache_total_miss_rate,
+    "tbe_uvm_cache_total_miss_rate_per_mille",
+    facebook::fb303::ExportTypeConsts::kNone,
+    std::array<double, 4>{{.25, .50, .75, .99}});
+
 // FLAGs to control UVMCacheStats.
 DEFINE_int32(
     tbe_uvm_cache_stat_report,
@@ -58,6 +75,12 @@ DEFINE_int32(
     "If tbe_uvm_cache_stat_report is enabled, more detailed raw stats will be printed with this "
     "period. This should be an integer multiple of tbe_uvm_cache_stat_report.");
 
+DEFINE_int32(
+    tbe_uvm_cache_enforced_misses,
+    0,
+    "If set to non-zero, a fraction of cache lookups (tbe_uvm_cache_enforced_misses / 256) is enforced to be misses; "
+    "this is for performance evaluation purposes only and should be zero otherwise.");
+
 // TODO: align this with uvm_cache_stats_index in
 // split_embeddings_cache_cuda.cu.
 const int kUvmCacheStatsSize = 6;
@@ -84,10 +107,11 @@ void process_uvm_cache_stats(
   // uvm_cache_stats_counters[0]: num_req_indices
   // uvm_cache_stats_counters[1]: num_unique_indices
   // uvm_cache_stats_counters[2]: num_unique_misses
-  // uvm_cache_stats_counters[3]: num_unique_conflict_misses
+  // uvm_cache_stats_counters[3]: num_conflict_unique_misses
+  // uvm_cache_stats_counters[4]: num_conflict_misses
   // They should be zero-out after the calculated rates are populated into
   // cache counters.
-  static std::vector<int64_t> uvm_cache_stats_counters(4);
+  static std::vector<int64_t> uvm_cache_stats_counters(5);
 
   // Export cache stats.
   auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
@@ -107,19 +131,32 @@
   // Calculate cache related ratios based on the cumulated numbers and
   // push them into the counter pools.
   if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
-    double unique_rate =
+    const double unique_rate =
         static_cast<double>(uvm_cache_stats_counters[1]) /
         uvm_cache_stats_counters[0] * 1000;
-    double unique_miss_rate =
+    const double unique_miss_rate =
         static_cast<double>(uvm_cache_stats_counters[2]) /
         uvm_cache_stats_counters[0] * 1000;
-    double unique_conflict_miss_rate =
+    const double conflict_unique_miss_rate =
         static_cast<double>(uvm_cache_stats_counters[3]) /
         uvm_cache_stats_counters[0] * 1000;
+    const double conflict_miss_rate =
+        static_cast<double>(uvm_cache_stats_counters[4]) /
+        uvm_cache_stats_counters[0] * 1000;
+    // total # misses = unique misses - conflict_unique_misses + conflict
+    // misses.
+    const double total_miss_rate =
+        static_cast<double>(
+            uvm_cache_stats_counters[2] - uvm_cache_stats_counters[3] +
+            uvm_cache_stats_counters[4]) /
+        uvm_cache_stats_counters[0] * 1000;
+
     STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
     STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
     STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
-        unique_conflict_miss_rate);
+        conflict_unique_miss_rate);
+    STATS_tbe_uvm_cache_conflict_miss_rate.addValue(conflict_miss_rate);
+    STATS_tbe_uvm_cache_total_miss_rate.addValue(total_miss_rate);
 
     // Fill all the elements of the vector uvm_cache_stats_counters as 0
     // to zero out the cumulated counters.
@@ -365,7 +402,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
     // cache_index_table_map: (linearized) index to table number map.
     // 1D tensor, dtype=int32.
     c10::optional<Tensor> cache_index_table_map,
-    // lxu_cache_state: Cache state (cached idnex, or invalid).
+    // lxu_cache_state: Cache state (cached index, or invalid).
     // 2D tensor: # sets x assoc. dtype=int64.
     c10::optional<Tensor> lxu_cache_state,
     // lxu_state: meta info for replacement (time stamp for LRU).
@@ -461,6 +498,16 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         uvm_cache_stats);
 
 #ifdef FBCODE_CAFFE2
+    if (FLAGS_tbe_uvm_cache_enforced_misses > 0) {
+      // Override some lxu_cache_locations (N for every 256 indices) with cache
+      // miss to enforce access to UVM.
+      lxu_cache_locations = emulate_cache_miss(
+          lxu_cache_locations.value(),
+          FLAGS_tbe_uvm_cache_enforced_misses,
+          gather_uvm_stats,
+          uvm_cache_stats);
+    }
+
     process_uvm_cache_stats(
         signature,
         total_cache_hash_size.value(),
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
index 52854a4f2e..3532928963 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -155,6 +155,12 @@ at::Tensor lxu_cache_lookup_cuda(
     bool gather_cache_stats,
     c10::optional<at::Tensor> uvm_cache_stats);
 
+at::Tensor emulate_cache_miss(
+    at::Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    at::Tensor uvm_cache_stats);
+
 ///@ingroup table-batched-embed-cuda
 /// Lookup the LRU/LFU cache: find the cache weights location for all indices.
 /// Look up the slots in the cache corresponding to `linear_cache_indices`, with
diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
index 9d23ee9fff..e5930ab745 100644
--- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -79,6 +79,18 @@ enum uvm_cache_stats_index {
   num_conflict_misses = 5,
 };
 
+// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
+// not sensitive to grid size as long as the number of thread blocks per SM is
+// not too small nor too big.
+constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;
+
+int get_max_thread_blocks_for_cache_kernels_() {
+  cudaDeviceProp* deviceProp =
+      at::cuda::getDeviceProperties(c10::cuda::current_device());
+  return deviceProp->multiProcessorCount *
+      MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
+}
+
 } // namespace
 
 int64_t host_lxu_cache_slot(int64_t h_in, int64_t C) {
@@ -495,6 +507,69 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(
 
 namespace {
 
+template <typename index_t>
+__global__ __launch_bounds__(kMaxThreads) void emulate_cache_miss_kernel(
+    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        uvm_cache_stats) {
+  const int32_t N = lxu_cache_locations.size(0);
+  int64_t n_enforced_misses = 0;
+  CUDA_KERNEL_LOOP(n, N) {
+    if ((n & 0x00FF) < enforced_misses_per_256) {
+      if (lxu_cache_locations[n] >= 0) {
+        n_enforced_misses++;
+      }
+      lxu_cache_locations[n] = kCacheLocationMissing;
+    }
+  }
+  if (gather_cache_stats && n_enforced_misses > 0) {
+    atomicAdd(
+        &uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
+        n_enforced_misses);
+  }
+}
+} // namespace
+
+Tensor emulate_cache_miss(
+    Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    Tensor uvm_cache_stats) {
+  TENSOR_ON_CUDA_GPU(lxu_cache_locations);
+  TENSOR_ON_CUDA_GPU(uvm_cache_stats);
+
+  const auto N = lxu_cache_locations.numel();
+  if (lxu_cache_locations.numel() == 0) {
+    // nothing to do
+    return lxu_cache_locations;
+  }
+
+  const dim3 blocks(std::min(
+      div_round_up(N, kMaxThreads),
+      get_max_thread_blocks_for_cache_kernels_()));
+
+  AT_DISPATCH_INDEX_TYPES(
+      lxu_cache_locations.scalar_type(), "emulate_cache_miss", [&] {
+        emulate_cache_miss_kernel<<<
+            blocks,
+            kMaxThreads,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            lxu_cache_locations
+                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+            enforced_misses_per_256,
+            gather_cache_stats,
+            uvm_cache_stats
+                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+  return lxu_cache_locations;
+}
+
+namespace {
 template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
     const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
@@ -622,19 +697,6 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
     }
   }
 }
-
-// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
-// not sensitive to grid size as long as the number thread blocks per SM is not
-// too small nor too big.
-constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;
-
-int get_max_thread_blocks_for_cache_kernels_() {
-  cudaDeviceProp* deviceProp =
-      at::cuda::getDeviceProperties(c10::cuda::current_device());
-  return deviceProp->multiProcessorCount *
-      MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
-}
-
 } // namespace
 
 std::pair<Tensor, Tensor> lru_cache_find_uncached_cuda(
@@ -798,8 +860,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
     at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         uvm_cache_stats) {
   const int32_t C = lxu_cache_state.size(0);
-  int64_t n_conflict_misses = 0;
-  int64_t n_inserted = 0;
+  int32_t n_conflict_misses = 0;
+  int32_t n_inserted = 0;
   for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
diff --git a/fbgemm_gpu/test/uvm_cache_miss_emulate_test.cpp b/fbgemm_gpu/test/uvm_cache_miss_emulate_test.cpp
new file mode 100644
index 0000000000..808ed33624
--- /dev/null
+++ b/fbgemm_gpu/test/uvm_cache_miss_emulate_test.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <gtest/gtest.h>
+
+#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
+
+using namespace ::testing;
+
+// Helper function that generates input tensor for emulate_cache_miss testing.
+at::Tensor generate_lxu_cache_locations(
+    const int64_t num_requests,
+    const int64_t num_sets,
+    const int64_t associativity = 32) {
+  const auto lxu_cache_locations = at::randint(
+      0,
+      num_sets * associativity,
+      {num_requests},
+      at::device(at::kCPU).dtype(at::kInt));
+  return lxu_cache_locations;
+}
+
+// Wrapper function that takes lxu_cache_locations on CPU, copies it to GPU,
+// runs emulate_cache_miss(), and then returns the result, placed on CPU.
+std::pair<at::Tensor, at::Tensor> run_emulate_cache_miss(
+    at::Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_uvm_stats = false) {
+  at::Tensor lxu_cache_locations_copy = at::_to_copy(lxu_cache_locations);
+  const auto options =
+      lxu_cache_locations.options().device(at::kCUDA).dtype(at::kInt);
+  const auto uvm_cache_stats =
+      gather_uvm_stats ? at::zeros({6}, options) : at::empty({0}, options);
+
+  const auto lxu_cache_location_with_cache_misses = emulate_cache_miss(
+      lxu_cache_locations_copy.to(at::kCUDA),
+      enforced_misses_per_256,
+      gather_uvm_stats,
+      uvm_cache_stats);
+  return {lxu_cache_location_with_cache_misses.cpu(), uvm_cache_stats.cpu()};
+}
+
+TEST(uvm_cache_miss_emulate_test, no_cache_miss) {
+  constexpr int64_t num_requests = 10000;
+  constexpr int64_t num_sets = 32768;
+  constexpr int64_t associativity = 32;
+
+  auto lxu_cache_locations_cpu =
+      generate_lxu_cache_locations(num_requests, num_sets, associativity);
+  auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
+      run_emulate_cache_miss(lxu_cache_locations_cpu, 0);
+  auto lxu_cache_location_with_cache_misses =
+      lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
+  EXPECT_TRUE(
+      at::equal(lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
+}
+
+TEST(uvm_cache_miss_emulate_test, enforced_cache_miss) {
+  constexpr int64_t num_requests = 10000;
+  constexpr int64_t num_sets = 32768;
+  constexpr int64_t associativity = 32;
+  constexpr std::array<int64_t, 6> enforced_misses_per_256_for_testing = {
+      1, 5, 7, 33, 100, 256};
+
+  for (const bool miss_in_lxu_cache_locations : {false, true}) {
+    for (const bool gather_cache_stats : {false, true}) {
+      for (const auto enforced_misses_per_256 :
+           enforced_misses_per_256_for_testing) {
+        auto lxu_cache_locations_cpu =
+            generate_lxu_cache_locations(num_requests, num_sets, associativity);
+        if (miss_in_lxu_cache_locations) {
+          // one miss in the original lxu_cache_locations; shouldn't be counted
+          // as enforced misses from emulate_cache_miss().
+          auto z = lxu_cache_locations_cpu.data_ptr<int32_t>();
+          z[0] = -1;
+        }
+        auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
+            run_emulate_cache_miss(
+                lxu_cache_locations_cpu,
+                enforced_misses_per_256,
+                gather_cache_stats);
+        auto lxu_cache_location_with_cache_misses =
+            lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
+        EXPECT_FALSE(at::equal(
+            lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
+
+        auto x = lxu_cache_locations_cpu.data_ptr<int32_t>();
+        auto y = lxu_cache_location_with_cache_misses.data_ptr<int32_t>();
+        int64_t enforced_misses = 0;
+        for (int32_t i = 0; i < lxu_cache_locations_cpu.numel(); ++i) {
+          if (x[i] != y[i]) {
+            EXPECT_EQ(y[i], -1);
+            enforced_misses++;
+          }
+        }
+        int64_t num_requests_over_256 =
+            static_cast<int64_t>(num_requests / 256);
+        int64_t expected_misses = num_requests_over_256 *
+                enforced_misses_per_256 +
+            std::min((num_requests - num_requests_over_256 * 256),
+                     enforced_misses_per_256);
+        if (miss_in_lxu_cache_locations) {
+          expected_misses--;
+        }
+        EXPECT_EQ(expected_misses, enforced_misses);
+        if (gather_cache_stats) {
+          auto uvm_cache_stats =
+              lxu_cache_location_with_cache_misses_and_uvm_cache_stats.second;
+          auto cache_stats_ptr = uvm_cache_stats.data_ptr<int32_t>();
+          // enforced misses are recorded as conflict misses.
+          EXPECT_EQ(expected_misses, cache_stats_ptr[5]);
+        }
+      }
+    }
+  }
+}
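
A note on the enforced-miss arithmetic exercised by uvm_cache_miss_emulate_test.cpp above: emulate_cache_miss_kernel() forces the first enforced_misses_per_256 entries of every aligned 256-index block to be misses, so when the request count is not a multiple of 256 the partial tail block contributes only min(tail_size, enforced_misses_per_256) misses. The standalone C++ sketch below is illustrative only and not part of the patch; the helper name expected_enforced_misses and the printed example values are made up for this note, and it simply mirrors the formula checked by the test.

// Illustrative sketch (not part of the patch): expected number of enforced
// misses for a given request count, mirroring the formula in the test above.
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t expected_enforced_misses(
    const int64_t num_requests,
    const int64_t enforced_misses_per_256) {
  // Each full 256-index block contributes enforced_misses_per_256 misses.
  const int64_t full_blocks = num_requests / 256;
  // The partial tail block contributes at most its own size.
  const int64_t tail = num_requests - full_blocks * 256;
  return full_blocks * enforced_misses_per_256 +
      std::min(tail, enforced_misses_per_256);
}

int main() {
  // Matches the test's num_requests = 10000 with enforced_misses_per_256 = 7:
  // 39 full blocks * 7 + min(16, 7) = 273 + 7 = 280.
  std::cout << expected_enforced_misses(10000, 7) << std::endl; // prints 280
  return 0;
}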