Enable int4 to int4 CPU STBE in fbgemm_gpu TBE API #2994

Closed · 3 commits

@@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

Tensor output;
SparseType o_dtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4);
bool output_is_bf16 = o_dtype == SparseType::BF16;
bool output_is_int8 = o_dtype == SparseType::INT8;
bool output_is_int4 = o_dtype == SparseType::INT4;
{% if not nobag %}
const int kINT8QparamsBytes = 8;
int64_t total_adjusted_D = total_D;
@@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
}
output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
{% else %}
const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements
int64_t adjusted_D = D;
if (o_dtype == SparseType::INT8) {
adjusted_D += kINT8QparamsBytes;
} else if (o_dtype == SparseType::INT4) {
adjusted_D += kINT4QparamsElems;
}
output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
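For reference, a minimal standalone sketch (not part of the patch) of the row sizing the nobag int4 branch above relies on: each output row stores D int4 data elements plus an fp16 scale and fp16 bias, whose 4 bytes occupy the space of kINT4QparamsElems = 8 int4 elements, and two int4 elements pack into one output byte.

// Illustrative sketch only; all names below are local to the example.
#include <cstdint>
constexpr int64_t kINT4QparamsElems = 8;  // fp16 scale + fp16 bias = 4 bytes = 8 int4 elements
constexpr int64_t div_up(int64_t val, int64_t unit) { return (val + unit - 1) / unit; }
constexpr int64_t adjusted_D_int4(int64_t D) { return D + kINT4QparamsElems; }
static_assert(adjusted_D_int4(128) == 136, "a 128-dim row is allocated as 136 int4 elements");
static_assert(div_up(adjusted_D_int4(128), 2) == 68, "which occupy 68 bytes of output per row");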

@@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
using other_fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type> ::type;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();
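For readers unfamiliar with the nested std::conditional above, here is a self-contained sketch of the type mapping it performs; the Torch and fbgemm types are stubbed with local stand-ins (Half, BFloat16, float16, bfloat16), so nothing below is the real API.

#include <cstdint>
#include <type_traits>
struct Half {};                  // stand-in for at::Half
struct BFloat16 {};              // stand-in for at::BFloat16
using float16 = std::uint16_t;   // stand-in for fbgemm's fp16 storage type
using bfloat16 = std::uint16_t;  // stand-in for fbgemm's bf16 storage type
template <typename output_t>
using other_fbgemm_out_t = typename std::conditional<
    std::is_same<output_t, Half>::value,
    float16,
    typename std::conditional<
        std::is_same<output_t, BFloat16>::value,
        bfloat16,
        float>::type>::type;
static_assert(std::is_same<other_fbgemm_out_t<Half>, float16>::value, "Half maps to float16");
static_assert(std::is_same<other_fbgemm_out_t<float>, float>::value, "float stays float");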
@@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const int32_t D_end = D_offsets_acc[t + 1];
const int32_t D = D_end - D_start;
{% else %}
const int32_t D_start = offsets_acc[t * B] * adjusted_D;
const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup(adjusted_D, 2) : adjusted_D;
const int32_t D_start = offsets_acc[t * B] * elems_D;
{% endif %}

const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
@@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
{% endif %}

const float* indice_weights_ptr = nullptr;
// int8 output only enabled for nobag case with ref impl
const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }};
// int8/int4 output only enabled for nobag case
const bool nobag_op = {{ "false" if not nobag else "output_is_int8 || output_is_int4" }};
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
{% endif %}
@@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
%}
using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }};
using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base or use_nbit else "other_fbgemm_out_t" }};
{% if use_nbit %}
const int output_bit_rate = output_is_int4 ? 4 : sizeof(fbgemm_out_t) * 8;
{% endif %}
// TODO: merge nobag int8 path with normal asmjit dispatch
{% if nobag %}
const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr;
@@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
{% endif %}
>(
{% if use_nbit %}
/*bit_rate=*/bit_rate,
/*input_bit_rate=*/bit_rate,
{% endif %}
D,
{% if has_asmjit %}
@@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
/*no_bag=*/nobag_op,
{% endif %}
/*is_bf16_out=*/output_is_bf16
{% if use_nbit %}
,/*no_bag=*/nobag_op,
/*output_bit_rate=*/output_bit_rate
{% endif %}
);
success = kernel(
{{ "B" if not nobag else "index_size"}},
19 changes: 19 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h
@@ -111,4 +111,23 @@ hfp8_to_float(uint8_t hfp8_val, int ebits, int exponent_bias) {
return val_out.F;
}

// Get the number of bytes of a row in a tensor with quantized nbit integers
inline int32_t nbit_elems_to_bytes(const at::Tensor& input) {
const auto input_sizes = input.sizes();
const int32_t ncols = input_sizes[1];
// at::kQUInt4x2 is the dtype for quantized int4 tensors and at::kQUInt2x4 is
// for quantized int2 tensors. QUIntMxN (M*N=8) means quantized M-bit integer
// with each byte holding N such elements.
// input_sizes[1] is the number of elements in each row, so we need to divide
// it by 2 or 4 for quint4x2 or quint2x4 respectively to get the number of
// bytes in each row.
if (input.dtype() == at::kQUInt2x4) {
return fbgemm_gpu::div_up(ncols, 4);
} else if (input.dtype() == at::kQUInt4x2) {
return fbgemm_gpu::div_up(ncols, 2);
} else {
return ncols;
}
}
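A standalone illustration (not part of the patch) of the element-to-byte conversion this helper performs, assuming the packing described in the comment above:

static_assert((11 + 2 - 1) / 2 == 6, "11 int4 elements (kQUInt4x2, 2 per byte) -> 6 bytes");
static_assert((11 + 4 - 1) / 4 == 3, "11 int2 elements (kQUInt2x4, 4 per byte) -> 3 bytes");
static_assert((128 + 2 - 1) / 2 == 64, "an even int4 row packs exactly: 128 elements -> 64 bytes");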

} // namespace fbgemm_gpu
2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h
@@ -122,6 +122,8 @@
at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2( \
at::ScalarType::QUInt4x2, uint8_t, __VA_ARGS__) \
default: \
AT_ERROR( \
#NAME, \
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h
@@ -56,6 +56,11 @@ __builtin_ia32_serialize(void) {
#define DISPATCH_TO_CPU(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))

#define DISPATCH_TO_QUANTIZED_CPU(name, function) \
m.impl( \
name, \
torch::dispatch(c10::DispatchKey::QuantizedCPU, TORCH_FN(function)))

#define DISPATCH_TO_META(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function)))

4 changes: 4 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/types.h
@@ -15,4 +15,8 @@ using fint32 = union fint32 {
float F;
};

inline int64_t div_up(int64_t val, int64_t unit) {
return (val + unit - 1) / unit;
}

} // namespace fbgemm_gpu
63 changes: 62 additions & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -123,7 +123,8 @@ Tensor _fusednbitrowwise_to_float_cpu(

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
const int32_t ncols = input_sizes[1];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
@@ -149,6 +150,40 @@ Tensor _fusednbitrowwise_to_float_cpu(
return output;
}

Tensor _fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
TENSOR_ON_CPU(input);
TENSOR_NDIM_EQUALS(input, 2);

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;

Tensor output;
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
input.options().dtype(at::kFloat));

float* output_data = static_cast<float*>(
output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
// unresolved data_ptr symbol.

fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
bit_rate,
input.data_ptr<uint8_t>(),
nrows,
ncols,
output_data,
/*scale_bias_last=*/false);

return output;
}
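To make the scale-bias-front layout concrete, here is an illustrative sketch (not part of the patch) of the row format the call above assumes (scale_bias_last == false) and of the resulting output width:

// Assumed row layout: [ fp16 scale | fp16 bias | packed nbit data ... ]
#include <cstdint>
constexpr int64_t sbfront_output_columns(int64_t row_bytes, int64_t bit_rate) {
  const int64_t num_elem_per_byte = 8 / bit_rate;  // 2 for int4, 4 for int2
  return (row_bytes - 2 * 2 /* two fp16 qparams */) * num_elem_per_byte;
}
static_assert(sbfront_output_columns(/*row_bytes=*/68, /*bit_rate=*/4) == 128,
              "a 68-byte int4 row (4 qparam bytes + 64 data bytes) dequantizes to 128 floats");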

/// @ingroup quantize-data-cpu
///
Tensor& _fused8bitrowwise_to_float_cpu_out(
@@ -274,6 +309,24 @@ Tensor fusednbitrowwise_to_float_cpu(
return _fusednbitrowwise_to_float_cpu<float>(input, bit_rate);
}

/// @ingroup quantize-data-cpu
/// @brief Dequantize int4/int2 rows with scale and bias stored in the front
/// into float32.
/// @param input Tensor of int4/int2 rows with scale and bias stored in the
/// front.
/// @param bit_rate Bit rate of each element. Should be 4 or 2.
/// @return Tensor of float32, holding dequantized numbers.
///
/// Dequantize int4/int2 rows with scale and bias stored in the front into
/// float32. The input tensor should have torch.quint4x2 or torch.quint2x4 dtype
/// and QuantizedCPU backend. This operator is recommended only for testing
/// purposes because its kernel is a reference implementation and is not optimized.
Tensor fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
}

/// @ingroup quantize-data-cpu
///
Tensor fusednbitrowwise_to_half_cpu(
Expand Down Expand Up @@ -466,6 +519,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
@@ -485,6 +540,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("dequantize_mx_cuda(Tensor input, int mx_group_size) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
DISPATCH_TO_QUANTIZED_CPU(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
}
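A hedged sketch (not part of the patch) of how the newly registered QuantizedCPU operator could be invoked from C++ through the PyTorch dispatcher; q_rows is assumed to already be a 2-D torch.quint4x2 tensor with the fp16 scale and bias stored at the front of each row, and the helper name is made up for this example:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical helper: dequantize int4 rows via the operator registered above.
at::Tensor dequantize_int4_rows_sbfront(const at::Tensor& q_rows) {
  static const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat", "")
          .typed<at::Tensor(const at::Tensor&, int64_t)>();
  return op.call(q_rows, /*bit_rate=*/4);
}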

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"FloatToFused8BitRowwiseQuantized",
6 changes: 4 additions & 2 deletions include/fbgemm/FbgemmEmbedding.h
@@ -159,7 +159,7 @@ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDMNBitWithStrides(
int bit_rate,
const int input_bit_rate,
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
@@ -169,7 +169,9 @@ GenerateEmbeddingSpMDMNBitWithStrides(
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
bool scale_bias_last = true,
bool is_bf16_out = false);
const bool is_bf16_out = false,
const bool no_bag = false,
int output_bit_rate = -1);

/**
* @param output_stride If -1, output_stride is same as block_size
3 changes: 2 additions & 1 deletion include/fbgemm/QuantUtils.h
@@ -366,7 +366,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
OutputType* output,
bool scale_bias_last = true);

/**
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
35 changes: 35 additions & 0 deletions include/fbgemm/Utils.h
@@ -13,6 +13,7 @@

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <string>
#include <type_traits>
@@ -416,4 +417,38 @@ FBGEMM_API bool is_autovec_disabled();
FBGEMM_API bool is_autovec_forced();
FBGEMM_API bool is_asmjit_disabled();

/**
* @brief A function to check if the input parameter in the nbit CPU TBE kernel
* is valid.
*/
template <typename OutType>
void nbit_embedding_sanity_check(
// assertions are ignored in release mode, in which case these parameters
// will be unused
[[maybe_unused]] const int input_bit_rate,
[[maybe_unused]] const int output_bit_rate,
[[maybe_unused]] const bool no_bag) {
assert(
(input_bit_rate == 2 || input_bit_rate == 4) &&
"input_bit_rate must be 2 or 4");
if (std::is_same<OutType, uint8_t>::value) {
assert(
(no_bag && input_bit_rate == 4 && output_bit_rate == 4) &&
"we currently only support int4 to int4 for sequential TBE");
} else {
assert(
(output_bit_rate == 8 * sizeof(OutType)) &&
"output_bit_rate should be equal to 8 * sizeof(OutType)");
}
}
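For illustration, two calls that satisfy the rules asserted above (sketch only, not part of the patch):

#include <cstdint>
#include "fbgemm/Utils.h"
inline void nbit_sanity_check_examples() {
  // int4 input -> int4 (byte-packed) sequential output: only allowed with no_bag.
  fbgemm::nbit_embedding_sanity_check<std::uint8_t>(
      /*input_bit_rate=*/4, /*output_bit_rate=*/4, /*no_bag=*/true);
  // int4 input -> fp32 output: output_bit_rate must equal 8 * sizeof(float) == 32.
  fbgemm::nbit_embedding_sanity_check<float>(
      /*input_bit_rate=*/4, /*output_bit_rate=*/32, /*no_bag=*/false);
}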

#define WARN_ONCE(...) \
do { \
static bool _warned = false; \
if (!_warned) { \
_warned = true; \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
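Usage sketch for the new macro (illustrative only; the message text is made up for this example):

#include <cstdio>
#include "fbgemm/Utils.h"
inline void maybe_warn_reference_path() {
  // Printed at most once per process, no matter how often this function runs.
  WARN_ONCE("int4-to-int4 sequential TBE is using the unoptimized reference kernel\n");
}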

} // namespace fbgemm