Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable bf16 output in TBE CPU kernel for other input types #1851

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ for (const auto t : c10::irange(T)) {
// default to 1 byte alignment for CPU TBE
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

// NOTE: currently we only support bf16 output when input is int4 or int2
TORCH_CHECK(o_dtype != SparseType::BF16 || (o_dtype == SparseType::BF16 && (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2)));

int tt;
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
Expand Down Expand Up @@ -268,10 +265,13 @@ for (const auto t : c10::irange(T)) {
{% endif %}
/*input_stride=*/D_bytes / sizeof(float),
{% if not nobag %}
/*scale_bias_last=*/false);
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true);
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
Expand Down Expand Up @@ -301,10 +301,13 @@ for (const auto t : c10::irange(T)) {
{% endif %}
/*input_stride=*/D_bytes / sizeof(float16),
{% if not nobag %}
/*scale_bias_last=*/false);
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true);
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
Expand Down Expand Up @@ -333,7 +336,8 @@ for (const auto t : c10::irange(T)) {
{% endif %}
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*exponent_bits=*/fp8_exponent_bits,
/*exponent_bias=*/fp8_exponent_bias);
/*exponent_bias=*/fp8_exponent_bias,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
index_size,
Expand All @@ -358,10 +362,13 @@ for (const auto t : c10::irange(T)) {
{% endif %}
/*input_stride=*/D_bytes / sizeof(uint8_t),
{% if not nobag %}
/*scale_bias_last=*/false);
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true);
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4357,7 +4357,7 @@ def test_nbit_forward_cpu(
)

@given(
nbit_weights_ty=st.sampled_from([SparseType.INT4, SparseType.INT2]),
nbit_weights_ty=get_nbit_weights_ty(),
use_array_for_index_remapping=st.booleans(),
do_pruning=st.booleans(),
)
Expand Down
9 changes: 6 additions & 3 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ GenerateEmbeddingSpMDM(
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true,
bool isbf16 = false);
bool is_bf16_out = false,
bool is_bf16_in = false);

/**
* @param output_stride If -1, output_stride is same as block_size
Expand Down Expand Up @@ -112,7 +113,8 @@ GenerateEmbeddingSpMDMWithStrides(
std::int64_t input_stride = -1,
bool scale_bias_last = true,
bool no_bag = false,
bool isbf16 = false);
bool is_bf16_out = false,
bool is_bf16_in = false);

/**
* @tparam IndexType can be int32_t or int64_t
Expand Down Expand Up @@ -195,7 +197,8 @@ GenerateEmbeddingSpMDMFP8WithStrides(
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
int exponent_bits = 4,
int exponent_bias = 7);
int exponent_bias = 7,
bool is_bf16_out = false);

template <
typename InType,
Expand Down
Loading