diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp index 2d706b59df..aa410c95da 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp @@ -229,9 +229,6 @@ for (const auto t : c10::irange(T)) { // default to 1 byte alignment for CPU TBE const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment); - // NOTE: currently we only support bf16 output when input is int4 or int2 - TORCH_CHECK(o_dtype != SparseType::BF16 || (o_dtype == SparseType::BF16 && (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2))); - int tt; for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt); size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes; @@ -268,10 +265,13 @@ for (const auto t : c10::irange(T)) { {% endif %} /*input_stride=*/D_bytes / sizeof(float), {% if not nobag %} - /*scale_bias_last=*/false); + /*scale_bias_last=*/false, + /*no_bag=*/false, + /*is_bf16_out=*/output_is_bf16); {% else %} /*scale_bias_last=*/false, - /*no_bag=*/true); + /*no_bag=*/true, + /*is_bf16_out=*/output_is_bf16); {% endif %} success = kernel( {% if not nobag %} @@ -301,10 +301,13 @@ for (const auto t : c10::irange(T)) { {% endif %} /*input_stride=*/D_bytes / sizeof(float16), {% if not nobag %} - /*scale_bias_last=*/false); + /*scale_bias_last=*/false, + /*no_bag=*/false, + /*is_bf16_out=*/output_is_bf16); {% else %} /*scale_bias_last=*/false, - /*no_bag=*/true); + /*no_bag=*/true, + /*is_bf16_out=*/output_is_bf16); {% endif %} success = kernel( {% if not nobag %} @@ -333,7 +336,8 @@ for (const auto t : c10::irange(T)) { {% endif %} /*input_stride=*/D_bytes / sizeof(uint8_t), /*exponent_bits=*/fp8_exponent_bits, - /*exponent_bias=*/fp8_exponent_bias); + /*exponent_bias=*/fp8_exponent_bias, + 
/*is_bf16_out=*/output_is_bf16); success = kernel( B, index_size, @@ -358,10 +362,13 @@ for (const auto t : c10::irange(T)) { {% endif %} /*input_stride=*/D_bytes / sizeof(uint8_t), {% if not nobag %} - /*scale_bias_last=*/false); + /*scale_bias_last=*/false, + /*no_bag=*/false, + /*is_bf16_out=*/output_is_bf16); {% else %} /*scale_bias_last=*/false, - /*no_bag=*/true); + /*no_bag=*/true, + /*is_bf16_out=*/output_is_bf16); {% endif %} success = kernel( {% if not nobag %} diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 9146499f8c..f2b7b97568 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -4357,7 +4357,7 @@ def test_nbit_forward_cpu( ) @given( - nbit_weights_ty=st.sampled_from([SparseType.INT4, SparseType.INT2]), + nbit_weights_ty=get_nbit_weights_ty(), use_array_for_index_remapping=st.booleans(), do_pruning=st.booleans(), ) diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index e46adac395..96ed38e5eb 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -80,7 +80,8 @@ GenerateEmbeddingSpMDM( int prefetch = 16, bool is_weight_positional = false, bool use_offsets = true, - bool isbf16 = false); + bool is_bf16_out = false, + bool is_bf16_in = false); /** * @param output_stride If -1, output_stride is same as block_size @@ -112,7 +113,8 @@ GenerateEmbeddingSpMDMWithStrides( std::int64_t input_stride = -1, bool scale_bias_last = true, bool no_bag = false, - bool isbf16 = false); + bool is_bf16_out = false, + bool is_bf16_in = false); /** * @tparam IndexType can be int32_t or int64_t @@ -195,7 +197,8 @@ GenerateEmbeddingSpMDMFP8WithStrides( std::int64_t output_stride = -1, std::int64_t input_stride = -1, int exponent_bits = 4, - int exponent_bias = 7); + int exponent_bias = 7, + bool is_bf16_out = false); template < typename InType, diff 
--git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc index a5b1139282..2977471bf1 100644 --- a/src/EmbeddingSpMDM.cc +++ b/src/EmbeddingSpMDM.cc @@ -6,6 +6,9 @@ * LICENSE file in the root directory of this source tree. */ +#include +#include +#include #define FBGEMM_EXPORTS #include "fbgemm/FbgemmEmbedding.h" @@ -106,7 +109,8 @@ class GenEmbeddingSpMDMLookup { int output_stride, int input_stride, bool scale_bias_last, - bool isbf16); + bool is_bf16_out, + bool is_bf16_in); private: static asmjit::JitRuntime& runtime() { @@ -123,7 +127,7 @@ class GenEmbeddingSpMDMLookup { // positional weights, normalize by lenths, prefetch distance, use_offsets, // output_stride, input_stride, and scale_bias_last static CodeCache< - std::tuple, + std::tuple, typename ReturnFunctionSignature< inType, indxType, @@ -160,7 +164,7 @@ template < bool ROWWISE_SPARSE, bool THREAD_LOCAL> CodeCache< - std::tuple, + std::tuple, typename ReturnFunctionSignature< inType, indxType, @@ -209,19 +213,20 @@ GenEmbeddingSpMDMLookup< int output_stride, int input_stride, bool scale_bias_last, - bool isbf16) { - std::tuple kernelSig = - std::make_tuple( - block_size, - has_weight, - is_weight_positional, - normalize_by_lengths, - prefetch, - use_offsets, - output_stride, - input_stride, - scale_bias_last, - isbf16); + bool is_bf16_out, + bool is_bf16_in) { + auto kernelSig = std::make_tuple( + block_size, + has_weight, + is_weight_positional, + normalize_by_lengths, + prefetch, + use_offsets, + output_stride, + input_stride, + scale_bias_last, + is_bf16_out, + is_bf16_in); return codeCache_.getOrCreate( kernelSig, @@ -231,12 +236,11 @@ GenEmbeddingSpMDMLookup< offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel { - bool is8bit = std::is_same::value; - bool is16bit = std::is_same::value; - bool is16bitout = std::is_same::value; - bool isbf16out = isbf16; - bool isfp16 = is16bit && !isbf16; - bool isfp16out = is16bitout && !isbf16out; + bool is_8bit_in = std::is_same::value; + bool is_16bit_in = 
std::is_same::value; + bool is_16bit_out = std::is_same::value; + bool is_fp16_in = is_16bit_in && !is_bf16_in; + bool is_fp16_out = is_16bit_out && !is_bf16_out; // TODO: Make this tunable int pref_dist = prefetch; @@ -248,16 +252,16 @@ GenEmbeddingSpMDMLookup< x86::Emitter* a = assembler.as(); #if defined(FBGEMM_LOG_CODE) std::string filename = "embeddinglookup"; - if (is8bit) { + if (is_8bit_in) { filename += "_8bit"; - } else if (isfp16) { + } else if (is_fp16_in) { filename += "_fp16"; - } else if (isbf16) { + } else if (is_bf16_in) { filename += "_bf16"; } - if (isbf16out) { + if (is_bf16_out) { filename += "_bf16_out"; - } else if (isfp16out) { + } else if (is_fp16_out) { filename += "_fp16_out"; } filename += "_emd_dim_" + std::to_string(block_size); @@ -422,7 +426,7 @@ GenEmbeddingSpMDMLookup< x86::Xmm mask_fp16_vreg; // mask for loading fp16 in avx2 vec_reg_t ones_vreg; // 2^15 for bf16_2_fp32_rn - if (is8bit) { + if (is_8bit_in) { // We need 2 vec registers for 1. scale 2. bias --unroll_factor; scale_vreg = vec_reg_t(unroll_factor); @@ -430,7 +434,7 @@ GenEmbeddingSpMDMLookup< bias_vreg = vec_reg_t(unroll_factor); } - if (isbf16out) { + if (is_bf16_out) { --unroll_factor; ones_vreg = vec_reg_t(unroll_factor); a->mov(scratchReg2_, 1 << 15); @@ -438,7 +442,8 @@ GenEmbeddingSpMDMLookup< a->vpbroadcastd(ones_vreg, ones_vreg.xmm()); } - if (is8bit || is16bit || (remainder && instSet == inst_set_t::avx2)) { + if (is_8bit_in || is_16bit_in || + (remainder && instSet == inst_set_t::avx2)) { --unroll_factor; src_vreg = vec_reg_t(unroll_factor); } @@ -452,7 +457,7 @@ GenEmbeddingSpMDMLookup< // AVX512 doesn't need to use vector register for masking --unroll_factor; mask_vreg = x86::ymm(unroll_factor); - if (remainder > 1 && (is16bit || isbf16out || isfp16out)) { + if (remainder > 1 && (is_16bit_in || is_bf16_out || is_fp16_out)) { --unroll_factor; mask_fp16_vreg = x86::xmm(unroll_factor); } @@ -469,7 +474,7 @@ GenEmbeddingSpMDMLookup< mask_vreg, x86::ymmword_ptr( 
scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); - if (is16bit || isbf16out || isfp16out) { + if (is_16bit_in || is_bf16_out || is_fp16_out) { if (remainder > 1) { a->vmovups( mask_fp16_vreg, @@ -680,7 +685,7 @@ GenEmbeddingSpMDMLookup< // broadcast the scale x86::Mem scale_src, bias_src; constexpr unsigned int CACHE_LINE_LEN = 64; - if (is8bit) { + if (is_8bit_in) { if (scale_bias_last) { scale_src = x86::dword_ptr( input, scratchReg1_, 0, block_size * sizeof(uint8_t)); @@ -711,14 +716,14 @@ GenEmbeddingSpMDMLookup< } } - if (has_weight && is8bit) { + if (has_weight && is_8bit_in) { a->vmulps(scale_vreg, scale_vreg, w_vreg); a->vmulps(bias_vreg, bias_vreg, w_vreg); } // The main computation int src_addr_offset = - is8bit && !scale_bias_last ? 2 * sizeof(uint16_t) : 0; + is_8bit_in && !scale_bias_last ? 2 * sizeof(uint16_t) : 0; for (int v = 0; v < cur_unroll_factor; ++v) { constexpr int BYTES_PER_VLOAD = vlen * sizeof(inType); auto src_addr = x86::dword_ptr( @@ -730,7 +735,7 @@ GenEmbeddingSpMDMLookup< // For 8bit SLS convert usigned 8-bit to 32bit int, then to float // multiply with scale and then add with bias - if (is8bit) { + if (is_8bit_in) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && instSet == inst_set_t::avx512) { a->k(x86::k(1)).z().vpmovzxbd(src_vreg, src_addr); @@ -743,7 +748,7 @@ GenEmbeddingSpMDMLookup< a->vcvtdq2ps(src_vreg, src_vreg); a->vaddps(out_vreg, out_vreg, bias_vreg); a->vfmadd231ps(out_vreg, src_vreg, scale_vreg); - } else if (is16bit) { + } else if (is_16bit_in) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (instSet == inst_set_t::avx2) { if (remainder % 2 == 0) { @@ -772,18 +777,18 @@ GenEmbeddingSpMDMLookup< a->vmovups(src_vreg.xmm(), x86::xmmword_ptr(x86::rsp)); } // remainder > 1 } // remainder % 2 - if (isfp16) { + if (is_fp16_in) { a->vcvtph2ps(src_vreg.ymm(), src_vreg.xmm()); - } else if (isbf16) { + } else if (is_bf16_in) { // bf16 a->vpmovzxwd(src_vreg.ymm(), src_vreg.xmm()); 
a->vpslld(src_vreg.ymm(), src_vreg.ymm(), 16); } } else { // avx512 - if (isfp16) { + if (is_fp16_in) { a->k(x86::k(1)).z().vcvtph2ps(src_vreg, src_addr); - } else if (isbf16) { + } else if (is_bf16_in) { // bf16 a->k(x86::k(1)).z().vpmovzxwd(src_vreg, src_addr); a->k(x86::k(1)).z().vpslld(src_vreg, src_vreg, 16); @@ -791,9 +796,9 @@ GenEmbeddingSpMDMLookup< } } else { // no remainder - if (isfp16) { + if (is_fp16_in) { a->vcvtph2ps(src_vreg, src_addr); - } else if (isbf16) { + } else if (is_bf16_in) { // bf16 a->vpmovzxwd(src_vreg, src_addr); a->vpslld(src_vreg, src_vreg, 16); @@ -869,9 +874,9 @@ GenEmbeddingSpMDMLookup< // fp16/bf16 output if (instSet == inst_set_t::avx2) { // round nearest with no exception - if (isfp16out) { + if (is_fp16_out) { a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8); - } else if (isbf16out) { + } else if (is_bf16_out) { a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); a->vpackusdw(out_vreg, out_vreg, out_vreg); @@ -899,18 +904,18 @@ GenEmbeddingSpMDMLookup< } } else { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { - if (isfp16out) { + if (is_fp16_out) { a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); - } else if (isbf16out) { + } else if (is_bf16_out) { // bf16 a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); } } else { - if (isfp16out) { + if (is_fp16_out) { a->vcvtps2ph(dst_addr, out_vreg, 8); - } else if (isbf16out) { + } else if (is_bf16_out) { // bf16 a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); @@ -970,7 +975,7 @@ GenEmbeddingSpMDMLookup< a->bind(exit); if (remainder && instSet == inst_set_t::avx2 && - (is16bit || isbf16out || isfp16out)) { + (is_16bit_in || is_bf16_out || is_fp16_out)) { a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t))); } @@ -1022,7 +1027,8 @@ typename EmbeddingSpMDMKernelSignature:: int64_t input_stride /*=-1*/, bool 
scale_bias_last /*=true*/, bool no_bag /*=false*/, - bool isbf16 /*=false*/) { + bool is_bf16_out /*=false*/, + bool is_bf16_in /*=false*/) { if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } @@ -1067,7 +1073,8 @@ typename EmbeddingSpMDMKernelSignature:: input_stride, scale_bias_last, no_bag, - isbf16); + is_bf16_out, + is_bf16_in); }; } @@ -1097,7 +1104,7 @@ typename EmbeddingSpMDMKernelSignature:: reinterpret_cast(out), is_weight_positional, use_offsets, - isbf16); + is_bf16_out); }; } else if (isZmm(isa)) { static GenEmbeddingSpMDMLookup< @@ -1119,7 +1126,8 @@ typename EmbeddingSpMDMKernelSignature:: output_stride, input_stride, scale_bias_last, - isbf16); + is_bf16_out, + is_bf16_in); return [=](int64_t output_size, int64_t index_size, int64_t data_size, @@ -1159,7 +1167,8 @@ typename EmbeddingSpMDMKernelSignature:: output_stride, input_stride, scale_bias_last, - isbf16); + is_bf16_out, + is_bf16_in); return [=](int64_t output_size, int64_t index_size, int64_t data_size, @@ -1209,7 +1218,8 @@ typename EmbeddingSpMDMKernelSignature:: input_stride, scale_bias_last, no_bag, - isbf16); + is_bf16_out, + is_bf16_in); }; #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 } @@ -1231,7 +1241,8 @@ typename EmbeddingSpMDMKernelSignature:: int prefetch, bool is_weight_positional, bool use_offsets, - bool isbf16) { + bool is_bf16_out, + bool is_bf16_in) { return GenerateEmbeddingSpMDMWithStrides< inType, indxType, @@ -1248,7 +1259,8 @@ typename EmbeddingSpMDMKernelSignature:: /*input_stride=*/-1, /*scale_bias_last=*/true, /*no_bag=*/false, - isbf16); + is_bf16_out, + is_bf16_in); } template @@ -1262,7 +1274,8 @@ typename EmbeddingSpMDMKernelSignature:: int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, int exponent_bits, - int exponent_bias) { + int exponent_bias, + bool is_bf16_out) { if (output_stride == -1) { output_stride = block_size; } @@ -1294,7 +1307,8 @@ typename EmbeddingSpMDMKernelSignature:: output_stride, input_stride, 
exponent_bits, - exponent_bias); + exponent_bias, + is_bf16_out); }; } @@ -1338,7 +1352,8 @@ GenerateEmbeddingSpMDMRowWiseSparse( /*output_stride=*/block_size, input_stride, /*scale_bias_last=*/true, - /*isbf16=*/false); + /*is_bf16_out=*/false, + /*is_bf16_in=*/false); return [=](int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, @@ -1379,7 +1394,8 @@ GenerateEmbeddingSpMDMRowWiseSparse( /*output_stride=*/block_size, input_stride, /*scale_bias_last=*/true, - /*isbf16=*/false); + /*is_bf16_out=*/false, + /*is_bf16_in=*/false); return [=](int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, @@ -1457,7 +1473,8 @@ GenerateEmbeddingSpMDMRowWiseSparse( int64_t input_stride, \ bool scale_bias_last, \ bool no_bag, \ - bool isbf16); + bool is_bf16_out, \ + bool is_bf16_in); #define INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ @@ -1473,7 +1490,8 @@ GenerateEmbeddingSpMDMRowWiseSparse( int64_t output_stride, \ int64_t input_stride, \ int exponent_bits, \ - int exponent_bias); + int exponent_bias, \ + bool is_bf16_out); #define INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ @@ -1494,7 +1512,8 @@ GenerateEmbeddingSpMDMRowWiseSparse( int prefetch, \ bool is_weight_positional, \ bool use_offsets, \ - bool isbf16); + bool is_bf16_out, \ + bool is_bf16_in); #define INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 585939b114..45715e42cc 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -1208,7 +1208,8 @@ bool EmbeddingSpMDM_ref( int64_t input_stride /*=-1*/, bool scale_bias_last /*=true*/, bool no_bag /*=false*/, - bool is_bf16_out /*=false*/) { + bool is_bf16_out /*=false*/, + bool is_bf16_in /*=false*/) { bool is8bit = 
is_same::value; if (output_stride == -1) { output_stride = block_size; @@ -1265,8 +1266,7 @@ bool EmbeddingSpMDM_ref( buf[j] + bias); } for (int j = 0; j < block_size; ++j) { - out[j] = is_same::value ? cpu_float2half_rn(buf[j]) - : buf[j]; + out[j] = convert_from_float_ref(buf[j], is_bf16_out); } out += output_stride; } // m @@ -1322,8 +1322,7 @@ bool EmbeddingSpMDM_ref( } } for (int j = 0; j < block_size; ++j) { - out[j] = is_same::value ? cpu_float2half_rn(buf[j]) - : buf[j]; + out[j] = convert_from_float_ref(buf[j], is_bf16_out); } out += output_stride; } @@ -1349,7 +1348,7 @@ bool EmbeddingSpMDM_ref( for (int j = 0; j < block_size; ++j) { const InType* inptr = input + input_stride * idx + j; buf[j] = - std::fma(w, convert_to_float_ref(*inptr, is_bf16_out), buf[j]); + std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]); } for (int j = 0; j < block_size; ++j) { out[j] = convert_from_float_ref(buf[j], is_bf16_out); @@ -1382,7 +1381,7 @@ bool EmbeddingSpMDM_ref( for (int j = 0; j < block_size; ++j) { const InType* inptr = input + input_stride * idx + j; buf[j] = - std::fma(w, convert_to_float_ref(*inptr, is_bf16_out), buf[j]); + std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]); } ++current; @@ -1506,7 +1505,8 @@ bool EmbeddingSpMDMFP8_ref( int64_t output_stride, int64_t input_stride, int exponent_bits, - int exponent_bias) { + int exponent_bias, + bool is_bf16_out /*=false*/) { if (output_stride == -1) { output_stride = block_size; } @@ -1555,8 +1555,9 @@ bool EmbeddingSpMDMFP8_ref( } } for (int j = 0; j < block_size; ++j) { - out[j] = - is_same::value ? cpu_float2half_rn(buf[j]) : buf[j]; + out[j] = is_same::value + ? 
convert_from_float_ref(buf[j], is_bf16_out) + : buf[j]; } out += output_stride; } @@ -2052,7 +2053,8 @@ template FBGEMM_API void transposeConvWeights( int64_t output_stride, \ bool scale_bias_last, \ bool no_bag, \ - bool is_bf16_out); + bool is_bf16_out, \ + bool is_bf16_in); #define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \ @@ -2124,7 +2126,8 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t) int64_t output_stride, \ int64_t input_stride, \ int exponent_bits, \ - int exponent_bias); + int exponent_bias, \ + bool is_bf16_out); #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ diff --git a/src/RefImplementations.h b/src/RefImplementations.h index ae63d6a414..f01aa57d5a 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -238,7 +238,8 @@ FBGEMM_API bool EmbeddingSpMDM_ref( std::int64_t input_stride = -1, bool scale_bias_last = true, bool no_bag = false, - bool is_bf16_out = false); + bool is_bf16_out = false, + bool is_bf16_in = false); template < typename IndexType = std::int64_t, @@ -283,7 +284,8 @@ bool EmbeddingSpMDMFP8_ref( int64_t output_stride = -1, int64_t input_stride = -1, int exponent_bits = 4, - int exponent_bias = 7); + int exponent_bias = 7, + bool is_bf16_out = false); template < typename InType = std::uint8_t, diff --git a/test/EmbeddingSpMDM8BitTest.cc b/test/EmbeddingSpMDM8BitTest.cc index db3dfea107..b658d480cc 100644 --- a/test/EmbeddingSpMDM8BitTest.cc +++ b/test/EmbeddingSpMDM8BitTest.cc @@ -50,9 +50,11 @@ vector prefetch_distances{0, 16, 1000000}; namespace { -class Fused8BitRowwiseEmbeddingLookupTest - : public testing::TestWithParam< - tuple> {}; +class Fused8BitRowwiseEmbeddingLookupTest : public testing::TestWithParam> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -68,7 +70,8 @@ INSTANTIATE_TEST_CASE_P( NONE, EMPTY_INDICES, OUT_OF_BOUND_INDICES, - 
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM))); + UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM), + ::testing::Values(FLOAT, FLOAT16, BFLOAT16))); TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { vector> inputs(GetInputs_()); @@ -80,19 +83,19 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { bool isOffset64b = bool_dist(generator); bool normalize_by_lengths = bool_dist(generator); bool use_offsets = bool_dist(generator); - bool is_output_float = bool_dist(generator); bool scale_bias_last = bool_dist(generator); int prefetch; EmbeddingSpMDMWeightChoice weight_choice; EmbeddingSpMDMCornerCase corner_case; - tie(prefetch, weight_choice, corner_case) = GetParam(); + EmbeddingSpMDMDtypeChoice out_type; + tie(prefetch, weight_choice, corner_case, out_type) = GetParam(); bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED; bool use_weight = weight_choice != UNWEIGHTED; if (corner_case != NONE || weight_choice == POSITIONAL_WEIGHTED) { // Check corner case only for subset of tests. 
- if (normalize_by_lengths || !is_output_float || !scale_bias_last) { + if (normalize_by_lengths || out_type != FLOAT || !scale_bias_last) { return; } } @@ -161,12 +164,14 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { int output_size_wo_sentries = batch_size * embedding_dim; vector output_ref(output_size_wo_sentries + num_sentries); vector output(output_ref.size()); - vector output_ref_fp16(output.size()), output_fp16(output.size()); + vector output_ref_16b(output.size()), output_16b(output.size()); for (size_t i = output_size_wo_sentries; i < output.size(); ++i) { output_ref[i] = sentry_value; output[i] = sentry_value; - output_ref_fp16[i] = cpu_float2half_rn(sentry_value); - output_fp16[i] = cpu_float2half_rn(sentry_value); + output_ref_16b[i] = + convert_from_float_ref(sentry_value, out_type == BFLOAT16); + output_16b[i] = + convert_from_float_ref(sentry_value, out_type == BFLOAT16); } bool success, success_ref; @@ -194,7 +199,9 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { use_offsets, \ /*output_stride=*/-1, \ /*input_stride=*/-1, \ - scale_bias_last); \ + scale_bias_last, \ + /*no_bag=*/false, /*is_bf16_out=*/out_type == BFLOAT16, \ + /*is_bf16_in=*/false); \ \ auto kernel = GenerateEmbeddingSpMDMWithStrides< \ uint8_t, \ @@ -209,7 +216,9 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { use_offsets, \ /*output_stride=*/-1, \ /*input_stride=*/-1, \ - scale_bias_last); \ + scale_bias_last, \ + /*no_bag=*/false, /*is_bf16_out=*/out_type == BFLOAT16, \ + /*is_bf16_in=*/false); \ success = kernel( \ batch_size, \ lengths_sum, \ @@ -221,7 +230,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { output.data()); #define TEST_OUT_TYPE(indices, offsets_or_lengths, IndexType, OffsetType) \ - if (is_output_float) { \ + if (out_type == FLOAT) { \ TEST_BASE( \ indices, \ offsets_or_lengths, \ @@ -234,8 +243,8 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { TEST_BASE( \ indices, \ offsets_or_lengths, \ - output_ref_fp16, \ - output_fp16, \ + 
output_ref_16b, \ + output_16b, \ IndexType, \ OffsetType, \ float16); \ @@ -267,10 +276,12 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { } if (success) { for (size_t i = 0; i < output.size(); ++i) { - float actual = - is_output_float ? output[i] : cpu_half2float(output_fp16[i]); - float expected = is_output_float ? output_ref[i] - : cpu_half2float(output_ref_fp16[i]); + float actual = out_type == FLOAT + ? output[i] + : convert_to_float_ref(output_16b[i], out_type == BFLOAT16); + float expected = out_type == FLOAT + ? output_ref[i] + : convert_to_float_ref(output_ref_16b[i], out_type == BFLOAT16); EXPECT_EQ(actual, expected) << "results differ at (" << i << ") reference: " << expected << ", FBGEMM: " << actual << " emb dim :" << embedding_dim; @@ -278,11 +289,13 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) { for (int offset = output_size_wo_sentries; offset < output_size_wo_sentries + num_sentries; ++offset) { - float actual = is_output_float ? output[offset] - : cpu_half2float(output_fp16[offset]); - float expected = is_output_float + float actual = out_type == FLOAT + ? output[offset] + : convert_to_float_ref(output_16b[offset], out_type == BFLOAT16); + float expected = out_type == FLOAT ? 
output_ref[offset] - : cpu_half2float(output_ref_fp16[offset]); + : convert_to_float_ref( + output_ref_16b[offset], out_type == BFLOAT16); EXPECT_EQ(actual, expected) << "results differ at (" << offset << ") reference: " << expected << ", FBGEMM: " << actual << " emb dim :" << embedding_dim; @@ -301,17 +314,17 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { bool isOffset64b = bool_dist(generator); bool normalize_by_lengths = bool_dist(generator); bool use_offsets = bool_dist(generator); - bool is_output_float = bool_dist(generator); bool scale_bias_last = bool_dist(generator); int prefetch; EmbeddingSpMDMWeightChoice weight_choice; EmbeddingSpMDMCornerCase corner_case; - tie(prefetch, weight_choice, corner_case) = GetParam(); + EmbeddingSpMDMDtypeChoice out_type; + tie(prefetch, weight_choice, corner_case, out_type) = GetParam(); bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED; bool use_weight = weight_choice != UNWEIGHTED; - if (!is_output_float || !scale_bias_last) { + if (out_type != FLOAT || !scale_bias_last) { return; } diff --git a/test/EmbeddingSpMDMTest.cc b/test/EmbeddingSpMDMTest.cc index 735f121981..1141e5e7be 100644 --- a/test/EmbeddingSpMDMTest.cc +++ b/test/EmbeddingSpMDMTest.cc @@ -133,10 +133,6 @@ TEST_P(EmbeddingSpMDMTest, basicTest) { bool is_output_float = out_type == FLOAT; bool is_output_bfloat16 = out_type == BFLOAT16; - if (isBf16 ^ is_output_bfloat16) { - // only support both in and out are bf16 now - return; - } if (corner_case != NONE || is_wt_positional) { // Check corner case only for subset of tests. 
if (isFp16 || normalize_by_lengths || use_output_input_stride || @@ -149,6 +145,22 @@ TEST_P(EmbeddingSpMDMTest, basicTest) { return; } + if (in_type == FLOAT) { + printf("in_type == FLOAT, "); + } else if (in_type == FLOAT16) { + printf("in_type == FLOAT16, "); + } else if (in_type == BFLOAT16) { + printf("in_type == BFLOAT16, "); + } + + if (out_type == FLOAT) { + printf("out_type == FLOAT.\n"); + } else if (out_type == FLOAT16) { + printf("out_type == FLOAT16.\n"); + } else if (out_type == BFLOAT16) { + printf("out_type == BFLOAT16.\n"); + } + for (auto input : inputs) { int batch_size = input[0]; int num_rows = input[1]; @@ -255,6 +267,7 @@ TEST_P(EmbeddingSpMDMTest, basicTest) { input_stride, \ true, \ false, \ + is_output_bfloat16, \ isBf16); \ \ auto kernel = GenerateEmbeddingSpMDMWithStrides< \ @@ -273,6 +286,7 @@ TEST_P(EmbeddingSpMDMTest, basicTest) { input_stride, \ true, \ false, \ + is_output_bfloat16, \ isBf16); \ success = kernel( \ batch_size, \