Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable TBE CPU bf16 output support on Mac and Windows platforms #1839

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,25 +398,21 @@ GenEmbeddingSpMDMNBitLookup<
x86::Ymm mask_vreg; // mask for avx2
x86::Xmm mask2_vreg;
x86::Xmm mask_fp16_vreg;
#if !defined(__APPLE__) && !defined(_WIN32)
vec_reg_t ones_vreg;
#endif

// We need 2 vec registers for 1. scale 2. bias
--unroll_factor;
scale_vreg = vec_reg_t(unroll_factor);
--unroll_factor;
bias_vreg = vec_reg_t(unroll_factor);

#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
--unroll_factor;
ones_vreg = vec_reg_t(unroll_factor);
a->mov(scratchReg2_, 1 << 15);
a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
}
#endif

--unroll_factor;
src_vreg = vec_reg_t(unroll_factor);
Expand Down Expand Up @@ -883,19 +879,15 @@ GenEmbeddingSpMDMNBitLookup<
} else {
// 16-bit output
if (instSet == inst_set_t::avx2) {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
a->vpaddd(out_vreg, out_vreg, ones_vreg);
a->vpsrld(out_vreg, out_vreg, 16);
a->vpackusdw(out_vreg, out_vreg, out_vreg);
a->vpermq(out_vreg, out_vreg, 0xd8);
} else {
#endif
// round nearest with no exception
a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
if (remainder > 1) {
a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm());
Expand All @@ -918,31 +910,23 @@ GenEmbeddingSpMDMNBitLookup<
}
} else {
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
// bf16
a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg);
a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16);
a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg);
} else {
#endif
a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
} else {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
// bf16
a->vpaddd(out_vreg, out_vreg, ones_vreg);
a->vpsrld(out_vreg, out_vreg, 16);
a->vpmovdw(dst_addr, out_vreg);
} else {
#endif
a->vcvtps2ph(dst_addr, out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
}
}
}
Expand Down
73 changes: 53 additions & 20 deletions test/EmbeddingSpMDMNBitTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class FusedNBitRowwiseEmbeddingLookupTest : public testing::TestWithParam<tuple<
int,
int,
EmbeddingSpMDMWeightChoice,
EmbeddingSpMDMCornerCase>> {};
EmbeddingSpMDMCornerCase,
EmbeddingSpMDMDtypeChoice>> {};
}; // namespace

INSTANTIATE_TEST_CASE_P(
Expand All @@ -74,7 +75,8 @@ INSTANTIATE_TEST_CASE_P(
NONE,
EMPTY_INDICES,
OUT_OF_BOUND_INDICES,
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM)));
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM),
::testing::Values(FLOAT, FLOAT16, BFLOAT16)));

TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
vector<vector<int>> inputs(GetInputs_());
Expand All @@ -86,19 +88,20 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
bool isOffset64b = bool_dist(generator);
bool normalize_by_lengths = bool_dist(generator);
bool use_offsets = bool_dist(generator);
bool is_output_float = bool_dist(generator);
bool scale_bias_last = bool_dist(generator);
bool test_thread_local = bool_dist(generator);
int bit_rate, prefetch;
EmbeddingSpMDMWeightChoice weight_choice;
EmbeddingSpMDMCornerCase corner_case;
tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam();
EmbeddingSpMDMDtypeChoice out_type;
tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam();
bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
bool use_weight = weight_choice != UNWEIGHTED;
bool is_bf16_out = out_type == BFLOAT16;

if (corner_case != NONE || weight_choice == POSITIONAL_WEIGHTED) {
// Check corner case only for subset of tests.
if (normalize_by_lengths || !is_output_float || !scale_bias_last ||
if (normalize_by_lengths || out_type != FLOAT || !scale_bias_last ||
test_thread_local) {
return;
}
Expand Down Expand Up @@ -171,11 +174,14 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
vector<float> output_ref(output_size_wo_sentries + num_sentries);
vector<float> output(output_ref.size());
vector<float16> output_ref_fp16(output.size()), output_fp16(output.size());
vector<bfloat16> output_ref_bf16(output.size()), output_bf16(output.size());
for (size_t i = output_size_wo_sentries; i < output.size(); ++i) {
output_ref[i] = sentry_value;
output[i] = sentry_value;
output_ref_fp16[i] = cpu_float2half_rn(sentry_value);
output_fp16[i] = cpu_float2half_rn(sentry_value);
FloatToBfloat16_ref(&sentry_value, &output_ref_bf16[i], 1);
FloatToBfloat16_ref(&sentry_value, &output_bf16[i], 1);
}

bool success, success_ref;
Expand Down Expand Up @@ -205,7 +211,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
use_offsets, \
/*output_stride=*/-1, \
/*input_stride=*/-1, \
scale_bias_last); \
scale_bias_last, \
is_bf16_out); \
\
auto kernel = GenerateEmbeddingSpMDMNBitWithStrides< \
IndexType, \
Expand All @@ -221,7 +228,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
use_offsets, \
/*output_stride=*/-1, \
/*input_stride=*/-1, \
scale_bias_last); \
scale_bias_last, \
is_bf16_out); \
success = kernel( \
batch_size, \
lengths_sum, \
Expand Down Expand Up @@ -263,7 +271,7 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
}

#define TEST_OUT_TYPE(indices, offsets_or_lengths, IndexType, OffsetType) \
if (is_output_float) { \
if (out_type == FLOAT) { \
TEST_THREAD_LOCAL( \
indices, \
offsets_or_lengths, \
Expand All @@ -272,6 +280,15 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
IndexType, \
OffsetType, \
float); \
} else if (out_type == BFLOAT16) { \
TEST_THREAD_LOCAL( \
indices, \
offsets_or_lengths, \
output_ref_bf16, \
output_bf16, \
IndexType, \
OffsetType, \
bfloat16); \
} else { \
TEST_THREAD_LOCAL( \
indices, \
Expand Down Expand Up @@ -308,24 +325,40 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
EXPECT_EQ(success, false);
}

auto get_actual = [&](int offset) {
if (out_type == FLOAT) {
return output[offset];
} else if (out_type == BFLOAT16) {
return cpu_bf162float(output[offset]);
} else {
return cpu_half2float(output[offset]);
}
};

auto get_expected = [&](int offset) {
if (out_type == FLOAT) {
return output_ref[offset];
} else if (out_type == BFLOAT16) {
return cpu_bf162float(output_ref[offset]);
} else {
return cpu_half2float(output_ref[offset]);
}
};

if (success) {
for (size_t i = 0; i < output.size(); ++i) {
float actual =
is_output_float ? output[i] : cpu_half2float(output_fp16[i]);
float expected = is_output_float ? output_ref[i]
: cpu_half2float(output_ref_fp16[i]);
float actual = get_actual(i);
float expected = get_expected(i);
EXPECT_EQ(actual, expected)
<< "results differ at (" << i << ") reference: " << expected
<< ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
}
for (int offset = output_size_wo_sentries;
offset < output_size_wo_sentries + num_sentries;
++offset) {
float actual = is_output_float ? output[offset]
: cpu_half2float(output_fp16[offset]);
float expected = is_output_float
? output_ref[offset]
: cpu_half2float(output_ref_fp16[offset]);
float actual = get_actual(offset);
float expected = get_expected(offset);
EXPECT_EQ(actual, expected)
<< "results differ at (" << offset << ") reference: " << expected
<< ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
Expand All @@ -344,17 +377,17 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
bool isOffset64b = bool_dist(generator);
bool normalize_by_lengths = bool_dist(generator);
bool use_offsets = bool_dist(generator);
bool is_output_float = bool_dist(generator);
bool scale_bias_last = bool_dist(generator);

int bit_rate, prefetch;
EmbeddingSpMDMWeightChoice weight_choice;
EmbeddingSpMDMCornerCase corner_case;
tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam();
EmbeddingSpMDMDtypeChoice out_type;
tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam();
bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
bool use_weight = weight_choice != UNWEIGHTED;

if (!is_output_float || !scale_bias_last) {
if (out_type != FLOAT || !scale_bias_last) {
return;
}

Expand Down