diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 3fc640c6fb341..bca54a72fe8a7 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -51,10 +51,10 @@ class MatMulNBits final : public OpKernel {
   const size_t nbits_;
   const bool column_wise_quant_{true};
   IAllocatorUniquePtr<void> packed_b_;
-  size_t packed_b_size_;
-  bool is_asym_;
-  bool all_constant_;
-  int64_t accuracy_level_;
+  size_t packed_b_size_{0};
+  bool is_asym_{false};
+  bool all_constant_{false};
+  int64_t accuracy_level_{0};
 };
 
 Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
@@ -65,57 +65,48 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
     return Status::OK();
   }
   auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
-  if (MlasIsNBitGemmAvailable(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type)) {
-    // better to use threadpool here, LLM weight will consume a lot of time
-    MLAS_THREADPOOL* pool = NULL;
-    if (input_idx == 1) {
-      auto qptr = tensor.Data<uint8_t>();
-      packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
-      packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-      if (packed_b_ == nullptr) {
-        return Status::OK();
-      }
-      MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
-                         is_asym_, false, compt_type, pool);
-      if (prepacked_weights) {
-        prepacked_weights->buffers_.push_back(std::move(packed_b_));
-        prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
-      }
-      is_packed = true;
+  MLAS_THREADPOOL* pool = NULL;
+  if (input_idx == 1) {
+    packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
+    if (packed_b_size_ == 0) return Status::OK();
+    auto qptr = tensor.Data<uint8_t>();
+    packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+    if (packed_b_ == nullptr) {
+      return Status::OK();
+    }
+    MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                       is_asym_, false, compt_type, pool);
+    if (prepacked_weights) {
+      prepacked_weights->buffers_.push_back(std::move(packed_b_));
+      prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
     }
-    if (input_idx == 2) {
-      auto sptr = tensor.Data<float>();
-      if (packed_b_ == nullptr) {
-        return Status::OK();
-      }
-      MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
-                         is_asym_, !is_asym_, compt_type, pool);
-      if (prepacked_weights) {
-        prepacked_weights->buffers_.push_back(std::move(packed_b_));
-        prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
-      }
-      is_packed = true;
+    is_packed = true;
+  }
+  if (input_idx == 2 && packed_b_ != nullptr) {
+    auto sptr = tensor.Data<float>();
+    MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                       is_asym_, !is_asym_, compt_type, pool);
+    if (prepacked_weights) {
+      prepacked_weights->buffers_.push_back(std::move(packed_b_));
+      prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
     }
-    if (input_idx == 3) {
-      auto zptr = tensor.Data<uint8_t>();
-      if (packed_b_ == nullptr) {
-        return Status::OK();
-      }
-      MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
-                         is_asym_, is_asym_, compt_type, pool);
-      if (prepacked_weights) {
-        prepacked_weights->buffers_.push_back(std::move(packed_b_));
-        prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
-      }
-      is_packed = true;
+    is_packed = true;
+  }
+  if (input_idx == 3 && packed_b_ != nullptr) {
+    auto zptr = tensor.Data<uint8_t>();
+    MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                       is_asym_, is_asym_, compt_type, pool);
+    if (prepacked_weights) {
+      prepacked_weights->buffers_.push_back(std::move(packed_b_));
+      prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
     }
+    is_packed = true;
   }
 
   return Status::OK();
 }
 
-Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
-                                              int input_idx,
+Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
                                               /*out*/ bool& used_shared_buffers) {
   used_shared_buffers = false;
   // Pack three tensors into one buffer
@@ -149,8 +140,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   Tensor* y = ctx->Output(0, helper.OutputShape());
 
   // Bail out early if the output is going to be empty
-  if (y->Shape().Size() == 0)
-    return Status::OK();
+  if (y->Shape().Size() == 0) return Status::OK();
 
   auto* y_data = y->MutableData<float>();
 
@@ -159,7 +149,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   const size_t N = static_cast<size_t>(helper.N());
   const size_t K = static_cast<size_t>(helper.K());
   const size_t lda = helper.Lda(false);
-  std::vector<MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS> gemm_params(max_len);
+  std::vector<MLAS_NBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
   AllocatorPtr allocator;
   auto status = ctx->GetTempSpaceAllocator(&allocator);
   ORT_RETURN_IF_ERROR(status);
@@ -172,7 +162,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
     gemm_params[i].C = y_data + helper.OutputOffsets()[i];
     gemm_params[i].ldc = N;
   }
-  MlasNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
+  MlasNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), reinterpret_cast<int8_t*>(ws_ptr.get()), thread_pool);
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index c4fe1bb56c70c..316344ad8c214 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -357,4 +357,4 @@ MlasDequantizeBlockwise(
     int rows,
     int columns,
     MLAS_THREADPOOL* thread_pool
-    );
\ No newline at end of file
+    );
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
index ca89d90cf8b15..f13bfda037633 100644
--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -90,35 +90,20 @@ typedef enum {
 } MLAS_COMPUTE_TYPE;
 
 /**
- * @brief Data parameters for Q4 GEMM routine
- *        C = A * B + Bias
- *        A must be a float32 matrix
- *        B must be a quantized and packed int4 blob
+ * @brief Data parameters for NBits GEMM routine
+ *        C = A * B
+ *        A, C must be a float32 matrix
+ *        B must be a packed nbits blob
 *        All except C are [in] parameters
 */
-struct MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS {
+struct MLAS_NBITS_GEMM_DATA_PACKED_PARAMS {
    const float* A = nullptr;  /**< address of A (float32 matrix)*/
-    const void* B = nullptr;   /**< address of B (quantized and packed int4 blob)*/
+    const void* B = nullptr;   /**< address of B (packed nbits blob)*/
    float* C = nullptr;        /**< address of result matrix */
    size_t lda = 0;            /**< leading dimension of A */
    size_t ldc = 0;            /**< leading dimension of C*/
};
 
-/**
- * @brief Check if the parameter combination is supported by the runtime device.
- *
- * @param N      the number of columns of matrix B.
- * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return support flag, true if the combination is supported. - */ -bool MLASCALL -MlasIsNBitGemmAvailable(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type); - /** * @brief Compute the byte size of the parameter combination * @@ -204,7 +189,7 @@ MlasNBitsGemmBatchPackedB( const size_t N, const size_t K, const size_t BatchN, - const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams, + const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool = nullptr -); \ No newline at end of file +); diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h index 95d84dda61fa4..8f0dc38ec1d5b 100644 --- a/onnxruntime/core/mlas/lib/jblas_defs.h +++ b/onnxruntime/core/mlas/lib/jblas_defs.h @@ -11,8 +11,7 @@ Licensed under the MIT License. #include "jblas/jit_blas_prologue_b.h" #include "jblas/jit_blas_wrapper.h" -namespace jblas -{ +namespace jblas { template using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< GemmCore_T::ISA, @@ -40,12 +39,12 @@ using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; -class ORTThreading : public jblas::parallel::IThreading -{ +class ORTThreading : public jblas::parallel::IThreading { public: ORTThreading(void* tp); void parallel_for(const jblas::parallel::thread_func& func) override; virtual void set_threads(int nthreads) override { assert(0); } + virtual void sync() override { assert(0); } void* mTp; }; diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp index 95d866c64b918..d6b906aa331b0 100644 --- a/onnxruntime/core/mlas/lib/jblas_gemm.cpp +++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp @@ -57,7 +57,7 @@ JblasQ4GemmCompF32( kernel.mProA.reduce({A, K}, &reduceA, M, K, &single); } typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), + B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), reduceA.template get(), reduceA.lda}; typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}}; @@ -121,7 +121,7 @@ JblasQ4GemmBatchDriver( const size_t N, const size_t K, const size_t BatchN, - const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams, + const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool ) @@ -222,6 +222,7 @@ size_t JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType) { GetCPUDevice(); + // from low precision to high precision switch (CompType) { case CompInt8: if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { @@ -233,7 +234,8 @@ JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPU if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { return JblasQ4BuSize>(int(BlkSize), N, K, isAsym); } - break; + case CompBf16: + case CompFp16: case CompFp32: if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { return JblasQ4BuSize>(int(BlkSize), N, K, isAsym); @@ -242,8 +244,6 @@ JblasQ4GemmPackBSize(size_t N, size_t K, 
size_t BlkSize, bool isAsym, MLAS_COMPU return JblasQ4BuSize>(int(BlkSize), N, K, isAsym); } break; - case CompBf16: - case CompFp16: default: return 0; } @@ -315,7 +315,8 @@ JblasQ4GemmPackB( ); return true; } - break; + case CompBf16: + case CompFp16: case CompFp32: if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { JblaNBitsGemmPackB>( @@ -329,9 +330,6 @@ JblasQ4GemmPackB( ); return true; } - break; - case CompBf16: - case CompFp16: default: return false; } diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h index d84f815ee1c45..28c6dc729792c 100644 --- a/onnxruntime/core/mlas/lib/jblas_gemm.h +++ b/onnxruntime/core/mlas/lib/jblas_gemm.h @@ -45,7 +45,7 @@ JblasQ4GemmBatchDriver( const size_t N, const size_t K, const size_t BatchN, - const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams, + const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool ); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index aa560638decee..48d975a7fd26d 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -1059,4 +1059,4 @@ MlasDequantizeBlockwise( int rows, int columns, MLAS_THREADPOOL* thread_pool - ); \ No newline at end of file + ); diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp index 289c8d0f3d985..a734f53432bb6 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm.cpp @@ -176,4 +176,4 @@ MlasQ8Q4GemmBatch( ) { MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); -} \ No newline at end of file +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index ad66013649c6b..03d20b426ec43 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -166,12 +166,6 @@ MlasNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsy return 0; } -bool MLASCALL -MlasIsNBitGemmAvailable(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType) -{ - return MlasNBitsGemmPackBSize(N, K, BlkSize, nbits, isAsym, CompType) > 0; -} - void MLASCALL MlasNBitsGemmPackB( void* PackedBuf, @@ -233,7 +227,7 @@ MlasNBitsGemmBatchPackedB( const size_t N, const size_t K, const size_t BatchN, - const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams, + const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool ) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h index c2f08a0fc9770..143adb771760b 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h @@ -73,9 +73,9 @@ class JitBase : protected Xbyak::CodeGenerator { jb(".maskflag"); cmp(_tmp, 0); jl(".zeroflag"); - uint64_t allmask = ((uint64_t)1 << N) - 1; + uint64_t allmask = (static_cast(1) << N) - 1; if (N == 64) { - allmask = (uint64_t)-1; + allmask = static_cast(-1); } mov(_tmp, allmask); kmovq(_msk, _tmp); @@ -256,19 +256,19 @@ class JitAmxtile : protected JitAvx512f { // Configure C tiles int t = 0; for (; t < CNum; ++t) { - tc.rows[t] = uint8_t(TILE_M); - tc.colb[t] = uint16_t(TILE_N * 4); + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_N * 4); } // Configure A tiles for (; t < CNum + ANum; ++t) { - tc.rows[t] = uint8_t(TILE_M); - tc.colb[t] = 
uint16_t(TILE_K * elesize); + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_K * elesize); } // Configure B tile. B effectively has 64 rows and 16 columns. int kpack = 4 / elesize; for (; t < CNum + ANum + BNum; ++t) { - tc.rows[t] = uint8_t(TILE_K / kpack); - tc.colb[t] = uint16_t(TILE_N * 4); + tc.rows[t] = static_cast(TILE_K / kpack); + tc.colb[t] = static_cast(TILE_N * 4); } } }; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h index 2ecb6630d3da0..8ecf3535c17f4 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h @@ -93,4 +93,4 @@ enum class JBLAS_PROLOGUEB_IDS : uint32_t { WeightKBlockF4, KBlockEnd, End, -}; \ No newline at end of file +}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h index bf3f45db18f62..5cac1080bc610 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h @@ -274,5 +274,4 @@ class CpuBase { int mNumThreads; }; } // namespace device - } // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h index 1cab348459a0c..ca82c9308a936 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h @@ -43,10 +43,10 @@ class AccumulatorWriteBack { static_assert(Valid, "fp32 to bf16 conversion only."); if constexpr (std::is_same::value) { return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( - (void*)cacheptr, (void*)cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); } else if constexpr (std::is_same, std::tuple>::value) { return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( - (void*)cacheptr, (void*)cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); } else if constexpr (sizeof(SType) == sizeof(DType)) { return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, _param.ldc, _param.elt_const_v, ops...); @@ -132,19 +132,21 @@ class CompFp32BlockEpilogue { auto ret = JblasNotSupport; if (_param.scaledtype == JBLAS_DTYPE::F32) { ret = kernel::wrapper::CompFp32BlockScale::template forward( - (float*)_param.scales + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, cachestep, M, N); + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); assert(ret == JblasSuccess); if (_param.zps != nullptr) { ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, - (float*)_param.scales + K_offset * _param.ldsb + N_offset, _param.ldra, + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, _param.reduce + M_offset * _param.ldra + K_offset); } assert(ret == JblasSuccess); return ret; } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { ret = kernel::wrapper::CompFp32BlockScale::template forward( - (utils::bf16*)_param.scales + K_offset * _param.ldsb + N_offset, srcptr, cachestep, 
dstptr, cachestep, M, N); + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); if (_param.zps != nullptr) { assert(0); } @@ -202,31 +204,34 @@ class CompInt8BlockEpilogue { size_t ReduceBTmpSize = N * sizeof(float); assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - auto scache = (float*)tmpcache; + auto scache = reinterpret_cast(tmpcache); ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - (utils::bf16*)_param.scalesB + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, false); + reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, + false); assert(ret == JblasSuccess); scab = scache; } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - scab = (float*)_param.scalesB + N_offset + K_offset * _param.ldsb; + scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; } float* redb = nullptr; if (_param.reduceB) { if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { - auto rcache = (float*)((char*)tmpcache + ScaleBTmpSize); + auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - (utils::bf16*)_param.reduceB + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, false); + reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, + false); assert(ret == JblasSuccess); redb = rcache; } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { - redb = (float*)_param.reduceB + N_offset + K_offset * _param.ldsb; + redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; } } - ret = kernel::wrapper::DequanS32Fp32::template forward(srcptr, cachestep, (float*)srcptr, cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, - _param.ldsa, scab); + ret = kernel::wrapper::DequanS32Fp32::template forward( + srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); assert(ret == JblasSuccess); - ret = kernel::wrapper::AccumulateFp32::template forward((float*)srcptr, cachestep, dstptr, cachestep, M, N); + ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, + dstptr, cachestep, M, N); assert(ret == JblasSuccess); if (_param.zpA == nullptr) { diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h index e024dcd45643e..364da9223940f 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h @@ -1946,6 +1946,518 @@ class AmxConfigure : protected jblas::xbyak::JitAmxtile { func_t mKernel = nullptr; }; +namespace kblock { +// optimize for kblock gemm, each block size in k dimension has dequant operation +// all accumulators use fp32 dtype. +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * 
BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; + typedef uint8_t AType; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = NRegs; + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= 1); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); + } + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, 
reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); + vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); + vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + } + } + } + + void generate_zp_correction(int _mtile) { + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * 
sizeof(AType)]); + auto& reg_zpA = reg_tmp; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); + } + + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); + for (int j = 0; j < NRegs; j++) { + vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); + vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +} // namespace kblock } // namespace code template
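// ---------------------------------------------------------------------------
// Illustrative usage sketch (appended for review context, not part of the diff).
// It shows how a caller is expected to drive the packed-B path after this change:
// availability is now probed by checking that MlasNBitsGemmPackBSize() returns a
// non-zero size (MlasIsNBitGemmAvailable() is removed), weights and scales are
// packed with MlasNBitsGemmPackB(), and the batched GEMM consumes the renamed
// MLAS_NBITS_GEMM_DATA_PACKED_PARAMS. The helper name, the symmetric 4-bit
// configuration, and the externally sized workspace are assumptions of the sketch.
#include <cstdint>
#include <vector>
#include "mlas_qnbit.h"  // MlasNBitsGemm* declarations touched by this PR

void PackedNBitsGemmSketch(const float* A, const uint8_t* quant_b, const float* scales,
                           float* C, size_t M, size_t N, size_t K, size_t blk_size,
                           int8_t* workspace,  // sized via the matching MLAS workspace query (not shown in this diff)
                           MLAS_THREADPOOL* pool) {
  constexpr int nbits = 4;         // 4-bit weights, symmetric quantization assumed
  constexpr bool is_asym = false;
  // Replacement for the removed MlasIsNBitGemmAvailable(): a zero pack size means
  // no packed kernel supports this N/K/block-size/compute-type combination.
  const size_t packed_size = MlasNBitsGemmPackBSize(N, K, blk_size, nbits, is_asym, CompFp32);
  if (packed_size == 0) {
    return;
  }
  std::vector<int8_t> packed_b(packed_size);
  // Pack the quantized weights first, then the scales, mirroring the PrePack calls above
  // (zero points are skipped because the sketch assumes symmetric quantization).
  MlasNBitsGemmPackB(packed_b.data(), quant_b, nullptr, nullptr, N, K, K, blk_size, nbits,
                     is_asym, false, CompFp32, pool);
  MlasNBitsGemmPackB(packed_b.data(), nullptr, scales, nullptr, N, K, K, blk_size, nbits,
                     is_asym, !is_asym, CompFp32, pool);

  MLAS_NBITS_GEMM_DATA_PACKED_PARAMS params;
  params.A = A;
  params.lda = K;
  params.B = packed_b.data();
  params.C = C;
  params.ldc = N;
  MlasNBitsGemmBatchPackedB(M, N, K, /*BatchN*/ 1, &params, workspace, pool);
}
// ---------------------------------------------------------------------------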