Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BF16 in FP8 quantize ops #1961

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,15 @@ DEVICE_INLINE half16 to_half16(float_16 v) {
return t;
}

// Override __bfloat162float to accept at::BFloat16
static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
#ifdef __HIP_PLATFORM_HCC__
return float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#else
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&input));
#endif
}

#ifdef __HIP_PLATFORM_HCC__
// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
Expand Down
6 changes: 4 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ at::Tensor _float_or_half_to_fused8bitrowwise_gpu(const at::Tensor& input);
at::Tensor _fused8bitrowwise_to_float_gpu(const at::Tensor& input);
at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor _paddedFP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true,
Expand All @@ -239,7 +240,8 @@ at::Tensor float_or_half_to_fused8bitrowwise_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_cpu(const at::Tensor& input);
at::Tensor FP8rowwise_to_float_cpu(
const at::Tensor& input,
const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor fused8bitrowwise_to_half_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_or_half_cpu(
const at::Tensor& input,
Expand Down
18 changes: 0 additions & 18 deletions fbgemm_gpu/src/quantize_ops/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,3 @@
#define QUANTIZE_OPS_MIN(a, b) ((a) < (b) ? (a) : (b))

using Tensor = at::Tensor;

namespace fbgemm_gpu {

namespace {

template <typename T>
__device__ inline __attribute__((always_inline)) T
quantize_ops_shfl_xor(const T val, int laneMask, int width) {
#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(0xffffffff, val, laneMask, width);
#endif
}

} // namespace

} // namespace fbgemm_gpu
121 changes: 100 additions & 21 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,16 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
output_row_scale_bias[0] = scale;
for (int64_t col = 0; col < ncols; ++col) {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_row[col] = float_to_hfp8(
__bfloat162float(input_row[col]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_row[col] = float_to_hfp8(
__half2float(input_row[col]) * scale, ebit, bias, max_pos);
} else {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
}
}
}
}
Expand Down Expand Up @@ -87,17 +95,24 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
for (int64_t col = threadIdx.x; col < ncols; col += lane_width) {
// Get thread-local minmax. These are the smallest min and max ever seen
// by this thread.
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__bfloat162float(input_row[col])));
} else if constexpr (std::is_same<input_t, at::Half>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__half2float(input_row[col])));
} else {
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
}
}
}

// Perform warp-wide min and max reductions. All threads in the warp
// participate, even if they aren't assigned to a row, since we can't assume
// the existence of the `*_sync` warp primitives with support for masking.
for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
maximum_element = fmaxf(
maximum_element,
quantize_ops_shfl_xor(maximum_element, offset, lane_width));
maximum_element =
fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
}

// only the leading thread in the warp is needed to return the final result in
Expand Down Expand Up @@ -149,8 +164,16 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
// TODO: lift range_list into shared memory. However, when nrows is large,
// it might exceed the size of shared memory.
// output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale);
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_addr[0] = float_to_hfp8(
__bfloat162float(input[input_idx]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_addr[0] = float_to_hfp8(
__half2float(input[input_idx]) * scale, ebit, bias, max_pos);
} else {
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
}
}
}
}
Expand All @@ -176,8 +199,17 @@ __global__ inline void _FP8rowwise_to_float_cuda_kernel(
reinterpret_cast<const float*>(input_row + output_columns);
output_t* output_row = output + row * output_columns;

output_row[col] =
const float output_ =
hfp8_to_float(input_row[col], ebit, bias) / input_row_scale_bias[0];

if constexpr (std::is_same<output_t, at::BFloat16>::value) {
*reinterpret_cast<__nv_bfloat16*>(&output_row[col]) =
__float2bfloat16(output_);
} else if constexpr (std::is_same<output_t, at::Half>::value) {
output_row[col] = __half2float(output_);
} else {
output_row[col] = output_;
}
}
}
}
Expand Down Expand Up @@ -221,8 +253,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
// think unsigned as we use 0, 255

if (nrows <= 20) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_float_to_FP8rowwise_cuda_kernel",
[&] {
_float_to_FP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
Expand Down Expand Up @@ -261,8 +297,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto num_blocks_warp =
cuda_calc_xblock_count(nrows, rows_per_block);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_get_FP8_qparam_cuda_kernel",
[&] {
_get_FP8_qparam_cuda_kernel<scalar_t>
<<<num_blocks_warp,
dim3(blockDim_x, rows_per_block),
Expand All @@ -285,8 +325,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_compute_FP8_quantize_cuda_kernel",
[&] {
_compute_FP8_quantize_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(),
Expand All @@ -306,7 +350,14 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
///@ingroup quantize-data-cuda
DLL_PUBLIC Tensor
_float_to_FP8rowwise_gpu(const Tensor& input, const bool forward) {
return _float_to_FP8rowwise_gpu_t<float>(input, forward);
auto input_type = input.dtype();
if (input_type == at::kHalf) {
return _float_to_FP8rowwise_gpu_t<half>(input, forward);
} else if (input_type == at::kBFloat16) {
return _float_to_FP8rowwise_gpu_t<__nv_bfloat16>(input, forward);
} else {
return _float_to_FP8rowwise_gpu_t<float>(input, forward);
}
}

template <typename output_t>
Expand Down Expand Up @@ -337,10 +388,18 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else { // T = at::Half
} else if constexpr (std::is_same_v<output_t, half>) { // T = at::Half
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kHalf));
} else if constexpr (std::is_same_v<
output_t,
__nv_bfloat16>) { // T = at::BFloat16
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false);
}

if (nrows == 0 || output_columns == 0) {
Expand All @@ -356,8 +415,12 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
output.scalar_type(), "FP8rowwise_to_float_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
output.scalar_type(),
"FP8rowwise_to_float_cuda_kernel",
[&] {
_FP8rowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<std::uint8_t>(),
Expand All @@ -373,8 +436,24 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {

DLL_PUBLIC at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
bool forward) {
return _FP8rowwise_to_float_gpu_t<float>(input, forward);
bool forward,
const int64_t output_dtype) {
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
Tensor output;
switch (output_sparse_dtype) {
case SparseType::FP32:
output = _FP8rowwise_to_float_gpu_t<float>(input, forward);
break;
case SparseType::FP16:
output = _FP8rowwise_to_float_gpu_t<half>(input, forward);
break;
case SparseType::BF16:
output = _FP8rowwise_to_float_gpu_t<__nv_bfloat16>(input, forward);
break;
default:
TORCH_CHECK(false);
}
return output;
}

} // namespace fbgemm_gpu
10 changes: 4 additions & 6 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,10 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
// participate, even if they aren't assigned to a row, since we can't assume
// the existence of the `*_sync` warp primitives with support for masking.
for (int offset = lane_width >> 1; offset > 0; offset >>= 1) {
minimum_element = fminf(
minimum_element,
quantize_ops_shfl_xor(minimum_element, offset, lane_width));
maximum_element = fmaxf(
maximum_element,
quantize_ops_shfl_xor(maximum_element, offset, lane_width));
minimum_element =
fminf(minimum_element, shfl_xor(minimum_element, offset, lane_width));
maximum_element =
fmaxf(maximum_element, shfl_xor(maximum_element, offset, lane_width));
}

// only the leading thread in the warp is needed to return the final result in
Expand Down
9 changes: 6 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
const Tensor& input,
const int64_t output_dtype) {
Tensor output;

SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
Expand All @@ -241,7 +240,10 @@ Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool forward) {
}

///@ingroup quantize-data-cpu
Tensor FP8rowwise_to_float_cpu(const Tensor& input, bool forward) {
Tensor FP8rowwise_to_float_cpu(
const Tensor& input,
bool forward,
const int64_t output_dtype) {
TORCH_CHECK(false, "fp8 is not supported by CPU");
return input;
}
Expand Down Expand Up @@ -413,7 +415,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("HalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("FloatOrHalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToFloat(Tensor input) -> Tensor");
m.def("FP8RowwiseQuantizedToFloat(Tensor input, bool forward) -> Tensor");
m.def(
"FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
m.def(
"Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0) -> Tensor");
Expand Down
17 changes: 15 additions & 2 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include "c10/core/ScalarType.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

Expand All @@ -20,7 +22,8 @@ namespace fbgemm_gpu {
///@ingroup quantize-data-meta
Tensor FP8rowwise_to_float_meta(
const Tensor& input,
[[maybe_unused]] bool forward) {
[[maybe_unused]] bool forward,
const int64_t output_dtype) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

const auto input_sizes = input.sizes();
Expand All @@ -31,7 +34,17 @@ Tensor FP8rowwise_to_float_meta(

auto output_dims = input_sizes.vec();
output_dims[last_dim] = output_columns;
return at::empty(output_dims, input.options().dtype(at::kFloat));
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
return at::empty(output_dims, input.options().dtype(at::kFloat));
case SparseType::FP16:
return at::empty(output_dims, input.options().dtype(at::kHalf));
case SparseType::BF16:
return at::empty(output_dims, input.options().dtype(at::kBFloat16));
default:
TORCH_CHECK(false);
}
}

} // namespace fbgemm_gpu
Expand Down
Loading