Enable bf16 output in TBE CPU kernel for other input types (#1851)
Summary:
Pull Request resolved: #1851

Enable bf16 output support in the TBE CPU kernel when the input weight type is int8, fp8, fp16, or fp32.

Differential Revision: D47028021

fbshipit-source-id: 8721ba1aa097702ae6a0844d312929124ed3448e
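
For background on what the new output path has to produce: bf16 (bfloat16) keeps fp32's sign bit and 8 exponent bits and truncates the mantissa to 7 bits, so converting a float accumulator to a bf16 output is a 32-to-16-bit narrowing with rounding. A minimal standalone sketch of such a conversion (an illustration of the format, not FBGEMM's actual implementation; NaN special-casing is omitted):

```cpp
#include <cstdint>
#include <cstring>

// Convert one fp32 value to bf16 (stored in a uint16_t) with
// round-to-nearest-even on the 16 discarded mantissa bits.
inline std::uint16_t float_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // round half to even
  return static_cast<std::uint16_t>(bits >> 16);
}
```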
excelle08 authored and facebook-github-bot committed Jun 30, 2023
1 parent a7e7d3e commit bdb0f81
Showing 7 changed files with 160 additions and 99 deletions.
27 changes: 17 additions & 10 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
@@ -229,9 +229,6 @@ for (const auto t : c10::irange(T)) {
 // default to 1 byte alignment for CPU TBE
 const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);
 
-// NOTE: currently we only support bf16 output when input is int4 or int2
-TORCH_CHECK(o_dtype != SparseType::BF16 || (o_dtype == SparseType::BF16 && (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2)));
-
 int tt;
 for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
 size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
@@ -268,10 +265,13 @@ for (const auto t : c10::irange(T)) {
 {% endif %}
 /*input_stride=*/D_bytes / sizeof(float),
 {% if not nobag %}
-/*scale_bias_last=*/false);
+/*scale_bias_last=*/false,
+/*no_bag=*/false,
+/*is_bf16_out=*/output_is_bf16);
 {% else %}
 /*scale_bias_last=*/false,
-/*no_bag=*/true);
+/*no_bag=*/true,
+/*is_bf16_out=*/output_is_bf16);
 {% endif %}
 success = kernel(
 {% if not nobag %}
@@ -301,10 +301,13 @@ for (const auto t : c10::irange(T)) {
 {% endif %}
 /*input_stride=*/D_bytes / sizeof(float16),
 {% if not nobag %}
-/*scale_bias_last=*/false);
+/*scale_bias_last=*/false,
+/*no_bag=*/false,
+/*is_bf16_out=*/output_is_bf16);
 {% else %}
 /*scale_bias_last=*/false,
-/*no_bag=*/true);
+/*no_bag=*/true,
+/*is_bf16_out=*/output_is_bf16);
 {% endif %}
 success = kernel(
 {% if not nobag %}
@@ -333,7 +336,8 @@ for (const auto t : c10::irange(T)) {
 {% endif %}
 /*input_stride=*/D_bytes / sizeof(uint8_t),
 /*exponent_bits=*/fp8_exponent_bits,
-/*exponent_bias=*/fp8_exponent_bias);
+/*exponent_bias=*/fp8_exponent_bias,
+/*is_bf16_out=*/output_is_bf16);
 success = kernel(
 B,
 index_size,
@@ -358,10 +362,13 @@ for (const auto t : c10::irange(T)) {
 {% endif %}
 /*input_stride=*/D_bytes / sizeof(uint8_t),
 {% if not nobag %}
-/*scale_bias_last=*/false);
+/*scale_bias_last=*/false,
+/*no_bag=*/false,
+/*is_bf16_out=*/output_is_bf16);
 {% else %}
 /*scale_bias_last=*/false,
-/*no_bag=*/true);
+/*no_bag=*/true,
+/*is_bf16_out=*/output_is_bf16);
 {% endif %}
 success = kernel(
 {% if not nobag %}
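Each call site above now forwards an `output_is_bf16` flag into the FBGEMM kernel generators. The flag's derivation sits outside these hunks; a hedged sketch of what it presumably looks like in the surrounding template code:

```cpp
// Assumed derivation (not shown in this diff): the caller-requested
// output dtype decides whether the generated kernel writes bf16.
const bool output_is_bf16 = o_dtype == SparseType::BF16;
```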
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -4357,7 +4357,7 @@ def test_nbit_forward_cpu(
     )
 
     @given(
-        nbit_weights_ty=st.sampled_from([SparseType.INT4, SparseType.INT2]),
+        nbit_weights_ty=get_nbit_weights_ty(),
         use_array_for_index_remapping=st.booleans(),
         do_pruning=st.booleans(),
     )
9 changes: 6 additions & 3 deletions include/fbgemm/FbgemmEmbedding.h
@@ -80,7 +80,8 @@ GenerateEmbeddingSpMDM(
     int prefetch = 16,
     bool is_weight_positional = false,
     bool use_offsets = true,
-    bool isbf16 = false);
+    bool is_bf16_out = false,
+    bool is_bf16_in = false);
 
 /**
  * @param output_stride If -1, output_stride is same as block_size
@@ -112,7 +113,8 @@ GenerateEmbeddingSpMDMWithStrides(
     std::int64_t input_stride = -1,
     bool scale_bias_last = true,
     bool no_bag = false,
-    bool isbf16 = false);
+    bool is_bf16_out = false,
+    bool is_bf16_in = false);
 
 /**
  * @tparam IndexType can be int32_t or int64_t
@@ -195,7 +197,8 @@ GenerateEmbeddingSpMDMFP8WithStrides(
     std::int64_t output_stride = -1,
     std::int64_t input_stride = -1,
     int exponent_bits = 4,
-    int exponent_bias = 7);
+    int exponent_bias = 7,
+    bool is_bf16_out = false);
 
 template <
     typename InType,