diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index e82219a0aff64..2384c72efe0a4 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
 option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
 option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
 option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
+option(onnxruntime_USE_JLAS "Build MLAS with JBLAS support" OFF)
 option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
 option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
 option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index a8b465f4e1c1d..b4200768e2b0a 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
 
 set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
-set(MLAS_WITH_JBLAS ON)
 
 #
 # All hardware agnostic source files here
 # hardware specific files would cause trouble in
@@ -47,7 +46,7 @@ set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
 function(add_jblas)
   add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
   target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
-  target_compile_definitions(onnxruntime_mlas PUBLIC MLAS_JBLAS)
+  target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_JBLAS)
   set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
 endfunction()
@@ -204,7 +203,7 @@ function(setup_mlas_source_for_windows)
       ${MLAS_SRC_DIR}/q4gemm_avx512.cpp
     )
   endif()
-  if(MLAS_WITH_JBLAS)
+  if(onnxruntime_USE_JLAS)
     add_jblas()
   endif()
 else()
@@ -569,7 +568,7 @@ else()
     )
     set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
     set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
-    if(MLAS_WITH_JBLAS)
+    if(onnxruntime_USE_JLAS)
       add_jblas()
     endif()
 endif()
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index fdc3fa01fdf1c..67215ea7ae5e3 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -51,29 +51,29 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
                             /*out*/ bool& is_packed,
                             /*out*/ PrePackedWeights* prepacked_weights) {
   is_packed = false;
-
-  if (accuracy_level_ > 0 && nbits_ == 4) {
+  auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
+  if (MlasNBitsGemmPackBSupport(N_, K_, block_size_, nbits_, is_asym_, compt_type)) {
     // TODO use threadpool here
     MLAS_THREADPOOL* pool = NULL;
     if (input_idx == 1) {
       auto qptr = tensor.Data<uint8_t>();
-      packed_b_size_ = MlasJblasQ4GemmPackBSize(N_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_));
+      packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, nbits_, is_asym_, compt_type);
       packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-      MlasJblasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, is_asym_, false, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+      MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits_, is_asym_, false, compt_type, pool);
       prepacked_weights->buffers_.push_back(std::move(packed_b_));
       prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
       is_packed = true;
     }
     if (input_idx == 2) {
       auto sptr = tensor.Data<float>();
-      MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, is_asym_, !is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+      MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits_, is_asym_, !is_asym_, compt_type, pool);
       prepacked_weights->buffers_.push_back(std::move(packed_b_));
       prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
       is_packed = true;
     }
     if (input_idx == 3) {
       auto zptr = tensor.Data<uint8_t>();
-      MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, is_asym_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+      MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits_, is_asym_, is_asym_, compt_type, pool);
       prepacked_weights->buffers_.push_back(std::move(packed_b_));
       prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
       is_packed = true;
@@ -87,7 +87,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
                                               int input_idx,
                                               /*out*/ bool& used_shared_buffers) {
   used_shared_buffers = false;
-
+  // Pack three tensors into one buffer
   if (input_idx == 1) {
     used_shared_buffers = true;
     packed_b_ = std::move(prepacked_buffers[0]);
@@ -140,7 +140,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
       gemm_params[i].C = y_data + helper.OutputOffsets()[i];
       gemm_params[i].ldc = N;
     }
-    MlasJblasQ4GemmBatch(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
+    MlasNBitsGemmBatch(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
     return Status::OK();
   }
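A note for readers, not part of the patch: `PrePack` is invoked once per constant input, so the packed buffer is assembled incrementally across the three branches above. A minimal sketch of the resulting call sequence, assuming 4-bit weights and hypothetical `qdata`/`scales`/`zp` pointers:

```cpp
// Sketch only: mirrors the staged PrePack calls above (input_idx 1, 2, 3).
// qdata/scales/zp are the raw initializer tensors; zp may be null.
void PackNBitsWeight(const uint8_t* qdata, const float* scales, const uint8_t* zp,
                     size_t N, size_t K, size_t blk, bool is_asym,
                     MLAS_COMPUTE_TYPE compt_type, MLAS_THREADPOOL* pool) {
  std::vector<int8_t> packed(MlasNBitsGemmPackBSize(N, K, blk, 4, is_asym, compt_type));
  // Quantized data first; packing is not finalized yet.
  MlasNBitsGemmPackB(packed.data(), qdata, nullptr, nullptr, N, K, K, blk, 4,
                     is_asym, /*lastCall=*/false, compt_type, pool);
  // Scales; this is the last call only for symmetric quantization.
  MlasNBitsGemmPackB(packed.data(), nullptr, scales, nullptr, N, K, K, blk, 4,
                     is_asym, /*lastCall=*/!is_asym, compt_type, pool);
  if (is_asym) {
    // Zero points finalize the buffer for asymmetric quantization.
    MlasNBitsGemmPackB(packed.data(), nullptr, nullptr, zp, N, K, K, blk, 4,
                       is_asym, /*lastCall=*/true, compt_type, pool);
  }
}
```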
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index 61ff05efa689f..bef8657b8e87d 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -30,10 +30,10 @@ Module Name:
  * @brief Define types of block quantization
  */
 typedef enum {
-    BlkQ4Sym = 0,     /*!< int4 Symmetric Block Quantization, zero_point = 0 */
-    BlkQ4Zp8 = 1,     /*!< int4 Block Quantization, zero_point is int8 type */
-    BlkQ4Sym64 = 2,   /*!< int4 Symmetric Block Quantization, 64 values per block*/
-    BlkQ4Sym128 = 4,  /*!< int4 Symmetric Block Quantization, 128 values per block*/
+    BlkQ4Sym = 0,    /*!< int4 Symmetric Block Quantization, zero_point = 0 */
+    BlkQ4Zp8 = 1,    /*!< int4 Block Quantization, zero_point is int8 type */
+    BlkQ4Sym64 = 2,  /*!< int4 Symmetric Block Quantization, 64 values per block*/
+    BlkQ4Sym128 = 4, /*!< int4 Symmetric Block Quantization, 128 values per block*/
 } MLAS_BLK_QUANT_TYPE;
 
 /**
@@ -323,7 +323,10 @@ MlasDequantizeBlockwise(ElementT* dst,
                         int columns,
                         MLAS_THREADPOOL* thread_pool);
 
-#ifdef MLAS_JBLAS
+bool MLASCALL
+MlasNBitsGemmPackBSupport(
+    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+
 /**
  * @brief Computes the number of bytes required to pack and int4-quantize
  *        a weight matrix
@@ -333,8 +336,8 @@ MlasDequantizeBlockwise(ElementT* dst,
  * @return size of the packing buffer, 0 if the operation is not yet supported.
  */
 size_t MLASCALL
-MlasJblasQ4GemmPackBSize(
-    size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+MlasNBitsGemmPackBSize(
+    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
 
 /**
  * @brief Prepack and Quantize fp32 weight tensor to int4 blocks
@@ -347,40 +350,19 @@ MlasJblasQ4GemmPackBSize(
  * @param ldb     leading dimension of B
  */
 void MLASCALL
-MlasJblasQ4GemmPackB(void* PackedBuf,
-                     const float* FpData,
-                     size_t N,
-                     size_t K,
-                     size_t ldb,
-                     size_t BlkSize,
-                     bool isAsym,
-                     MLAS_COMPUTE_TYPE CompType,
-                     MLAS_THREADPOOL* ThreadPool);
-
-/**
- * @brief Prepack and Quantize fp32 weight tensor to int4 blocks
- *
- * @param QType      type of block quantization
- * @param PackedBuf  destination buffer
- * @param FpData     the pointer to fp32 matrix
- * @param N          the number of columns of matrix B.
- * @param K          the number of rows of matrix B.
- * @param ldb        leading dimension of B
- */
-void MLASCALL
-MlasJblasNBitsGemmPackB(void* PackedBuf,
-                        const uint8_t* QData,
-                        const float* Scale,
-                        const uint8_t* Zp,
-                        size_t N,
-                        size_t K,
-                        size_t ldb,
-                        size_t BlkSize,
-                        bool isAsym,
-                        bool lastCall,
-                        MLAS_COMPUTE_TYPE CompType,
-                        MLAS_THREADPOOL* ThreadPool);
-
+MlasNBitsGemmPackB(void* PackedBuf,
+                   const uint8_t* QData,
+                   const float* Scale,
+                   const uint8_t* Zp,
+                   size_t N,
+                   size_t K,
+                   size_t ldb,
+                   size_t BlkSize,
+                   int nbits,
+                   bool isAsym,
+                   bool lastCall,
+                   MLAS_COMPUTE_TYPE CompType,
+                   MLAS_THREADPOOL* ThreadPool);
 /**
  * @brief Unpack and dequantize from int4 to fp32, reverse operation of
  *        MlasQ4GemmPackB
@@ -392,12 +374,12 @@ MlasJblasNBitsGemmPackB(void* PackedBuf,
  * @param ldb     leading dimension of B
  */
 void MLASCALL
-MlasJblasQ4GemmUnPackB(float* FpData,
-                       const void* PackedBuf,
-                       size_t N,
-                       size_t K,
-                       size_t ldb,
-                       MLAS_THREADPOOL* ThreadPool);
+MlasNBitsGemmUnPackB(float* FpData,
+                     const void* PackedBuf,
+                     size_t N,
+                     size_t K,
+                     size_t ldb,
+                     MLAS_THREADPOOL* ThreadPool);
 
 /**
  * @brief Calculate the buffer size needed for int8 block quantize
@@ -407,13 +389,11 @@ MlasJblasQ4GemmUnPackB(float* FpData,
  * @return buffer size (in bytes) needed, 0 if not yet supported on current
  *         hardware
  */
-
 void MLASCALL
-MlasJblasQ4GemmBatch(const size_t M,
-                     const size_t N,
-                     const size_t K,
-                     const size_t BatchN,
-                     const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-                     int8_t* WorkSpace,
-                     MLAS_THREADPOOL* ThreadPool = nullptr);
-#endif
+MlasNBitsGemmBatch(const size_t M,
+                   const size_t N,
+                   const size_t K,
+                   const size_t BatchN,
+                   const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
+                   int8_t* WorkSpace,
+                   MLAS_THREADPOOL* ThreadPool = nullptr);
\ No newline at end of file
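Not part of the patch: the renamed entry points in this header compose roughly as below. `qweight`, `scales`, `act`, `out`, and `threadpool` are hypothetical names, and the workspace sizing (M * K * 4 bytes) follows the benchmark change at the end of this diff:

```cpp
#include "mlas_q4.h"  // declares the renamed MlasNBitsGemm* entry points
#include <vector>

// Sketch under stated assumptions: 4-bit symmetric weights, fp32 compute.
// qweight/scales are prequantized inputs; act/out are fp32 activations.
void RunNBitsGemm(const uint8_t* qweight, const float* scales,
                  const float* act, float* out,
                  size_t M, size_t N, size_t K, size_t BlkSize,
                  MLAS_THREADPOOL* threadpool) {
  size_t psize = MlasNBitsGemmPackBSize(N, K, BlkSize, /*nbits=*/4,
                                        /*isAsym=*/false, CompFp32);
  if (psize == 0) return;  // unsupported; callers fall back to the generic path
  std::vector<int8_t> packed(psize), workspace(M * K * 4);
  MlasNBitsGemmPackB(packed.data(), qweight, scales, /*Zp=*/nullptr, N, K,
                     /*ldb=*/K, BlkSize, /*nbits=*/4, /*isAsym=*/false,
                     /*lastCall=*/true, CompFp32, threadpool);
  MLAS_Q4_GEMM_DATA_PARAMS params{};
  params.A = act;
  params.lda = K;
  params.B = packed.data();
  params.C = out;
  params.ldc = N;
  MlasNBitsGemmBatch(M, N, K, /*BatchN=*/1, &params, workspace.data(), threadpool);
}
```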
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp
index ab7c2af3d25a7..a9a6cf07ed9cf 100644
--- a/onnxruntime/core/mlas/lib/q4_dq.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -18,7 +18,6 @@ Module Name:
 --*/
 
 #ifdef MLAS_JBLAS
-#include "../framework/allocator.h"
 #include "mlas_jblas_defs.h"
 using namespace jblas;
 #endif
@@ -32,228 +31,6 @@ BlkQ4BufSize(size_t N, size_t K)
     return N * KBlocks * T::BlobSize;
 }
 
-#ifdef MLAS_JBLAS
-
-template <typename T>
-static size_t
-JblasQ4BuSize(T& launcher, int block_size, size_t N, size_t K, bool isAsym)
-{
-    auto stor = launcher.mProB.createStorage(N, K, block_size, JBLAS_DTYPE::S4_CLIP,
-                                             JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, isAsym);
-    // TODO(Yu) support more S4 quant type, scale dtype
-    return stor.mSize;
-}
-
-size_t MLASCALL
-MlasJblasQ4GemmPackBSize(
-    size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType)
-{
-    GetCPUDevice();
-    switch (CompType) {
-        case CompInt8:
-            if (_cd->AVX512_VNNI()) {
-                return JblasQ4BuSize(JblasAvx512VnniS4Fp32Fp32, int(BlkSize), N, K, isAsym);
-            }
-            if (_cd->AVX_VNNI()) {
-                return JblasQ4BuSize(JblasAvxVnniS4Fp32Fp32, int(BlkSize), N, K, isAsym);
-            }
-            break;
-        case CompFp32:
-            if (_cd->AVX512F()) {
-                return JblasQ4BuSize(JblasAvx512fS4Fp32Fp32, int(BlkSize), N, K, isAsym);
-            }
-            if (_cd->AVX2()) {
-                return JblasQ4BuSize(JblasAvx2S4Fp32Fp32, int(BlkSize), N, K, isAsym);
-            }
-            break;
-        case CompBf16:
-        case CompFp16:
-        default:
-            break;
-    }
-}
-
-template <typename T>
-void
-JblaQ4GemmPackB(T& JblasKernel,
-                int BlkSize,
-                void* PackedBuf,
-                const float* FpData,
-                int N,
-                int K,
-                bool IsAsym,
-                int ldb,
-                MLAS_THREADPOOL* ThreadPool)
-{
-    auto stor = JblasKernel.mProB.createStorage(N, K, BlkSize, JBLAS_DTYPE::S4_CLIP,
-                                                JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym);
-    stor.assign((int8_t*)PackedBuf);
-    ORTThreading orth(ThreadPool);
-    JblasKernel.mProB.packWeight(N, K, FpData, ldb, &stor, &orth);
-}
-
-void MLASCALL
-MlasJblasQ4GemmPackB(void* PackedBuf,
-                     const float* FpData,
-                     size_t N,
-                     size_t K,
-                     size_t ldb,
-                     size_t BlkSize,
-                     bool isAsym,
-                     MLAS_COMPUTE_TYPE CompType,
-                     MLAS_THREADPOOL* ThreadPool)
-{
-    GetCPUDevice();
-    switch (CompType) {
-        case CompInt8:
-            if (_cd->AVX512_VNNI()) {
-                return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData,
-                                       int(N), int(K), isAsym, int(ldb), ThreadPool);
-            }
-            if (_cd->AVX_VNNI()) {
-                return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData,
-                                       int(N), int(K), isAsym, int(ldb), ThreadPool);
-            }
-            break;
-        case CompFp32:
-            if (_cd->AVX512F()) {
-                return JblaQ4GemmPackB(JblasAvx512fS4Fp32Fp32, int(BlkSize), PackedBuf, FpData,
-                                       int(N), int(K), isAsym, int(ldb), ThreadPool);
-            }
-            if (_cd->AVX2()) {
-                return JblaQ4GemmPackB(JblasAvx2S4Fp32Fp32, int(BlkSize), PackedBuf, FpData, int(N),
-                                       int(K), isAsym, int(ldb), ThreadPool);
-            }
-            break;
-        case CompBf16:
-        case CompFp16:
-        default:
-            break;
-    }
-}
-
-template <typename T>
-void
-JblaNBitsGemmPackB(T& JblasKernel,
-                   void* PackedBuf,
-                   int BlkSize,
-                   const uint8_t* QData,
-                   const float* Scale,
-                   const uint8_t* Zp,
-                   int N,
-                   int K,
-                   bool IsAsym,
-                   bool lastCall,
-                   int ldb,
-                   MLAS_THREADPOOL* ThreadPool)
-{
-    auto stor = JblasKernel.mProB.createStorage(N, K, BlkSize, JBLAS_DTYPE::S4_CLIP,
-                                                JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym);
-    stor.assign((int8_t*)PackedBuf);
-    ORTThreading orth(ThreadPool);
-    JblasKernel.mProB.packNbitsWeight(N, K, IsAsym, QData, ldb, Scale, Zp, &stor, &orth);
-    if (lastCall) {
-        JblasKernel.mProB.reduceWeight(&stor, &orth);
-    }
-}
-
-void MLASCALL
-MlasJblasNBitsGemmPackB(void* PackedBuf,
-                        const uint8_t* QData,
-                        const float* Scale,
-                        const uint8_t* Zp,
-                        size_t N,
-                        size_t K,
-                        size_t ldb,
-                        size_t BlkSize,
-                        bool isAsym,
-                        bool lastCall,
-                        MLAS_COMPUTE_TYPE CompType,
-                        MLAS_THREADPOOL* ThreadPool)
-{
-    GetCPUDevice();
-    switch (CompType) {
-        case CompInt8:
-            if (_cd->AVX512_VNNI()) {
-                return JblaNBitsGemmPackB(JblasAvx512VnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
-                                          Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
-                                          ThreadPool);
-            }
-            if (_cd->AVX_VNNI()) {
-                return JblaNBitsGemmPackB(JblasAvxVnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
-                                          Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
-                                          ThreadPool);
-            }
-            break;
-        case CompFp32:
-            if (_cd->AVX512F()) {
-                return JblaNBitsGemmPackB(JblasAvx512fS4Fp32Fp32, PackedBuf, int(BlkSize), QData,
-                                          Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
-                                          ThreadPool);
-            }
-            if (_cd->AVX2()) {
-                return JblaNBitsGemmPackB(JblasAvx2S4Fp32Fp32, PackedBuf, int(BlkSize), QData,
-                                          Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb),
-                                          ThreadPool);
-            }
-            break;
-        case CompBf16:
-        case CompFp16:
-        default:
-            break;
-    }
-}
-
-void MLASCALL
-MlasJblasQ4GemmUnPackB(float* FpData,
-                       const void* PackedBuf,
-                       size_t N,
-                       size_t K,
-                       size_t ldb,
-                       MLAS_THREADPOOL* ThreadPool)
-{
-    auto ptr =
-        jblas::storage::gemm::PackedWeightParser::deserialBuffer(const_cast<void*>(PackedBuf));
-    ORTThreading orth(ThreadPool);
-    GetCPUDevice();
-    if (ptr) {
-        if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
-            auto NTile =
-                jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK,
-                                                    jblas::gemm::CoreAttr::NTILE_SHIFT);
-            auto CType = jblas::gemm::CoreAttr::get_mask_val(
-                ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT);
-            if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) {
-                if (NTile == 48 && _cd->AVX512F()) {
-                    JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
-                                                              &orth);
-                    return;
-                }
-                if (NTile == 24 && _cd->AVX2()) {
-                    JblasAvx2S4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
-                                                           &orth);
-                    return;
-                }
-            }
-            if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) {
-                if (NTile == 48 && _cd->AVX512_VNNI()) {
-                    JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData,
-                                                                 int(ldb), &orth);
-                    return;
-                }
-                if (NTile == 24 && _cd->AVX_VNNI()) {
-                    JblasAvxVnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
-                                                              &orth);
-                    return;
-                }
-            }
-        }
-        delete ptr;
-    }
-}
-
-#endif
-
 size_t MLASCALL
 MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K)
 {
@@ -1095,3 +872,271 @@ MlasDequantizeBlockwise(float* dst,
                         int rows,
                         int columns,
                         MLAS_THREADPOOL* thread_pool);
+
+#ifdef MLAS_JBLAS
+
+template <typename T>
+static size_t
+JblasQ4BuSize(int block_size, size_t N, size_t K, bool isAsym)
+{
+    static T launcher;
+    auto stor = launcher.mProB.createStorage(N, K, block_size, JBLAS_DTYPE::S4_CLIP,
+                                             JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, isAsym);
+    // TODO(Yu) support more S4 quant type, scale dtype
+    return stor.mSize;
+}
+
+static size_t
+JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType)
+{
+    GetCPUDevice();
+    switch (CompType) {
+        case CompInt8:
+            if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
+                return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>(int(BlkSize), N, K,
+                                                                             isAsym);
+            }
+            if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) {
+                return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>(int(BlkSize), N, K,
+                                                                             isAsym);
+            }
+            if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
+                return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(int(BlkSize), N, K,
+                                                                          isAsym);
+            }
+            break;
+        case CompFp32:
+            if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+                return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX512F>>(int(BlkSize), N, K,
+                                                                         isAsym);
+            }
+            if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+                return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX2>>(int(BlkSize), N, K, isAsym);
+            }
+            break;
+        case CompBf16:
+        case CompFp16:
+        default:
+            return 0;
+    }
+    return 0;
+}
+
+template <typename T>
+void
+JblaNBitsGemmPackB(void* PackedBuf,
+                   int BlkSize,
+                   const uint8_t* QData,
+                   const float* Scale,
+                   const uint8_t* Zp,
+                   int N,
+                   int K,
+                   bool IsAsym,
+                   bool lastCall,
+                   int ldb,
+                   MLAS_THREADPOOL* ThreadPool)
+{
+    static T JblasKernel;
+    auto stor = JblasKernel.mProB.createStorage(N, K, BlkSize, JBLAS_DTYPE::S4_CLIP,
+                                                JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym);
+    stor.assign((int8_t*)PackedBuf);
+    ORTThreading orth(ThreadPool);
+    JblasKernel.mProB.packNbitsWeight(N, K, IsAsym, QData, ldb, Scale, Zp, &stor, &orth);
+    if (lastCall) {
+        JblasKernel.mProB.reduceWeight(&stor, &orth);
+    }
+}
+
+static bool
+JblasQ4GemmPackB(void* PackedBuf,
+                 const uint8_t* QData,
+                 const float* Scale,
+                 const uint8_t* Zp,
+                 size_t N,
+                 size_t K,
+                 size_t ldb,
+                 size_t BlkSize,
+                 bool isAsym,
+                 bool lastCall,
+                 MLAS_COMPUTE_TYPE CompType,
+                 MLAS_THREADPOOL* ThreadPool)
+{
+    GetCPUDevice();
+    switch (CompType) {
+        case CompInt8:
+            if (_cd->AMX_INT8()) {
+                JblaNBitsGemmPackB<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>(
+                    PackedBuf, int(BlkSize), QData, Scale, Zp, int(N), int(K), isAsym, lastCall,
+                    int(ldb), ThreadPool);
+                return true;
+            }
+            if (_cd->AVX512_VNNI()) {
+                JblaNBitsGemmPackB<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>(
+                    PackedBuf, int(BlkSize), QData, Scale, Zp, int(N), int(K), isAsym, lastCall,
+                    int(ldb), ThreadPool);
+                return true;
+            }
+            if (_cd->AVX_VNNI()) {
+                JblaNBitsGemmPackB<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(
+                    PackedBuf, int(BlkSize), QData, Scale, Zp, int(N), int(K), isAsym, lastCall,
+                    int(ldb), ThreadPool);
+                return true;
+            }
+            break;
+        case CompFp32:
+            if (_cd->AVX512F()) {
+                JblaNBitsGemmPackB<tLauncher_Fp32_S4_F32F32<tAVX512F>>(
+                    PackedBuf, int(BlkSize), QData, Scale, Zp, int(N), int(K), isAsym, lastCall,
+                    int(ldb), ThreadPool);
+                return true;
+            }
+            if (_cd->AVX2()) {
+                JblaNBitsGemmPackB<tLauncher_Fp32_S4_F32F32<tAVX2>>(
+                    PackedBuf, int(BlkSize), QData, Scale, Zp, int(N), int(K), isAsym, lastCall,
+                    int(ldb), ThreadPool);
+                return true;
+            }
+            break;
+        case CompBf16:
+        case CompFp16:
+        default:
+            return false;
+    }
+    return false;
+}
+
+static bool
+JblasQ4GemmUnPackB(float* FpData,
+                   const void* PackedBuf,
+                   size_t N,
+                   size_t K,
+                   size_t ldb,
+                   MLAS_THREADPOOL* ThreadPool)
+{
+    auto ptr =
+        jblas::storage::gemm::PackedWeightParser::deserialBuffer(const_cast<void*>(PackedBuf));
+    ORTThreading orth(ThreadPool);
+    GetCPUDevice();
+    if (ptr) {
+        if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+            auto NTile =
+                jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK,
+                                                    jblas::gemm::CoreAttr::NTILE_SHIFT);
+            auto CType = jblas::gemm::CoreAttr::get_mask_val(
+                ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT);
+            if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) {
+                if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512F, tAVX512F::ISA> proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+                if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAVX2, tAVX2::ISA> proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+            }
+            if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) {
+                if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_US, tAMX_INT8_US::ISA>
+                        proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+                if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512_VNNI, tAVX512_VNNI::ISA>
+                        proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+                if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAVX_VNNI, tAVX_VNNI::ISA> proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+            }
+            if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) {
+                if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+                    static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_SS, tAMX_INT8_SS::ISA>
+                        proB;
+                    proB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), &orth);
+                    goto __END;
+                }
+            }
+        }
+    __END:
+        delete ptr;
+        return true;
+    }
+    return false;
+}
+
+#endif
+
+/**
+ * @brief Computes the number of bytes required to pack and int4-quantize
+ *        a weight matrix
+ * @param QType  type of block quantization
+ * @param N      the number of columns of matrix B.
+ * @param K      the number of rows of matrix B.
+ * @return size of the packing buffer, 0 if the operation is not yet supported.
+ */
+size_t MLASCALL
+MlasNBitsGemmPackBSize(
+    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType)
+{
+#ifdef MLAS_JBLAS
+    if (nbits == 4) {
+        auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
+        if (jsize) {
+            return jsize;
+        }
+    }
+#endif
+    return 0;
+}
+
+bool MLASCALL
+MlasNBitsGemmPackBSupport(
+    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType)
+{
+    return MlasNBitsGemmPackBSize(N, K, BlkSize, nbits, isAsym, CompType) > 0;
+}
+
+void MLASCALL
+MlasNBitsGemmPackB(void* PackedBuf,
+                   const uint8_t* QData,
+                   const float* Scale,
+                   const uint8_t* Zp,
+                   size_t N,
+                   size_t K,
+                   size_t ldb,
+                   size_t BlkSize,
+                   int nbits,
+                   bool isAsym,
+                   bool lastCall,
+                   MLAS_COMPUTE_TYPE CompType,
+                   MLAS_THREADPOOL* ThreadPool)
+{
+#ifdef MLAS_JBLAS
+    if (nbits == 4) {
+        if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall,
+                             CompType, ThreadPool)) {
+            return;
+        }
+    }
+#endif
+}
+
+void MLASCALL
+MlasNBitsGemmUnPackB(float* FpData,
+                     const void* PackedBuf,
+                     size_t N,
+                     size_t K,
+                     size_t ldb,
+                     MLAS_THREADPOOL* ThreadPool)
+{
+#ifdef MLAS_JBLAS
+    if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
+        return;
+    }
+#endif
+}
diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp
index ca22fe984bff1..6acd3eeeb3f1e 100644
--- a/onnxruntime/core/mlas/lib/q4gemm.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm.cpp
@@ -140,6 +140,30 @@ MlasQ4GemmBatchDriver(MLAS_BLK_QUANT_TYPE QType,
     });
 }
 
+void MLASCALL
+MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
+                const size_t M,
+                const size_t N,
+                const size_t K,
+                const size_t BatchN,
+                const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
+                MLAS_THREADPOOL* ThreadPool)
+{
+    MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
+}
+
+void MLASCALL
+MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
+                  const size_t M,
+                  const size_t N,
+                  const size_t K,
+                  const size_t BatchN,
+                  const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
+                  MLAS_THREADPOOL* ThreadPool)
+{
+    MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
+}
+
 #ifdef MLAS_JBLAS
 
 jblas::ORTThreading::ORTThreading(void* tp)
@@ -169,7 +193,7 @@ JblasQ4GemmCompF32(const int M,
 {
     if (M <= 32) {
         using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
-        using Launcher = JBLAS_FP32_S4_F32F32<GemmCore_T>;
+        using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>;
         static Launcher kernel;
         auto reduceA = kernel.mProA.createStorage(M, K, B->mBlockSize);
         if (B->mIsAsym) {
@@ -211,7 +235,7 @@ JblasQ4GemmCompInt8(const int M,
                     jblas::parallel::IThreading* th)
 {
     using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
-    using Launcher = JBLAS_INT8_S4_F32F32<GemmCore_T>;
+    using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>;
     static Launcher kernel;
     auto quanA = kernel.mProA.createStorage(M, K, B->mBlockSize, B->mIsAsym);
@@ -236,7 +260,7 @@ JblasQ4GemmCompInt8(const int M,
     jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
 }
 
-void
+static bool
 JblasQ4GemmBatchDriver(const size_t M,
                        const size_t N,
                        const size_t K,
@@ -247,6 +271,7 @@ JblasQ4GemmBatchDriver(const size_t M,
 {
     GetCPUDevice();
     ORTThreading orth(ThreadPool);
+    bool processed = true;
     for (size_t i = 0; i < BatchN; i++) {
         auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(
             const_cast<void*>(DataParams[i].B));
@@ -261,76 +286,79 @@ JblasQ4GemmBatchDriver(const size_t M,
                 jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT);
             if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
-                if (NTile == 48 && _cd->AVX512F()) {
-                    JblasQ4GemmCompF32<jblas::gemm::SCoreRowNAvx512f<48, 8>>(
+                if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+                    JblasQ4GemmCompF32<tAVX512F>(
                         M, N, K, DataParams[i].A, DataParams[i].lda,
                         (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
                         DataParams[i].ldc, WorkSpace, &orth);
-                    return;
+                    goto __END;
+                }
+                if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+                    JblasQ4GemmCompF32<tAVX2>(M, N, K, DataParams[i].A, DataParams[i].lda,
+                                              (jblas::storage::gemm::StorageWeightKBlockS4*)ptr,
+                                              DataParams[i].C, DataParams[i].ldc, WorkSpace,
+                                              &orth);
+                    goto __END;
                 }
-                if (NTile == 24 && _cd->AVX2()) {
-                    JblasQ4GemmCompF32<jblas::gemm::SCoreRowNAvx2<24, 4>>(
+            }
+            if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
+                if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+                    JblasQ4GemmCompInt8<tAMX_INT8_US>(
                         M, N, K, DataParams[i].A, DataParams[i].lda,
                         (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
                         DataParams[i].ldc, WorkSpace, &orth);
-                    return;
+                    goto __END;
                 }
-            }
-            if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
-                if (NTile == 48 && _cd->AVX512_VNNI()) {
-                    JblasQ4GemmCompInt8<jblas::gemm::ICoreRowNAvx512vnni<48, 8>>(
+                if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+                    JblasQ4GemmCompInt8<tAVX512_VNNI>(
                         M, N, K, DataParams[i].A, DataParams[i].lda,
                         (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
                         DataParams[i].ldc, WorkSpace, &orth);
-                    return;
+                    goto __END;
                 }
-                if (NTile == 24 && _cd->AVX_VNNI()) {
-                    JblasQ4GemmCompInt8<jblas::gemm::ICoreRowNAvxvnni<24, 4>>(
+                if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+                    JblasQ4GemmCompInt8<tAVX_VNNI>(
                         M, N, K, DataParams[i].A, DataParams[i].lda,
                         (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
                         DataParams[i].ldc, WorkSpace, &orth);
-                    return;
+                    goto __END;
+                }
+            }
+            if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) {
+                if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+                    JblasQ4GemmCompInt8<tAMX_INT8_SS>(
+                        M, N, K, DataParams[i].A, DataParams[i].lda,
+                        (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C,
+                        DataParams[i].ldc, WorkSpace, &orth);
+                    goto __END;
                 }
             }
         }
+    __END:
         delete ptr;
+        } else {
+            processed = false;
+            break;
         }
     }
+    return processed;
 }
 
-void MLASCALL
-MlasJblasQ4GemmBatch(const size_t M,
-                     const size_t N,
-                     const size_t K,
-                     const size_t BatchN,
-                     const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-                     int8_t* WorkSpace,
-                     MLAS_THREADPOOL* ThreadPool)
-{
-    JblasQ4GemmBatchDriver(M, N, K, BatchN, DataParams, WorkSpace, ThreadPool);
-}
 #endif
 
 void MLASCALL
-MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-                const size_t M,
-                const size_t N,
-                const size_t K,
-                const size_t BatchN,
-                const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-                MLAS_THREADPOOL* ThreadPool)
-{
-    MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
-}
-
-void MLASCALL
-MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-                  const size_t M,
-                  const size_t N,
-                  const size_t K,
-                  const size_t BatchN,
-                  const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
-                  MLAS_THREADPOOL* ThreadPool)
+MlasNBitsGemmBatch(const size_t M,
+                   const size_t N,
+                   const size_t K,
+                   const size_t BatchN,
+                   const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
+                   int8_t* WorkSpace,
+                   MLAS_THREADPOOL* ThreadPool)
 {
-    MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
+#ifdef MLAS_JBLAS
+    if (JblasQ4GemmBatchDriver(M, N, K, BatchN, DataParams, WorkSpace, ThreadPool)) {
+        // PackedWeight is created by jblas
+        return;
+    }
+#endif
 }
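Background for the dispatch above: the packed-weight header encodes the kernel's N-tile width and compute class in `mCoreId`, and `get_mask_val` is a plain shift-and-mask. An illustrative decoder follows; the field layout here is assumed for the example, and the real masks and shifts live in `jblas::gemm::CoreAttr`:

```cpp
#include <cstdint>

// Field layout assumed for illustration only.
constexpr uint32_t NTILE_MASK = 0xFF, NTILE_SHIFT = 0;
constexpr uint32_t COMP_MASK = 0xFF00, COMP_SHIFT = 8;

inline uint32_t get_mask_val(uint32_t core_id, uint32_t mask, uint32_t shift) {
  return (core_id & mask) >> shift;  // extract one packed field
}

// e.g. an AVX512F fp32 kernel with a 48-wide N tile would satisfy:
//   get_mask_val(id, NTILE_MASK, NTILE_SHIFT) == 48
//   get_mask_val(id, COMP_MASK, COMP_SHIFT) == <the COMP_FP32 code>
```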
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h
index 92d1c1d0c1db6..58629d09b7a54 100644
--- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h
@@ -423,16 +423,21 @@ class SchedulerKBlock : public Scheduler2D {
       size_t maxN = (mL1Size - sizeA) / (mBlock[0] * mEleSize[2] * 2 + refk * mEleSize[1]);
       return maxN;
     };
+    auto getMaxK = [&](size_t refN) {
+      size_t sizeC = refN * mEleSize[2] * mBlock[0] * 2;
+      size_t maxK = (mL1Size - sizeC) / (mBlock[0] * mEleSize[0] + refN * mEleSize[1]);
+      return maxK;
+    };
     auto maxN = getMaxN(startK);
     if (maxN <= mThdSize[1]) {
       mBlock[1] = int(maxN);
       mBlock[1] = utils::padto_le(mBlock[1], mStep[1]);
       mBlock[2] = int(startK);
     } else {
-      mBlock[2] = int(startK * 2);
-      mBlock[1] = int(getMaxN(mBlock[2]));
-      mBlock[1] = utils::padto_le(mBlock[1], mStep[1]);
-      mBlock[1] = std::min(mBlock[1], mThdSize[1]);
+      mBlock[1] = mThdSize[1];
+      mBlock[2] = getMaxK(mBlock[1]);
+      mBlock[2] = utils::padto_le(mBlock[2], mStep[2]);
+      mBlock[2] = std::min(mKBlock, mBlock[2]);
     }
   }
   size_t mL2Size = 0, mL1Size = 0, mL2Use = 0;
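To make the new `getMaxK` arithmetic concrete, a worked example with made-up cache and element sizes (the real values come from the scheduler's configuration):

```cpp
#include <cstddef>

// Made-up numbers, mirroring getMaxK above: fix the per-thread N width
// and solve for the largest K block whose A/B/C tiles still fit in L1.
void BlockingExample() {
  size_t L1 = 48 * 1024;                // assumed L1 data-cache size in bytes
  size_t MBlock = 16;                   // mBlock[0]: rows of A/C per step
  size_t eleA = 4, eleB = 2;            // mEleSize[0], mEleSize[1]: fp32 A, bf16 B
  size_t eleC = 4;                      // mEleSize[2]: fp32 C
  size_t refN = 144;                    // mThdSize[1]: per-thread N width
  size_t sizeC = refN * eleC * MBlock * 2;                     // 18432 (double-buffered C)
  size_t maxK = (L1 - sizeC) / (MBlock * eleA + refN * eleB);  // (49152-18432)/352 = 87
  (void)maxK;  // then rounded down to a multiple of mStep[2], capped at mKBlock
}
```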
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h
index f5aa76c46ff64..b006e0b410cd8 100644
--- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h
@@ -111,7 +111,8 @@ class ActivationKBlockQuantize {
   inline QParam createStorage(int m, int k, int kblock, bool hasreduce) {
     QParam tmp;
     int kpad = utils::padto(k, _GemmCore_T::KTILE);
-    tmp.resize(m, kpad, kblock == -1 ? kpad : kblock, JBLAS_DTYPE::U8, JBLAS_DTYPE::F32, JBLAS_DTYPE::U8,
+    int mpad = utils::padto(m, _GemmCore_T::MTILE);
+    tmp.resize(mpad, kpad, kblock == -1 ? kpad : kblock, JBLAS_DTYPE::U8, JBLAS_DTYPE::F32, JBLAS_DTYPE::U8,
                JBLAS_DTYPE::F32, std::is_same_v<SRC_T, utils::bf16>, hasreduce);
     return tmp;
   }
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h
index 0b43e11679062..e75d787603877 100644
--- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h
@@ -274,20 +274,43 @@ class WeightKBlockS8 {
             }
           }
         }
-        if (stor->mIsAsym && zero_points) {
+      }
+    });
+  } else if (stor->mScaT == JBLAS_DTYPE::BF16) {
+    threading->parallel_for([&](int tidx) {
+      parallel::ThreadProblem2D thdp{tidx};
+      _para.getIndex(thdp);
+      if (thdp.valid) {
+        if (scales) {
           for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
             if (i < rawnk_scale) {
               for (size_t j = 0; j < N; j++) {
-                stor->template ZPtr<int8_t>()[i * stor->mNPad + j] = zero_points[j * rawnk_scale + i];
+                stor->template SPtr<utils::bf16>()[i * stor->mNPad + j] = utils::bf16(scales[j * rawnk_scale + i]);
               }
             } else {
-              std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
+              std::memset(stor->template SPtr<utils::bf16>() + i * stor->mNPad, 0, stor->mNPad * sizeof(utils::bf16));
             }
           }
         }
       }
     });
   }
+  if (stor->mIsAsym && zero_points)
+    threading->parallel_for([&](int tidx) {
+      parallel::ThreadProblem2D thdp{tidx};
+      _para.getIndex(thdp);
+      if (thdp.valid) {
+        for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
+          if (i < rawnk_scale) {
+            for (size_t j = 0; j < N; j++) {
+              stor->template ZPtr<int8_t>()[i * stor->mNPad + j] = zero_points[j * rawnk_scale + i];
+            }
+          } else {
+            std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
+          }
+        }
+      }
+    });
 }
 
 virtual void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales,
@@ -315,7 +338,7 @@ class WeightKBlockS8 {
       utils::afree(deq);
     }
   }
-  template <JBLAS_ISA ISA_T, typename RED_T>
+  template <typename RED_T>
   void reduce(const int N, const int K, const int KBlock, const float* B, const int ldb, RED_T* rptr, const int ldr,
               parallel::IThreading* threading) {
     parallel::Scheduler2D _para({threading->num_threads(), K, N, KBlock, 16});
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/mlas_jblas_defs.h b/onnxruntime/core/mlas/lib/x86_64/jblas/mlas_jblas_defs.h
index b0d7d1f61e618..daeeb147a22ea 100644
--- a/onnxruntime/core/mlas/lib/x86_64/jblas/mlas_jblas_defs.h
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/mlas_jblas_defs.h
@@ -17,26 +17,26 @@
 namespace jblas {
 
 template <class GemmCore_T>
-using JBLAS_FP32_S4_F32F32 =
+using tLauncher_Fp32_S4_F32F32 =
     jblas::wrapper::gemm::LauncherKBlock<GemmCore_T::ISA, GemmCore_T,
                                          jblas::prologue_a::gemm::ActivationKBlockBaseF32,
                                          jblas::prologue_b::gemm::WeightKBlockS4,
                                          jblas::epilogue::gemm::CompFp32BlockEpilogue,
                                          jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
 
 template <class GemmCore_T>
-using JBLAS_INT8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
     GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationF32KBlockQuantize,
     jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::CompInt8BlockEpilogue,
     jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
 
-using Jblas_Fp32_AVX512F_F32F32 = JBLAS_FP32_S4_F32F32<jblas::gemm::SCoreRowNAvx512f<48, 8>>;
-using Jblas_Fp32_AVX2_F32F32 = JBLAS_FP32_S4_F32F32<jblas::gemm::SCoreRowNAvx2<24, 4>>;
-using Jblas_Int8_AVX512F_F32F32 = JBLAS_INT8_S4_F32F32<jblas::gemm::ICoreRowNAvx512vnni<48, 8>>;
-using Jblas_Int8_AVX2_F32F32 = JBLAS_INT8_S4_F32F32<jblas::gemm::ICoreRowNAvxvnni<24, 4>>;
-static Jblas_Fp32_AVX512F_F32F32 JblasAvx512fS4Fp32Fp32;
-static Jblas_Fp32_AVX2_F32F32 JblasAvx2S4Fp32Fp32;
-static Jblas_Int8_AVX512F_F32F32 JblasAvx512VnniS4Fp32Fp32;
-static Jblas_Int8_AVX2_F32F32 JblasAvxVnniS4Fp32Fp32;
+using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>;
+using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>;
+using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>;
+using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>;
+using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>;
+using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>;
+using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>;
 
 class ORTThreading : public jblas::parallel::IThreading {
  public:
diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
index f1cfdaba2c2a4..26c40e24690dc 100644
--- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
+++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp
@@ -62,7 +62,7 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_COMPUTE_TYPE cmp_type) {
   const size_t N = static_cast<size_t>(state.range(1));
   const size_t K = static_cast<size_t>(state.range(2));
   const size_t threads = static_cast<size_t>(state.range(3));
-  const size_t pack_b_size = MlasJblasQ4GemmPackBSize(N, K, block_size, is_asym, cmp_type);
+  const size_t pack_b_size = MlasNBitsGemmPackBSize(N, K, block_size, 4, is_asym, cmp_type);
 
   OrtThreadPoolParams tpo;
   tpo.thread_pool_size = int(threads);
@@ -72,11 +72,14 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_COMPUTE_TYPE cmp_type) {
       tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP));
 
   auto A1 = RandomVectorUniform(static_cast<size_t>(M * K), -1.0f, 1.0f);
-  auto B1 = RandomVectorUniform(static_cast<size_t>(N * K), -1.0f, 1.0f);
+  auto B1 = RandomVectorUniform<uint8_t>(static_cast<size_t>(N * K / 2), 0, 255);
+  auto blk_num = (K + block_size - 1) / block_size;
+  auto B_scale = RandomVectorUniform(static_cast<size_t>(N * blk_num), 0.003f, 0.005f);
   std::vector<float> C1(static_cast<size_t>(M * N));
+  auto B_zp = RandomVectorUniform<uint8_t>(static_cast<size_t>(N * blk_num / 2), 0, 255);
 
   std::vector<int8_t> B1_packed(pack_b_size);
-  MlasJblasQ4GemmPackB(B1_packed.data(), B1.data(), N, K, N, block_size, is_asym, cmp_type, tp.get());
+  MlasNBitsGemmPackB(B1_packed.data(), B1.data(), B_scale.data(), is_asym ? B_zp.data() : nullptr, N, K, N, block_size, 4, is_asym, true, cmp_type, tp.get());
 
   MLAS_Q4_GEMM_DATA_PARAMS params1;
   params1.A = A1.data();
@@ -87,10 +90,10 @@ void Q4GEMM_Jblas(benchmark::State& state, int block_size, bool is_asym, MLAS_COMPUTE_TYPE cmp_type) {
   params1.B = B1_packed.data();
   params1.OutputProcessor = nullptr;
 
   std::vector<int8_t> workspace(size_t(M) * K * 4);
-  MlasJblasQ4GemmBatch(M, N, K, 1, &params1, workspace.data(), tp.get());
+  MlasNBitsGemmBatch(M, N, K, 1, &params1, workspace.data(), tp.get());
 
   for (auto _ : state) {
-    MlasJblasQ4GemmBatch(M, N, K, 1, &params1, workspace.data(), tp.get());
+    MlasNBitsGemmBatch(M, N, K, 1, &params1, workspace.data(), tp.get());
   }
 }
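For reference, a registration along these lines would drive the updated benchmark; the test-case name and argument values are hypothetical, following google-benchmark's `BENCHMARK_CAPTURE` pattern:

```cpp
// Hypothetical registration; Args are {M, N, K, threads} as consumed by
// state.range(0..3) in Q4GEMM_Jblas above.
BENCHMARK_CAPTURE(Q4GEMM_Jblas, Blk32SymFp32, /*block_size=*/32, /*is_asym=*/false, CompFp32)
    ->Args({1, 4096, 4096, 8})
    ->UseRealTime()
    ->Unit(benchmark::kMicrosecond);
```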