diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h index 7942eb6b7..06b28ec11 100644 --- a/bestla/bestla/bestla.h +++ b/bestla/bestla/bestla.h @@ -67,7 +67,6 @@ enum class BTLA_DTYPE : uint32_t { U8 = EleBits8 | TypeInt | SubType1, S3_CLIP = EleBits3 | TypeInt, S4_CLIP = EleBits4 | TypeInt, - S4_FULLRANGE = EleBits4 | TypeInt | SubType1, F4_E2M1 = EleBits4 | TypeFloat, F4_BNB = EleBits4 | TypeFloat | SubType1, F4_NF4 = EleBits4 | TypeFloat | SubType2, diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h index ca161fd6d..5d09cbe1a 100644 --- a/bestla/bestla/bestla_device.h +++ b/bestla/bestla/bestla_device.h @@ -328,8 +328,8 @@ class CpuDevice { mHybrid = false; } } - numcores = P_core.size() + E_core.size(); - numthreads = P_core.size() + E_core.size() + SMT_core.size(); + numcores = static_cast(P_core.size() + E_core.size()); + numthreads = static_cast(P_core.size() + E_core.size() + SMT_core.size()); { // set PE @@ -515,7 +515,7 @@ class CpuRuntime { } else { mL1Cache_P = mL1Cache; mL2Cache_P = mL2Cache; - P_core_num = _cd->getPcoreNum(); + P_core_num = static_cast(_cd->getPcoreNum()); E_core_num = thread - P_core_num; } mL1Cache_E = _cd->getL1CacheSize_E(); diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 7a996ae05..d20524446 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -137,23 +137,25 @@ class StdThreading : public IThreading { } inline void sync(int tidx, int idx = 0) override { - flag[idx].fetch_sub(1); - if (cr->mHybrid) { - Timer_T tm; - tm.start(); - while (true) { - if (flag[idx].load() == 0) - break; - else - _mm_pause(); - } - thread_time[tidx] -= int(tm.stop()); - } else { - while (true) { - if (flag[idx].load() == 0) - break; - else - _mm_pause(); + if (mThreadNum > 1) { + flag[idx].fetch_sub(1); + if (cr->mHybrid) { + Timer_T tm; + tm.start(); + while (true) { + if (flag[idx].load() == 0) + break; + else + _mm_pause(); + } + thread_time[tidx] -= int(tm.stop()); + } else { + while (true) { + if (flag[idx].load() == 0) + break; + else + _mm_pause(); + } } } } diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 99f3ccc90..94d80ac41 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -625,7 +625,7 @@ class WeightKBlockNInteger { auto wptr = _param.packedW; if (wptr->mDType == BTLA_DTYPE::S8) { return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); @@ -674,7 +674,7 @@ class WeightKBlockNInteger { kernel::wrapper::Dq8GetScale::template forward( aptr + internal_k_offset * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size, internal_k_offset * wptr->mN + n_offset, wptr->mDqBlockSize, dq_offset_idx, wptr->DQPtr(), - wptr->CStep(), n_size, false); + wptr->CStep(), n_size, false, wptr->mN); } return BTLA_CODE::Success; } @@ -713,11 +713,6 @@ class WeightKBlockNInteger { wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, 
ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8S8Fp::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, @@ -761,14 +756,6 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, @@ -802,14 +789,6 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, - zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S8) { kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, @@ -873,14 +852,10 @@ class WeightKBlockNInteger { auto KPad = wptr->mKPad; auto bptr = wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; + assert(wptr->mDType == BTLA_DTYPE::S4_CLIP); for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8::template forward( - bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8::template forward( - bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); - } + kernel::wrapper::DecompressKBlockS4S8::template forward( + bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize); } *dststep = k_size; return BTLA_CODE::Success; @@ -916,9 +891,6 @@ class WeightKBlockNInteger { if (quant_dtype == BTLA_DTYPE::S8) { kernel::wrapper::QuantizeSignIntRowBlock::forward(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); - } else if (quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::QuantizeSignIntRowBlock::forward( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); } else if (quant_dtype == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); @@ -932,7 +904,7 @@ class WeightKBlockNInteger { static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, BTLA_DTYPE quant_dtype) { - if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { + if (quant_dtype == BTLA_DTYPE::S4_CLIP) { return kernel::wrapper::CompressS8S4::forward(srcptr, reinterpret_cast(dstptr), row, col, ld_src, ld_dst); } else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 || @@ -1051,7 +1023,7 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> { auto internal_n_offset = n_offset + i; auto internal_k_offset = k_offset / _GemmCore_T::PACK_ROW; auto internal_kblock = wptr->mBlockSize / _GemmCore_T::PACK_ROW; - auto dq_offset_idx = wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1; + auto dq_offset_idx = static_cast(wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1); if (wptr->mDType == BTLA_DTYPE::F4_NF4) { kernel::wrapper::DecompressDqKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward( diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 2bfc58dc6..0878e1a9f 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -90,9 +90,7 @@ #define CompileAMXINT8() (CompileAMX()) #endif -#if CompileBF16() || CompileFP16() #include -#endif namespace bestla { namespace utils { @@ -157,6 +155,18 @@ struct f8 { x = v; return *this; } + + inline float tofloat() const { + int32_t r = x + 127; + uint32_t tmp = bit_cast(r & 0xff); + tmp <<= 23; + return bit_cast(tmp); + } + + inline 
float mul(float src) const { + auto scale = tofloat(); + return src * scale; + } }; struct fp16 { @@ -326,8 +336,6 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) { return "unsigned_int8"; case BTLA_DTYPE::S4_CLIP: return "int4_clip"; - case BTLA_DTYPE::S4_FULLRANGE: - return "int4_fullrange"; case BTLA_DTYPE::F4_E2M1: return "fp4_e2m1"; case BTLA_DTYPE::F4_BNB: @@ -697,13 +705,13 @@ inline bool isFastExp() { } } // namespace utils -static float fp4_bnb_dequant_fp32_LUT[] = { +static float fp4_bnb_dequant_fp32_LUT alignas(64)[] = { 0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, -1.f * 0.00000000f, -1.f * 5.208333333e-03f, -1.f * 0.66666667f, -1.f * 1.00000000f, -1.f * 0.33333333f, -1.f * 0.50000000f, -1.f * 0.16666667f, -1.f * 0.25000000f}; -static float fp4_e2m1_dequant_fp32_LUT[] = { +static float fp4_e2m1_dequant_fp32_LUT alignas(64)[] = { 0.f, 0.010416666666666666f, 0.16666666666666666f, @@ -722,27 +730,27 @@ static float fp4_e2m1_dequant_fp32_LUT[] = { -1.f * 1.f, }; -static float nf4_dequant_fp32_LUT[] = {0.f, - -0.6961928009986877f, - -0.5250730514526367f, - -0.39491748809814453f, - -0.28444138169288635f, - -0.18477343022823334f, - -0.09105003625154495f, - -1.f, - 0.07958029955625534f, - 0.16093020141124725f, - 0.24611230194568634f, - 0.33791524171829224f, - 0.44070982933044434f, - 0.5626170039176941f, - 0.7229568362236023f, - 1.0f}; +static float nf4_dequant_fp32_LUT alignas(64)[] = {0.f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + -1.f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; // 8bit dynamic-tree-quantization map from bitsandbytes double-quant implementation. 
// For more details pls refer // (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] -static float dq8_bnb_LUT[] = { +static float dq8_bnb_LUT alignas(64)[] = { -0.99297f, -0.97891f, -0.96484f, -0.95078f, -0.93672f, -0.92266f, -0.90859f, -0.89453f, -0.88047f, -0.86641f, -0.85234f, -0.83828f, -0.82422f, -0.81016f, -0.79609f, -0.78203f, -0.76797f, -0.75391f, -0.73984f, -0.72578f, -0.71172f, -0.69766f, -0.68359f, -0.66953f, -0.65547f, -0.64141f, -0.62734f, -0.61328f, -0.59922f, -0.58516f, diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index de577ae64..db895ea74 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -28,23 +28,42 @@ namespace avx2 { #else #endif -static uint8_t shuffle_map[] = {0x00, 0x01, 0x02, 0x03, 0xff, 0xff, 0xff, 0xff, - 0x04, 0x05, 0x06, 0x07, 0xff, 0xff, 0xff, 0xff}; +template +static inline __m256i unpack_4bits_avx2(void* srcptr, __m256i mask) { + auto raw_data = _mm_loadu_si128(reinterpret_cast<__m128i*>(srcptr)); + auto ymm0 = _mm256_cvtepu8_epi16(raw_data); + auto ymm1 = _mm256_slli_epi16(ymm0, 8); + ymm0 = _mm256_slli_epi16(ymm0, 4); + ymm0 = _mm256_or_si256(ymm0, ymm1); + ymm0 = _mm256_and_si256(ymm0, mask); + if constexpr (LowBits) { + ymm0 = _mm256_srli_epi16(ymm0, 4); + } + return ymm0; +} -template -static inline __m128i unpack_4bits_sse(void* srcptr) { - auto shuffle_v = _mm_loadu_si128(reinterpret_cast<__m128i*>(shuffle_map)); - auto raw_data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr)); - auto xmm0 = _mm_shuffle_epi8(raw_data, shuffle_v); - auto xmm1 = _mm_srli_epi32(xmm0, 0x04); - auto and_helper = _mm_set1_epi8(0x0f); - xmm0 = _mm_and_si128(xmm0, and_helper); - xmm1 = _mm_and_si128(xmm1, and_helper); - auto xmm2 = _mm_unpacklo_epi8(xmm0, xmm1); - auto xmm3 = _mm_unpackhi_epi8(xmm0, xmm1); - xmm2 = _mm_unpacklo_epi64(xmm2, xmm3); - if constexpr (S4_T != BTLA_DTYPE::S4_FULLRANGE) xmm2 = _mm_slli_epi32(xmm2, 4); - return xmm2; +template +static inline void convert_s4_s8_N_avx2(int8_t* dstptr, int8_t* srcptr, __m256i mask) { + static_assert(N % 2 == 0); + static_assert(N <= 64); + if constexpr (N == 32) { + auto dst0 = unpack_4bits_avx2(srcptr, mask); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + } else if constexpr (N > 32) { + auto dst0 = unpack_4bits_avx2(srcptr, mask); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr), dst0); + int8_t temp[32]; + memcpy(temp, srcptr + 16, (N - 32) / 2); + dst0 = unpack_4bits_avx2(temp, mask); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr + 32, temp, (N - 32)); + } else { + int8_t temp[32]; + memcpy(temp, srcptr, N / 2); + auto dst0 = unpack_4bits_avx2(temp, mask); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(temp), dst0); + memcpy(dstptr, temp, N); + } } inline __m256 ymm_cvt_bf16_fp32(__m128i vbf16) { @@ -70,16 +89,6 @@ inline __m128i ymm_cvt_fp32_bf16(__m256 vfp32) { return ymm_cvtepi32_epi16(_mm256_bsrli_epi128(_mm256_castps_si256(vfp32), 2)); } -template -static inline void convert_s4_s8_16_sse(int8_t* dstptr, int8_t* srcptr) { - auto dst0 = unpack_4bits_sse(srcptr); - if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { - auto s8 = _mm_set1_epi8(8); - dst0 = _mm_sub_epi8(dst0, s8); - } - _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0); -} - template static inline void convert_s8_fp_v8(T* dstptr, int8_t* srcptr) { auto xmm = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr)); @@ -93,11 +102,6 @@ static inline void convert_s8_fp_v8(T* dstptr, int8_t* 
srcptr) { } } -static inline void fp4_pad_4bit(int8_t* dstptr, int8_t* srcptr) { - auto dst0 = unpack_4bits_sse(srcptr); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr), dst0); -} - template static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { static_assert(N % 8 == 0); @@ -113,8 +117,8 @@ static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vsca } inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, - int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, - bool zeropadding) { + int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, bool zeropadding, + int mN) { auto head_proc_num = utils::updiv(scale_offset, 8) * 8 - scale_offset; auto ymm_dq_offset = _mm256_set1_ps(dq_scale[dq_offset_idx]); @@ -136,10 +140,10 @@ inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, in for (int i = 0; i < row; i++) { if (head_proc_num > col) { - get_fp_scale_ref(col, scale_offset, src + i * src_stride, dst + i * dst_stride); + get_fp_scale_ref(col, scale_offset + i * mN, src + i * src_stride, dst + i * dst_stride); } else { - get_fp_scale_ref(head_proc_num, scale_offset, src + i * src_stride, dst + i * dst_stride); - auto scale_offset_iter = scale_offset + head_proc_num; + get_fp_scale_ref(head_proc_num, scale_offset + i * mN, src + i * src_stride, dst + i * dst_stride); + auto scale_offset_iter = scale_offset + i * mN + head_proc_num; uint8_t* src_iter_ptr = src + head_proc_num; float* dst_iter_ptr = dst + head_proc_num; auto body_loop = (col - head_proc_num) / 8; @@ -367,10 +371,10 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { size_t elesize = static_cast(row) * col; - size_t ele16 = utils::padto_le(elesize, 16); + size_t velt = utils::padto_le(elesize, 32); size_t i = 0; - for (; i < ele16; i += 16) { - convert_s4_s8_16_sse(dstptr + i, reinterpret_cast(srcptr + i / 2)); + for (; i < velt; i += 32) { + convert_s4_s8_N_avx2<32, S4_T>(dstptr + i, reinterpret_cast(srcptr + i / 2), vmask); } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; @@ -389,13 +393,16 @@ inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); if (col == ld_src) { size_t elesize = static_cast(row) * col; - size_t ele16 = utils::padto_le(elesize, 16); + + size_t velt = utils::padto_le(elesize, 32); size_t i = 0; - assert(tmpsize >= 16); - for (; i < ele16; i += 16) { - convert_s4_s8_16_sse(tmp, reinterpret_cast(srcptr + i / 2)); + assert(tmpsize >= 32); + for (; i < velt; i += 32) { + convert_s4_s8_N_avx2<32, S4_T>(tmp, reinterpret_cast(srcptr + i / 2), vmask); convert_s8_fp_v8(dstptr + i, tmp); convert_s8_fp_v8(dstptr + i + 8, tmp + 8); + convert_s8_fp_v8(dstptr + i + 16, tmp + 16); + convert_s8_fp_v8(dstptr + i + 24, tmp + 24); } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; @@ -460,7 +467,7 @@ inline BTLA_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type); if constexpr (WITH_SCALE) { if constexpr (std::is_same_v<_S_T, utils::f8>) { - dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x); + dstptr[i * ld_dst + j] = sptr[j / _PACK_ROW].mul(fp_v); } else if constexpr (std::is_same_v<_S_T, float>) { dstptr[i * ld_dst + j] = fp_v * sptr[j / 
_PACK_ROW]; } @@ -524,32 +531,30 @@ static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* sr if constexpr (!std::is_same_v) { dstptr[i * dststep + j] += alpha[j] * srcptr[i * srcstep + j]; } else { - dstptr[i * dststep + j] += std::pow(2, alpha[j].x) * srcptr[i * srcstep + j]; + dstptr[i * dststep + j] += alpha[j].mul(srcptr[i * srcstep + j]); } } } return BTLA_CODE::Success; } -template -static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps) { +template +static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256 vLutL, __m256 vLutH) { static_assert(N % 8 == 0); - float* LUT; - static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, - "Unsupported F4 type"); - if constexpr (F4_T == BTLA_DTYPE::F4_BNB) { - LUT = fp4_bnb_dequant_fp32_LUT; - } else if constexpr (F4_T == BTLA_DTYPE::F4_NF4) { - LUT = nf4_dequant_fp32_LUT; - } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { - LUT = fp4_e2m1_dequant_fp32_LUT; - } int constexpr VLoop = N / 8; + auto v7 = _mm256_set1_epi32(7); + auto v8 = _mm256_set1_epi32(8); for (int iv = 0; iv < VLoop; iv++) { auto idx = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); auto pad_idx = _mm256_cvtepu8_epi32(idx); - auto fp32_dq_v = _mm256_i32gather_ps(LUT, pad_idx, 4); - fp32_dq_v = _mm256_mul_ps(fp32_dq_v, vscales[iv]); + auto mskgt8 = _mm256_cmpgt_epi32(pad_idx, v7); + auto fp32_dq_v0 = _mm256_permutevar8x32_ps(vLutL, pad_idx); + pad_idx = _mm256_sub_epi32(pad_idx, v8); + auto fp32_dq_v1 = _mm256_permutevar8x32_ps(vLutH, pad_idx); + auto fp32_dq_v = _mm256_blendv_ps(fp32_dq_v0, fp32_dq_v1, _mm256_castsi256_ps(mskgt8)); + if constexpr (MULS_T) { + fp32_dq_v = _mm256_mul_ps(fp32_dq_v, vscales[iv]); + } if constexpr (std::is_same_v<_DST_T, float>) { _mm256_storeu_ps(dstptr + iv * 8, fp32_dq_v); } else if constexpr (std::is_same_v<_DST_T, utils::bf16>) { @@ -559,9 +564,11 @@ static inline void dequant_f4_N(_DST_T* dstptr, int8_t* srcptr, __m256* vscales, } } -template -static inline void unpack_f4_N(_DST_T* dstptr, int8_t* srcptr) { - static_assert(N % 8 == 0); +template +inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, int8_t* tmp, size_t tmpsize) { + uint32_t mask = 0xf0f0f0f0; + auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); float* LUT; static_assert(F4_T == BTLA_DTYPE::F4_BNB || F4_T == BTLA_DTYPE::F4_NF4 || F4_T == BTLA_DTYPE::F4_E2M1, "Unsupported F4 type"); @@ -572,33 +579,16 @@ static inline void unpack_f4_N(_DST_T* dstptr, int8_t* srcptr) { } else if constexpr (F4_T == BTLA_DTYPE::F4_E2M1) { LUT = fp4_e2m1_dequant_fp32_LUT; } - int constexpr VLoop = N / 8; - for (int iv = 0; iv < VLoop; iv++) { - auto idx = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + iv * 8)); - auto pad_idx = _mm256_cvtepu8_epi32(idx); - auto fp32_dq_v = _mm256_i32gather_ps(LUT, pad_idx, 4); - if constexpr (std::is_same_v<_DST_T, float>) { - _mm256_storeu_ps(dstptr + iv * 8, fp32_dq_v); - } else if constexpr (std::is_same_v<_DST_T, utils::bf16>) { - auto bf16v = ymm_cvt_fp32_bf16(fp32_dq_v); - _mm_storeu_si128(reinterpret_cast<__m128i*>(dstptr + iv * 8), bf16v); - } - } -} - -template -inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dstptr, int row, int col, int ld_src, - int ld_dst, int8_t* tmp, size_t tmpsize) { - uint32_t mask = 0xf0f0f0f0; - auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + auto 
vLutL = _mm256_loadu_ps(LUT); + auto vLutH = _mm256_loadu_ps(LUT + 8); if (col == ld_src) { size_t elesize = static_cast(row) * col; - size_t ele16 = utils::padto_le(elesize, 16); + size_t velt = utils::padto_le(elesize, 32); size_t i = 0; - assert(tmpsize >= 16); - for (; i < ele16; i += 16) { - fp4_pad_4bit(tmp, reinterpret_cast(srcptr + i / 2)); - unpack_f4_N<16, DST_T, F4_T>(dstptr + i, tmp); + assert(tmpsize >= 32); + for (; i < velt; i += 32) { + convert_s4_s8_N_avx2<32, F4_T>(tmp, reinterpret_cast(srcptr + i / 2), vmask); + dequant_f4_N<32, DST_T, F4_T, false>(dstptr + i, tmp, nullptr, vLutL, vLutH); } for (; i < elesize; i += 2) { auto tmp = srcptr[i / 2]; @@ -610,13 +600,26 @@ inline BTLA_CODE decompress_kblock_f4_fp_noscale(utils::f4x2* srcptr, DST_T* dst return BTLA_CODE::Success; } -template -static inline BTLA_CODE decompress_kblock_bit4_packrow1( - utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*), - void (*pad_bit4_16)(int8_t*, int8_t*), void (*pad_bit4_8)(int8_t*, int8_t*), int8_t* tmpbuf, size_t tmpsize) { +template +static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, + int ld_src, int ld_dst, _ST* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, int8_t* tmpbuf, + size_t tmpsize) { uint32_t mask = 0xf0f0f0f0; auto vmask = _mm256_set1_epi32(*reinterpret_cast(&mask)); + float* LUT = nullptr; + if constexpr (QT_T == BTLA_DTYPE::F4_BNB) { + LUT = fp4_bnb_dequant_fp32_LUT; + } else if constexpr (QT_T == BTLA_DTYPE::F4_NF4) { + LUT = nf4_dequant_fp32_LUT; + } else if constexpr (QT_T == BTLA_DTYPE::F4_E2M1) { + LUT = fp4_e2m1_dequant_fp32_LUT; + } + __m256 vLutL, vLutH; + if (LUT) { + vLutL = _mm256_loadu_ps(LUT); + vLutH = _mm256_loadu_ps(LUT + 8); + } int constexpr NReg = _NCOL / 8; assert(col == _NCOL); assert(ld_src == _NCOL); @@ -625,13 +628,21 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1( __m256i vzps[NReg]; int constexpr UnrollRow = 4; assert(kblock % UnrollRow == 0); - int constexpr Loop16 = _NCOL * UnrollRow / 16; + int constexpr NTile = 32; + int constexpr Loop32 = _NCOL * UnrollRow / NTile; assert(tmpsize >= (_NCOL * UnrollRow)); int row0 = kblock - k_offset % kblock; row0 = row0 == kblock ? 0 : row0; row0 = row0 > row ? 
row : row0; int row1 = row - row0; int irow = 0; + auto dequantize = [&](_DST_T* dstptr, int8_t* srcptr, __m256* vscales, __m256i* vzps = nullptr) { + if constexpr (QT_T == BTLA_DTYPE::S4_CLIP) { + dequant_s8_N_avx2<_NCOL, _IS_SYM>(dstptr, srcptr, vscales, vzps); + } else { + dequant_f4_N<_NCOL, _DST_T, QT_T, true>(dstptr, srcptr, vscales, vLutL, vLutH); + } + }; if (row0) { int rowpad4 = utils::padto_le(row0, UnrollRow); for (int iv = 0; iv < NReg; iv++) { @@ -643,19 +654,15 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1( } } for (; irow < rowpad4; irow += UnrollRow) { - for (int iter16 = 0; iter16 < Loop16; iter16++) - pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8 * iter16)); + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); for (int iterr = 0; iterr < UnrollRow; iterr++) dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); } for (; irow < row0; irow++) { - if constexpr (_NCOL == 24) { - pad_bit4_16(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2)); - pad_bit4_8(tmpbuf + 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8)); - } else { - for (int iter16 = 0; iter16 < 3; iter16++) - pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8 * iter16)); - } + convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); + dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); } } @@ -671,8 +678,10 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1( } } for (int irr = 0; irr < kblock; irr += UnrollRow) { - for (int iter16 = 0; iter16 < Loop16; iter16++) - pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + 8 * iter16)); + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + NTile / 2 * iter16), + vmask); for (int iterr = 0; iterr < UnrollRow; iterr++) dequantize(dstptr + (irow + irr + iterr) * ld_src, tmpbuf + iterr * _NCOL, vscales, vzps); } @@ -689,55 +698,82 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow1( auto rowre = row - irow; int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow; for (; irow < rowpad4; irow += UnrollRow) { - for (int iter16 = 0; iter16 < Loop16; iter16++) - pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8 * iter16)); + for (int iter16 = 0; iter16 < Loop32; iter16++) + convert_s4_s8_N_avx2( + tmpbuf + iter16 * NTile, reinterpret_cast(srcptr + irow * ld_src / 2 + NTile / 2 * iter16), vmask); for (int iterr = 0; iterr < UnrollRow; iterr++) dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps); } for (; irow < row; irow++) { - if constexpr (_NCOL == 24) { - pad_bit4_16(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2)); - pad_bit4_8(tmpbuf + 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8)); - } else { - for (int iter16 = 0; iter16 < 3; iter16++) - pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast(srcptr + irow * ld_src / 2 + 8 * iter16)); - } + convert_s4_s8_N_avx2<_NCOL, QT_T>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2), vmask); dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps); } } return BTLA_CODE::Success; } -template +template static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int 
ld_dst, _ST* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad, - void (*dequantize)(_DST_T*, int8_t*, __m256*, __m256i*), - void (*pad_bit4)(int8_t*, int8_t*), int8_t* tmp, + int k_offset, int kblock, int NPad, int8_t* tmp, size_t tmpsize) { return BTLA_CODE::NotSupport; } +template +static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, + int ld_dst, _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, int8_t* tmp, size_t tmpsize) { + auto ret = BTLA_CODE::NotSupport; + if constexpr (_PACK_ROW == 1 && std::is_same_v<_DST_T, float> && std::is_same_v<_ST, float>) { + if (zero_points == nullptr) { + if (col == 24) { + ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } else if (col == 48) { + ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } else { + assert(0); + } + + } else { + if (col == 24) { + ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } else if (col == 48) { + ret = decompress_kblock_bit4_packrow1(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, k_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); + } else { + assert(0); + } + } + } + return ret; +} + template static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, int8_t* tmp, size_t tmpsize) { if constexpr (_PACK_ROW == 1) { if (col == 24) { - return decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, - &dequant_f4_N<24, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize); + return decompress_kblock_bit4_packrow1<_F4_T, true, 24, _ST, _DST_T>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); } if (col == 48) { - return decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, - &dequant_f4_N<48, _DST_T, _F4_T>, fp4_pad_4bit, &ref::convert_s4_s8_8<_F4_T>, tmp, tmpsize); + return decompress_kblock_bit4_packrow1<_F4_T, true, 48, _ST, _DST_T>( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad, tmp, tmpsize); } } else if constexpr (_PACK_ROW == 2) { - return decompress_kblock_bit4_packrow2(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, - k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>, - fp4_pad_4bit, tmp, tmpsize); + return decompress_kblock_bit4_packrow2<_F4_T, true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + nullptr, k_offset, kblock, NPad, tmp, tmpsize); } + assert(0); return BTLA_CODE::NotSupport; } diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 9a3a9f738..c3c8a7dfa 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -64,28 +64,16 @@ static inline __m512i unpack_4bits(__m256i v4bits, __m512i vmask) { return zmm1; } -template -static inline void convert_s4_s8(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { +static inline void convert_s4_s8_highbits(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int 
LoadMask) { auto ymm = _mm256_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); auto zmm = unpack_4bits(ymm, vmask); - if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { - zmm = _mm512_srli_epi32(zmm, 4); - auto s8 = _mm512_set1_epi8(8); - zmm = _mm512_sub_epi8(zmm, s8); - } _mm512_mask_storeu_epi64(dstptr, __mmask8(LoadMask), zmm); } -template -static inline void convert_s4_s8_v32(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { +static inline void convert_s4_s8_highbits_v32(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int LoadMask) { auto xmm = _mm_maskz_loadu_epi32(__mmask8(LoadMask), reinterpret_cast(srcptr)); auto ymm = _mm256_castsi128_si256(xmm); auto zmm = unpack_4bits(ymm, vmask); - if constexpr (S4_T == BTLA_DTYPE::S4_FULLRANGE) { - zmm = _mm512_srli_epi32(zmm, 4); - auto s8 = _mm512_set1_epi8(8); - zmm = _mm512_sub_epi8(zmm, s8); - } auto ymm_out = _mm512_castsi512_si256(zmm); _mm256_mask_storeu_epi64(dstptr, __mmask8(LoadMask), ymm_out); } @@ -103,7 +91,7 @@ static inline void convert_s8_fp_v16(T* dstptr, int8_t* srcptr) { } } -constexpr void (*pad_fp4)(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int) = &convert_s4_s8; +constexpr void (*pad_fp4)(int8_t* dstptr, int8_t* srcptr, __m512i vmask, int) = &convert_s4_s8_highbits; template static inline void dequant_s8_N(_DST_T* dstptr, int8_t* srcptr, __m512* vscales, __m512i* vzps = nullptr) { @@ -390,13 +378,11 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } for (; irow < row0; irow++) { + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, + LoadMask64); if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, - LoadMask64); dequant_s8_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -416,13 +402,11 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } for (int irr = 0; irr < kblock; irr += 1) { + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), + zmm_mask, LoadMask64); if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8( - tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); dequant_f4_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); dequant_s8_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -440,13 +424,11 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } } for (; irow < row; irow++) { + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, + LoadMask64); if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, - LoadMask64); dequant_s8_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -479,18 +461,13 @@ static inline BTLA_CODE 
decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } for (; irow < row0; irow++) { + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, + LoadMask64); + convert_s4_s8_highbits_v32(tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), + zmm_mask, LoadMask64); if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); - convert_s4_s8_v32( - tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), zmm_mask, - LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, - LoadMask64); - convert_s4_s8_v32<_SRCT>(tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), - zmm_mask, LoadMask64); dequant_s8_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -510,19 +487,14 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } for (int irr = 0; irr < kblock; irr += 1) { - if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8( - tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); - convert_s4_s8_v32( - tmpbuf + 64, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2 + 32), zmm_mask, - LoadMask64); - dequant_f4_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); - } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2), zmm_mask, LoadMask64); - convert_s4_s8_v32<_SRCT>(tmpbuf + 64, + convert_s4_s8_highbits_v32(tmpbuf + 64, reinterpret_cast(srcptr + (irow + irr) * ld_src / 2 + icol / 2 + 32), zmm_mask, LoadMask64); + if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { + dequant_f4_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); + } else { dequant_s8_N(dstptr + (irow + irr) * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -540,17 +512,13 @@ static inline BTLA_CODE decompress_kblock_bit4_packrow2(utils::bit4x2* srcptr, _ } } for (; irow < row; irow++) { + convert_s4_s8_highbits(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, + LoadMask64); + convert_s4_s8_highbits_v32(tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), + zmm_mask, LoadMask64); if constexpr (SRC_TYPE == BTLA_DTYPE::TypeFloat) { - convert_s4_s8(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), - zmm_mask, LoadMask64); - convert_s4_s8_v32( - tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), zmm_mask, LoadMask64); dequant_f4_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } else { - convert_s4_s8<_SRCT>(tmpbuf, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2), zmm_mask, - LoadMask64); - convert_s4_s8_v32<_SRCT>(tmpbuf + 64, reinterpret_cast(srcptr + irow * ld_src / 2 + icol / 2 + 32), - zmm_mask, LoadMask64); dequant_s8_N(dstptr + irow * ld_dst + icol, tmpbuf, vscales, vzps); } } @@ -627,11 +595,11 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d if (zero_points == nullptr) { return decompress_kblock_bit4_packrow1<_ST, _DST_T, true>( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &dequant_s8_N<48, _DST_T, true>, &convert_s4_s8, tmp, tmpsize); + 
&dequant_s8_N<48, _DST_T, true>, &convert_s4_s8_highbits, tmp, tmpsize); } else { return decompress_kblock_bit4_packrow1<_ST, _DST_T, false>( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &dequant_s8_N<48, _DST_T, false>, &convert_s4_s8, tmp, tmpsize); + &dequant_s8_N<48, _DST_T, false>, &convert_s4_s8_highbits, tmp, tmpsize); } } else if constexpr (_PACK_ROW == 2) { if (zero_points == nullptr) { @@ -673,8 +641,8 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 - _mm512_storeu_epi8((__m512i*)dst, zmm1); - _mm512_storeu_epi8((__m512i*)(dst + 64), zmm2); + _mm512_storeu_si512((__m512i*)dst, zmm1); + _mm512_storeu_si512((__m512i*)(dst + 64), zmm2); }; assert(head_ignore_num % 8 == 0); @@ -792,14 +760,14 @@ static inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, size_t i = 0; constexpr int LoadMask64 = (1 << (64 / 8)) - 1; for (; i < ele256; i += 256) { - convert_s4_s8(dstptr + i + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); - convert_s4_s8(dstptr + i + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); - convert_s4_s8(dstptr + i + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); - convert_s4_s8(dstptr + i + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); + convert_s4_s8_highbits(dstptr + i + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); + convert_s4_s8_highbits(dstptr + i + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); + convert_s4_s8_highbits(dstptr + i + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); + convert_s4_s8_highbits(dstptr + i + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); } if (i + 64 <= ele64) { for (; i < ele64; i += 64) { - convert_s4_s8(dstptr + i, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); + convert_s4_s8_highbits(dstptr + i, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); } } for (; i < elesize; i += 2) { @@ -862,6 +830,64 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym(const float* srcptr, } return BTLA_CODE::Success; } +template +static inline BTLA_CODE quantize_f32_sign_int_rowblock_sym_auto(const float* srcptr, int8_t* dstptr, int row, int col, + int ld_src, int ld_dst, float* scales, int blocksize) { + int constexpr VLen = 16; + int col16 = utils::padto_le(col, VLen); + int i = 0; + auto align_row = row / blocksize * blocksize; + for (; i < col16; i += VLen) { + int j = 0; + float tmp_min[VLen]; + float tmp_max[VLen]; + float tmp_abs[VLen]; + auto simd_process_block = [&](int size) { + __m512 vscale; + __m512 vmaxval = _mm512_set1_ps(std::numeric_limits::min()); + __m512 vminval = _mm512_set1_ps(std::numeric_limits::max()); + __m512 vabsval = _mm512_set1_ps(0.f); + for (size_t ij = 0; ij < size; ij++) { + auto vsrc = _mm512_loadu_ps(&srcptr[(j + ij) * ld_src + i]); + vmaxval = _mm512_max_ps(vmaxval, vsrc); + vminval = _mm512_min_ps(vminval, vsrc); + vsrc = _mm512_abs_ps(vsrc); + vabsval = _mm512_max_ps(vabsval, vsrc); + } + _mm512_storeu_ps(tmp_min, vminval); + _mm512_storeu_ps(tmp_max, vmaxval); + _mm512_storeu_ps(tmp_abs, vabsval); + auto constexpr NBits = utils::bestla_dtype_bits(QDT_T); + int constexpr FullValue = 1 << (NBits - 1); + int constexpr GenValue = FullValue - 1; + for (int iv = 0; iv < VLen; iv++) { + int NVal = GenValue; + auto sum = 
tmp_max[iv] + tmp_min[iv]; + if (abs(sum) >= tmp_abs[iv] / FullValue) { + NVal = sum > 0.f ? -FullValue : FullValue; + } + NVal = NVal << (8 - NBits); + tmp_abs[iv] = NVal; + } + auto vmag = _mm512_loadu_ps(tmp_abs); + vscale = _mm512_div_ps(vabsval, vmag); + auto vrscale = _mm512_div_ps(vmag, vabsval); + _mm512_storeu_ps(&scales[j / blocksize * ld_dst + i], vscale); + for (size_t ij = 0; ij < size; ij++) { + auto vsrc = _mm512_loadu_ps(&srcptr[(j + ij) * ld_src + i]); + vsrc = _mm512_mul_ps(vsrc, vrscale); + auto vdsrc = _mm512_cvtps_epi32(vsrc); + auto vbsrc = _mm512_cvtepi32_epi8(vdsrc); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) * ld_dst + i]), vbsrc); + } + }; + for (; j < align_row; j += blocksize) simd_process_block(blocksize); + if (j < row) simd_process_block(row - align_row); + } + kernel::ref::quantize_f32_sign_int_rowblock(srcptr + i, dstptr + i, row, col - i, ld_src, ld_dst, scales + i, + nullptr, blocksize); + return BTLA_CODE::Success; +} static inline BTLA_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, @@ -930,12 +956,17 @@ static inline BTLA_CODE quantize_f32_sign_int_rowblock_asym(const float* srcptr, return BTLA_CODE::Success; } -template +template static inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, int blocksize) { if (zero_points == nullptr) - return quantize_f32_sign_int_rowblock_sym(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize); + if constexpr (QDT_T == BTLA_DTYPE::S4_CLIP || QDT_T == BTLA_DTYPE::S3_CLIP) { + return quantize_f32_sign_int_rowblock_sym_auto(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + blocksize); + } else { + return quantize_f32_sign_int_rowblock_sym(srcptr, dstptr, row, col, ld_src, ld_dst, scales, blocksize); + } else return quantize_f32_sign_int_rowblock_asym(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, blocksize); @@ -1325,8 +1356,8 @@ static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* s } inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, - int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, - bool zeropadding) { + int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, bool zeropadding, + int mN) { auto head_proc_num = utils::updiv(scale_offset, 16) * 16 - scale_offset; auto zmm_dq_offset = _mm512_set1_ps(dq_scale[dq_offset_idx]); @@ -1344,13 +1375,13 @@ inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, in for (int i = 0; i < row; i++) { if (head_proc_num > col) { auto mask = _cvtu32_mask16(0xffff >> (16 - col)); - get_fp_scale(col, mask, scale_offset, src + i * src_stride, dst + i * dst_stride); + get_fp_scale(col, mask, scale_offset + i * mN, src + i * src_stride, dst + i * dst_stride); } else { // TODO(zhe): consider head_proc_num==0 case. 
auto head_mask = _cvtu32_mask16(0xffff >> (16 - head_proc_num)); auto body_mask = _cvtu32_mask16(0xffff); - get_fp_scale(head_proc_num, head_mask, scale_offset, src + i * src_stride, dst + i * dst_stride); - auto scale_offset_iter = scale_offset + head_proc_num; + get_fp_scale(head_proc_num, head_mask, scale_offset + i * mN, src + i * src_stride, dst + i * dst_stride); + auto scale_offset_iter = scale_offset + i * mN + head_proc_num; uint8_t* src_iter_ptr = src + head_proc_num; float* dst_iter_ptr = dst + head_proc_num; auto body_loop = (col - head_proc_num) / 16; @@ -1419,17 +1450,17 @@ inline BTLA_CODE decompress_kblock_s4_s8fp(utils::int4x2* srcptr, _DST_T* dstptr size_t i = 0; constexpr int LoadMask64 = (1 << (64 / 8)) - 1; for (; i < ele256; i += 256) { - convert_s4_s8(tmp + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); - convert_s4_s8(tmp + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); - convert_s4_s8(tmp + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); - convert_s4_s8(tmp + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); + convert_s4_s8_highbits(tmp + 0, reinterpret_cast(srcptr + i / 2 + 0), zmm_mask, LoadMask64); + convert_s4_s8_highbits(tmp + 64, reinterpret_cast(srcptr + i / 2 + 32), zmm_mask, LoadMask64); + convert_s4_s8_highbits(tmp + 128, reinterpret_cast(srcptr + i / 2 + 64), zmm_mask, LoadMask64); + convert_s4_s8_highbits(tmp + 192, reinterpret_cast(srcptr + i / 2 + 96), zmm_mask, LoadMask64); for (size_t j = 0; j < 256; j += 16) { convert_s8_fp_v16(dstptr + i + j, tmp + j); } } if (i + 64 <= ele64) { for (; i < ele64; i += 64) { - convert_s4_s8(tmp, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); + convert_s4_s8_highbits(tmp, reinterpret_cast(srcptr + i / 2), zmm_mask, LoadMask64); for (size_t j = 0; j < 64; j += 16) { convert_s8_fp_v16(dstptr + i + j, tmp + j); } @@ -1497,7 +1528,7 @@ static inline BTLA_CODE accum_alphaN_f32_f32(const SCA_T* alpha, const float* sr if constexpr (!std::is_same_v) { dstptr[i * dststep + j] += static_cast(alpha[j]) * srcptr[i * srcstep + j]; } else { - dstptr[i * dststep + j] += std::pow(2, alpha[j].x) * srcptr[i * srcstep + j]; + dstptr[i * dststep + j] += alpha[j].mul(srcptr[i * srcstep + j]); } } } diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 18d876226..ba4460119 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -181,7 +181,14 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, int col, int ld_src, int ld_dst) { assert(col % 128 == 0); - + auto round3bit = [](int8_t src) { + int32_t dst = src; + dst = dst >= 0 ? dst + 16 : dst - 16; + dst = dst / 32; + dst = dst > 3 ? 3 : dst; + dst = dst < -4 ? 
-4 : dst; + return static_cast(dst); + }; auto bit2_interleave = [&](int8_t* src, int8_t* dst) { for (int i = 0; i < 128 / 4; i++) { dst[4 * i] = src[i]; @@ -191,30 +198,36 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x } }; + int8_t round_buf[128]; int8_t interleave_buf[128]; for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 128) { - bit2_interleave(const_cast(srcptr + i * ld_src + j), interleave_buf); + for (int k = 0; k < 128; k++) { + round_buf[k] = round3bit(const_cast(srcptr + i * ld_src + j + k)[0]) << 5; + } + bit2_interleave(round_buf, interleave_buf); for (int k = 0; k < 32; k++) { bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; } + for (int k = j; k < j + 128; k += 8) { + bit1ptr[i * ld_dst / 8 + k / 8].a = round_buf[k - j] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].b = round_buf[k - j + 1] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].c = round_buf[k - j + 2] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].d = round_buf[k - j + 3] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].e = round_buf[k - j + 4] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].f = round_buf[k - j + 5] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].g = round_buf[k - j + 6] >> 7; + bit1ptr[i * ld_dst / 8 + k / 8].h = round_buf[k - j + 7] >> 7; + } } } // store 1 bit without interleave as mask. for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 8) { - bit1ptr[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7; - bit1ptr[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7; } } return BTLA_CODE::Success; @@ -236,17 +249,8 @@ static inline BTLA_CODE decompress_s4_f32(utils::int4x2* srcptr, float* dstptr, template inline int8_t get_s8(int8_t v) { - switch (S4_T) { - case BTLA_DTYPE::S4_CLIP: - return v << 4; - case BTLA_DTYPE::S4_FULLRANGE: - v &= 0x0f; - return v - 8; - default: - assert(false); - break; - } - return static_cast(0); + static_assert(S4_T == BTLA_DTYPE::S4_CLIP); + return v << 4; } template @@ -290,14 +294,6 @@ inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) { dstptr[7] = static_cast(tmp); } -template <> -inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { - convert_s4_s8_8_lowbits(dstptr, srcptr); - for (size_t i = 0; i < 8; i++) { - dstptr[i] -= 8; - } -} - template <> inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) { convert_s4_s8_8_lowbits(dstptr, srcptr); @@ -315,6 +311,7 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) template inline BTLA_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst) { + static_assert(S4_T == BTLA_DTYPE::S4_CLIP); for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 2) { auto tmp = srcptr[i * ld_src / 2 + j / 2]; @@ -864,7 +861,7 @@ static inline BTLA_CODE get2d_e8m0_scale(const void* srcptr, void* dstptr, int r return BTLA_CODE::Success; } -template +template inline 
BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, int blocksize) { int raw_blocksize = blocksize; @@ -883,24 +880,6 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst dstptr[(j + ij) * ld_dst + i] = utils::cast(srcptr[(j + ij) * ld_src + i] * rscale); } }; - auto s4_fullrange_calc_store_scale_and_quantv_sym = [&](int blocksize) { - float amax = 0.f, max = 0.f; - for (size_t ij = 0; ij < blocksize; ij++) { - auto v = srcptr[(j + ij) * ld_src + i]; - if (amax < std::abs(v)) { - amax = std::abs(v); - max = v; - } - } - float scale = max / -8.f; - float rscale = scale != 0.f ? 1.f / scale : 0.f; - scales[j / raw_blocksize * ld_dst + i] = scale; - for (size_t ij = 0; ij < blocksize; ij++) { - auto quant_v = srcptr[(j + ij) * ld_src + i] * rscale; - int8_t x = std::min(static_cast(15), static_cast(quant_v + 8.5f)); - dstptr[(j + ij) * ld_dst + i] = x << 4; - } - }; auto s8_calc_store_scale_and_quantv_asym = [&](int blocksize) { float maxval = 0.f; float minval = 0.f; @@ -918,45 +897,47 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst dstptr[(j + ij) * ld_dst + i] = utils::cast((srcptr[(j + ij) * ld_src + i] - fmedium) * rscale); } }; - auto s4_fullrange_calc_store_scale_and_quantv_asym = [&](int blocksize) { - float maxval = 0.f; - float minval = 0.f; + auto sNauto_calc_store_scale_and_quantv_sym = [&](int blocksize) { + auto constexpr NBits = utils::bestla_dtype_bits(QDT_T); + int constexpr FullValue = 1 << (NBits - 1); + int constexpr GenValue = FullValue - 1; + float maxval = std::numeric_limits::min(); + float minval = std::numeric_limits::max(); + float absmax = 0; for (size_t ij = 0; ij < blocksize; ij++) { - auto v = srcptr[(j + ij) * ld_src + i]; - maxval = std::max(maxval, v); - minval = std::min(minval, v); + maxval = std::max(maxval, srcptr[(j + ij) * ld_src + i]); + minval = std::min(minval, srcptr[(j + ij) * ld_src + i]); + absmax = std::max(absmax, std::abs(srcptr[(j + ij) * ld_src + i])); + } + int NVal = GenValue; + auto sum = maxval + minval; + if (abs(sum) >= absmax / FullValue) { + NVal = sum > 0.f ? -FullValue : FullValue; } - float max = std::abs(maxval) < std::abs(minval) ? minval - maxval : maxval - minval; - float scale = max / -16.f; - float rscale = scale != 0.f ? 
1.f / scale : 0.f; + NVal = NVal << (8 - NBits); + float scale = absmax / NVal; + float rscale = 1.f / scale; scales[j / raw_blocksize * ld_dst + i] = scale; - float fmedium = (maxval + minval) / 2; - ; - int8_t bzp = utils::cast((0.f - fmedium) * rscale); - zero_points[j / raw_blocksize * ld_dst + i] = bzp; for (size_t ij = 0; ij < blocksize; ij++) { - auto quant_v = (srcptr[(j + ij) * ld_src + i] - fmedium) * rscale; - int8_t x = std::min(static_cast(15), static_cast(quant_v + 8.5f)); - dstptr[(j + ij) * ld_dst + i] = x << 4; + dstptr[(j + ij) * ld_dst + i] = utils::cast(srcptr[(j + ij) * ld_src + i] * rscale); } }; auto dispatch_calc = [&](int blocksize) { - switch (S4_T) { + switch (QDT_T) { case BTLA_DTYPE::S8: - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S3_CLIP: if (zero_points == nullptr) { s8_calc_store_scale_and_quantv_sym(blocksize); } else { s8_calc_store_scale_and_quantv_asym(blocksize); } break; - case BTLA_DTYPE::S4_FULLRANGE: + case BTLA_DTYPE::S3_CLIP: + case BTLA_DTYPE::S4_CLIP: if (zero_points == nullptr) { - s4_fullrange_calc_store_scale_and_quantv_sym(blocksize); + sNauto_calc_store_scale_and_quantv_sym(blocksize); } else { - s4_fullrange_calc_store_scale_and_quantv_asym(blocksize); + s8_calc_store_scale_and_quantv_asym(blocksize); } break; default: @@ -1252,11 +1233,11 @@ inline BTLA_CODE dq8_bnb_double_quant(float* scale, size_t scale_size, int dq_bl } inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, - int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, - bool zeropadding) { + int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride, bool zeropadding, + int mN) { for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { - auto dq_s_idx = (scale_offset + j) / dq_blk; + auto dq_s_idx = (i * mN + scale_offset + j) / dq_blk; dst[i * dst_stride + j] = dq8_bnb_LUT[src[i * src_stride + j]] * dq_scale[dq_s_idx] + dq_scale[dq_offset_idx]; } } diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 5b7c7f5b8..e4a416f00 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -241,21 +241,21 @@ class Dq8GetScale { public: template static BTLA_CODE forward(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk, int dq_offset_idx, - float* dq_scale, int src_stride, int dst_stride, bool zeropadding) { + float* dq_scale, int src_stride, int dst_stride, bool zeropadding, int mN) { #if CompileAVX512F() if (ISA_T >= BTLA_ISA::AVX512F) { return kernel::avx512f::dq8_get_fp_scale(src, dst, row, col, scale_offset, dq_blk, dq_offset_idx, dq_scale, - src_stride, dst_stride, zeropadding); + src_stride, dst_stride, zeropadding, mN); } #endif #if CompileAVX2() if (ISA_T >= BTLA_ISA::AVX2) { return kernel::avx2::dq8_get_fp_scale(src, dst, row, col, scale_offset, dq_blk, dq_offset_idx, dq_scale, - src_stride, dst_stride, zeropadding); + src_stride, dst_stride, zeropadding, mN); } #endif return kernel::ref::dq8_get_fp_scale(src, dst, row, col, scale_offset, dq_blk, dq_offset_idx, dq_scale, src_stride, - dst_stride, zeropadding); + dst_stride, zeropadding, mN); } }; @@ -297,18 +297,17 @@ class Transpose2D { class QuantizeSignIntRowBlock { public: - template + template static inline BTLA_CODE forward(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, int blocksize) { #if CompileAVX512F() - if constexpr (utils::isa_base::avx512f && - S4_T != BTLA_DTYPE::S4_FULLRANGE) { // 
TODO(zhe): support simd version s4_fullrange quantization. - return avx512f::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, - zero_points, blocksize); + if constexpr (utils::isa_base::avx512f) { + return avx512f::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, + zero_points, blocksize); } #endif - return ref::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, - blocksize); + return ref::quantize_f32_sign_int_rowblock(srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, + blocksize); } }; @@ -422,35 +421,10 @@ class DecompressKBlockS4Fp { #endif #if CompileAVX2() // AVX2 device only focus on fp32 data and layout - if constexpr (utils::isa_base::avx2 && std::is_same_v<_SCA_T, float> && std::is_same_v<_DST_T, float> && - _PACK_ROW == 1) { - if (zero_points == nullptr) { - if (col == 24) { - ret = avx2::decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &avx2::dequant_s8_N_avx2<24, true>, &avx2::convert_s4_s8_16_sse, &ref::convert_s4_s8_8, - reinterpret_cast(tmp), tmpsize); - } else if (col == 48) { - ret = avx2::decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &avx2::dequant_s8_N_avx2<48, true>, &avx2::convert_s4_s8_16_sse, &ref::convert_s4_s8_8, - reinterpret_cast(tmp), tmpsize); - } - - } else { - if (col == 24) { - ret = avx2::decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &avx2::dequant_s8_N_avx2<24, false>, &avx2::convert_s4_s8_16_sse, &ref::convert_s4_s8_8, - reinterpret_cast(tmp), tmpsize); - } else if (col == 48) { - ret = avx2::decompress_kblock_bit4_packrow1( - srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, k_offset, kblock, NPad, - &avx2::dequant_s8_N_avx2<48, false>, &avx2::convert_s4_s8_16_sse, &ref::convert_s4_s8_8, - reinterpret_cast(tmp), tmpsize); - } - } - + if constexpr (utils::isa_base::avx2) { + ret = avx2::decompress_kblock_s4_fp(srcptr, dstptr, row, col, ld_src, ld_dst, + scales, zero_points, k_offset, kblock, NPad, + reinterpret_cast(tmp), tmpsize); if (ret == BTLA_CODE::Success) return ret; } #endif diff --git a/bestla/bestla/ut/bestla_gemm.cpp b/bestla/bestla/ut/bestla_gemm.cpp index 9a686126f..697d1bd70 100644 --- a/bestla/bestla/ut/bestla_gemm.cpp +++ b/bestla/bestla/ut/bestla_gemm.cpp @@ -170,18 +170,12 @@ class UT_GEMM_AVX2 { ut_48(1, 48, 3); ut_48(1, 144, 3); - -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } void ut_24(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::SCoreRowNAvx2<24>; - Core gemm; + static Core gemm; if (n % Core::NTILE != 0) { return; } @@ -203,7 +197,7 @@ class UT_GEMM_AVX2 { void ut_48(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::SCoreRowNAvx2<48>; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -221,43 +215,8 @@ class UT_GEMM_AVX2 { gemm.forward(A.data(), B.data(), C.data(), m, n, k, k * 4, k * 4, n * 4, 0, cache, CacheSize); ut::buffer_error(RefC.data(), C.data(), RefC.size(), 0.001f); } - - template - void benchmark(int m, int n, int k, int batch, float* A, float* B, float* C) { - LOG_T log; - CORE_T core; - for (size_t i = 0; i < batch; i++) { - log.start(); - for (size_t im = 
0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 4, k * 4, n * 4, 0, - cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch), B(k * n * batch), C(m * n * batch, 0.f), RefC(m * n * batch, 0.f); - using LOG = timer_statistics_logger<500>; - using Core8 = gemm::SCoreRowNAvx2<8>; - using Core16 = gemm::SCoreRowNAvx2<16>; - using Core24 = gemm::SCoreRowNAvx2<24>; - using Core24_B = gemm::SCoreRowNAvx2<24, 4>; - using Core32 = gemm::SCoreRowNAvx2<32>; - using Core48 = gemm::SCoreRowNAvx2<48>; - - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - benchmark(m, n, k, batch, A.data(), B.data(), C.data()); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX2 sUT_GEMM_AVX2; #endif @@ -271,19 +230,12 @@ class UT_GEMM_AVX512F { ut_48(1, 48, 3); ut_48(1, 144, 3); - ut(1024, 144, 154); - -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } void ut_32(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::SCoreRowNAvx512f<32>; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -302,7 +254,7 @@ class UT_GEMM_AVX512F { void ut_48(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::SCoreRowNAvx512f<48>; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -317,62 +269,8 @@ class UT_GEMM_AVX512F { gemm.forward(A.data(), B.data(), C.data(), m, n, k, k * 4, k * 4, n * 4, 0, cache, CacheSize); ut::buffer_error(RefC.data(), C.data(), RefC.size(), 0.001f); } - - template - void benchmark(int m, int n, int k, int batch, float* A, float* B, float* C, float timems) { - LOG_T log; - CORE_T core; - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 4, k * 4, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void ut(size_t m, size_t n, size_t k) { - printf("%s %d %d %d\n", __FUNCTION__, int(m), int(n), int(k)); - avector A(m * k), B(k * n), C(m * n, 0.f), RefC(m * n, 0.f); - using Core48_B = gemm::SCoreRowNAvx512f<48, 8>; - ref_fp32(A.data(), B.data(), C.data(), m, n, k, k * sizeof(A[0]), k * sizeof(B[0]), - n * sizeof(C[0]), 0); - Core48_B core; - for (size_t im = 0; im < m; im += Core48_B::MTILE) { - auto im_re = remainsize(im, m, Core48_B::MTILE); - core.forward(A.data(), B.data(), C.data(), im_re, n, k, k * sizeof(A[0]), k * sizeof(B[0]), n * sizeof(C[0]), 0, - cache, CacheSize); - } - buffer_error(RefC.data(), C.data(), RefC.size()); - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d 
%d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch), B(k * n * batch), C(m * n * batch, 0.f), RefC(m * n * batch, 0.f); - using LOG = timer_statistics_logger<100>; - using Core16 = gemm::SCoreRowNAvx512f<16>; - using Core32 = gemm::SCoreRowNAvx512f<32>; - using Core48 = gemm::SCoreRowNAvx512f<48>; - using Core48_B = gemm::SCoreRowNAvx512f<48, 8>; - using Core64 = gemm::SCoreRowNAvx512f<64>; - - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX512F sUT_GEMM_AVX512F; #endif @@ -386,19 +284,13 @@ class UT_GEMM_AVX512VNNI { ut<48, 0>(4, 96, 12); ut<48, 8>(4, 96, 12); - -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAvx512vnni; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -417,50 +309,8 @@ class UT_GEMM_AVX512VNNI { CacheSize); ut::buffer_error(RefC.data(), C.data(), RefC.size(), 1); } - - template - void benchmark(int m, int n, int k, int batch, uint8_t* A, int8_t* B, int32_t* C, float timems) { - LOG_T log; - CORE_T core; - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 1, k * 1, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, 0), RefC(m * n * batch, 0); - fill_buffer_randn(A.data(), A.size(), (uint8_t)0, (uint8_t)255); - fill_buffer_randn(B.data(), B.size(), (int8_t)-127, (int8_t)127); - using LOG = timer_statistics_logger<100>; - using Core16 = gemm::ICoreRowNAvx512vnni<16>; - using Core32 = gemm::ICoreRowNAvx512vnni<32>; - using Core48 = gemm::ICoreRowNAvx512vnni<48>; - using Core48_B = gemm::ICoreRowNAvx512vnni<48, 8>; - using Core64 = gemm::ICoreRowNAvx512vnni<64>; - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX512VNNI sUT_GEMM_AVX512VNNI; #endif @@ -480,7 +330,7 @@ class UT_GEMM_AVX512VNNI_KBLOCK { void ut(int m, int n, int k, int kblock) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAvx512vnniKBlock; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -565,7 +415,7 @@ class UT_GEMM_AVX512VNNI_KBLOCK { 
ut::buffer_error(RefC.data(), C.data(), RefC.size(), 0.001f); } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX512VNNI_KBLOCK sUT_GEMM_AVX512VNNI_KBLOCK; #endif @@ -588,7 +438,7 @@ class UT_GEMM_AVXVNNI_KBLOCK { void ut(int m, int n, int k, int kblock) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAvxvnniKBlock; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -626,7 +476,7 @@ class UT_GEMM_AVXVNNI_KBLOCK { ut::buffer_error(RefC.data(), C.data(), RefC.size(), 0.001f); } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVXVNNI_KBLOCK sUT_GEMM_AVXVNNI_KBLOCK; #endif @@ -646,7 +496,7 @@ class UT_GEMM_AMXINT8_KBLOCK { void ut_splitblock(int m, int n, int k, int kblock) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAmxint8KBlock; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -684,7 +534,7 @@ class UT_GEMM_AMXINT8_KBLOCK { ut::buffer_error(RefC.data(), C.data(), RefC.size(), 0.001f); } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AMXINT8_KBLOCK sUT_GEMM_AMXINT8_KBLOCK; #endif @@ -696,18 +546,13 @@ class UT_GEMM_AVXVNNI { ut<24>(4, 48, 12); ut<48>(2, 96, 12); -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAvxvnni; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -726,51 +571,8 @@ class UT_GEMM_AVXVNNI { CacheSize); ut::buffer_error(RefC.data(), C.data(), RefC.size(), 1); } - - template - void benchmark(int m, int n, int k, int batch, uint8_t* A, int8_t* B, int32_t* C, float timems) { - LOG_T log; - CORE_T core; - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 1, k * 1, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, 0), RefC(m * n * batch, 0); - fill_buffer_randn(A.data(), A.size(), (uint8_t)0, (uint8_t)255); - fill_buffer_randn(B.data(), B.size(), (int8_t)-127, (int8_t)127); - using LOG = timer_statistics_logger<100>; - utils::timer tm; - using Core16 = gemm::ICoreRowNAvxvnni<16>; - using Core24 = gemm::ICoreRowNAvxvnni<24>; - using Core24_B = gemm::ICoreRowNAvxvnni<24, 4>; - using Core32 = gemm::ICoreRowNAvxvnni<32>; - using Core48 = gemm::ICoreRowNAvxvnni<48>; - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVXVNNI sUT_GEMM_AVXVNNI; #endif @@ -781,19 +583,13 @@ class UT_GEMM_AVX512FP16 { CheckISA(AVX512_FP16); ut<32, 0>(4, 64, 3); 
ut<64, 0>(4, 128, 3); - -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::HCoreRowNAvx512fp16; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -808,54 +604,10 @@ class UT_GEMM_AVX512FP16 { ref_fp16(matAbf16.data(), matBbf16.data(), refC.data(), m, n, k, k * 2, reordered_bstride, n * 2, 0); gemm.forward(matAbf16.data(), matBbf16.data(), matC.data(), m, n, k, k * sizeof(fp16), k * sizeof(fp16), n * sizeof(fp16), 0, cache, CacheSize); - ut::buffer_error(refC.data(), matC.data(), refC.size(), fp16(0.00001f * k)); - } - - template - void benchmark(int m, int n, int k, int batch, utils::fp16* A, utils::fp16* B, utils::fp16* C, float timems) { - LOG_T log; - CORE_T core; - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k, k, n, 0, cache, - CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, utils::fp16(0.f)), RefC(m * n * batch, utils::fp16(0.f)); - fill_buffer_randn(A.data(), A.size(), (utils::fp16)-0.5f, (utils::fp16)0.5f); - fill_buffer_randn(B.data(), B.size(), (utils::fp16)-0.5f, (utils::fp16)0.5f); - using LOG = timer_statistics_logger<100>; - using Core32 = gemm::HCoreRowNAvx512fp16<32>; - using Core64 = gemm::HCoreRowNAvx512fp16<64>; - using Core96 = gemm::HCoreRowNAvx512fp16<96>; - using Core96_B = gemm::HCoreRowNAvx512fp16<96, 8>; - using Core128 = gemm::HCoreRowNAvx512fp16<128>; - - float testtime = 500.f; - - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); + ut::buffer_error(refC.data(), matC.data(), refC.size(), fp16(FP16_ERR)); } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX512FP16 sUT_GEMM_AVX512FP16; #endif @@ -867,19 +619,13 @@ class UT_GEMM_AVX512BF16 { ut<48, 0>(4, 96, 6); ut<48, 8>(4, 96, 6); ut<64, 0>(4, 128, 6); - -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(388, 192, 512, 1024); - benchmark_all(512, 192, 768, 1024); - benchmark_all(512, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::HCoreRowNAvx512bf16; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -896,49 +642,8 @@ class UT_GEMM_AVX512BF16 { n * sizeof(float), 0, cache, CacheSize); ut::buffer_error(refC.data(), matC.data(), refC.size(), 0.001f); } - - template - void benchmark(int m, int n, int k, int batch, utils::bf16* A, utils::bf16* B, float* C, float timems) { - LOG_T log; - CORE_T core; - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int 
im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 2, k * 2, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, utils::bf16(0.f)), RefC(m * n * batch, utils::bf16(0.f)); - fill_buffer_randn(A.data(), A.size(), (utils::bf16)-0.5f, (utils::bf16)0.5f); - fill_buffer_randn(B.data(), B.size(), (utils::bf16)-0.5f, (utils::bf16)0.5f); - using LOG = timer_statistics_logger<100>; - using Core32 = gemm::HCoreRowNAvx512bf16<32>; - using Core48 = gemm::HCoreRowNAvx512bf16<48>; - using Core48_B = gemm::HCoreRowNAvx512bf16<48, 8>; - using Core64 = gemm::HCoreRowNAvx512bf16<64>; - - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AVX512BF16 sUT_GEMM_AVX512BF16; #endif @@ -952,18 +657,13 @@ class UT_GEMM_AMXBF16 { ut<32, 32>(4, 96, 96); ut<48, 0>(4, 96, 96); ut<64, 16>(4, 128, 96); -#ifdef JBLAS_UT_BENCHMARK - benchmakr_all(384, 192, 512, 1024); - benchmakr_all(384, 192, 768, 1024); - benchmakr_all(768, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::HCoreRowNAmxbf16; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -982,48 +682,8 @@ class UT_GEMM_AMXBF16 { n * sizeof(float), 0, cache, CacheSize); ut::buffer_error(refC.data(), matC.data(), m * n, 0.001f); } - - template - void benchmark(int m, int n, int k, int batch, utils::bf16* A, utils::bf16* B, float* C, float timems) { - LOG_T log; - CORE_T core; - core.configure(m, n, k); - - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 2, k * 2, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmakr_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, utils::bf16(0.f)), RefC(m * n * batch, utils::bf16(0.f)); - fill_buffer_randn(A.data(), A.size(), (utils::bf16)-0.5f, (utils::bf16)0.5f); - fill_buffer_randn(B.data(), B.size(), (utils::bf16)-0.5f, (utils::bf16)0.5f); - using LOG = timer_statistics_logger<100>; - using Core32 = gemm::HCoreRowNAmxbf16<32, 32>; - using Core48 = gemm::HCoreRowNAmxbf16<48, 16>; - using Core64 = gemm::HCoreRowNAmxbf16<64, 16>; - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), 
C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AMXBF16 sUT_GEMM_AMXBF16; #endif @@ -1040,18 +700,13 @@ class UT_GEMM_AMXINT8 { ut<32, 32>(32, 64, 64 * 3); ut<64, 16>(16, 128, 64 * 3); -#ifdef JBLAS_UT_BENCHMARK - benchmark_all(384, 192, 512, 1024); - benchmark_all(384, 192, 768, 1024); - benchmark_all(768, 384, 512, 1024); -#endif } template void ut(int m, int n, int k) { printf("Test Case: %d %d %d\n", m, n, k); using Core = gemm::ICoreRowNAmxint8; - Core gemm; + static Core gemm; if (n % Core::Code::NTILE != 0) { return; } @@ -1070,48 +725,8 @@ class UT_GEMM_AMXINT8 { CacheSize); ut::buffer_error(RefC.data(), C.data(), m * n, 1); } - - template - void benchmark(int m, int n, int k, int batch, uint8_t* A, int8_t* B, int32_t* C, float timems) { - LOG_T log; - CORE_T core; - core.configure(m, n, k); - utils::timer tm; - tm.start(); - while (tm.stop() < timems) { - for (size_t i = 0; i < batch; i++) { - log.start(); - for (int im = 0; im < m; im += CORE_T::MTILE) { - auto im_re = remainsize(im, m, CORE_T::MTILE); - core.forward(A + i * m * k + im * k, B + i * n * k, C + i * m * n + im * n, im_re, n, k, k * 1, k * 1, n * 4, - 0, cache, CacheSize); - } - if (log.stop()) { - printf("tar%d %d %s\n", CORE_T::NTILE, CORE_T::MTILE, log.get_log_str()); - } - } - } - } - - void benchmark_all(size_t m, size_t n, size_t k, size_t batch) { - printf("%s %d %d %d %d\n", __FUNCTION__, int(m), int(n), int(k), int(batch)); - avector A(m * k * batch); - avector B(k * n * batch); - avector C(m * n * batch, 0), RefC(m * n * batch, 0); - fill_buffer_randn(A.data(), A.size(), (uint8_t)0, (uint8_t)255); - fill_buffer_randn(B.data(), B.size(), (int8_t)-127, (int8_t)127); - using LOG = timer_statistics_logger<100>; - using Core32 = gemm::ICoreRowNAmxint8<32, 32>; - using Core48 = gemm::ICoreRowNAmxint8<48, 16>; - using Core64 = gemm::ICoreRowNAmxint8<64, 16>; - - float testtime = 500.f; - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - benchmark(m, n, k, batch, A.data(), B.data(), C.data(), testtime); - } }; -#ifdef JBLAS_UT_GEMM +#ifdef BTLA_UT_GEMM static UT_GEMM_AMXINT8 sUT_GEMM_AMXINT8; #endif } // namespace ut diff --git a/bestla/bestla/ut/bestla_prologue_a.cpp b/bestla/bestla/ut/bestla_prologue_a.cpp index 5f7795479..3c79b5e8c 100644 --- a/bestla/bestla/ut/bestla_prologue_a.cpp +++ b/bestla/bestla/ut/bestla_prologue_a.cpp @@ -145,8 +145,8 @@ class UT_ActivationU8KBlockQuantize { } } buffer_error(redref.data(), reduce.data(), reduce.size(), INT8_ERR); - buffer_error(redqref.data(), reduce.data(), reduce.size(), FP32_ERR); - buffer_error(reduce.data(), quanAct.template RPtr(), reduce.size(), FP32_ERR); + buffer_error(redqref.data(), reduce.data(), reduce.size(), 0.01f); + buffer_error(reduce.data(), quanAct.template RPtr(), reduce.size(), 0.01f); } } }; @@ -293,4 +293,4 @@ static UT_ShuffleActivationKblock sUT_ShuffleActivationKblock; #endif } // namespace ut } // namespace bestla -#endif \ No newline at end of file +#endif diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 27170b33f..ba48e1bfa 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -173,15 +173,54 @@ class UT_BlockQunatize_F8 { static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif +class UT_BlockQunatize_S3S4 { + public: + UT_BlockQunatize_S3S4() { + UT_START(); + CheckISA(AVX512F); + ut(127, 4096, 32, BTLA_DTYPE::S3_CLIP); + 
ut(4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + ut(127, 4096, 32, BTLA_DTYPE::S4_CLIP); + ut(4096, 4096, 32, BTLA_DTYPE::S4_CLIP); + ut(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + } + + void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { + printf("%s DType %s: %d %d %d\n", __FUNCTION__, utils::bestla_dtype_str(QUANT_T), n, k, blocksize); + int ldb = n; + utils::aligned_vector raw(n * k); + ut::fill_buffer_randn(raw.data(), raw.size(), -0.5f, 0.5f); + + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger, RuntimeISA>; + PrologueB kernel; + auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false); + avector buffer(ptr.mSize); + ptr.assign(buffer.data()); + kernel.packWeight(n, k, raw.data(), ldb, &ptr, UT_Threading::get()); + avector dequant(n * k, 0); + kernel.unpackWeight(n, k, &ptr, dequant.data(), n, UT_Threading::get()); + ut::buffer_error(raw.data(), dequant.data(), dequant.size(), 0.01f); + } +}; +#ifdef BTLA_UT_PROLOGUE_B +// no proper threshold for this UT +// static UT_BlockQunatize_S3S4 sUT_BlockQunatize_S3S4; +#endif + class UT_S3_WOQ { public: UT_S3_WOQ() { UT_START(); CheckISA(AVX512F); ut(1, 4096, 4096, 32, 56); + CheckISA(AVX512_VNNI); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); + CheckISA(AMX_BF16); ut(1, 4096, 4096, 32, 56); + CheckISA(AMX_INT8); ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); } template @@ -376,11 +415,9 @@ class UT_BlockQuantize_INT4 { CheckISA(AVX2); CheckISA(AVX512F); ut_2(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, false); - ut_2(4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, false); CheckISA(AVX512F); ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, false); ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_CLIP, true); - ut_512vnni(4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, false); } void ut_2(int n, int k, int blocksize, BTLA_DTYPE qtype, bool asym = false) { printf("Test Case: %d %d %d %s\n", n, k, blocksize, asym ? 
"asym" : "sym"); @@ -477,7 +514,6 @@ class UT_StorageMemCheck { UT_START(); CheckISA(AVX512F); ut_s4(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - ut_s4(4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, true); ut_f4(4096, 4096, 32, BTLA_DTYPE::F4_BNB); ut_f4(4096, 4096, 32, BTLA_DTYPE::F4_E2M1); } @@ -526,7 +562,6 @@ class UT_ShuffleIndices { UT_START(); CheckISA(AVX2); // ut_file(); - ut_s4(4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, true); ut_s4(4096, 4096, 128, BTLA_DTYPE::S4_CLIP); } @@ -646,16 +681,8 @@ class UT_CompFp32 { false); ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, - false); - ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, - false); - ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::F32, - false); ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, false); - ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, BTLA_DTYPE::BF16, - false); CheckISA(AVX512F); ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, @@ -664,16 +691,8 @@ class UT_CompFp32 { false); ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, false); - ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, - BTLA_DTYPE::F32, false); - ut_int(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE, - BTLA_DTYPE::F32, false); - ut_int(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE, - BTLA_DTYPE::F32, false); ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, false); - ut_int(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE, - BTLA_DTYPE::BF16, false); } void ut_s8() { @@ -826,11 +845,7 @@ class UT_CompInt8 { ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); } if (_cd->AVX512_VNNI()) { @@ -841,11 +856,7 @@ class UT_CompInt8 { ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP, true); } @@ -853,11 +864,7 @@ class UT_CompInt8 { request_perm_xtile_data(); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); - ut(16, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP, true); ut_s8s8(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut_s8s8(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); @@ -1195,11 +1202,7 @@ class UT_CompBf16 { ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); } void ut_s8() { @@ 
-1483,11 +1486,7 @@ class UT_CompFp16 { ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, 128, BTLA_DTYPE::S4_FULLRANGE); - ut(2, 4096, 4096, -1, BTLA_DTYPE::S4_FULLRANGE); ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_CLIP); - ut(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); } void ut_s8() { diff --git a/bestla/bestla/ut/bestla_ut.h b/bestla/bestla/ut/bestla_ut.h index b9e40f54e..dbab29ead 100644 --- a/bestla/bestla/ut/bestla_ut.h +++ b/bestla/bestla/ut/bestla_ut.h @@ -366,7 +366,7 @@ struct UT_GEMMData_Row_u8s8 { float _cmax = std::numeric_limits::min(); matCRef.resize(M * LDC); auto tmpsrcscale = alpha * matA.scales[0] * matB.scales[0]; -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < M; j++) { for (int i = 0; i < N; i += 1) { int tmp = 0; @@ -385,14 +385,14 @@ struct UT_GEMMData_Row_u8s8 { matC.scales[0] = (_cmax - _cmin) / (255.f); matC.zeropoints[0] = int((0 - _cmin) / matC.scales[0]); auto tmpscale = 1.f / matC.scales[0]; -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < M; j++) { for (int i = 0; i < N; i += 1) { matC.data()[j * LDC + i] = utils::cast(matCRef[j * LDC + i] * tmpscale + matC.zeropoints[0]); } } matCDequan.resize(matCRef.size()); -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < M; j++) { for (int i = 0; i < N; i += 1) { matCDequan.data()[j * LDC + i] = ((int)matC.data()[j * LDC + i] - matC.zeropoints[0]) * matC.scales[0]; @@ -412,7 +412,7 @@ struct UT_GEMMData_Row_u8s8 { }; static inline void gemmref_u8s8s32(int m, int n, int k, uint8_t* A, int8_t* B, int32_t* C, int lda, int ldb, int ldc) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { int tmp = 0; @@ -425,7 +425,7 @@ static inline void gemmref_u8s8s32(int m, int n, int k, uint8_t* A, int8_t* B, i } static inline void gemmref_s8s8s32(int m, int n, int k, int8_t* A, int8_t* B, int32_t* C, int lda, int ldb, int ldc) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { int tmp = 0; @@ -441,7 +441,7 @@ static inline void kblockgemmref_u8zp_s8_f32(int m, int n, int k, int kblock, ui int8_t* B, float* scaleB, float* C, int lda, int ldsa, int ldb, int ldsb, int ldc) { int kblk = utils::padto_le(k, kblock); -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { float tmp = 0.f; @@ -468,7 +468,7 @@ static inline void kblockgemmref_u8zp_s8_f32(int m, int n, int k, int kblock, ui static inline void kblockgemmref_u8zp_s8_f32(int m, int n, int k, int kblock, uint8_t* A, uint8_t* zpA, float* scaleA, int8_t* B, utils::bf16* scaleB, float* C, int lda, int ldsa, int ldb, int ldsb, int ldc) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { float tmp = 0.f; @@ -507,7 +507,7 @@ struct UT_GEMMData_Row_bf16 { } void calc_ref(float alpha, float beta) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { auto tmp = 0.f; @@ -527,7 +527,7 @@ struct UT_GEMMData_Row_bf16 { }; static inline void gemmref_fp32fp32fp32(int m, int n, int k, float* A, float* B, float* C, int lda, int ldb, int ldc) { -#pragma omp parallel for 
collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { float tmp = 0; @@ -541,7 +541,7 @@ static inline void gemmref_fp32fp32fp32(int m, int n, int k, float* A, float* B, static inline void gemmref_bf16bf16fp32(int m, int n, int k, utils::bf16* A, utils::bf16* B, float* C, int lda, int ldb, int ldc) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { float tmp = 0; @@ -555,7 +555,7 @@ static inline void gemmref_bf16bf16fp32(int m, int n, int k, utils::bf16* A, uti static inline void gemmref_fp16fp16fp16(int m, int n, int k, utils::fp16* A, utils::fp16* B, utils::fp16* C, int lda, int ldb, int ldc) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int j = 0; j < m; j++) { for (int i = 0; i < n; i += 1) { float tmp = 0; @@ -586,7 +586,7 @@ struct UT_GEMMData_Row_fp16 { } void calc_ref(float alpha, float beta) { -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { utils::fp16 tmp = utils::fp16(0.f); @@ -630,7 +630,7 @@ struct UT_GEMMData_Row_f32 { int ldc, int ldd, float alpha, float beta) { int NBlock = 128; #if 1 -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int i = 0; i < n; i += NBlock) { for (int j = 0; j < m; j++) { int remainn = i + NBlock <= n ? NBlock : n - i; @@ -645,7 +645,7 @@ struct UT_GEMMData_Row_f32 { } } #else -#pragma omp parallel for collapse(2) +#pragma omp parallel for for (int i = 0; i < n; i += 1) { for (int j = 0; j < m; j++) { auto tmp = 0.f; diff --git a/bestla/bestla/ut/kernel_intrin.cpp b/bestla/bestla/ut/kernel_intrin.cpp index e2cff3869..db848bd82 100644 --- a/bestla/bestla/ut/kernel_intrin.cpp +++ b/bestla/bestla/ut/kernel_intrin.cpp @@ -42,8 +42,40 @@ class UT_Avx512f_decompress_kblock_s4_fp { ut::buffer_error(ref_wei.data(), bf16_wei.data(), bf16_wei.size(), DST_T(BF16_ERR)); } }; -#ifdef BTLA_KERNEL_INTRIN +#ifdef BTLA_UT_KERNEL_INTRIN static UT_Avx512f_decompress_kblock_s4_fp sUT_Avx512f_decompress_kblock_s4_fp; #endif +class UT_avx2_decompress_s4_s8 { + public: + UT_avx2_decompress_s4_s8() { + UT_START(); + CheckISA(AVX2); + ut(32, 128); + ut(32, 96); + ut(32, 48); + } + + template + void ut(int row, int col) { + printf("Test Case %s_%s: %d %d\n", __FUNCTION__, bestla_dtype_str(S4_T), row, col); + std::vector s4_wei(row * col / 2); + std::vector s8_wei(col * row); + std::vector rev(col * row); + fill_buffer_randn(s8_wei.data(), s8_wei.size(), int8_t(-128), int8_t(127)); + + for (int i = 0; i < col * row; i += 2) { + s8_wei[i] = s8_wei[i] & 0xf0; + s8_wei[i + 1] = s8_wei[i + 1] & 0xf0; + s4_wei[i / 2].x = utils::int4x2::convert(s8_wei[i]); + s4_wei[i / 2].y = utils::int4x2::convert(s8_wei[i + 1]); + } + kernel::avx2::decompress_s4_s8(s4_wei.data(), rev.data(), row, col, col, col); + + ut::buffer_error(s8_wei.data(), rev.data(), rev.size(), int8_t(0)); + } +}; +#ifdef BTLA_UT_KERNEL_INTRIN +static UT_avx2_decompress_s4_s8 sUT_avx2_decompress_s4_s8; +#endif } // namespace ut } // namespace bestla diff --git a/bestla/bestla/ut/kernel_wrapper.cpp b/bestla/bestla/ut/kernel_wrapper.cpp index 26765c06c..37ba3d6f9 100644 --- a/bestla/bestla/ut/kernel_wrapper.cpp +++ b/bestla/bestla/ut/kernel_wrapper.cpp @@ -40,7 +40,11 @@ class UT_DecompressKBlockS4FP { kernel::wrapper::DecompressKBlockS4Fp::template forward( s4_wei.data(), ref_wei.data(), row, col, ld_src, ld_dst, scales.data(), asym ? 
zero_points.data() : nullptr, k_offset, kblock, NPad, cache, CacheSize); - ut::buffer_error(ref_wei.data(), bf16_wei.data(), bf16_wei.size(), DST_T(0.01f)); + DST_T thres = DST_T(0.01f); + if constexpr (std::is_same_v) { + thres = DST_T(BF16_ERR); + } + ut::buffer_error(ref_wei.data(), bf16_wei.data(), bf16_wei.size(), thres); } template @@ -68,7 +72,11 @@ class UT_DecompressKBlockS4FP { kernel::wrapper::DecompressKBlockS4Fp::template forward( s4_wei.data(), ref_wei.data(), row, col, ld_src, ld_dst, scales.data(), asym ? zero_points.data() : nullptr, k_offset, kblock, NPad, cache, CacheSize); - ut::buffer_error(ref_wei.data(), bf16_wei.data(), bf16_wei.size(), DST_T(0.01f)); + DST_T thres = DST_T(0.01f); + if constexpr (std::is_same_v) { + thres = DST_T(BF16_ERR); + } + ut::buffer_error(ref_wei.data(), bf16_wei.data(), bf16_wei.size(), thres); } }; #ifdef BTLA_UT_KERNEL_WRAPPER diff --git a/neural_speed/application/common.cpp b/neural_speed/application/common.cpp index b7e55cfdc..b9e7d9459 100644 --- a/neural_speed/application/common.cpp +++ b/neural_speed/application/common.cpp @@ -649,7 +649,7 @@ void quant_print_usage(int argc, char** argv, const quant_params& params) { fprintf(stderr, " --nthread number of threads to use (default: 1)\n"); fprintf(stderr, " --weight_dtype number of bits to use for quantization: int4/int8/fp8_e4m3/fp8_e5m2/" - "fp4_e2m1/nf4 (default: int4)\n"); + "fp4_e2m1/nf4/int3 (default: int4)\n"); fprintf(stderr, " --alg quantization algorithm to use: sym/asym (default: sym)\n"); fprintf(stderr, " --group_size group size: 32/128/-1 (per channel) (default: 32)\n"); fprintf(stderr, " --scale_dtype fp32/bf16/fp8 type for scales (default: fp32)\n"); diff --git a/neural_speed/core/layers/bestla_gemm.cpp b/neural_speed/core/layers/bestla_gemm.cpp index 471dc0e03..f6290d0c5 100644 --- a/neural_speed/core/layers/bestla_gemm.cpp +++ b/neural_speed/core/layers/bestla_gemm.cpp @@ -237,12 +237,14 @@ size_t BTLAGemmPackBSizeLocal(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE Qua ScaleDtype, isAsym, shuffle_indice); } } + [[fallthrough]]; case NE_COMP_F16: case NE_COMP_BF16: if (_cd->AMX_BF16() && BlkSize % tAMX_BF16::KTILE == 0) { return BTLABuSize>(static_cast(BlkSize), N, K, QuantType, ScaleDtype, isAsym, shuffle_indice); } + [[fallthrough]]; case NE_COMP_F32: case NE_COMP_UNDEF: // currently only f32 activation if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { @@ -253,7 +255,7 @@ size_t BTLAGemmPackBSizeLocal(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE Qua return BTLABuSize>(static_cast(BlkSize), N, K, QuantType, ScaleDtype, isAsym, shuffle_indice); } - break; + [[fallthrough]]; default: return 0; } @@ -262,22 +264,15 @@ size_t BTLAGemmPackBSizeLocal(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE Qua size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantType, BTLA_DTYPE ScaleDtype, bool isAsym, ne_comp_type CompType, int* shuffle_indice) { - switch (QuantType) { - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S3_CLIP: - case BTLA_DTYPE::S4_FULLRANGE: - case BTLA_DTYPE::S8: - return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, - isAsym, CompType, shuffle_indice); - case BTLA_DTYPE::F8_E4M3: - case BTLA_DTYPE::F8_E5M2: - case BTLA_DTYPE::F4_BNB: - case BTLA_DTYPE::F4_E2M1: - case BTLA_DTYPE::F4_NF4: - return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, isAsym, + auto qtype = utils::bestla_dtype_type(QuantType); + if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeInt)) { + return BTLAGemmPackBSizeLocal(N, K, 
BlkSize, QuantType, ScaleDtype, isAsym, CompType, shuffle_indice); - default: - return 0; + } else if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeFloat)) { + return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, isAsym, + CompType, shuffle_indice); + } else { + assert(0); } return 0; } @@ -332,6 +327,7 @@ bool BTLAGemmQuantPackBLocal(void* PackedBuf, const float* FpData, size_t N, siz return true; } } + [[fallthrough]]; case NE_COMP_F16: case NE_COMP_BF16: if (_cd->AMX_BF16() && BlkSize % tAMX_BF16::KTILE == 0) { @@ -340,6 +336,7 @@ bool BTLAGemmQuantPackBLocal(void* PackedBuf, const float* FpData, size_t N, siz ScaleDtype, isAsym, static_cast(ldb), isTrans, ThreadPool); return true; } + [[fallthrough]]; case NE_COMP_F32: case NE_COMP_UNDEF: // currently only f32 activation if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { @@ -354,6 +351,7 @@ bool BTLAGemmQuantPackBLocal(void* PackedBuf, const float* FpData, size_t N, siz ScaleDtype, isAsym, static_cast(ldb), isTrans, ThreadPool); return true; } + [[fallthrough]]; default: return false; } @@ -363,24 +361,17 @@ bool BTLAGemmQuantPackBLocal(void* PackedBuf, const float* FpData, size_t N, siz bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb, size_t BlkSize, BTLA_DTYPE QuantType, BTLA_DTYPE ScaleDtype, bool isAsym, ne_comp_type CompType, bool isTrans, void* ThreadPool) { - switch (QuantType) { - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S3_CLIP: - case BTLA_DTYPE::S4_FULLRANGE: - case BTLA_DTYPE::S8: - return BTLAGemmQuantPackBLocal( - PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); - case BTLA_DTYPE::F8_E5M2: - case BTLA_DTYPE::F8_E4M3: - case BTLA_DTYPE::F4_BNB: - case BTLA_DTYPE::F4_E2M1: - case BTLA_DTYPE::F4_NF4: - return BTLAGemmQuantPackBLocal( - PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); - default: - return false; + auto qtype = utils::bestla_dtype_type(QuantType); + if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeInt)) { + return BTLAGemmQuantPackBLocal( + PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); + } else if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeFloat)) { + return BTLAGemmQuantPackBLocal( + PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); + } else { + assert(0); + return false; } - return false; } template @@ -441,6 +432,7 @@ bool BTLAGemmPackBLocal(void* PackedBuf, const int8_t* QData, const float* Scale return true; } } + [[fallthrough]]; case NE_COMP_F16: case NE_COMP_BF16: if (_cd->AMX_BF16() && BlkSize % tAMX_BF16::KTILE == 0) { @@ -449,6 +441,7 @@ bool BTLAGemmPackBLocal(void* PackedBuf, const int8_t* QData, const float* Scale QuantType, ScaleDtype, isAsym, static_cast(ldb), shuffle_indice, ThreadPool); return true; } + [[fallthrough]]; case NE_COMP_F32: case NE_COMP_UNDEF: // currently only f32 activation if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { @@ -463,6 +456,7 @@ bool BTLAGemmPackBLocal(void* PackedBuf, const int8_t* QData, const float* Scale QuantType, ScaleDtype, isAsym, static_cast(ldb), shuffle_indice, ThreadPool); return true; } + [[fallthrough]]; default: return false; } @@ -566,23 +560,15 @@ bool BTLAGemmBatchDriver(const size_t M, const size_t N, const size_t K, const s size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantType, BTLA_DTYPE ScaleDtype, bool isAsym, 
ne_comp_type CompType, int* shuffle_indice) { - switch (QuantType) { - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S3_CLIP: - case BTLA_DTYPE::S4_FULLRANGE: - case BTLA_DTYPE::S8: - return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, - isAsym, CompType, shuffle_indice); - case BTLA_DTYPE::F8_E4M3: - case BTLA_DTYPE::F8_E5M2: - case BTLA_DTYPE::F4_BNB: - case BTLA_DTYPE::F4_E2M1: - case BTLA_DTYPE::F4_NF4: - return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, isAsym, + auto qtype = utils::bestla_dtype_type(QuantType); + if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeInt)) { + return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, isAsym, CompType, shuffle_indice); - - default: - return 0; + } else if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeFloat)) { + return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, isAsym, + CompType, shuffle_indice); + } else { + assert(0); } return 0; } @@ -590,22 +576,15 @@ size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantTyp bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb, size_t BlkSize, BTLA_DTYPE QuantType, BTLA_DTYPE ScaleDtype, bool isAsym, ne_comp_type CompType, bool isTrans, void* ThreadPool) { - switch (QuantType) { - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S3_CLIP: - case BTLA_DTYPE::S4_FULLRANGE: - case BTLA_DTYPE::S8: - return BTLAGemmQuantPackBLocal( - PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); - case BTLA_DTYPE::F8_E5M2: - case BTLA_DTYPE::F8_E4M3: - case BTLA_DTYPE::F4_BNB: - case BTLA_DTYPE::F4_E2M1: - case BTLA_DTYPE::F4_NF4: - return BTLAGemmQuantPackBLocal( - PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); - default: - return false; + auto qtype = utils::bestla_dtype_type(QuantType); + if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeInt)) { + return BTLAGemmQuantPackBLocal( + PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); + } else if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeFloat)) { + return BTLAGemmQuantPackBLocal( + PackedBuf, FpData, N, K, ldb, BlkSize, QuantType, ScaleDtype, isAsym, CompType, isTrans, ThreadPool); + } else { + assert(0); } return false; } @@ -613,17 +592,15 @@ bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K bool BTLAGemmPackB(void* PackedBuf, const int8_t* QData, const float* Scales, const int8_t* Zp, size_t N, size_t K, size_t ldb, size_t BlkSize, BTLA_DTYPE QuantType, BTLA_DTYPE ScaleDtype, bool isAsym, ne_comp_type CompType, int* shuffle_indice, void* ThreadPool) { - // only for integer quant - switch (QuantType) { - case BTLA_DTYPE::S3_CLIP: - case BTLA_DTYPE::S4_CLIP: - case BTLA_DTYPE::S4_FULLRANGE: - case BTLA_DTYPE::S8: - return BTLAGemmPackBLocal(PackedBuf, QData, Scales, Zp, N, K, ldb, - BlkSize, QuantType, ScaleDtype, isAsym, - CompType, shuffle_indice, ThreadPool); - default: - return false; + auto qtype = utils::bestla_dtype_type(QuantType); + if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeInt)) { + return BTLAGemmPackBLocal(PackedBuf, QData, Scales, Zp, N, K, ldb, BlkSize, + QuantType, ScaleDtype, isAsym, CompType, + shuffle_indice, ThreadPool); + } else if (qtype == utils::bestla_dtype_type(BTLA_DTYPE::TypeFloat)) { + assert(0); + } else { + assert(0); } return false; } diff --git a/neural_speed/core/layers/mha_dense_wrapper.h 
b/neural_speed/core/layers/mha_dense_wrapper.h index 9d416ce13..67d42fd3c 100644 --- a/neural_speed/core/layers/mha_dense_wrapper.h +++ b/neural_speed/core/layers/mha_dense_wrapper.h @@ -1190,7 +1190,7 @@ class weight_cvt_bf16_ntile48_t { for (int j = 0; j < n_size; j += 16) { const auto cur_src = src + i * 48 + j * 2; const auto cur_dst = dst + i * 48 + j; - const auto src_lo = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_loadu_epi16(cur_src), 16U)); + const auto src_lo = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_loadu_si512(cur_src), 16U)); const auto src_hi = _mm512_castsi512_ps(_mm512_maskz_loadu_epi16(mask_hi, cur_src)); _mm512_store_ps(cur_dst + 0, src_lo); _mm512_store_ps(cur_dst + 48, src_hi); diff --git a/neural_speed/models/model_utils/quant_utils.cpp b/neural_speed/models/model_utils/quant_utils.cpp index 05d5fc158..c25d9595a 100644 --- a/neural_speed/models/model_utils/quant_utils.cpp +++ b/neural_speed/models/model_utils/quant_utils.cpp @@ -410,7 +410,7 @@ void ne_common_quantize(const int nthread, const quant_params_internal& params, if (new_type == NE_TYPE_BTLA) { size_t k_ = tensor.ne.at(0); size_t n_ = tensor.ne.at(1); - printf("JBLAS "); + printf("BesTLA "); new_size = bestla_quantize(f32_data, work.addr, params, nthread, n_, k_); } else if (new_type >= NE_TYPE_Q4_0 && new_type < NE_TYPE_BTLA) { printf("GGML "); diff --git a/neural_speed/models/whisper/whisper.cpp b/neural_speed/models/whisper/whisper.cpp index f57b82b0d..1865094ac 100644 --- a/neural_speed/models/whisper/whisper.cpp +++ b/neural_speed/models/whisper/whisper.cpp @@ -458,7 +458,7 @@ struct whisper_state { whisper_kv_cache_t kv_cross; whisper_mel_t mel; - whisper_decoder_t decoders[WHISPER_MAX_DECODERS] = {}; + whisper_decoder_t decoders[WHISPER_MAX_DECODERS]; // memory buffers used by encode / decode contexts model_ctx_buffer buf_compute;
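Editor's note on the quantization hunks above: the removed S4_FULLRANGE-specific lambdas are superseded by a single symmetric path (sNauto_calc_store_scale_and_quantv_sym) shared by S3_CLIP and S4_CLIP. Per block it decides whether the absolute maximum should map onto the generic range +/-(2^(N-1)-1) or onto the extra negative code -2^(N-1) (spent on whichever side dominates), and it stores the resulting codes left-aligned inside an int8. The sketch below restates that selection logic as a standalone function; it is a minimal illustration under assumed names (quantize_block_sym_sn, NBITS) and an assumed rounding/clamping policy, not the library's actual utils::cast / template machinery.

```cpp
// Minimal sketch (hypothetical names) of the per-block scale selection behind
// sNauto_calc_store_scale_and_quantv_sym for S3_CLIP / S4_CLIP symmetric quantization.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

template <int NBITS>  // assumed: 3 for S3_CLIP, 4 for S4_CLIP
void quantize_block_sym_sn(const float* src, int n, int8_t* dst, float* scale_out) {
  constexpr int FullValue = 1 << (NBITS - 1);  // e.g. 8 for 4-bit
  constexpr int GenValue = FullValue - 1;      // e.g. 7 for 4-bit
  float maxval = std::numeric_limits<float>::lowest();
  float minval = std::numeric_limits<float>::max();
  float absmax = 0.f;
  for (int i = 0; i < n; i++) {
    maxval = std::max(maxval, src[i]);
    minval = std::min(minval, src[i]);
    absmax = std::max(absmax, std::fabs(src[i]));
  }
  // Default: map absmax onto +/-(2^(N-1)-1). If the block is clearly one-sided
  // (|max+min| is not negligible versus absmax), spend the extra code -2^(N-1)
  // on the dominant side instead; a negative NVal flips the scale's sign so the
  // dominant extreme lands exactly on that code.
  int nval = GenValue;
  const float sum = maxval + minval;
  if (std::fabs(sum) >= absmax / FullValue) nval = sum > 0.f ? -FullValue : FullValue;
  nval *= 1 << (8 - NBITS);  // codes are kept left-aligned within an int8
  const float scale = absmax / static_cast<float>(nval);
  const float rscale = scale != 0.f ? 1.f / scale : 0.f;
  *scale_out = scale;
  for (int i = 0; i < n; i++) {
    // round + clamp is an assumption here; the library uses its own cast helper
    float q = std::round(src[i] * rscale);
    dst[i] = static_cast<int8_t>(std::min(127.f, std::max(-128.f, q)));
  }
}
```

With this path in place, the asymmetric request for S3_CLIP/S4_CLIP simply reuses the 8-bit lambda, which is why s8_calc_store_scale_and_quantv_asym is now dispatched when zero_points != nullptr. Separately, dq8_get_fp_scale gains the packed-weight column count (mN) so the double-quant scale index advances per row: dq_s_idx = (i * mN + scale_offset + j) / dq_blk, rather than ignoring the row offset.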