Skip to content

Commit

Permalink
Reorganize code around cache eviction (#1886)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1886

- Reorganize code around cache eviction to reduce duplicate code

Reviewed By: sryap

Differential Revision: D47622836

fbshipit-source-id: 4318e1d2f84a9e233a009e3089d0807622ba9f35
  • Loading branch information
q10 authored and facebook-github-bot committed Jul 21, 2023
1 parent 8372bc7 commit 5a0be46
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,13 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel(
auto weight_row_template =
WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
weights, cache_weights, D, nullptr);
if (!std::is_same<emb_t, float>::value && stochastic_rounding) {
StochasticRoundingRNGState state;
// Different for every *run* and every *thread*.
auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
threadIdx.x + run_id * blockDim.x,
&state);
weight_row_template.set_stoc_state(&state);
}

weight_row_template.set_stochastic_rounding(
stochastic_rounding,
stochastic_rounding_philox_args,
threadIdx.x + run_id * blockDim.x
);

float2 qparams_template;
if (is_int8 && !cache_weights) {
qparams_template = weight_row_template.load_qparams();
Expand Down
155 changes: 114 additions & 41 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>

// clang-format off
#ifdef __HIP_PLATFORM_HCC__
Expand Down Expand Up @@ -1275,6 +1276,40 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) {
}
}

// Warp-wide minimum reduction via butterfly (XOR) shuffles.
// Every participating lane ends up holding the same minimum value.
// ReduceWidth must be a power of two (defaults to the full warp).
template <typename T, int ReduceWidth = kWarpSize>
DEVICE_INLINE T warp_reduce_min(T val) {
#pragma unroll
  for (int offset = ReduceWidth >> 1; offset > 0; offset >>= 1) {
    const T other = shfl_xor(val, offset);
    val = (other < val) ? other : val;
  }
  return val;
}

// Warp-wide maximum reduction via butterfly (XOR) shuffles.
// Every participating lane ends up holding the same maximum value.
// ReduceWidth must be a power of two (defaults to the full warp).
template <typename T, int ReduceWidth = kWarpSize>
DEVICE_INLINE T warp_reduce_max(T val) {
#pragma unroll
  for (int offset = ReduceWidth >> 1; offset > 0; offset >>= 1) {
    const T other = shfl_xor(val, offset);
    val = (other > val) ? other : val;
  }
  return val;
}

// Derive rowwise int8 quantization parameters from per-lane min/max values.
// Reduces the min/max across the warp, has lane 0 compute
//   .x = scale  = (max - min) / 255
//   .y = offset = min
// and then broadcasts the result so every lane returns the same float2.
template <typename scalar_t>
DEVICE_INLINE float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) {
  const scalar_t warp_min = warp_reduce_min<scalar_t>(local_min);
  const scalar_t warp_max = warp_reduce_max<scalar_t>(local_max);
  float2 qparams;
  if (threadIdx.x == 0) {
    qparams.x = (warp_max - warp_min) / 255.0f;
    qparams.y = warp_min;
  }
  // Lanes other than 0 hold uninitialized qparams until this broadcast.
  qparams.x = shfl_sync(qparams.x, 0);
  qparams.y = shfl_sync(qparams.y, 0);
  return qparams;
}

template <typename emb_t, typename cache_t, typename dst_t>
// TODO: pass in dimension info and calculate qparams for rowwise integer
// quantization
Expand All @@ -1293,13 +1328,8 @@ struct WeightRow {
int dim_;
StochasticRoundingRNGState* stoc_rounding_state_;

DEVICE_INLINE void set_stoc_state(
StochasticRoundingRNGState* stoc_rounding_state) {
stoc_rounding_state_ = stoc_rounding_state;
}

// load from cache if resident; else load from embedding
DEVICE_INLINE Vec4T<dst_t> load(int32_t d, float2 qparams) {
DEVICE_INLINE Vec4T<dst_t> load(int32_t d, float2 qparams) const {
if (cache_row_) {
return dequantize_load<dst_t, cache_t>(cache_row_ + d, qparams);
} else {
Expand All @@ -1326,9 +1356,86 @@ struct WeightRow {
store_qparams_to_row(row_ + dim_, qparams);
}

DEVICE_INLINE float2 load_qparams() {
DEVICE_INLINE float2 load_qparams() const {
return load_qparams_from_row<emb_t>(row_ + dim_);
}

// Cooperatively copy this row into `target` using `num_lanes` warp lanes.
// Each lane handles every num_lanes-th 4-element (Vec4) slice of the row.
// dim_length is the row width in elements; the loop strides in units of 4,
// so dim_length is presumably a multiple of 4 — TODO confirm with callers.
// All lanes of the cooperating group must call this with the same
// dim_length/num_lanes and distinct lane_id values.
DEVICE_INLINE void warp_copy_to(
WeightRow<emb_t, cache_t, cache_t>& target,
const int32_t dim_length,
const int32_t num_lanes,
const int32_t lane_id) const {
float2 qparams;
if constexpr (std::is_same_v<emb_t, uint8_t>) {
// Load quantization params from embedding row
qparams = load_qparams();
}
// NOTE(review): for non-uint8 emb_t, qparams stays uninitialized here;
// presumably load()/store() ignore it on those paths — confirm.

// Copy over for each warp-sized slice of Vec4's
for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
const auto slice = load(d * 4, qparams);
target.store(slice, d * 4, qparams);
}
}

// Cooperatively evict this row from the cache back into the embedding row,
// using `num_lanes` warp lanes (each handles every num_lanes-th Vec4 slice).
// For uint8 embeddings, first recomputes fresh quantization params from the
// cached (dequantized) weights and stores them into the embedding row; then
// quantize-stores each slice. For non-uint8 embeddings, qparams is unused.
// All cooperating lanes must call this together: warp_find_qparams performs
// warp-wide shuffles.
DEVICE_INLINE void warp_evict(
const int32_t dim_length,
const int32_t num_lanes,
const int32_t lane_id) {
float2 qparams;

if constexpr (std::is_same_v<emb_t, uint8_t>) {
auto local_min = std::numeric_limits<at::acc_type<cache_t, true>>::max();
auto local_max =
std::numeric_limits<at::acc_type<cache_t, true>>::lowest();

// Compute the qparams from the cache row (not embedding row) weights
for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
Vec4T<cache_t> cache_slice = load(d * 4, qparams); // qparams not used
local_max = max(local_max, vec4_max(cache_slice));
local_min = min(local_min, vec4_min(cache_slice));
}

// Compute the max and min across the warps
qparams = warp_find_qparams(local_min, local_max);

if (lane_id == 0) {
// Store the qparams into the embedding row
store_qparams(qparams);
}
}

// NOTE(review): warp_find_qparams reduces across threadIdx.x lanes; this
// presumes lane_id maps onto threadIdx.x within the warp — confirm.
for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) {
// Dequantize-load a slice of the cache row
Vec4T<cache_t> cache_slice = load(d * 4, qparams);
// and evict the slice into the embedding row
evict(cache_slice, d * 4, qparams); // FP32 -> FP16/FP32
}
}

// Enable stochastic rounding for subsequent quantizing stores from this row.
// No-op when emb_t is float (full precision needs no rounding) or when
// stochastic_rounding is false.
// salt_value must be different for every *run* and every *thread*; callers
// pass threadIdx.x + run_id * blockDim.x.
DEVICE_INLINE void set_stochastic_rounding(
const bool stochastic_rounding,
const at::PhiloxCudaState stochastic_rounding_philox_args,
const uint64_t salt_value) {
if constexpr (!std::is_same_v<emb_t, float>) {
if (stochastic_rounding) {
StochasticRoundingRNGState state;
auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&state);

// Set the internal stoc_rounding_state_
// NOTE(review): `state` is a function-local; storing its address in
// stoc_rounding_state_ leaves that pointer dangling as soon as this
// method returns, so any later dereference (e.g. during store/evict)
// is undefined behavior. Consider holding the RNG state by value in
// the struct instead — TODO confirm and fix.
stoc_rounding_state_ = &state;
}
}
}
};

__host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) {
Expand Down Expand Up @@ -1454,40 +1561,6 @@ DEVICE_INLINE scalar_t vec4_max(fbgemm_gpu::Vec4T<scalar_t> vec4) {
return max_val;
}

// Min a register value across all warp threads
// (XOR-shuffle butterfly reduction; every lane receives the warp minimum.
// ReduceWidth must be a power of two.)
template <typename T, int ReduceWidth = kWarpSize>
DEVICE_INLINE T warp_reduce_min(T val) {
#pragma unroll
for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
val = std::min(val, shfl_xor(val, mask));
}
return val;
}

// Max a register value across all warp threads
// (XOR-shuffle butterfly reduction; every lane receives the warp maximum.
// ReduceWidth must be a power of two.)
template <typename T, int ReduceWidth = kWarpSize>
DEVICE_INLINE T warp_reduce_max(T val) {
#pragma unroll
for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
val = std::max(val, shfl_xor(val, mask));
}
return val;
}

// Derive rowwise int8 quantization params from per-lane min/max values:
// reduce min/max across the warp, compute scale (.x) and offset (.y) on
// lane 0, then broadcast so every lane returns the same float2.
template <typename scalar_t>
__device__ float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) {
float2 qparams;
local_min = warp_reduce_min<scalar_t>(local_min);
local_max = warp_reduce_max<scalar_t>(local_max);
if (threadIdx.x == 0) {
// scale covers the [min, max] range across 255 int8 steps
qparams.x = (local_max - local_min) / 255.0f;
qparams.y = local_min;
}
// Lanes other than 0 hold uninitialized qparams until this broadcast.
qparams.x = shfl_sync(qparams.x, 0);
qparams.y = shfl_sync(qparams.y, 0);
return qparams;
}

struct __align__(32) float8 {
__host__ __device__ float8() {}
float4 vals[2];
Expand Down
Loading

0 comments on commit 5a0be46

Please sign in to comment.