Enable int4 to int4 CPU STBE in fbgemm_gpu TBE API
Summary: Enable int4-to-int4 sequential CPU TBE in the codegen template so that fbgemm_gpu's `IntNBitTableBatchedEmbeddingBagsCodegen` can support int4 output.

Differential Revision: D61305978
excelle08 authored and facebook-github-bot committed Aug 15, 2024
1 parent 9467a99 commit 6453fb6
Showing 3 changed files with 23 additions and 9 deletions.
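A quick aside on the row-size arithmetic the template changes below rely on. This is a minimal sketch, assuming the nobag int4 layout described in the diff's own comments (D packed int4 elements per row, followed by an fp16 scale and fp16 bias); `adjusted_elems` and `row_bytes` are hypothetical helper names for illustration, not fbgemm API:

// Each nobag int4 output row stores D int4 elements plus 4 bytes of
// scale+bias, i.e. kINT4QparamsElems = 8 extra int4 elements per row.
#include <cstdint>
#include <iostream>

constexpr int kINT4QparamsElems = 8; // 4 bytes of qparams = 8 int4 elems

int64_t adjusted_elems(int64_t D) { // row width in int4 elements
  return D + kINT4QparamsElems;
}

int64_t row_bytes(int64_t D) { // two int4 elements pack into one byte
  return (adjusted_elems(D) + 1) / 2; // divup(elems, 2)
}

int main() {
  // D = 128: 136 elements -> 68 bytes (64 bytes data + 4 bytes scale/bias).
  std::cout << adjusted_elems(128) << " elems, " << row_bytes(128) << " bytes\n";
}

The same divup-by-2 conversion shows up in the template below when it computes per-table output offsets for int4 rows.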
@@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

Tensor output;
SparseType o_dtype = static_cast<SparseType>(output_dtype);
-TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
+TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4);
bool output_is_bf16 = o_dtype == SparseType::BF16;
bool output_is_int8 = o_dtype == SparseType::INT8;
+bool output_is_int4 = o_dtype == SparseType::INT4;
{% if not nobag %}
const int kINT8QparamsBytes = 8;
int64_t total_adjusted_D = total_D;
@@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
}
output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
{% else %}
-const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
+constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
+constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements
int64_t adjusted_D = D;
if (o_dtype == SparseType::INT8) {
adjusted_D += kINT8QparamsBytes;
+} else if (o_dtype == SparseType::INT4) {
+adjusted_D += kINT4QparamsElems;
}
output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));

@@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
using other_fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
-std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;
+std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type> ::type;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();
@@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const int32_t D_end = D_offsets_acc[t + 1];
const int32_t D = D_end - D_start;
{% else %}
-const int32_t D_start = offsets_acc[t * B] * adjusted_D;
+const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup(adjusted_D, 2) : adjusted_D;
+const int32_t D_start = offsets_acc[t * B] * elems_D;
{% endif %}

const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
@@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
{% endif %}

const float* indice_weights_ptr = nullptr;
-// int8 output only enabled for nobag case with ref impl
-const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }};
+// int8/int4 output only enabled for nobag case
+const bool nobag_op = {{ "false" if not nobag else "output_is_int8 || output_is_int4" }};
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
{% endif %}
@@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
%}
-using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }};
+using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base or use_nbit else "other_fbgemm_out_t" }};
+{% if use_nbit %}
+const int output_bit_rate = output_is_int4 ? 4 : sizeof(fbgemm_out_t) * 8;
+{% endif %}
// TODO: merge nobag int8 path with normal asmjit dispatch
{% if nobag %}
const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr;
@@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
{% endif %}
>(
{% if use_nbit %}
-/*bit_rate=*/bit_rate,
+/*input_bit_rate=*/bit_rate,
{% endif %}
D,
{% if has_asmjit %}
@@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
/*no_bag=*/nobag_op,
{% endif %}
/*is_bf16_out=*/output_is_bf16
+{% if use_nbit %}
+,/*no_bag=*/nobag_op,
+/*output_bit_rate=*/output_bit_rate
+{% endif %}
);
success = kernel(
{{ "B" if not nobag else "index_size"}},
2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h
@@ -122,6 +122,8 @@
at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+PRIVATE_CASE_TYPE_OUTPUT2( \
+at::ScalarType::QUInt4x2, uint8_t, __VA_ARGS__) \
default: \
AT_ERROR( \
#NAME, \
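For context on the dispatch entry added above: QUInt4x2 stores two 4-bit values per byte, which is why it dispatches with uint8_t as the underlying storage type. A tiny illustrative sketch (the low-nibble-first order is an assumption here, not taken from the source):

#include <cstdint>
#include <utility>

// Pack two 4-bit values into one byte (low nibble first, assumed).
uint8_t pack_int4(uint8_t lo, uint8_t hi) {
  return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

// Recover the two 4-bit values from a packed byte.
std::pair<uint8_t, uint8_t> unpack_int4(uint8_t b) {
  return {static_cast<uint8_t>(b & 0x0F), static_cast<uint8_t>(b >> 4)};
}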
2 changes: 1 addition & 1 deletion include/fbgemm/FbgemmEmbedding.h
@@ -171,7 +171,7 @@ GenerateEmbeddingSpMDMNBitWithStrides(
bool scale_bias_last = true,
const bool is_bf16_out = false,
const bool no_bag = false,
-const bool output_bit_rate = 8 * sizeof(OutType));
+const int output_bit_rate = 8 * sizeof(OutType));

/**
* @param output_stride If -1, output_stride is same as block_size
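Why the signature change above matters: with the parameter declared const bool, the default argument 8 * sizeof(OutType) silently collapses to true (i.e., 1), so every output bit rate reached the kernel as 1. A self-contained illustration of the conversion:

#include <iostream>

int main() {
  const bool as_bool = 8 * sizeof(float); // bool conversion: 32 -> true
  const int as_int = 8 * sizeof(float);   // keeps the value 32
  std::cout << as_bool << " vs " << as_int << "\n"; // prints "1 vs 32"
}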
