Add BF16 in FP8 quantize ops (pytorch#1961)
Summary:
Pull Request resolved: pytorch#1961

- Added an output_dtype argument to the dequantization functions to select
  float, half, or bfloat16 output; it is an integer value defined by
  SparseType (float: 0, half: 1, bfloat16: 5). A caller-side sketch follows
  below.
- Added type conversions in the quantize and dequantize kernels, using native
  CUDA/HIP intrinsics for the half/bfloat16 <-> float conversions and writing
  the conversions out explicitly.
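
For reference, a minimal caller-side sketch (not part of this diff) of selecting BF16 output
through the new argument. It assumes the _FP8rowwise_to_float_gpu declaration from
fbgemm_gpu/sparse_ops.h shown in the header diff below; the helper name
dequantize_fp8_rows_to_bf16 is hypothetical.

  // Sketch only: output_dtype follows the SparseType mapping above
  // (float: 0, half: 1, bfloat16: 5); the default of 0 keeps the previous
  // float32 output for existing callers.
  #include <ATen/ATen.h>
  #include "fbgemm_gpu/sparse_ops.h"

  at::Tensor dequantize_fp8_rows_to_bf16(const at::Tensor& fp8_rowwise_input) {
    // fp8_rowwise_input is the uint8 rowwise-quantized CUDA tensor produced by
    // the FP8 quantize op; passing 5 (SparseType::BF16) makes the GPU op
    // allocate an at::kBFloat16 output and run the BF16 dequantize branch.
    return fbgemm_gpu::_FP8rowwise_to_float_gpu(
        fp8_rowwise_input, /*forward=*/true, /*output_dtype=*/5);
  }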

Differential Revision: D47904459

fbshipit-source-id: 41d3f0c50365d0482aab912c202f458a787419d8
sryap authored and facebook-github-bot committed Aug 21, 2023
1 parent d43717a commit facb7ed
Showing 5 changed files with 229 additions and 26 deletions.
6 changes: 4 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -220,7 +220,8 @@ at::Tensor _float_or_half_to_fused8bitrowwise_gpu(const at::Tensor& input);
at::Tensor _fused8bitrowwise_to_float_gpu(const at::Tensor& input);
at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
- const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor _paddedFP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true,
@@ -239,7 +240,8 @@ at::Tensor float_or_half_to_fused8bitrowwise_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_cpu(const at::Tensor& input);
at::Tensor FP8rowwise_to_float_cpu(
const at::Tensor& input,
- const bool forward = true);
const bool forward = true,
const int64_t output_dtype = 0);
at::Tensor fused8bitrowwise_to_half_cpu(const at::Tensor& input);
at::Tensor fused8bitrowwise_to_float_or_half_cpu(
const at::Tensor& input,
137 changes: 119 additions & 18 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -17,6 +17,27 @@ namespace fbgemm_gpu {

namespace {

// TODO: Move these helper functions to header
__device__ inline float bf16_to_fp32(const void* input) {
#ifdef __HIP_PLATFORM_HCC__
return float(*reinterpret_cast<const hip_bfloat16*>(input));
#else
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(input));
#endif
}

#ifdef __HIP_PLATFORM_HCC__
__device__ inline at::BFloat16 fp32_to_bf16(const float* input) {
hip_bfloat16 result(*input);
return *reinterpret_cast<at::BFloat16*>(&(result.data));
}
#else
__device__ inline at::BFloat16 fp32_to_bf16(const float* input) {
const __nv_bfloat16 result = __float2bfloat16(*input);
return *reinterpret_cast<const at::BFloat16*>(&result);
}
#endif

// FP32/FP16 -> FP8 rowwise kernel
template <typename input_t>
__global__ inline void _float_to_FP8rowwise_cuda_kernel(
@@ -48,8 +69,16 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
output_row_scale_bias[0] = scale;
for (int64_t col = 0; col < ncols; ++col) {
- output_row[col] =
- float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_row[col] = float_to_hfp8(
bf16_to_fp32(&input_row[col]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_row[col] = float_to_hfp8(
__half2float(input_row[col]) * scale, ebit, bias, max_pos);
} else {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
}
}
}
}
@@ -87,7 +116,15 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
for (int64_t col = threadIdx.x; col < ncols; col += lane_width) {
// Get thread-local minmax. These are the smallest min and max ever seen
// by this thread.
- maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
maximum_element =
fmaxf(maximum_element, fabs(bf16_to_fp32(&input_row[col])));
} else if constexpr (std::is_same<input_t, at::Half>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__half2float(input_row[col])));
} else {
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
}
}
}

Expand Down Expand Up @@ -149,8 +186,16 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
// TODO: lift range_list into shared memory. However, when nrows is large,
// it might exceed the size of shared memory.
// output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale);
- output_addr[0] =
- float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_addr[0] = float_to_hfp8(
bf16_to_fp32(&input[input_idx]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_addr[0] = float_to_hfp8(
__half2float(input[input_idx]) * scale, ebit, bias, max_pos);
} else {
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
}
}
}
}
@@ -176,8 +221,17 @@ __global__ inline void _FP8rowwise_to_float_cuda_kernel(
reinterpret_cast<const float*>(input_row + output_columns);
output_t* output_row = output + row * output_columns;

- output_row[col] =
const float output_ =
hfp8_to_float(input_row[col], ebit, bias) / input_row_scale_bias[0];

if constexpr (std::is_same<output_t, at::BFloat16>::value) {
*reinterpret_cast<at::BFloat16*>(&output_row[col]) =
fp32_to_bf16(&output_);
} else if constexpr (std::is_same<output_t, at::Half>::value) {
output_row[col] = __half2float(output_);
} else {
output_row[col] = output_;
}
}
}
}
@@ -221,8 +275,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
// think unsigned as we use 0, 255

if (nrows <= 20) {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
- input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_float_to_FP8rowwise_cuda_kernel",
[&] {
_float_to_FP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
@@ -261,8 +319,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto num_blocks_warp =
cuda_calc_xblock_count(nrows, rows_per_block);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
- input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_get_FP8_qparam_cuda_kernel",
[&] {
_get_FP8_qparam_cuda_kernel<scalar_t>
<<<num_blocks_warp,
dim3(blockDim_x, rows_per_block),
@@ -285,8 +347,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
- input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"_compute_FP8_quantize_cuda_kernel",
[&] {
_compute_FP8_quantize_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(),
@@ -306,7 +372,14 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
///@ingroup quantize-data-cuda
DLL_PUBLIC Tensor
_float_to_FP8rowwise_gpu(const Tensor& input, const bool forward) {
- return _float_to_FP8rowwise_gpu_t<float>(input, forward);
auto input_type = input.dtype();
if (input_type == at::kHalf) {
return _float_to_FP8rowwise_gpu_t<half>(input, forward);
} else if (input_type == at::kBFloat16) {
return _float_to_FP8rowwise_gpu_t<__nv_bfloat16>(input, forward);
} else {
return _float_to_FP8rowwise_gpu_t<float>(input, forward);
}
}
template <typename output_t>
@@ -337,10 +410,18 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
- } else { // T = at::Half
} else if constexpr (std::is_same_v<output_t, half>) { // T = at::Half
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kHalf));
} else if constexpr (std::is_same_v<
output_t,
__nv_bfloat16>) { // T = at::BFloat16
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false);
}
if (nrows == 0 || output_columns == 0) {
@@ -356,8 +437,12 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
- output.scalar_type(), "FP8rowwise_to_float_cuda_kernel", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
output.scalar_type(),
"FP8rowwise_to_float_cuda_kernel",
[&] {
_FP8rowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<std::uint8_t>(),
@@ -373,8 +458,24 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
DLL_PUBLIC at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
- bool forward) {
- return _FP8rowwise_to_float_gpu_t<float>(input, forward);
bool forward,
const int64_t output_dtype) {
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
Tensor output;
switch (output_sparse_dtype) {
case SparseType::FP32:
output = _FP8rowwise_to_float_gpu_t<float>(input, forward);
break;
case SparseType::FP16:
output = _FP8rowwise_to_float_gpu_t<half>(input, forward);
break;
case SparseType::BF16:
output = _FP8rowwise_to_float_gpu_t<__nv_bfloat16>(input, forward);
break;
default:
TORCH_CHECK(false);
}
return output;
}
} // namespace fbgemm_gpu
9 changes: 6 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -214,7 +214,6 @@ Tensor fused8bitrowwise_to_float_or_half_cpu(
const Tensor& input,
const int64_t output_dtype) {
Tensor output;

SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
@@ -241,7 +240,10 @@ Tensor float_to_FP8rowwise_cpu(const Tensor& input, bool forward) {
}

///@ingroup quantize-data-cpu
- Tensor FP8rowwise_to_float_cpu(const Tensor& input, bool forward) {
Tensor FP8rowwise_to_float_cpu(
const Tensor& input,
bool forward,
const int64_t output_dtype) {
TORCH_CHECK(false, "fp8 is not supported by CPU");
return input;
}
@@ -413,7 +415,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("HalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("FloatOrHalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToFloat(Tensor input) -> Tensor");
- m.def("FP8RowwiseQuantizedToFloat(Tensor input, bool forward) -> Tensor");
m.def(
"FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor");
m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor");
m.def(
"Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0) -> Tensor");
17 changes: 15 additions & 2 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -10,6 +10,8 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include "c10/core/ScalarType.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

@@ -20,7 +22,8 @@ namespace fbgemm_gpu {
///@ingroup quantize-data-meta
Tensor FP8rowwise_to_float_meta(
const Tensor& input,
- [[maybe_unused]] bool forward) {
[[maybe_unused]] bool forward,
const int64_t output_dtype) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

const auto input_sizes = input.sizes();
@@ -31,7 +34,17 @@ Tensor FP8rowwise_to_float_meta(

auto output_dims = input_sizes.vec();
output_dims[last_dim] = output_columns;
- return at::empty(output_dims, input.options().dtype(at::kFloat));
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
return at::empty(output_dims, input.options().dtype(at::kFloat));
case SparseType::FP16:
return at::empty(output_dims, input.options().dtype(at::kHalf));
case SparseType::BF16:
return at::empty(output_dims, input.options().dtype(at::kBFloat16));
default:
TORCH_CHECK(false);
}
}

} // namespace fbgemm_gpu