Skip to content

Commit

Permalink
Implement cache miss emulation in UVM_CACHING (pytorch#1637)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1637

Enforce cache misses (even if trace-driven testing doesn't experience cache miss due to limited trace size) so that we can evaluate performance under cache misses.

Note that it's not exactly cache misses; enforce access to UVM by overriding lxu_cache_locations -- N / 256 requests.

Differential Revision: D42194019

fbshipit-source-id: 5857bf342e5613a9a2a25e46525041a668a1d3c0
  • Loading branch information
doehyun authored and facebook-github-bot committed Mar 10, 2023
1 parent 8616ed7 commit 63d10c3
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 23 deletions.
62 changes: 54 additions & 8 deletions fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
Expand Down Expand Up @@ -37,14 +38,29 @@ DEFINE_quantile_stat(
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Miss rate due to conflict in cache associativity.
// (Unique) Miss rate due to conflict in cache associativity.
// # unique misses due to conflict / # requested indices.
DEFINE_quantile_stat(
tbe_uvm_cache_conflict_unique_miss_rate,
"tbe_uvm_cache_conflict_unique_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Miss rate due to conflict in cache associativity.
// # misses due to conflict / # requested indices.
DEFINE_quantile_stat(
tbe_uvm_cache_conflict_miss_rate,
"tbe_uvm_cache_conflict_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Total miss rate.
DEFINE_quantile_stat(
tbe_uvm_cache_total_miss_rate,
"tbe_uvm_cache_total_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// FLAGs to control UVMCacheStats.
DEFINE_int32(
tbe_uvm_cache_stat_report,
Expand All @@ -58,6 +74,12 @@ DEFINE_int32(
"If tbe_uvm_cache_stat_report is enabled, more detailed raw stats will be printed with this "
"period. This should be an integer multiple of tbe_uvm_cache_stat_report.");

DEFINE_int32(
tbe_uvm_cache_enforced_misses,
0,
"If set to non-zero, some cache lookups (tbe_uvm_cache_enforced_misses / 256) are enforced to be misses; "
"this is performance evaluation purposes only; and should be zero otherwise.");

// TODO: align this with uvm_cache_stats_index in
// split_embeddings_cache_cuda.cu.
const int kUvmCacheStatsSize = 6;
Expand All @@ -84,10 +106,11 @@ void process_uvm_cache_stats(
// uvm_cache_stats_counters[0]: num_req_indices
// uvm_cache_stats_counters[1]: num_unique_indices
// uvm_cache_stats_counters[2]: num_unique_misses
// uvm_cache_stats_counters[3]: num_unique_conflict_misses
// uvm_cache_stats_counters[3]: num_conflict_unique_misses
// uvm_cache_stats_counters[4]: num_conflict_misses
// They should be zero-out after the calculated rates are populated into
// cache counters.
static std::vector<int64_t> uvm_cache_stats_counters(4);
static std::vector<int64_t> uvm_cache_stats_counters(5);

// Export cache stats.
auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
Expand All @@ -107,19 +130,32 @@ void process_uvm_cache_stats(
// Calculate cache related ratios based on the cumulated numbers and
// push them into the counter pools.
if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
double unique_rate =
const double unique_rate =
static_cast<double>(uvm_cache_stats_counters[1]) /
uvm_cache_stats_counters[0] * 1000;
double unique_miss_rate =
const double unique_miss_rate =
static_cast<double>(uvm_cache_stats_counters[2]) /
uvm_cache_stats_counters[0] * 1000;
double unique_conflict_miss_rate =
const double conflict_unique_miss_rate =
static_cast<double>(uvm_cache_stats_counters[3]) /
uvm_cache_stats_counters[0] * 1000;
const double conflict_miss_rate =
static_cast<double>(uvm_cache_stats_counters[4]) /
uvm_cache_stats_counters[0] * 1000;
// total # misses = unique misses - conflict_unique_misses + conflict
// misses.
const double total_miss_rate =
static_cast<double>(
uvm_cache_stats_counters[2] - uvm_cache_stats_counters[3] +
uvm_cache_stats_counters[4]) /
uvm_cache_stats_counters[0] * 1000;

STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
unique_conflict_miss_rate);
conflict_unique_miss_rate);
STATS_tbe_uvm_cache_conflict_miss_rate.addValue(conflict_miss_rate);
STATS_tbe_uvm_cache_total_miss_rate.addValue(total_miss_rate);

// Fill all the elements of the vector uvm_cache_stats_counters as 0
// to zero out the cumulated counters.
Expand Down Expand Up @@ -365,7 +401,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
// cache_index_table_map: (linearized) index to table number map.
// 1D tensor, dtype=int32.
c10::optional<Tensor> cache_index_table_map,
// lxu_cache_state: Cache state (cached idnex, or invalid).
// lxu_cache_state: Cache state (cached index, or invalid).
// 2D tensor: # sets x assoc. dtype=int64.
c10::optional<Tensor> lxu_cache_state,
// lxu_state: meta info for replacement (time stamp for LRU).
Expand Down Expand Up @@ -461,6 +497,16 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
uvm_cache_stats);

#ifdef FBCODE_CAFFE2
if (FLAGS_tbe_uvm_cache_enforced_misses > 0) {
// Override some lxu_cache_locations (N for every 256 indices) with cache
// miss to enforce access to UVM.
lxu_cache_locations = emulate_cache_miss(
lxu_cache_locations.value(),
FLAGS_tbe_uvm_cache_enforced_misses,
gather_uvm_stats,
uvm_cache_stats);
}

process_uvm_cache_stats(
signature,
total_cache_hash_size.value(),
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ at::Tensor lxu_cache_lookup_cuda(
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

at::Tensor emulate_cache_miss(
at::Tensor lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
at::Tensor uvm_cache_stats);

///@ingroup table-batched-embed-cuda
/// Lookup the LRU/LFU cache: find the cache weights location for all indices.
/// Look up the slots in the cache corresponding to `linear_cache_indices`, with
Expand Down
92 changes: 77 additions & 15 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ enum uvm_cache_stats_index {
num_conflict_misses = 5,
};

// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
// not sensitive to grid size as long as the number thread blocks per SM is not
// too small nor too big.
constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;

int get_max_thread_blocks_for_cache_kernels_() {
cudaDeviceProp* deviceProp =
at::cuda::getDeviceProperties(c10::cuda::current_device());
return deviceProp->multiProcessorCount *
MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
}

} // namespace

int64_t host_lxu_cache_slot(int64_t h_in, int64_t C) {
Expand Down Expand Up @@ -495,6 +507,69 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(
namespace {
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void emulate_cache_miss_kernel(
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t N = lxu_cache_locations.size(0);
int64_t n_enforced_misses = 0;
CUDA_KERNEL_LOOP(n, N) {
if ((n & 0x00FF) < enforced_misses_per_256) {
if (lxu_cache_locations[n] >= 0) {
n_enforced_misses++;
}
lxu_cache_locations[n] = kCacheLocationMissing;
}
}
if (gather_cache_stats && n_enforced_misses > 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
n_enforced_misses);
}
}
} // namespace
Tensor emulate_cache_miss(
Tensor lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSOR_ON_CUDA_GPU(lxu_cache_locations);
TENSOR_ON_CUDA_GPU(uvm_cache_stats);
const auto N = lxu_cache_locations.numel();
if (lxu_cache_locations.numel() == 0) {
// nothing to do
return lxu_cache_locations;
}
const dim3 blocks(std::min(
div_round_up(N, kMaxThreads),
get_max_thread_blocks_for_cache_kernels_()));
AT_DISPATCH_INDEX_TYPES(
lxu_cache_locations.scalar_type(), "emulate_cache_miss", [&] {
emulate_cache_miss_kernel<<<
blocks,
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
lxu_cache_locations
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
enforced_misses_per_256,
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return lxu_cache_locations;
}
namespace {
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
Expand Down Expand Up @@ -622,19 +697,6 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
}
}
}
// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
// not sensitive to grid size as long as the number thread blocks per SM is not
// too small nor too big.
constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;
int get_max_thread_blocks_for_cache_kernels_() {
cudaDeviceProp* deviceProp =
at::cuda::getDeviceProperties(c10::cuda::current_device());
return deviceProp->multiProcessorCount *
MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
}
} // namespace
std::pair<Tensor, Tensor> lru_cache_find_uncached_cuda(
Expand Down Expand Up @@ -798,8 +860,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
int64_t n_conflict_misses = 0;
int64_t n_inserted = 0;
int32_t n_conflict_misses = 0;
int32_t n_inserted = 0;
for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
n += gridDim.x * blockDim.y) {
// check if this warp is responsible for this whole segment.
Expand Down
119 changes: 119 additions & 0 deletions fbgemm_gpu/test/uvm_cache_miss_emulate_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gtest/gtest.h>

#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

using namespace ::testing;

// Helper function that generates input tensor for emulate_cache_miss testing.
at::Tensor generate_lxu_cache_locations(
int64_t num_requests,
int64_t num_sets,
int64_t associativity = 32) {
const auto lxu_cache_locations = at::randint(
0,
num_sets * associativity,
{num_requests},
at::device(at::kCPU).dtype(at::kInt));
return lxu_cache_locations;
}

// Wrapper function that takes lxu_cache_locations on CPU, copies it to GPU,
// runs emulate_cache_miss(), and then returns the result, placed on CPU.
std::pair<at::Tensor, at::Tensor> run_emulate_cache_miss(
at::Tensor lxu_cache_locations,
int64_t enforced_misses_per_256,
bool gather_uvm_stats = false) {
at::Tensor lxu_cache_locations_copy = at::_to_copy(lxu_cache_locations);
const auto options =
lxu_cache_locations.options().device(at::kCUDA).dtype(at::kInt);
const auto uvm_cache_stats =
gather_uvm_stats ? at::zeros({6}, options) : at::empty({0}, options);

const auto lxu_cache_location_with_cache_misses = emulate_cache_miss(
lxu_cache_locations_copy.to(at::kCUDA),
enforced_misses_per_256,
gather_uvm_stats,
uvm_cache_stats);
return {lxu_cache_location_with_cache_misses.cpu(), uvm_cache_stats.cpu()};
}

TEST(uvm_cache_miss_emulate_test, no_cache_miss) {
const int64_t num_requests = 10000;
const int64_t num_sets = 32768;
const int64_t associativity = 32;

auto lxu_cache_locations_cpu =
generate_lxu_cache_locations(num_requests, num_sets, associativity);
auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
run_emulate_cache_miss(lxu_cache_locations_cpu, 0);
auto lxu_cache_location_with_cache_misses =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
EXPECT_TRUE(
at::equal(lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
}

TEST(uvm_cache_miss_emulate_test, enforced_cache_miss) {
const int64_t num_requests = 10000;
const int64_t num_sets = 32768;
const int64_t associativity = 32;
std::vector<int64_t> enforced_misses_per_256_for_testing = {
1, 5, 7, 33, 100, 256};

for (const bool miss_in_lxu_cache_locations : {false, true}) {
for (const bool gather_cache_stats : {false, true}) {
for (const auto enforced_misses_per_256 :
enforced_misses_per_256_for_testing) {
auto lxu_cache_locations_cpu =
generate_lxu_cache_locations(num_requests, num_sets, associativity);
if (miss_in_lxu_cache_locations) {
// one miss in the original lxu_cache_locations; shouldn't be counted
// as enforced misses from emulate_cache_miss().
auto z = lxu_cache_locations_cpu.data_ptr<int32_t>();
z[0] = -1;
}
auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
run_emulate_cache_miss(
lxu_cache_locations_cpu,
enforced_misses_per_256,
gather_cache_stats);
auto lxu_cache_location_with_cache_misses =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
EXPECT_FALSE(at::equal(
lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));

auto x = lxu_cache_locations_cpu.data_ptr<int32_t>();
auto y = lxu_cache_location_with_cache_misses.data_ptr<int32_t>();
int64_t enforced_misses = 0;
for (int32_t i = 0; i < lxu_cache_locations_cpu.numel(); ++i) {
if (x[i] != y[i]) {
EXPECT_EQ(y[i], -1);
enforced_misses++;
}
}
int64_t num_requests_over_256 =
static_cast<int64_t>(num_requests / 256);
int64_t expected_misses = num_requests_over_256 *
enforced_misses_per_256 +
std::min((num_requests - num_requests_over_256 * 256),
enforced_misses_per_256);
if (miss_in_lxu_cache_locations) {
expected_misses--;
}
EXPECT_EQ(expected_misses, enforced_misses);
if (gather_cache_stats) {
auto uvm_cache_stats =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.second;
auto cache_stats_ptr = uvm_cache_stats.data_ptr<int32_t>();
// enforced misses are recorded as conflict misses.
EXPECT_EQ(expected_misses, cache_stats_ptr[5]);
}
}
}
}
}

0 comments on commit 63d10c3

Please sign in to comment.