diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh index d4edba9504..f5d4b68c9d 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh @@ -76,18 +76,13 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( auto weight_row_template = WeightRow>( weights, cache_weights, D, nullptr); - if (!std::is_same::value && stochastic_rounding) { - StochasticRoundingRNGState state; - // Different for every *run* and every *thread*. - auto stochastic_rounding_seeds = - at::cuda::philox::unpack(stochastic_rounding_philox_args); - stochastic_rounding_init( - std::get<0>(stochastic_rounding_seeds) ^ - std::get<1>(stochastic_rounding_seeds), - threadIdx.x + run_id * blockDim.x, - &state); - weight_row_template.set_stoc_state(&state); - } + + weight_row_template.set_stochastic_rounding( + stochastic_rounding, + stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x + ); + float2 qparams_template; if (is_int8 && !cache_weights) { qparams_template = weight_row_template.load_qparams(); diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 301dba66f9..566d14c780 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -10,6 +10,7 @@ #include #include +#include // clang-format off #ifdef __HIP_PLATFORM_HCC__ @@ -1275,6 +1276,40 @@ DEVICE_INLINE void store_qparams_to_row(uint8_t* ptr, float2 qparams) { } } +// Min a register value across all warp threads +template +DEVICE_INLINE T warp_reduce_min(T val) { +#pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val = std::min(val, shfl_xor(val, mask)); + } + return val; +} + +// Max a register value across all warp threads +template +DEVICE_INLINE 
T warp_reduce_max(T val) { +#pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val = std::max(val, shfl_xor(val, mask)); + } + return val; +} + +template +DEVICE_INLINE float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) { + float2 qparams; + local_min = warp_reduce_min(local_min); + local_max = warp_reduce_max(local_max); + if (threadIdx.x == 0) { /* Only the thread with threadIdx.x == 0 fills qparams here (x = value range / 255, y = minimum); every thread then takes x and y from the shfl_sync(..., 0) calls below. Presumably that is a broadcast from lane 0; confirm against shfl_sync. */ + qparams.x = (local_max - local_min) / 255.0f; + qparams.y = local_min; + } + qparams.x = shfl_sync(qparams.x, 0); + qparams.y = shfl_sync(qparams.y, 0); + return qparams; +} + template // TODO: pass in dimension info and calculate qparams for rowwise integer // quantization @@ -1293,13 +1328,8 @@ struct WeightRow { int dim_; StochasticRoundingRNGState* stoc_rounding_state_; - DEVICE_INLINE void set_stoc_state( - StochasticRoundingRNGState* stoc_rounding_state) { - stoc_rounding_state_ = stoc_rounding_state; - } - // load from cache if resident; else load from embedding - DEVICE_INLINE Vec4T load(int32_t d, float2 qparams) { + DEVICE_INLINE Vec4T load(int32_t d, float2 qparams) const { if (cache_row_) { return dequantize_load(cache_row_ + d, qparams); } else { @@ -1326,9 +1356,86 @@ struct WeightRow { store_qparams_to_row(row_ + dim_, qparams); } - DEVICE_INLINE float2 load_qparams() { + DEVICE_INLINE float2 load_qparams() const { return load_qparams_from_row(row_ + dim_); } + + /* Copies this row into target in Vec4 slices: the calling thread handles d = lane_id, lane_id + num_lanes, ... while d * 4 < dim_length, using load() on this row and store() on target. */ DEVICE_INLINE void warp_copy_to( + WeightRow& target, + const int32_t dim_length, + const int32_t num_lanes, + const int32_t lane_id) const { + float2 qparams; /* NOTE(review): left uninitialized unless the if-constexpr branch below assigns it, yet it is passed to load() and store() in the loop either way; presumably ignored on those paths, confirm. */ + if constexpr (std::is_same_v) { + // Load quantization params from embedding row + qparams = load_qparams(); + } + + // Copy over for each warp-sized slice of Vec4's + for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) { + const auto slice = load(d * 4, qparams); + target.store(slice, d * 4, qparams); + } + } + + /* Writes the row back through evict() in Vec4 slices, striding by num_lanes from lane_id; in the if-constexpr branch it first derives qparams from the row's min/max and has lane_id == 0 store them. */ DEVICE_INLINE void warp_evict( + const int32_t dim_length, + const int32_t num_lanes, + const int32_t lane_id) { + float2 
qparams; + + if constexpr (std::is_same_v) { + auto local_min = std::numeric_limits>::max(); + auto local_max = + std::numeric_limits>::lowest(); + + // Compute the qparams from the cache row (not embedding row) weights + for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) { + Vec4T cache_slice = load(d * 4, qparams); // qparams not used + local_max = max(local_max, vec4_max(cache_slice)); + local_min = min(local_min, vec4_min(cache_slice)); + } + + // Compute the max and min across the warps + qparams = warp_find_qparams(local_min, local_max); /* NOTE(review): warp_find_qparams tests threadIdx.x == 0 while this method tests lane_id == 0; the callers in this patch pass threadIdx.x as lane_id, so the two agree only under that convention. */ + + if (lane_id == 0) { + // Store the qparams into the embedding row + store_qparams(qparams); + } + } + + for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) { + // Dequantize-load a slice of the cache row + Vec4T cache_slice = load(d * 4, qparams); + // and evict the slice into the embedding row + evict(cache_slice, d * 4, qparams); // FP32 -> FP16/FP32 + } + } + + /* When the if-constexpr condition holds and stochastic_rounding is true, initializes a StochasticRoundingRNGState via stochastic_rounding_init (XOR of the two values from at::cuda::philox::unpack, plus salt_value) and points stoc_rounding_state_ at it; otherwise does nothing. */ DEVICE_INLINE void set_stochastic_rounding( + const bool stochastic_rounding, + const at::PhiloxCudaState stochastic_rounding_philox_args, + const uint64_t salt_value) { + if constexpr (!std::is_same_v) { + if (stochastic_rounding) { + StochasticRoundingRNGState state; + auto stochastic_rounding_seeds = + at::cuda::philox::unpack(stochastic_rounding_philox_args); + + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. 
+ salt_value, + &state); + + // Set the internal stoc_rounding_state_ + stoc_rounding_state_ = &state; /* NOTE(review): state is a local of the enclosing if-block, so this member pointer outlives the object it points to once the block closes; dereferencing it afterwards is a dangling access. Consider owning the RNG state in WeightRow or having the caller supply storage. */ + } + } + } }; __host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { @@ -1454,40 +1561,6 @@ DEVICE_INLINE scalar_t vec4_max(fbgemm_gpu::Vec4T vec4) { return max_val; } -// Min a register value across all warp threads -template -DEVICE_INLINE T warp_reduce_min(T val) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val = std::min(val, shfl_xor(val, mask)); - } - return val; -} - -// Max a register value across all warp threads -template -DEVICE_INLINE T warp_reduce_max(T val) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val = std::max(val, shfl_xor(val, mask)); - } - return val; -} - -template -__device__ float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) { - float2 qparams; - local_min = warp_reduce_min(local_min); - local_max = warp_reduce_max(local_max); - if (threadIdx.x == 0) { - qparams.x = (local_max - local_min) / 255.0f; - qparams.y = local_min; - } - qparams.x = shfl_sync(qparams.x, 0); - qparams.y = shfl_sync(qparams.y, 0); - return qparams; -} - struct __align__(32) float8 { __host__ __device__ float8() {} float4 vals[2]; diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 9ea83d33e0..4cdbce96d6 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -139,7 +139,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( const int32_t D_current = D_end_current - D_start_current; int32_t D_emb = D_current; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } auto weight_row = WeightRow>( @@ -147,19 +147,12 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( &lxu_cache_weights[b][0], D_current, nullptr); - if (!std::is_same::value && stochastic_rounding) { - 
StochasticRoundingRNGState state; - // different for every *run* and every *thread*. - auto stochastic_rounding_seeds = - at::cuda::philox::unpack(stochastic_rounding_philox_args); - stochastic_rounding_init( - std::get<0>(stochastic_rounding_seeds) ^ - std::get<1>(stochastic_rounding_seeds), - blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x, - &state); - weight_row.set_stoc_state(&state); - } + + weight_row.set_stochastic_rounding( + stochastic_rounding, + stochastic_rounding_philox_args, + blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x); float2 qparams; if (std::is_same::value) { @@ -1035,57 +1028,32 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( const int32_t D_end_current = D_offsets[t_current + 1]; const int32_t D_current = D_end_current - D_start_current; int32_t D_emb = D_current; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } + auto weight_row = WeightRow( &weights[weights_offset_current + idx_current * D_emb + 0], &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], D_current, nullptr); - if (!std::is_same::value && stochastic_rounding) { - StochasticRoundingRNGState state; - // different for every *run* and every *thread*. 
- auto stochastic_rounding_seeds = - at::cuda::philox::unpack(stochastic_rounding_philox_args); - stochastic_rounding_init( - std::get<0>(stochastic_rounding_seeds) ^ - std::get<1>(stochastic_rounding_seeds), - (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x) * - kWarpSize + - l, - &state); - weight_row.set_stoc_state(&state); - } - float2 qparams; - at::acc_type local_min = - std::numeric_limits>::max(); - at::acc_type local_max = - std::numeric_limits>::lowest(); - if (std::is_same::value) { - for (int32_t d = threadIdx.x; d * 4 < D_current; d += blockDim.x) { - Vec4T cache_weights_vec = - weight_row.load(d * 4, qparams); // qparams not used - local_max = max(local_max, vec4_max(cache_weights_vec)); - local_min = min(local_min, vec4_min(cache_weights_vec)); - } - qparams = warp_find_qparams(local_min, local_max); - if (threadIdx.x == 0) { - weight_row.store_qparams(qparams); - } - } - for (int32_t d = threadIdx.x; d * 4 < D_current; d += blockDim.x) { - Vec4T cache_weights_vec = weight_row.load(d * 4, qparams); - weight_row.evict( - cache_weights_vec, d * 4, qparams); // FP32 -> FP16/FP32 - } + + weight_row.set_stochastic_rounding( + stochastic_rounding, + stochastic_rounding_philox_args, + (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x) * + kWarpSize + + l); + + weight_row.warp_evict(D_current, blockDim.x, threadIdx.x); } + int32_t D_emb = D_insert; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } - // insert into cache + auto weight_row_cache = WeightRow( &weights[weights_offset_insert + idx_insert * D_emb + 0], &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], @@ -1098,14 +1066,9 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( D_insert, nullptr); - float2 qparams; - if (std::is_same::value) { - qparams = weight_row_emb.load_qparams(); - } - for (int32_t d = threadIdx.x; d * 4 < D_insert; d += blockDim.x) { - auto 
row = weight_row_emb.load(d * 4, qparams); - weight_row_cache.store(row, d * 4, qparams); - } + weight_row_emb.warp_copy_to( + weight_row_cache, D_insert, blockDim.x, threadIdx.x); + if (threadIdx.x == 0) { lxu_cache_state[cache_set][insert_slot] = insert_idx; lru_state[cache_set][insert_slot] = time_stamp; @@ -1113,6 +1076,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( lxu_cache_locking_counter[cache_set][insert_slot] += 1; } } + n_inserted++; } n_conflict_misses += (SL - n_inserted); @@ -2102,7 +2066,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( const int32_t D_current = D_end_current - D_start_current; int32_t D_emb = D_current; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } auto weight_row = WeightRow( @@ -2110,49 +2074,24 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], D_current, nullptr); - if (!std::is_same::value && stochastic_rounding) { - StochasticRoundingRNGState state; - // different for every *run* and every *thread*. 
- auto stochastic_rounding_seeds = - at::cuda::philox::unpack(stochastic_rounding_philox_args); - stochastic_rounding_init( - std::get<0>(stochastic_rounding_seeds) ^ - std::get<1>(stochastic_rounding_seeds), - (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x) * - kWarpSize + - l, - &state); - weight_row.set_stoc_state(&state); - } - float2 qparams; - at::acc_type local_min = - std::numeric_limits>::max(); - at::acc_type local_max = - std::numeric_limits>::lowest(); - if (std::is_same::value) { - for (int32_t d = threadIdx.x; d * 4 < D_current; d += blockDim.x) { - Vec4T cache_weights_vec = - weight_row.load(d * 4, qparams); // qparams not used - local_max = max(local_max, vec4_max(cache_weights_vec)); - local_min = min(local_min, vec4_min(cache_weights_vec)); - } - qparams = warp_find_qparams(local_min, local_max); - if (threadIdx.x == 0) { - weight_row.store_qparams(qparams); - } - } - for (int32_t d = threadIdx.x; d * 4 < D_current; d += blockDim.x) { - Vec4T cache_weights_vec = weight_row.load(d * 4, qparams); - weight_row.evict(cache_weights_vec, d * 4, qparams); - } + weight_row.set_stochastic_rounding( + stochastic_rounding, + stochastic_rounding_philox_args, + (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x) * + kWarpSize + + l); + + weight_row.warp_evict(D_current, blockDim.x, threadIdx.x); } + // insert into cache int32_t D_emb = D_insert; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } + auto weight_row_cache = WeightRow( &weights[weights_offset_insert + idx_insert * D_emb + 0], &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], @@ -2165,14 +2104,9 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( D_insert, nullptr); - float2 qparams; - if (std::is_same::value) { - qparams = weight_row_emb.load_qparams(); - } - for (int32_t d = threadIdx.x; d * 4 < D_insert; d += blockDim.x) { - auto row = 
weight_row_emb.load(d * 4, qparams); - weight_row_cache.store(row, d * 4, qparams); - } + weight_row_emb.warp_copy_to( + weight_row_cache, D_insert, blockDim.x, threadIdx.x); + if (threadIdx.x == 0) { lxu_cache_state[cache_set][insert_slot] = insert_idx; } @@ -2972,7 +2906,7 @@ __global__ __launch_bounds__(kMaxThreads) void reset_weight_momentum_kernel( } int32_t D_emb = D; - if (std::is_same::value) { + if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; }