From 4c07742421372e23292c0cafd514d8df580fd608 Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Fri, 17 Nov 2023 10:20:24 +0800
Subject: [PATCH] resolve matmul_nbits.cc

---
 .../cpu/quantization/matmul_nbits.cc          | 37 +++++++++++++++----
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index e0a1835831a65..75fb2372c3103 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -9,7 +9,6 @@
 #include "core/mlas/inc/mlas_q4.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
-#include "core/mlas/inc/mlas_q4.h"
 
 namespace onnxruntime {
 namespace contrib {
@@ -26,6 +25,13 @@ class MatMulNBits final : public OpKernel {
                 "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
     info.GetAttrOrDefault<int64_t>("accuracy_level", &accuracy_level_, 0);
     is_asym_ = info.GetInputCount() >= 4;
+    const Tensor* tensor_B = nullptr;
+    const Tensor* tensor_scale = nullptr;
+    const Tensor* tensor_zero_point = nullptr;
+    bool get_B = info.TryGetConstantInput(1, &tensor_B);
+    bool get_scale = info.TryGetConstantInput(2, &tensor_scale);
+    bool get_zero_point = !is_asym_ || info.TryGetConstantInput(3, &tensor_zero_point);
+    all_constant_ = get_B && get_scale && get_zero_point;
   }
 
   Status Compute(OpKernelContext* context) const override;
@@ -46,6 +52,7 @@ class MatMulNBits final : public OpKernel {
   IAllocatorUniquePtr<void> packed_b_;
   size_t packed_b_size_;
   bool is_asym_;
+  bool all_constant_;
   int64_t accuracy_level_;
 };
 
@@ -53,15 +60,22 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
                             /*out*/ bool& is_packed,
                             /*out*/ PrePackedWeights* prepacked_weights) {
   is_packed = false;
-  auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
+  if (!all_constant_) {
+    return Status::OK();
+  }
+  auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
   if (MlasNBitsGemmPackBSupport(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type)) {
-    // TODO use threadpool here
+    // better to use threadpool here, LLM weight will consume a lot of time
     MLAS_THREADPOOL* pool = NULL;
     if (input_idx == 1) {
       auto qptr = tensor.Data<uint8_t>();
       packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
       packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-      MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, false, compt_type, pool);
+      if (packed_b_ == nullptr) {
+        return Status::OK();
+      }
+      MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                         is_asym_, false, compt_type, pool);
       if (prepacked_weights) {
         prepacked_weights->buffers_.push_back(std::move(packed_b_));
         prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -70,7 +84,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
     }
     if (input_idx == 2) {
       auto sptr = tensor.Data<float>();
-      MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, !is_asym_, compt_type, pool);
+      if (packed_b_ == nullptr) {
+        return Status::OK();
+      }
+      MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                         is_asym_, !is_asym_, compt_type, pool);
       if (prepacked_weights) {
         prepacked_weights->buffers_.push_back(std::move(packed_b_));
         prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -79,7 +97,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
     }
     if (input_idx == 3) {
       auto zptr = tensor.Data<uint8_t>();
-      MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, is_asym_, compt_type, pool);
+      if (packed_b_ == nullptr) {
+        return Status::OK();
+      }
+      MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+                         is_asym_, is_asym_, compt_type, pool);
       if (prepacked_weights) {
         prepacked_weights->buffers_.push_back(std::move(packed_b_));
         prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -140,7 +162,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
   AllocatorPtr allocator;
   auto status = ctx->GetTempSpaceAllocator(&allocator);
   ORT_RETURN_IF_ERROR(status);
-  auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, SafeInt<size_t>(K) * M);  // workspace for activation process(dynamic quantization and others)
+  // workspace for activation process(dynamic quantization and others)
+  auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, SafeInt<size_t>(K) * std::max<size_t>(M, 32));
   for (size_t i = 0; i < max_len; i++) {
     gemm_params[i].A = a_data + helper.LeftOffsets()[i];
     gemm_params[i].lda = lda;