Implement cache miss emulation in UVM_CACHING #1637

Closed
wants to merge 1 commit into from
65 changes: 56 additions & 9 deletions fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
@@ -4,12 +4,12 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>
#include <algorithm>
#include "c10/core/ScalarType.h"
#ifdef FBCODE_CAFFE2
#include "common/stats/Stats.h"
@@ -18,6 +18,8 @@
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

#include <algorithm>

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

@@ -37,14 +39,29 @@ DEFINE_quantile_stat(
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Miss rate due to conflict in cache associativity.
// (Unique) Miss rate due to conflict in cache associativity.
// # unique misses due to conflict / # requested indices.
DEFINE_quantile_stat(
tbe_uvm_cache_conflict_unique_miss_rate,
"tbe_uvm_cache_conflict_unique_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Miss rate due to conflict in cache associativity.
// # misses due to conflict / # requested indices.
DEFINE_quantile_stat(
tbe_uvm_cache_conflict_miss_rate,
"tbe_uvm_cache_conflict_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});

// Total miss rate.
DEFINE_quantile_stat(
tbe_uvm_cache_total_miss_rate,
"tbe_uvm_cache_total_miss_rate_per_mille",
facebook::fb303::ExportTypeConsts::kNone,
std::array<double, 4>{{.25, .50, .75, .99}});
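
All three stats above, like the existing unique and unique-miss rates, are reported per mille of the requested indices. A minimal illustration (the helper below is illustrative only, not part of this diff):

#include <cstdint>

// Illustrative only: per-mille ratio as pushed into the quantile stats.
double per_mille(int64_t numerator, int64_t num_req_indices) {
  return static_cast<double>(numerator) / num_req_indices * 1000.0;
}
// e.g. 120 conflict unique misses out of 4'000 requested indices
// -> per_mille(120, 4000) == 30.0, i.e. a 3% rate expressed per mille.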

// FLAGs to control UVMCacheStats.
DEFINE_int32(
tbe_uvm_cache_stat_report,
@@ -58,6 +75,12 @@ DEFINE_int32(
"If tbe_uvm_cache_stat_report is enabled, more detailed raw stats will be printed with this "
"period. This should be an integer multiple of tbe_uvm_cache_stat_report.");

DEFINE_int32(
tbe_uvm_cache_enforced_misses,
0,
"If set to non-zero, some cache lookups (tbe_uvm_cache_enforced_misses / 256) are enforced to be misses; "
"this is performance evaluation purposes only; and should be zero otherwise.");

// TODO: align this with uvm_cache_stats_index in
// split_embeddings_cache_cuda.cu.
const int kUvmCacheStatsSize = 6;
@@ -84,10 +107,11 @@ void process_uvm_cache_stats(
// uvm_cache_stats_counters[0]: num_req_indices
// uvm_cache_stats_counters[1]: num_unique_indices
// uvm_cache_stats_counters[2]: num_unique_misses
// uvm_cache_stats_counters[3]: num_unique_conflict_misses
// uvm_cache_stats_counters[3]: num_conflict_unique_misses
// uvm_cache_stats_counters[4]: num_conflict_misses
// They should be zeroed out after the calculated rates are populated into
// the cache counters.
static std::vector<int64_t> uvm_cache_stats_counters(4);
static std::vector<int64_t> uvm_cache_stats_counters(5);

// Export cache stats.
auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
@@ -107,19 +131,32 @@
// Calculate cache related ratios based on the cumulated numbers and
// push them into the counter pools.
if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
double unique_rate =
const double unique_rate =
static_cast<double>(uvm_cache_stats_counters[1]) /
uvm_cache_stats_counters[0] * 1000;
double unique_miss_rate =
const double unique_miss_rate =
static_cast<double>(uvm_cache_stats_counters[2]) /
uvm_cache_stats_counters[0] * 1000;
double unique_conflict_miss_rate =
const double conflict_unique_miss_rate =
static_cast<double>(uvm_cache_stats_counters[3]) /
uvm_cache_stats_counters[0] * 1000;
const double conflict_miss_rate =
static_cast<double>(uvm_cache_stats_counters[4]) /
uvm_cache_stats_counters[0] * 1000;
// total # misses = unique misses - conflict_unique_misses + conflict
// misses.
const double total_miss_rate =
static_cast<double>(
uvm_cache_stats_counters[2] - uvm_cache_stats_counters[3] +
uvm_cache_stats_counters[4]) /
uvm_cache_stats_counters[0] * 1000;

STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
unique_conflict_miss_rate);
conflict_unique_miss_rate);
STATS_tbe_uvm_cache_conflict_miss_rate.addValue(conflict_miss_rate);
STATS_tbe_uvm_cache_total_miss_rate.addValue(total_miss_rate);

// Fill all the elements of the vector uvm_cache_stats_counters as 0
// to zero out the cumulated counters.
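
A worked instance of the total-miss accounting above, with illustrative counter values:

// Illustrative numbers only:
//   num_req_indices            = 10'000
//   num_unique_misses          =  1'200   (uvm_cache_stats_counters[2])
//   num_conflict_unique_misses =    200   (uvm_cache_stats_counters[3])
//   num_conflict_misses        =    350   (uvm_cache_stats_counters[4])
// total misses    = 1'200 - 200 + 350 = 1'350
// total_miss_rate = 1'350 / 10'000 * 1000 = 135.0 per mille
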
@@ -365,7 +402,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
// cache_index_table_map: (linearized) index to table number map.
// 1D tensor, dtype=int32.
c10::optional<Tensor> cache_index_table_map,
// lxu_cache_state: Cache state (cached idnex, or invalid).
// lxu_cache_state: Cache state (cached index, or invalid).
// 2D tensor: # sets x assoc. dtype=int64.
c10::optional<Tensor> lxu_cache_state,
// lxu_state: meta info for replacement (time stamp for LRU).
@@ -461,6 +498,16 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
uvm_cache_stats);

#ifdef FBCODE_CAFFE2
if (FLAGS_tbe_uvm_cache_enforced_misses > 0) {
// Override some lxu_cache_locations (N out of every 256 indices) with cache
// misses to enforce access to UVM.
lxu_cache_locations = emulate_cache_miss(
lxu_cache_locations.value(),
FLAGS_tbe_uvm_cache_enforced_misses,
gather_uvm_stats,
uvm_cache_stats);
}
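
In an FBCODE build the new flag is a regular gflags flag, so a benchmark driver can flip it at runtime before issuing lookups; a hedged sketch (the helper name is illustrative, not part of this diff):

#include <gflags/gflags.h>

// Illustrative only: enable enforced misses for a performance run, assuming
// the flag is registered via the DEFINE_int32 above.
void enable_enforced_misses_for_benchmark() {
  // Force roughly 16/256 = 6.25% of cache lookups to miss.
  gflags::SetCommandLineOption("tbe_uvm_cache_enforced_misses", "16");
}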

process_uvm_cache_stats(
signature,
total_cache_hash_size.value(),
6 changes: 6 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -155,6 +155,12 @@ at::Tensor lxu_cache_lookup_cuda(
bool gather_cache_stats,
c10::optional<at::Tensor> uvm_cache_stats);

at::Tensor emulate_cache_miss(
at::Tensor lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
at::Tensor uvm_cache_stats);
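
A minimal caller-side sketch of this declaration (mirroring the unit test added below; the wrapper name is illustrative, and it assumes lxu_cache_locations is an int32 CUDA tensor):

#include <ATen/ATen.h>
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

at::Tensor force_some_misses(at::Tensor lxu_cache_locations) {
  // 6 counters, matching kUvmCacheStatsSize on the host side.
  auto uvm_cache_stats =
      at::zeros({6}, lxu_cache_locations.options().dtype(at::kInt));
  return emulate_cache_miss(
      lxu_cache_locations,
      /*enforced_misses_per_256=*/5,
      /*gather_cache_stats=*/true,
      uvm_cache_stats);
}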

///@ingroup table-batched-embed-cuda
/// Lookup the LRU/LFU cache: find the cache weights location for all indices.
/// Look up the slots in the cache corresponding to `linear_cache_indices`, with
92 changes: 77 additions & 15 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -79,6 +79,18 @@ enum uvm_cache_stats_index {
num_conflict_misses = 5,
};

// Experiments showed that the performance of lru/lxu_cache_find_uncached_kernel
// is not sensitive to grid size as long as the number of thread blocks per SM
// is neither too small nor too big.
constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;

int get_max_thread_blocks_for_cache_kernels_() {
cudaDeviceProp* deviceProp =
at::cuda::getDeviceProperties(c10::cuda::current_device());
return deviceProp->multiProcessorCount *
MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
}
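
For intuition, a worked instance of how this cap combines with the problem size in the launches below (illustrative device numbers; assumes kMaxThreads is 1024):

// Illustrative only: on a GPU with 108 SMs,
//   cap    = 108 * MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 108 * 16 = 1'728
//   needed = div_round_up(N, kMaxThreads) = div_round_up(4'000'000, 1'024) = 3'907
//   grid   = min(needed, cap) = 1'728 thread blocks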

} // namespace

int64_t host_lxu_cache_slot(int64_t h_in, int64_t C) {
@@ -495,6 +507,69 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(

namespace {

template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void emulate_cache_miss_kernel(
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t N = lxu_cache_locations.size(0);
int64_t n_enforced_misses = 0;
CUDA_KERNEL_LOOP(n, N) {
if ((n & 0x00FF) < enforced_misses_per_256) {
if (lxu_cache_locations[n] >= 0) {
n_enforced_misses++;
}
lxu_cache_locations[n] = kCacheLocationMissing;
}
}
if (gather_cache_stats && n_enforced_misses > 0) {
atomicAdd(
&uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
n_enforced_misses);
}
}
} // namespace
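
The selection rule in the kernel above marks the first enforced_misses_per_256 positions of every aligned 256-index block as misses, and only previously-hit entries are counted as enforced misses. A sequential CPU reference sketch (not part of this diff; the helper name is illustrative):

#include <cstdint>
#include <vector>

// Illustrative only: sequential reference of emulate_cache_miss_kernel.
// Returns the number of previously-hit entries that were forced to miss.
int64_t emulate_cache_miss_reference(
    std::vector<int32_t>& lxu_cache_locations,
    int64_t enforced_misses_per_256) {
  int64_t n_enforced_misses = 0;
  for (size_t n = 0; n < lxu_cache_locations.size(); ++n) {
    if (static_cast<int64_t>(n & 0x00FF) < enforced_misses_per_256) {
      if (lxu_cache_locations[n] >= 0) {
        ++n_enforced_misses; // entry was a hit before being overridden
      }
      lxu_cache_locations[n] = -1; // kCacheLocationMissing
    }
  }
  return n_enforced_misses;
}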

Tensor emulate_cache_miss(
Tensor lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_cache_stats,
Tensor uvm_cache_stats) {
TENSOR_ON_CUDA_GPU(lxu_cache_locations);
TENSOR_ON_CUDA_GPU(uvm_cache_stats);

const auto N = lxu_cache_locations.numel();
if (lxu_cache_locations.numel() == 0) {
// nothing to do
return lxu_cache_locations;
}

const dim3 blocks(std::min(
div_round_up(N, kMaxThreads),
get_max_thread_blocks_for_cache_kernels_()));

AT_DISPATCH_INDEX_TYPES(
lxu_cache_locations.scalar_type(), "emulate_cache_miss", [&] {
emulate_cache_miss_kernel<<<
blocks,
kMaxThreads,
0,
at::cuda::getCurrentCUDAStream()>>>(
lxu_cache_locations
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
enforced_misses_per_256,
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return lxu_cache_locations;
}

namespace {
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
@@ -622,19 +697,6 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel(
}
}
}

// Experiments showed that the performance of lru/lxu_cache_find_uncached_kernel
// is not sensitive to grid size as long as the number of thread blocks per SM
// is neither too small nor too big.
constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;

int get_max_thread_blocks_for_cache_kernels_() {
cudaDeviceProp* deviceProp =
at::cuda::getDeviceProperties(c10::cuda::current_device());
return deviceProp->multiProcessorCount *
MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
}

} // namespace

std::pair<Tensor, Tensor> lru_cache_find_uncached_cuda(
@@ -798,8 +860,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
uvm_cache_stats) {
const int32_t C = lxu_cache_state.size(0);
int64_t n_conflict_misses = 0;
int64_t n_inserted = 0;
int32_t n_conflict_misses = 0;
int32_t n_inserted = 0;
for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
n += gridDim.x * blockDim.y) {
// check if this warp is responsible for this whole segment.
119 changes: 119 additions & 0 deletions fbgemm_gpu/test/uvm_cache_miss_emulate_test.cpp
@@ -0,0 +1,119 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gtest/gtest.h>

#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

using namespace ::testing;

// Helper function that generates input tensor for emulate_cache_miss testing.
at::Tensor generate_lxu_cache_locations(
const int64_t num_requests,
const int64_t num_sets,
const int64_t associativity = 32) {
const auto lxu_cache_locations = at::randint(
0,
num_sets * associativity,
{num_requests},
at::device(at::kCPU).dtype(at::kInt));
return lxu_cache_locations;
}

// Wrapper function that takes lxu_cache_locations on CPU, copies it to GPU,
// runs emulate_cache_miss(), and then returns the result, placed on CPU.
std::pair<at::Tensor, at::Tensor> run_emulate_cache_miss(
at::Tensor lxu_cache_locations,
const int64_t enforced_misses_per_256,
const bool gather_uvm_stats = false) {
at::Tensor lxu_cache_locations_copy = at::_to_copy(lxu_cache_locations);
const auto options =
lxu_cache_locations.options().device(at::kCUDA).dtype(at::kInt);
const auto uvm_cache_stats =
gather_uvm_stats ? at::zeros({6}, options) : at::empty({0}, options);

const auto lxu_cache_location_with_cache_misses = emulate_cache_miss(
lxu_cache_locations_copy.to(at::kCUDA),
enforced_misses_per_256,
gather_uvm_stats,
uvm_cache_stats);
return {lxu_cache_location_with_cache_misses.cpu(), uvm_cache_stats.cpu()};
}

TEST(uvm_cache_miss_emulate_test, no_cache_miss) {
constexpr int64_t num_requests = 10000;
constexpr int64_t num_sets = 32768;
constexpr int64_t associativity = 32;

auto lxu_cache_locations_cpu =
generate_lxu_cache_locations(num_requests, num_sets, associativity);
auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
run_emulate_cache_miss(lxu_cache_locations_cpu, 0);
auto lxu_cache_location_with_cache_misses =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
EXPECT_TRUE(
at::equal(lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
}

TEST(uvm_cache_miss_emulate_test, enforced_cache_miss) {
constexpr int64_t num_requests = 10000;
constexpr int64_t num_sets = 32768;
constexpr int64_t associativity = 32;
constexpr std::array<int64_t, 6> enforced_misses_per_256_for_testing = {
1, 5, 7, 33, 100, 256};

for (const bool miss_in_lxu_cache_locations : {false, true}) {
for (const bool gather_cache_stats : {false, true}) {
for (const auto enforced_misses_per_256 :
enforced_misses_per_256_for_testing) {
auto lxu_cache_locations_cpu =
generate_lxu_cache_locations(num_requests, num_sets, associativity);
if (miss_in_lxu_cache_locations) {
// one miss in the original lxu_cache_locations; shouldn't be counted
// as enforced misses from emulate_cache_miss().
auto z = lxu_cache_locations_cpu.data_ptr<int32_t>();
z[0] = -1;
}
auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
run_emulate_cache_miss(
lxu_cache_locations_cpu,
enforced_misses_per_256,
gather_cache_stats);
auto lxu_cache_location_with_cache_misses =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
EXPECT_FALSE(at::equal(
lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));

auto x = lxu_cache_locations_cpu.data_ptr<int32_t>();
auto y = lxu_cache_location_with_cache_misses.data_ptr<int32_t>();
int64_t enforced_misses = 0;
for (int32_t i = 0; i < lxu_cache_locations_cpu.numel(); ++i) {
if (x[i] != y[i]) {
EXPECT_EQ(y[i], -1);
enforced_misses++;
}
}
int64_t num_requests_over_256 =
static_cast<int64_t>(num_requests / 256);
int64_t expected_misses = num_requests_over_256 *
enforced_misses_per_256 +
std::min((num_requests - num_requests_over_256 * 256),
enforced_misses_per_256);
if (miss_in_lxu_cache_locations) {
expected_misses--;
}
EXPECT_EQ(expected_misses, enforced_misses);
if (gather_cache_stats) {
auto uvm_cache_stats =
lxu_cache_location_with_cache_misses_and_uvm_cache_stats.second;
auto cache_stats_ptr = uvm_cache_stats.data_ptr<int32_t>();
// enforced misses are recorded as conflict misses.
EXPECT_EQ(expected_misses, cache_stats_ptr[5]);
}
}
}
}
}
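
A worked instance of the expected-miss arithmetic above (illustrative values):

// With num_requests = 10'000 and enforced_misses_per_256 = 5:
//   num_requests_over_256 = 10'000 / 256 = 39      (39 * 256 = 9'984)
//   expected_misses = 39 * 5 + std::min(10'000 - 9'984, 5) = 195 + 5 = 200
// If index 0 was already a miss (miss_in_lxu_cache_locations == true), that
// position is not counted as an enforced miss, so the expectation drops to 199.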