Skip to content

Commit

Permalink
Add BF16 in FP8 quantize ops (pytorch#1961)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1961

- Added output_dtype for half, bfloat16 and float as output in
  dequantization functions; currently it's an integer value defined by
  Sparse_dtype (float:0, half:1, bfloat16:5)
- Added type conversion in quant and dequant kernels by using native
  CUDA/HIP functions for half to float conversion and writing
  everything explicitly.

Reviewed By: jianyuh

Differential Revision: D47904459

fbshipit-source-id: f608d7da5dcf05ff78a6e0eb13d985ed99207d1a
  • Loading branch information
sryap authored and facebook-github-bot committed Aug 23, 2023
1 parent 3265211 commit 56e870d
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 53 deletions.
9 changes: 9 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,15 @@ DEVICE_INLINE half16 to_half16(float_16 v) {
return t;
}

// Override __bfloat162float to accept at::BFloat16
static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
#ifdef __HIP_PLATFORM_HCC__
return float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#else
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#endif
}

#ifdef __HIP_PLATFORM_HCC__
// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
Expand Down
6 changes: 4 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ at::Tensor _float_or_half_to_fused8bitrowwise_gpu(const at::Tensor& input);
at::Tensor _fused8bitrowwise_to_float_gpu(const at::Tensor& input);
at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor _paddedFP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true,
Expand All @@ -239,7 +240,8 @@ at::Tensor float_or_half_to_fused8bitrowwise_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_cpu(const at::Tensor& input);
at::Tensor FP8rowwise_to_float_cpu(
const at::Tensor& input,
const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor fused8bitrowwise_to_half_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_or_half_cpu(
const at::Tensor& input,
Expand Down
18 changes: 0 additions & 18 deletions fbgemm_gpu/src/quantize_ops/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,3 @@
#define QUANTIZE_OPS_MIN(a, b) ((a) < (b) ? (a) : (b))

using Tensor = at::Tensor;

namespace fbgemm_gpu {

namespace {

template <typename T>
__device__ inline __attribute__((always_inline)) T
quantize_ops_shfl_xor(const T val, int laneMask, int width) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(0xffffffff, val, laneMask, width);
#endif
}

} // namespace

} // namespace fbgemm_gpu
121 changes: 100 additions & 21 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,16 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
output_row_scale_bias[0] = scale;
for (int64_t col = 0; col < ncols; ++col) {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_row[col] = float_to_hfp8(
__bfloat162float(input_row[col]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_row[col] = float_to_hfp8(
__half2float(input_row[col]) * scale, ebit, bias, max_pos);
} else {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
}
}
}
}
Expand Down Expand Up @@ -87,17 +95,24 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
for (int64_t col = threadIdx.x; col < ncols; col += lane_width) {
// Get thread-local minmax. These are the smallest min and max ever seen
// by this thread.
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__bfloat162float(input_row[col])));
} else if constexpr (std::is_same<input_t, at::Half>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__half2float(input_row[col])));
} else {
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
}
}
}

// Perform warp-wide min and max reductions. All threads in the warp
// participate, even if they aren't assigned to a row, since we can't assume
// the existence of the `*_sync` warp primitives with support for masking.
for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
maximum_element = fmaxf(
maximum_element,
quantize_ops_shfl_xor(maximum_element, offset, lane_width));
maximum_element =
fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
}

// only the leading thread in the warp is needed to return the final result in
Expand Down Expand Up @@ -149,8 +164,16 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
// TODO: lift range_list into shared memory. However, when nrows is large,
// it might exceed the size of shared memory.
// output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale);
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_addr[0] = float_to_hfp8(
__bfloat162float(input[input_idx]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_addr[0] = float_to_hfp8(
__half2float(input[input_idx]) * scale, ebit, bias, max_pos);
} else {
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
}
}
}
}
Expand All @@ -176,8 +199,17 @@ __global__ inline void _FP8rowwise_to_float_cuda_kernel(
reinterpret_cast<const float*>(input_row + output_columns);
output_t* output_row = output + row * output_columns;

output_row[col] =
const float output_ =
hfp8_to_float(input_row[col], ebit, bias) / input_row_scale_bias[0];

if constexpr (std::is_same<output_t, at::BFloat16>::value) {
*reinterpret_cast<__nv_bfloat16*>(&output_row[col]) =
__float2bfloat16(output_);
} else if constexpr (std::is_same<output_t, at::Half>::value) {
output_row[col] = __half2float(output_);
} else {
output_row[col] = output_;
}
}
}
}
Expand Down Expand Up @@ -221,8 +253,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
// think unsigned as we use 0, 255

if (nrows <= 20) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_float_to_FP8rowwise_cuda_kernel",
[&] {
_float_to_FP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
Expand Down Expand Up @@ -261,8 +297,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto num_blocks_warp =
cuda_calc_xblock_count(nrows, rows_per_block);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_get_FP8_qparam_cuda_kernel",
[&] {
_get_FP8_qparam_cuda_kernel<scalar_t>
<<<num_blocks_warp,
dim3(blockDim_x, rows_per_block),
Expand All @@ -285,8 +325,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_compute_FP8_quantize_cuda_kernel",
[&] {
_compute_FP8_quantize_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(),
Expand All @@ -306,7 +350,14 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
///@ingroup quantize-data-cuda
DLL_PUBLIC Tensor
_float_to_FP8rowwise_gpu(const Tensor& input, const bool forward) {
return _float_to_FP8rowwise_gpu_t<float>(input, forward);
auto input_type = input.dtype();
if (input_type == at::kHalf) {
return _float_to_FP8rowwise_gpu_t<half>(input, forward);
} else if (input_type == at::kBFloat16) {
return _float_to_FP8rowwise_gpu_t<__nv_bfloat16>(input, forward);
} else {
return _float_to_FP8rowwise_gpu_t<float>(input, forward);
}
}
template <typename output_t>
Expand Down Expand Up @@ -337,10 +388,18 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else { // T = at::Half
} else if constexpr (std::is_same_v<output_t, half>) { // T = at::Half
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kHalf));
} else if constexpr (std::is_same_v<
output_t,
__nv_bfloat16>) { // T = at::BFloat16
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false);
}
if (nrows == 0 || output_columns == 0) {
Expand All @@ -356,8 +415,12 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
output.scalar_type(), "FP8rowwise_to_float_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
output.scalar_type(),
"FP8rowwise_to_float_cuda_kernel",
[&] {
_FP8rowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<std::uint8_t>(),
Expand All @@ -373,8 +436,24 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
DLL_PUBLIC at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
bool forward) {
return _FP8rowwise_to_float_gpu_t<float>(input, forward);
bool forward,
const int64_t output_dtype) {
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
Tensor output;
switch (output_sparse_dtype) {
case SparseType::FP32:
output = _FP8rowwise_to_float_gpu_t<float>(input, forward);
break;
case SparseType::FP16:
output = _FP8rowwise_to_float_gpu_t<half>(input, forward);
break;
case SparseType::BF16:
output = _FP8rowwise_to_float_gpu_t<__nv_bfloat16>(input, forward);
break;
default:
TORCH_CHECK(false);
}
return output;
}
} // namespace fbgemm_gpu
10 changes: 4 additions & 6 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,10 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
// participate, even if they aren't assigned to a row, since we can't assume
// the existence of the `*_sync` warp primitives with support for masking.
for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
minimum_element = fminf(
minimum_element,
quantize_ops_shfl_xor(minimum_element, offset, lane_width));
maximum_element = fmaxf(
maximum_element,
quantize_ops_shfl_xor(maximum_element, offset, lane_width));
minimum_element =
fminf(minimum_element, shfl_xor(minimum_element, offset, lane_width));
maximum_element =
fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
}

// only the leading thread in the warp is needed to return the final result in
Expand Down
9 changes: 6 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
const Tensor& input,
const int64_t output_dtype) {
Tensor output;

SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
Expand All @@ -241,7 +240,10 @@ Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool forward) {
}

///@ingroup quantize-data-cpu
Tensor FP8rowwise_to_float_cpu(const Tensor& input, bool forward) {
Tensor FP8rowwise_to_float_cpu(
const Tensor& input,
bool forward,
const int64_t output_dtype) {
TORCH_CHECK(false, "fp8 is not supported by CPU");
return input;
}
Expand Down Expand Up @@ -413,7 +415,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("HalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("FloatOrHalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToFloat(Tensor input) -> Tensor");
m.def("FP8RowwiseQuantizedToFloat(Tensor input, bool forward) -> Tensor");
m.def(
"FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
m.def(
"Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0) -> Tensor");
Expand Down
17 changes: 15 additions & 2 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include "c10/core/ScalarType.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand All @@ -20,7 +22,8 @@ namespace fbgemm_gpu {
///@ingroup quantize-data-meta
Tensor FP8rowwise_to_float_meta(
const Tensor& input,
[[maybe_unused]] bool forward) {
[[maybe_unused]] bool forward,
const int64_t output_dtype) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

const auto input_sizes = input.sizes();
Expand All @@ -31,7 +34,17 @@ Tensor FP8rowwise_to_float_meta(

auto output_dims = input_sizes.vec();
output_dims[last_dim] = output_columns;
return at::empty(output_dims, input.options().dtype(at::kFloat));
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
return at::empty(output_dims, input.options().dtype(at::kFloat));
case SparseType::FP16:
return at::empty(output_dims, input.options().dtype(at::kHalf));
case SparseType::BF16:
return at::empty(output_dims, input.options().dtype(at::kBFloat16));
default:
TORCH_CHECK(false);
}
}

} // namespace fbgemm_gpu
Expand Down
Loading

0 comments on commit 56e870d

Please sign in to comment.