From 009adb61a3c3d7d52eb1ec4c1afa5fa43df51859 Mon Sep 17 00:00:00 2001 From: "Luo, Yu" Date: Tue, 9 Jan 2024 17:35:30 +0800 Subject: [PATCH] move neural_speed gemms to contrib_ops --- cmake/CMakeLists.txt | 4 +- cmake/external/neural_speed.cmake | 15 +- cmake/onnxruntime_mlas.cmake | 12 - cmake/onnxruntime_providers_cpu.cmake | 12 + .../cpu/quantization}/bestla_defs.h | 18 +- .../cpu/quantization/bestla_gemm.cc | 462 ++++++++++++++++ .../cpu/quantization/bestla_gemm.h | 75 +++ .../cpu/quantization/matmul_nbits.cc | 46 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 141 ----- onnxruntime/core/mlas/lib/bestla_gemm.cpp | 505 ------------------ onnxruntime/core/mlas/lib/bestla_gemm.h | 61 --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 127 ----- .../test/contrib_ops/matmul_4bits_test.cc | 47 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 54 -- 14 files changed, 624 insertions(+), 955 deletions(-) rename onnxruntime/{core/mlas/lib => contrib_ops/cpu/quantization}/bestla_defs.h (79%) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.h delete mode 100644 onnxruntime/core/mlas/lib/bestla_gemm.cpp delete mode 100644 onnxruntime/core/mlas/lib/bestla_gemm.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f223a8a8b076f..2e2e9c25210d6 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1177,8 +1177,8 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() - -if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) +set(USE_NEURAL_SPEED FALSE) +if (onnxruntime_USE_NEURAL_SPEED) include(neural_speed) endif() diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake index ca5cdd95f39ad..ba48929be769f 100644 --- a/cmake/external/neural_speed.cmake +++ b/cmake/external/neural_speed.cmake @@ -1,7 +1,6 @@ -set(BTLA_URL https://github.com/intel/neural-speed.git) -set(BTLA_TAG 368ccbd2823e7ecef862d09e7b2385e6b2553081) # bestla v0.1 +set(NEURAL_SPEED_URL https://github.com/intel/neural-speed.git) +set(NEURAL_SPEED_TAG 18720b319d6921c28e59cc9e003e50cee9a85fcc) # kernel-only release v0.2 -set(USE_NEURAL_SPEED FALSE) if (onnxruntime_USE_NEURAL_SPEED) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") set(USE_NEURAL_SPEED TRUE) @@ -10,11 +9,11 @@ if (onnxruntime_USE_NEURAL_SPEED) endif() if(USE_NEURAL_SPEED) FetchContent_Declare( - bestla - GIT_REPOSITORY ${BTLA_URL} - GIT_TAG ${BTLA_TAG} + neural_speed + GIT_REPOSITORY ${NEURAL_SPEED_URL} + GIT_TAG ${NEURAL_SPEED_TAG} ) - FetchContent_MakeAvailable(bestla) - add_compile_definitions(MLAS_NEURAL_SPEED) + FetchContent_MakeAvailable(neural_speed) + add_compile_definitions(ORT_NEURAL_SPEED) endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7b60459ca5884..d12546be65cbf 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,14 +45,6 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) -function(add_neural_speed) - target_link_libraries(onnxruntime_mlas PRIVATE bestla::bestla) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/bestla_gemm.cpp - ) - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) -endfunction() - #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -611,10 +603,6 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() -if(USE_NEURAL_SPEED) - add_neural_speed() -endif() 
-
 foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
   target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
   onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake
index f60faa4d39116..e5daaff3ede0f 100644
--- a/cmake/onnxruntime_providers_cpu.cmake
+++ b/cmake/onnxruntime_providers_cpu.cmake
@@ -60,6 +60,13 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
       "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
     )
   endif()
+  if(NOT USE_NEURAL_SPEED)
+    list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs
+      "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/bestla_defs.h"
+      "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/bestla_gemm.cc"
+      "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/bestla_gemm.h"
+    )
+  endif()
   # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
   source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
   list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
@@ -144,6 +151,11 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
   target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
 endif()

+if(USE_NEURAL_SPEED)
+  target_link_libraries(onnxruntime_providers PRIVATE bestla::bestla)
+  set_target_properties(onnxruntime_providers PROPERTIES COMPILE_WARNING_AS_ERROR OFF) # ignore warnings inside neural-speed
+endif()
+
 if (MSVC)
   target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
 #  if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
diff --git a/onnxruntime/core/mlas/lib/bestla_defs.h b/onnxruntime/contrib_ops/cpu/quantization/bestla_defs.h
similarity index 79%
rename from onnxruntime/core/mlas/lib/bestla_defs.h
rename to onnxruntime/contrib_ops/cpu/quantization/bestla_defs.h
index 0f6b7116c20d9..dc0a358bc1dd7 100644
--- a/onnxruntime/core/mlas/lib/bestla_defs.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/bestla_defs.h
@@ -11,8 +11,7 @@ Licensed under the MIT License.
 #include "bestla/bestla_prologue_a.h"
 #include "bestla/bestla_wrapper.h"

-namespace bestla
-{
+namespace bestla {

 using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>;
 using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>;
@@ -33,14 +32,13 @@ using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger;
 template
 using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat;

-class ORTThreading : public parallel::IThreading
-{
-  public:
-    ORTThreading(void* tp);
-    void parallel_for(const parallel::thread_func& func) const override;
-    void set_threads(int nthreads) override { assert(0); }
-    void sync() const override { assert(0); }
-    void* mTp;
+class ORTThreading : public parallel::IThreading {
+ public:
+  explicit ORTThreading(void* tp);
+  void parallel_for(const parallel::thread_func& func) const override;
+  void set_threads(int nthreads) override { assert(0); }
+  void sync() const override { assert(0); }
+  void* mTp;
 };

 }  // namespace bestla
diff --git a/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc
new file mode 100644
index 0000000000000..2953f8c157db7
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.cc
@@ -0,0 +1,462 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    bestla_gemm.cc
+
+Abstract:
+
+    Currently only supports Q4 GEMM.
+--*/ + +#include "contrib_ops/cpu/quantization/bestla_defs.h" +#include "contrib_ops/cpu/quantization/bestla_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; +namespace bestla { +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) { +} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { + func(static_cast(tid)); + }); +} + +template +static void +NSSQ4GemmCompF32( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = wrapper::gemm::LauncherKBlock< + GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32, + prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::CompFp32BlockEpilogue, + epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::BEpiParam blkargs{ + B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}; + + typename Launcher::Param args{gp, {A, lda_, &reduceA}, {B}, blkargs, {C, ldc_}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightKBlockNInteger, + epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void +NSSQ4GemmCompInt8( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = wrapper::gemm::LauncherIntKBlock< + GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize, + prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M, N, K, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t +NSSQ4GemmCompF32WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, 
+    const float* A,
+    const size_t lda,
+    storage::gemm::StorageWeightKBlockNInteger* B,
+    float* C,
+    const size_t ldc) {
+  auto M_ = static_cast(M);
+  auto K_ = static_cast(K);
+  (void)(N);
+  (void)(lda);
+  (void)(ldc);
+  if (M <= 16) {
+    using ProA = prologue_a::gemm::ActivationKBlockBaseF32;
+    static ProA proA;
+    if (B->IsAsym()) {
+      auto reduceA = proA.createStorage(M_, K_, B->mBlockSize);
+      return reduceA.mSize;
+    }
+    return 0;
+  } else {
+    using ProA = prologue_a::gemm::ActivationBase;
+    return 0;
+  }
+  return 0;
+}
+
+template
+static size_t
+NSSQ4GemmCompInt8WorkspaceSize(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const float* A,
+    const size_t lda,
+    storage::gemm::StorageWeightKBlockNInteger* B,
+    float* C,
+    const size_t ldc) {
+  (void)(N);
+  (void)(lda);
+  (void)(ldc);
+  using ProA = prologue_a::gemm::ActivationF32KBlockQuantize;
+  static ProA proA;
+  auto quanA =
+      proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym());
+  return quanA.mSize;
+}
+
+}  // namespace bestla
+
+using namespace bestla;
+
+bool NSSQ4GemmBatchDriver(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+    int8_t* WorkSpace,
+    void* ThreadPool) {
+  GetCPUDevice();
+  bestla::ORTThreading orth(ThreadPool);
+  bool processed = true;
+  for (size_t i = 0; i < BatchN; i++) {
+    auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+    auto uptr = std::unique_ptr(ptr);
+    if (ptr) {
+      auto NTile =
+          gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
+      auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
+      auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
+      auto btype = static_cast(gemm::CompTypeHelper::get_B(CType));
+      if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
+        auto kptr = reinterpret_cast(ptr);
+        if (btype == gemm::CompType::tFP32 && PackRow == 1) {
+          if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F()) {
+            bestla::NSSQ4GemmCompF32(
+                M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+                WorkSpace, &orth);
+          } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2()) {
+            bestla::NSSQ4GemmCompF32(
+                M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+                WorkSpace, &orth);
+          }
+        }
+        if (btype == gemm::CompType::tS8 && PackRow == 4) {
+          if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) {
+            bestla::NSSQ4GemmCompInt8(
+                M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+                WorkSpace, &orth);
+          } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) {
+            bestla::NSSQ4GemmCompInt8(
+                M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+                WorkSpace, &orth);
+          } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) {
+            bestla::NSSQ4GemmCompInt8(
+                M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+                WorkSpace, &orth);
+          }
+        }
+      }
+    } else {
+      processed = false;
+      break;
+    }
+  }
+  return processed;
+}
+
+size_t
+NSSQ4GemmBatchWorkspaceSize(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
+  GetCPUDevice();
+  size_t size = 0;
+  for (size_t i = 0; i < BatchN; i++) {
+    
auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + size = std::max( + NSSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + size = std::max( + NSSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) { + size = std::max( + NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) { + size = std::max( + NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) { + size = std::max( + NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t +NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage( + static_cast(N), static_cast(K), static_cast(block_size), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +size_t +NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, int64_t CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case 4: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + case 3: + case 2: + case 1: + case 0: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + break; + default: + return 0; + } + return 0; +} + +template +static void +NSQ4GemmPackBImpl( + void* PackedBuf, + size_t BlkSize, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + bool IsAsym, + bool lastCall, + size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage( + N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +bool NSQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + int64_t CompType, + void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case 4: // int8 + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + case 3: // bf16 + case 2: // fp16 + case 1: // fp32 + case 0: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + default: + return false; + } + return false; +} + +bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = 
reinterpret_cast(ptr);
+      if (btype == gemm::CompType::tFP32 && PackRow == 1) {
+        if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+          static tWeiNInt proB;
+          proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+        } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+          static tWeiNInt proB;
+          proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+        }
+      }
+      if (btype == gemm::CompType::tS8 && PackRow == 4) {
+        if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) {
+          static tWeiNInt proB;
+          proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+        } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) {
+          static tWeiNInt proB;
+          proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+        } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) {
+          static tWeiNInt proB;
+          proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+        }
+      }
+    }
+    return true;
+  }
+  return false;
+}
diff --git a/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.h
new file mode 100644
index 0000000000000..924b78f4fb05a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/bestla_gemm.h
@@ -0,0 +1,75 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    bestla_gemm.h
+
+Abstract:
+
+    Currently only supports Q4 GEMM.
+--*/
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+/**
+ * @brief Data parameters for NBits GEMM routine
+ *        C = A * B
+ *        A, C must be a float32 matrix
+ *        B must be a packed nbits blob
+ *        All except C are [in] parameters
+ */
+struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
+  const float* A = nullptr; /**< address of A (float32 matrix) */
+  const void* B = nullptr;  /**< address of B (packed nbits blob) */
+  float* C = nullptr;       /**< address of result matrix */
+  size_t lda = 0;           /**< leading dimension of A */
+  size_t ldc = 0;           /**< leading dimension of C */
+};
+
+size_t
+NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, int64_t accuracy_level);
+
+bool
+NSQ4GemmPackB(
+    void* PackedBuf,
+    const uint8_t* QData,
+    const float* Scale,
+    const uint8_t* Zp,
+    size_t N,
+    size_t K,
+    size_t ldb,
+    size_t BlkSize,
+    bool isAsym,
+    bool lastCall,
+    int64_t CompType,
+    void* ThreadPool
+);
+
+bool
+NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool);
+
+bool
+NSSQ4GemmBatchDriver(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+    int8_t* WorkSpace,
+    void* ThreadPool);
+
+size_t
+NSSQ4GemmBatchWorkspaceSize(
+    const size_t M,
+    const size_t N,
+    const size_t K,
+    const size_t BatchN,
+    const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams);
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index a9703dc68dd26..ddf1be917c775 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -9,6 +9,9 @@
 #include "core/mlas/inc/mlas_q4.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
+#ifdef ORT_NEURAL_SPEED
+#include "bestla_gemm.h"
+#endif
 
 namespace onnxruntime {
 namespace contrib {
@@ -64,16 +67,19 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
   if (!all_constant_) {
     return Status::OK();
   }
-  auto compt_type = static_cast(accuracy_level_);
+#ifdef ORT_NEURAL_SPEED
   MLAS_THREADPOOL* pool = NULL;
+  if
(nbits_ != 4) { + return Status::OK(); + } if (input_idx == 1) { - packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + packed_b_size_ = NSQ4GemmPackBSize(N_, K_, block_size_, is_asym_, accuracy_level_); if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); std::memset(packed_b_.get(), 0, packed_b_size_); - MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, false, compt_type, pool); + NSQ4GemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, + is_asym_, false, accuracy_level_, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -82,8 +88,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 2 && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, !is_asym_, compt_type, pool); + NSQ4GemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, + is_asym_, !is_asym_, accuracy_level_, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -92,21 +98,27 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 3 && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, is_asym_, compt_type, pool); + NSQ4GemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, + is_asym_, is_asym_, accuracy_level_, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); } is_packed = true; } - +#else + (void)(alloc); + (void)(tensor); + (void)(input_idx); + (void)(prepacked_weights); +#endif return Status::OK(); } Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; +#ifdef ORT_NEURAL_SPEED // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -120,6 +132,10 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } +#else + (void)(prepacked_buffers); + (void)(input_idx); +#endif return Status::OK(); } @@ -128,7 +144,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); - +#ifdef ORT_NEURAL_SPEED if (packed_b_.get()) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); @@ -147,7 +163,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); + std::vector gemm_params(max_len); AllocatorPtr allocator; auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); @@ -158,14 +174,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = 
MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = NSSQ4GemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), - thread_pool); + NSSQ4GemmBatchDriver(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), + thread_pool); return Status::OK(); } - +#endif const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 1e83dd1cec400..9620dd42d1da9 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -77,144 +77,3 @@ MlasIsSQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen ); - -/** - * @brief Define compute types of block quantization - */ -typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32 = 1, /*!< input fp32, accumulator fp32 */ - CompFp16 = 2, /*!< input fp16, accumulator fp16 */ - CompBf16 = 3, /*!< input bf16, accumulator fp32 */ - CompInt8 = 4 /*!< input int8, accumulator int32 */ -} MLAS_SQNBIT_COMPUTE_TYPE; - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type -); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. 
Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t block_size, - int nbits, - bool is_asym, - bool last_call, - MLAS_SQNBIT_COMPUTE_TYPE comp_type, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmUnPackB( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool = nullptr -); diff --git a/onnxruntime/core/mlas/lib/bestla_gemm.cpp b/onnxruntime/core/mlas/lib/bestla_gemm.cpp deleted file mode 100644 index 92fd445bab757..0000000000000 --- a/onnxruntime/core/mlas/lib/bestla_gemm.cpp +++ /dev/null @@ -1,505 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - bestla_gemm.cpp - -Abstract: - - Currently only support Q4 gemm. 
---*/ - -#include "bestla_gemm.h" - -#include "bestla_defs.h" -#include "mlasi.h" - -namespace bestla -{ -ORTThreading::ORTThreading(void* tp) - : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) -{ -} - -void -ORTThreading::parallel_for(const parallel::thread_func& func) const -{ - MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { - func(static_cast(tid)); - }); -} - -template -static void -BTLASQ4GemmCompF32( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - parallel::IThreading* th -) -{ - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); - if (M <= 16) { - using Parallel = parallel::gemm::SchedulerKBlock; - using Launcher = wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationKBlockBaseF32, - prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::CompFp32BlockEpilogue, - epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->IsAsym()) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), - reduceA.template RPtr(), reduceA.lda}; - - typename Launcher::Param args{gp, {A, lda_, &reduceA}, {B}, blkargs, {C, ldc_}}; - parallel::GemmRun(kernel, args, th); - } else { - using Parallel = parallel::gemm::SchedulerBase; - using Launcher = wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationBase, prologue_b::gemm::WeightKBlockNInteger, - epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_}}; - parallel::GemmRun(kernel, args, th); - } -} - -template -static void -BTLASQ4GemmCompInt8( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - parallel::IThreading* th -) -{ - using Parallel = parallel::gemm::SchedulerKBlockS; - using Launcher = wrapper::gemm::LauncherIntKBlock< - GemmCore_T::ISA, GemmCore_T, prologue_a::gemm::ActivationF32KBlockQuantize, - prologue_b::gemm::WeightKBlockNInteger, epilogue::gemm::AccumulatorWriteBackFp32>; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - utils::GemmProblem gp(1, M, N, K, B->mBlockSize); - typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_}}; - parallel::GemmRun(kernel, args, th); -} - -template -static size_t -BTLASQ4GemmCompF32WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, - float* C, - const size_t ldc -) -{ 
- auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(N); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using ProA = prologue_a::gemm::ActivationKBlockBaseF32; - static ProA proA; - if (B->IsAsym()) { - auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - using ProA = prologue_a::gemm::ActivationBase; - return 0; - } - return 0; -} - -template -static size_t -BTLASQ4GemmCompInt8WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, - float* C, - const size_t ldc -) -{ - (void)(N); - (void)(lda); - (void)(ldc); - using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; - static ProA proA; - auto quanA = - proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); - return quanA.mSize; -} - -} // namespace bestla - -using namespace bestla; - -bool -BTLASQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - bestla::ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - auto NTile = - gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F()) { - bestla::BTLASQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2()) { - bestla::BTLASQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) { - bestla::BTLASQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) { - bestla::BTLASQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) { - bestla::BTLASQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -size_t -BTLASQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == 
BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - auto NTile = - gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - size = std::max( - BTLASQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - size = std::max( - BTLASQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) { - size = std::max( - BTLASQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) { - size = std::max( - BTLASQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) { - size = std::max( - BTLASQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - } - } - } - return size; -} - -template -static size_t -BTLAQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) -{ - static T proB; - auto stor = proB.createStorage( - static_cast(N), static_cast(K), static_cast(block_size), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, - BTLA_DTYPE::BF16, isAsym - ); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -size_t -BTLAQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) -{ - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case CompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
- if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - return BTLAQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - return BTLAQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - return BTLAQ4BuSize>(BlkSize, N, K, isAsym); - } - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return BTLAQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return BTLAQ4BuSize>(BlkSize, N, K, isAsym); - } - break; - default: - return 0; - } - return 0; -} - -template -static void -BTLAQ4GemmPackBImpl( - void* PackedBuf, - size_t BlkSize, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - bool IsAsym, - bool lastCall, - size_t ldb, - MLAS_THREADPOOL* ThreadPool -) -{ - static T proB; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = proB.createStorage( - N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, IsAsym - ); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - proB.reduceWeight(&stor, &orth); - } -} - -bool -BTLAQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case CompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
- if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - BTLAQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - BTLAQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - BTLAQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - BTLAQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - BTLAQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - default: - return false; - } - return false; -} - -bool -BTLAQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - auto NTile = - gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto wptr = reinterpret_cast(ptr); - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8()) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI()) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI()) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} diff --git a/onnxruntime/core/mlas/lib/bestla_gemm.h b/onnxruntime/core/mlas/lib/bestla_gemm.h deleted file mode 100644 index 6eaa350341739..0000000000000 --- a/onnxruntime/core/mlas/lib/bestla_gemm.h +++ /dev/null @@ -1,61 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - bestla_gemm.h - -Abstract: - - Currently only support Q4 gemm. 
---*/ - -#pragma once - -#include "mlas_qnbit.h" - -size_t -BTLAQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); - -bool -BTLAQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -); - -bool -BTLAQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb - , MLAS_THREADPOOL* ThreadPool); - -bool -BTLASQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -); - -size_t -BTLASQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 9697a1a73463e..f964b1affec31 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -15,9 +15,6 @@ Module Name: --*/ #include "sqnbitgemm.h" -#ifdef MLAS_NEURAL_SPEED -#include "bestla_gemm.h" -#endif namespace { @@ -145,127 +142,3 @@ MlasIsSQNBitGemmAvailable( return true; } - -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType -) -{ -#ifdef MLAS_NEURAL_SPEED - if (nbits == 4) { - auto jsize = BTLAQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } -#endif - (void)(N); - (void)(K); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(CompType); - return 0; -} - -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - int nbits, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ -#ifdef MLAS_NEURAL_SPEED - if (nbits == 4) { - if (BTLAQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -#endif - (void)(PackedBuf); - (void)(QData); - (void)(Scale); - (void)(Zp); - (void)(N); - (void)(K); - (void)(ldb); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(lastCall); - (void)(CompType); - (void)(ThreadPool); -} - -void MLASCALL -MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ -#ifdef MLAS_NEURAL_SPEED - if (BTLAQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -#endif - (void)(FpData); - (void)(PackedBuf); - (void)(N); - (void)(K); - (void)(ldb); - (void)(ThreadPool); -} - -size_t MLASCALL -MlasSQNBitsGemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ -#ifdef MLAS_NEURAL_SPEED - return BTLASQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - return 0; -} - -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetMlasPlatform(); -#ifdef 
MLAS_NEURAL_SPEED - if (BTLASQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by bestla - return; - } -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - (void)(WorkSpace); - (void)(ThreadPool); -} diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3814e65cf24a7..5562a20d077c1 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -63,7 +63,7 @@ void QuantizeDequantize(std::vector& raw_vals, tp.get()); } -void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, MLAS_SQNBIT_COMPUTE_TYPE comp_type, +void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t comp_type, bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) { RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); @@ -134,7 +134,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, MLAS_SQNBIT_CO } test.AddOutput("Y", {M, N}, expected_vals); - if (comp_type == CompInt8) { + if (comp_type == 4) { test.SetOutputAbsErr("Y", 0.1f); } @@ -147,10 +147,17 @@ TEST(MatMulNBits, Float32) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - for (auto comp : {CompUndef, CompFp32, CompInt8}) { - RunTest(M, N, K, block_size, comp, false, false); - RunTest(M, N, K, block_size, comp, true, false); +#ifdef ORT_NEURAL_SPEED + for (auto accuracy_level : {0, 1, 4}) { + RunTest(M, N, K, block_size, accuracy_level, false, false); + RunTest(M, N, K, block_size, accuracy_level, true, false); } +#else + for (auto accuracy_level : {0}) { + RunTest(M, N, K, block_size, accuracy_level, false, false); + RunTest(M, N, K, block_size, accuracy_level, true, false); + } +#endif } } } @@ -163,8 +170,8 @@ TEST(MatMulNBits, Float16) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - RunTest(M, N, K, block_size, CompUndef, false, true); - RunTest(M, N, K, block_size, CompUndef, true, true); + RunTest(M, N, K, block_size, 0, false, true); + RunTest(M, N, K, block_size, 0, true, true); } } } @@ -174,9 +181,9 @@ TEST(MatMulNBits, Float16) { TEST(MatMulNBits, Float16Large) { for (auto block_size : {16, 32, 64, 128}) { for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, CompUndef, symmetric, true, 0.05f); - RunTest(1, 4096, 11008, block_size, CompUndef, symmetric, true, 0.05f); - RunTest(1, 11008, 4096, block_size, CompUndef, symmetric, true, 0.05f); + RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f); } } } @@ -184,11 +191,11 @@ TEST(MatMulNBits, Float16Large) { #endif void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym, - MLAS_SQNBIT_COMPUTE_TYPE acc_lvl) { + int64_t acc_lvl) { // (M x K) X (K x N) OpTester test("MatMulNBits", 1, kMSDomain); - test.AddAttribute("accuracy_level", int64_t(acc_lvl)); + test.AddAttribute("accuracy_level", acc_lvl); test.AddAttribute("block_size", int64_t(block_size)); test.AddAttribute("bits", QBits); test.AddAttribute("N", N); @@ -268,7 +275,7 @@ void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, 
int block_si test.AddInput("zero_points", {N, static_cast(kblks / 2)}, input3_vals, true); } test.AddOutput("Y", {M, N}, expected_vals, false); - if (acc_lvl == CompInt8) { + if (acc_lvl == 4) { test.SetOutputAbsErr("Y", 0.1f); } @@ -341,14 +348,14 @@ void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_si } } -#ifdef MLAS_NEURAL_SPEED +#ifdef ORT_NEURAL_SPEED TEST(MatMulNBits, SharedPrepackedWeights) { - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, CompFp32); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, CompFp32); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, CompFp32); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, CompInt8); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, CompInt8); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, CompInt8); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, 4); } #endif } // namespace test diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index e21dee7284156..2f2635dab0512 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -84,57 +84,3 @@ BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime(); BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime(); - -#ifdef MLAS_NEURAL_SPEED -void Q4GEMM_BTLA(benchmark::State& state, int block_size, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE cmp_type) { - if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); - if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); - if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); - if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); - - const size_t M = static_cast(state.range(0)); - const size_t N = static_cast(state.range(1)); - const size_t K = static_cast(state.range(2)); - const size_t threads = static_cast(state.range(3)); - block_size = block_size == -1 ? static_cast(K) : block_size; - const size_t pack_b_size = MlasNBitsGemmPackBSize(N, K, block_size, 4, is_asym, cmp_type); - - OrtThreadPoolParams tpo; - tpo.thread_pool_size = static_cast(threads); - tpo.auto_set_affinity = true; - std::unique_ptr tp(onnxruntime::concurrency::CreateThreadPool( - &onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - - auto A1 = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); - auto B1 = RandomVectorUniform(static_cast(N * K / 2), 0, 255); - auto blk_num = static_cast((K + block_size - 1) / block_size); - auto B_scale = RandomVectorUniform(static_cast(N * blk_num), 0.003f, 0.005f); - std::vector C1(static_cast(M * N)); - auto B_zp = RandomVectorUniform(static_cast(N * blk_num / 2), 0, 255); - - std::vector B1_packed(pack_b_size); - MlasNBitsGemmPackB(B1_packed.data(), B1.data(), B_scale.data(), is_asym ? 
B_zp.data() : nullptr, N, K, K, block_size, - 4, is_asym, true, cmp_type, tp.get()); - - MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS params1; - params1.A = A1.data(); - params1.lda = K; - params1.C = C1.data(); - params1.ldc = N; - params1.B = B1_packed.data(); - std::vector workspace(static_cast(M <= 32 ? 32 : M) * K * 4); - MlasSQNBitsGemmBatchPackedB(M, N, K, 1, ¶ms1, workspace.data(), tp.get()); - - for (auto _ : state) { - MlasSQNBitsGemmBatchPackedB(M, N, K, 1, ¶ms1, workspace.data(), tp.get()); - } -} - -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4B32SymInt8, 32, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4B128SymInt8, 128, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4PerNSymInt8, -1, false, CompInt8)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4B32SymFp32, 32, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4B128SymFp32, 128, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4PerNSymFp32, -1, false, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); -BENCHMARK_CAPTURE(Q4GEMM_BTLA, Q4B32AsymFp32, 32, true, CompFp32)->Apply(GemmSizeProducts)->UseRealTime(); -#endif
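
Usage sketch (not part of the patch): the relocated entry points keep the same pack-then-compute contract that the removed MLAS wrappers exposed. The sketch below shows the intended call order, assuming 4-bit symmetric weights, block size 32, accuracy_level 4 (the int8 compute path), and a caller-supplied thread pool that may be null, in which case the kernels fall back to serial execution. The wrapper function and its argument names are illustrative and do not come from the patch.

#include <cstdint>
#include <vector>

#include "contrib_ops/cpu/quantization/bestla_gemm.h"

// Hypothetical driver; every name here except the NS* entry points and the
// params struct is made up for illustration.
void RunQ4GemmOnce(const float* A, const uint8_t* q_weight, const float* scales,
                   size_t M, size_t N, size_t K, void* thread_pool) {
  const size_t blk_size = 32;        // MatMulNBits "block_size" attribute
  const int64_t accuracy_level = 4;  // 4 requests the int8 compute path when the ISA allows it

  // 1. Size and fill the packed-B blob. MatMulNBits::PrePack calls NSQ4GemmPackB once per
  //    constant input (quantized data, then scales, then zero points when asymmetric);
  //    lastCall is true only for the final tensor so the pack epilogue runs exactly once.
  const size_t packed_size = NSQ4GemmPackBSize(N, K, blk_size, /*isAsym=*/false, accuracy_level);
  if (packed_size == 0) return;  // unsupported ISA/block-size combination
  std::vector<int8_t> packed_b(packed_size);
  NSQ4GemmPackB(packed_b.data(), q_weight, /*Scale=*/nullptr, /*Zp=*/nullptr, N, K, /*ldb=*/K,
                blk_size, /*isAsym=*/false, /*lastCall=*/false, accuracy_level, thread_pool);
  NSQ4GemmPackB(packed_b.data(), /*QData=*/nullptr, scales, /*Zp=*/nullptr, N, K, /*ldb=*/K,
                blk_size, /*isAsym=*/false, /*lastCall=*/true, accuracy_level, thread_pool);

  // 2. One parameter block per batch entry, mirroring MatMulNBits::Compute.
  std::vector<float> C(M * N);
  NS_SQNBITS_GEMM_DATA_PACKED_PARAMS params;
  params.A = A;
  params.lda = K;
  params.B = packed_b.data();
  params.C = C.data();
  params.ldc = N;

  // 3. The workspace holds the quantized or reduced activations for whichever
  //    kernel the packed blob selects at run time.
  const size_t ws_size = NSSQ4GemmBatchWorkspaceSize(M, N, K, /*BatchN=*/1, &params);
  std::vector<int8_t> workspace(ws_size);
  NSSQ4GemmBatchDriver(M, N, K, /*BatchN=*/1, &params, workspace.data(), thread_pool);
}

NSSQ4GemmBatchDriver returns false when the blob was not produced by these pack routines; in the operator this cannot happen, because Compute only takes the neural-speed path when PrePack produced a packed buffer, and otherwise falls through to the generic MLAS dequantize-and-matmul path.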