Skip to content

Commit

Permalink
resolve matmul_nbits.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Nov 17, 2023
1 parent f5ada45 commit 4c07742
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -26,6 +25,13 @@ class MatMulNBits final : public OpKernel {
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
info.GetAttrOrDefault<int64_t>("accuracy_level", &accuracy_level_, 0);
is_asym_ = info.GetInputCount() >= 4;
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool get_B = info.TryGetConstantInput(0, &tensor_B);
bool get_scale = info.TryGetConstantInput(1, &tensor_scale);
bool get_zero_point = info.TryGetConstantInput(2, &tensor_zero_point);
all_constant_ = get_B && get_scale && get_zero_point;
}

Status Compute(OpKernelContext* context) const override;
Expand All @@ -46,22 +52,30 @@ class MatMulNBits final : public OpKernel {
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_;
bool is_asym_;
bool all_constant_;
int64_t accuracy_level_;
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
if (!all_constant_) {
return;
}
if (MlasNBitsGemmPackBSupport(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type)) {
// TODO use threadpool here
auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
// better to use threadpool here, LLM weight will consume a lot of time
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
auto qptr = tensor.Data<uint8_t>();
packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, false, compt_type, pool);
if (packed_b_ == nullptr) {
return;
}
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
Expand All @@ -70,7 +84,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
if (input_idx == 2) {
auto sptr = tensor.Data<float>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, !is_asym_, compt_type, pool);
if (packed_b_ == nullptr) {
return;
}
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, !is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
Expand All @@ -79,7 +97,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
if (input_idx == 3) {
auto zptr = tensor.Data<uint8_t>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), is_asym_, is_asym_, compt_type, pool);
if (packed_b_ == nullptr) {
return;
}
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
Expand Down Expand Up @@ -140,7 +162,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
auto ws_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K) * M); // workspace for activation process(dynamic quantization and others)
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K) * std::max(M, 32));
for (size_t i = 0; i < max_len; i++) {
gemm_params[i].A = a_data + helper.LeftOffsets()[i];
gemm_params[i].lda = lda;
Expand Down

0 comments on commit 4c07742

Please sign in to comment.