pass pre-pack UT of matmul_nbits
luoyu-intel committed Nov 3, 2023
1 parent 347d9ff commit 2272174
Showing 5 changed files with 271 additions and 57 deletions.
51 changes: 31 additions & 20 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -20,8 +20,8 @@ class MatMulNBits final : public OpKernel {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
info.GetAttrOrDefault<int64_t>("is_asym", &is_asym_, 1);
info.GetAttrOrDefault<int64_t>("compute_type", &comp_type_, 0);
info.GetAttrOrDefault<int64_t>("compute_type", &comp_type_, -1);
is_asym_ = info.GetInputCount() >= 4;
}

Status Compute(OpKernelContext* context) const override;
@@ -39,9 +39,8 @@ class MatMulNBits final : public OpKernel {
int64_t block_size_;
int64_t nbits_;
IAllocatorUniquePtr<void> packed_b_;
const uint8_t *qptr_, *zptr_;
const float* sptr_;
int64_t is_asym_;
size_t packed_b_size_;
bool is_asym_;
int64_t comp_type_;
};

@@ -51,24 +50,29 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = false;

if (comp_type_ != -1 && nbits_ == 4) {
// only pack Matrix B
// TODO use threadpool here
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
qptr_ = tensor.Data<uint8_t>();
auto qptr = tensor.Data<uint8_t>();
packed_b_size_ = MlasJblasQ4GemmPackBSize(N_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_));
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasJblasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_), pool);
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
is_packed = true;
}
if (input_idx == 2) {
sptr_ = tensor.Data<float>();
auto sptr = tensor.Data<float>();
MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_), pool);
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
is_packed = true;
}
if (input_idx == 3) {
zptr_ = tensor.Data<uint8_t>();
auto packed_b_size = MlasJblasQ4GemmPackBSize(N_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_));
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size, true);
MlasJblasNBitsGemmPackB(packed_b_.get(), qptr_, sptr_, zptr_, N_, K_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_), NULL);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size);
}
auto zptr = tensor.Data<uint8_t>();
MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(comp_type_), pool);
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
is_packed = true;
}
}
Expand All @@ -85,17 +89,21 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prep
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}

if (input_idx == 2) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 3) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
return Status::OK();
}

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const auto* a_data = a->Data<float>();

if (packed_b_.get()) {
@@ -129,6 +137,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
return Status::OK();
}

const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
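The PrePack change above stops caching raw pointers (the removed qptr_/sptr_/zptr_ members) until all constant inputs have arrived. Instead, each call packs whatever just arrived — quantized B at input 1, scales at input 2, zero points at input 3 — into the same packed buffer, passing nullptr for the pieces handled by the other calls. The standalone sketch below re-assembles that staged call sequence from the diff; the wrapper function, parameter names, and the std::vector buffer are illustrative only, while the MLAS entry points and their argument order are taken from the code above. In the real kernel the buffer comes from the ORT allocator and can be shared through prepacked_weights.

#include <cstdint>
#include <vector>

// Sketch of the staged packing performed by MatMulNBits::PrePack after this
// commit (simplified: no ORT allocator, no prepacked_weights handling).
void PackNBitsInStages(const uint8_t* qptr, const float* sptr, const uint8_t* zptr,
                       int64_t N, int64_t K, int64_t block_size, bool is_asym,
                       int64_t comp_type) {
  const size_t packed_b_size = MlasJblasQ4GemmPackBSize(
      N, K, block_size, is_asym, static_cast<MLAS_COMPUTE_TYPE>(comp_type));
  std::vector<uint8_t> packed_b(packed_b_size);

  // input_idx == 1: only the quantized weight blob is available yet.
  MlasJblasNBitsGemmPackB(packed_b.data(), qptr, /*scales*/ nullptr, /*zero_points*/ nullptr,
                          N, K, K, block_size, is_asym,
                          static_cast<MLAS_COMPUTE_TYPE>(comp_type), /*thread_pool*/ nullptr);

  // input_idx == 2: merge the scales into the same packed buffer.
  MlasJblasNBitsGemmPackB(packed_b.data(), nullptr, sptr, nullptr,
                          N, K, K, block_size, is_asym,
                          static_cast<MLAS_COMPUTE_TYPE>(comp_type), nullptr);

  // input_idx == 3 (asymmetric weights only): merge the zero points last.
  if (is_asym) {
    MlasJblasNBitsGemmPackB(packed_b.data(), nullptr, nullptr, zptr,
                            N, K, K, block_size, is_asym,
                            static_cast<MLAS_COMPUTE_TYPE>(comp_type), nullptr);
  }
}

This staged flow is why the jblas packing routines in the files below now tolerate null scales and zero_points pointers: each stage only supplies the data it has.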
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3467,6 +3467,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Attr("K", "size of each input feature", AttributeProto::INT)
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
.Attr("compute_type", "compute type for compressed weight, can be: 0(fp32), 1(int8), 2(bf16), or 3(fp16) (default fp32)", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
52 changes: 29 additions & 23 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h
@@ -263,19 +263,26 @@ class WeightKBlockS8 {
parallel::ThreadProblem2D thdp{tidx};
_para.getIndex(thdp);
if (thdp.valid) {
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
if (i < rawnk_scale) {
for (size_t j = 0; j < N; j++) {
stor->template SPtr<float>()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
if (scales) {
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
if (i < rawnk_scale) {
for (size_t j = 0; j < N; j++) {
stor->template SPtr<float>()[i * stor->mNPad + j] = scales[j * rawnk_scale + i];
}
} else {
std::memset(stor->template SPtr<float>() + i * stor->mNPad, 0, stor->mNPad * sizeof(float));
}
if (stor->mIsAsym)
}
}
if (stor->mIsAsym && zero_points) {
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
if (i < rawnk_scale) {
for (size_t j = 0; j < N; j++) {
stor->template ZPtr<int8_t>()[i * stor->mNPad + j] = zero_points[j * rawnk_scale + i];
}
} else {
std::memset(stor->template SPtr<float>() + i * stor->mNPad, 0, stor->mNPad * sizeof(float));
if (stor->mIsAsym)
} else {
std::memset(stor->template ZPtr<int8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(zero_points[0]));
}
}
}
}
@@ -493,28 +500,27 @@ class WeightKBlockS4 : public WeightKBlockS8<_GemmCore_T, ISA_T> {
virtual void packNbitsWeight(const int N, const int K, bool isasym, const uint8_t* B, const int ldb,
const float* scales, const uint8_t* zero_points, void* ptr,
parallel::IThreading* threading) {
if (B == nullptr || scales == nullptr) {
assert(0);
return;
}
if (isasym && zero_points == nullptr) {
assert(0);
return;
}
auto stor = reinterpret_cast<StorageWeight*>(ptr);
auto tmp = utils::amalloc<float>((size_t)stor->mKPad * stor->mNPad);
auto blks = utils::updiv(K, stor->mBlockSize);
auto tmpscales = (float*)tmp;
auto tmpzeropoints = (int8_t*)(tmpscales + N * blks);
for (size_t i = 0; i < N * blks; i += 2) {
tmpscales[i] = scales[i] / 16;
tmpscales[i + 1] = scales[i + 1] / 16;
auto tmpzp = *(zero_points + i / 2);
tmpzeropoints[i] = ((tmpzp & 0xf) - 8) << 4;
tmpzeropoints[i + 1] = (((tmpzp & 0xf0) >> 4) - 8) << 4;
if (scales) {
for (size_t i = 0; i < N * blks; i += 2) {
tmpscales[i] = scales[i] / 16;
tmpscales[i + 1] = scales[i + 1] / 16;
}
}
if (zero_points) {
for (size_t i = 0; i < N * blks; i += 2) {
auto tmpzp = *(zero_points + i / 2);
tmpzeropoints[i] = ((tmpzp & 0xf) - 8) << 4;
tmpzeropoints[i + 1] = (((tmpzp & 0xf0) >> 4) - 8) << 4;
}
}

WeightKBlockS8<_GemmCore_T, ISA_T>::setTransposeQuantCorrection(N, K, tmpzeropoints, tmpscales, ptr, threading);
WeightKBlockS8<_GemmCore_T, ISA_T>::setTransposeQuantCorrection(N, K, zero_points ? tmpzeropoints : nullptr,
scales ? tmpscales : nullptr, ptr, threading);
if (B) {
auto s8ptr = (int8_t*)tmp;
auto transposeunpackfunc_u4s4 = [&]() {
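The packNbitsWeight change above also shows how the 4-bit quantization parameters are mapped into the int8 domain before the shared WeightKBlockS8 path is reused: each 4-bit zero point (two are packed per byte) is recentered by 8 and shifted into the high nibble of an int8, and the matching scale is divided by 16, so the dequantized value scale * (q - zp) is preserved once the 4-bit weights are expanded the same way. A minimal standalone sketch of that conversion for one packed byte follows; the function and variable names are illustrative, not from the source.

#include <cstdint>

// Convert one byte holding two 4-bit zero points (low nibble first) and the
// two matching float scales into the int8-domain values used by the packing
// code above. The /16 on the scale compensates for the <<4 on the zero point.
void ConvertU4PairToS8Domain(uint8_t packed_zp, const float scale_in[2],
                             int8_t zp_out[2], float scale_out[2]) {
  scale_out[0] = scale_in[0] / 16.0f;
  scale_out[1] = scale_in[1] / 16.0f;
  zp_out[0] = static_cast<int8_t>(((packed_zp & 0x0f) - 8) << 4);         // low nibble
  zp_out[1] = static_cast<int8_t>((((packed_zp & 0xf0) >> 4) - 8) << 4);  // high nibble
}

The null checks added around the scale and zero-point loops let the same routine be called with only the weight blob, only the scales, or only the zero points, matching the staged PrePack calls in matmul_nbits.cc.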
12 changes: 6 additions & 6 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx2.h
@@ -580,13 +580,13 @@ static inline JBLAS_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* ds
int ld_dst, _ST* scales, int k_offset, int kblock, int NPad,
int8_t* tmp, size_t tmpsize) {
if constexpr (_PACK_ROW == 1) {
return decompress_kblock_bit4_packrow1<true, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
&dequant_f4_N<48, _DST_T, _F4_T>, fp4_pad_4bit, tmp, tmpsize);
return decompress_kblock_bit4_packrow1<true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr,
k_offset, kblock, NPad, &dequant_f4_N<48, _DST_T, _F4_T>,
fp4_pad_4bit, tmp, tmpsize);
} else if constexpr (_PACK_ROW == 2) {
return decompress_kblock_bit4_packrow2<true, _ST, _DST_T>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr, k_offset, kblock, NPad,
&dequant_f4_N<64, _DST_T, _F4_T>, fp4_pad_4bit, tmp, tmpsize);
return decompress_kblock_bit4_packrow2<true, _ST, _DST_T>(srcptr, dstptr, row, col, ld_src, ld_dst, scales, nullptr,
k_offset, kblock, NPad, &dequant_f4_N<64, _DST_T, _F4_T>,
fp4_pad_4bit, tmp, tmpsize);
}
return JblasNotSupport;
}