Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix1 #1

Merged
merged 5 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 41 additions & 51 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
const size_t nbits_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_;
bool is_asym_;
bool all_constant_;
int64_t accuracy_level_;
size_t packed_b_size_{0};
bool is_asym_{false};
bool all_constant_{false};
int64_t accuracy_level_{0};
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
Expand All @@ -65,57 +65,48 @@
return Status::OK();
}
auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
if (MlasIsNBitGemmAvailable(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type)) {
// better to use threadpool here, LLM weight will consume a lot of time
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
auto qptr = tensor.Data<uint8_t>();
packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
if (packed_b_ == nullptr) {
return Status::OK();
}
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
MLAS_THREADPOOL* pool = NULL;
if (input_idx == 1) {
packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
if (packed_b_ == nullptr) {
return Status::OK();
}
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
if (input_idx == 2) {
auto sptr = tensor.Data<float>();
if (packed_b_ == nullptr) {
return Status::OK();
}
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, !is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
is_packed = true;
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, !is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
if (input_idx == 3) {
auto zptr = tensor.Data<uint8_t>();
if (packed_b_ == nullptr) {
return Status::OK();
}
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
is_packed = true;
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, is_asym_, compt_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}

return Status::OK();
}

Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
int input_idx,
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
// Pack three tensors into one buffer
Expand Down Expand Up @@ -149,8 +140,7 @@
Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0)
return Status::OK();
if (y->Shape().Size() == 0) return Status::OK();

auto* y_data = y->MutableData<float>();

Expand All @@ -159,7 +149,7 @@
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
std::vector<MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS> gemm_params(max_len);
std::vector<MLAS_NBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
Expand All @@ -172,7 +162,7 @@
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
MlasNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
MlasNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), reinterpret_cast<int8_t*>(ws_ptr.get()), thread_pool);

Check warning on line 165 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc#L165

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:165:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/inc/mlas_q4.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,4 @@
int rows,
int columns,
MLAS_THREADPOOL* thread_pool
);
);

Check warning on line 360 in onnxruntime/core/mlas/inc/mlas_q4.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/inc/mlas_q4.h#L360

Closing ) should be moved to the previous line [whitespace/parens] [2]
Raw output
onnxruntime/core/mlas/inc/mlas_q4.h:360:  Closing ) should be moved to the previous line  [whitespace/parens] [2]
31 changes: 8 additions & 23 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,35 +90,20 @@ typedef enum {
} MLAS_COMPUTE_TYPE;

/**
* @brief Data parameters for Q4 GEMM routine
* C = A * B + Bias
* A must be a float32 matrix
* B must be a quantized and packed int4 blob
* @brief Data parameters for NBits GEMM routine
* C = A * B
* A, C must be a float32 matrix
* B must be a packed nbits blob
* All except C are [in] parameters
*/
struct MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS {
struct MLAS_NBITS_GEMM_DATA_PACKED_PARAMS {
const float* A = nullptr; /**< address of A (float32 matrix)*/
const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/
const void* B = nullptr; /**< address of B (packed nbits blob)*/
float* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldc = 0; /**< leading dimension of C*/
};

/**
* @brief Check if the parameter combination is supported by the runtime device.
*
* @param N the number of columns of matrix B.
* @param K the number of rows of matrix B.
* @param block_size size of the block to quantize, elements from the same block share the same
* scale and zero point
* @param nbits number of bits used for weight quantization (default 4)
* @param is_asym flag for asymmetric quantization
* @param comp_type specify input data type and accumulator data type
* @return support flag, true if the combination is supported.
*/
bool MLASCALL
MlasIsNBitGemmAvailable(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);

/**
* @brief Compute the byte size of the parameter combination
*
Expand Down Expand Up @@ -204,7 +189,7 @@ MlasNBitsGemmBatchPackedB(
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams,
const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
int8_t* WorkSpace,
MLAS_THREADPOOL* ThreadPool = nullptr
);
);
7 changes: 3 additions & 4 deletions onnxruntime/core/mlas/lib/jblas_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
#include "jblas/jit_blas_prologue_b.h"
#include "jblas/jit_blas_wrapper.h"

namespace jblas
{
namespace jblas {
template <class GemmCore_T>
using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
GemmCore_T::ISA,
Expand Down Expand Up @@ -40,12 +39,12 @@
using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>;
using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>;

class ORTThreading : public jblas::parallel::IThreading
{
class ORTThreading : public jblas::parallel::IThreading {
public:
ORTThreading(void* tp);
void parallel_for(const jblas::parallel::thread_func& func) override;
virtual void set_threads(int nthreads) override { assert(0); }
virtual void sync() override { assert(0); }

Check warning on line 47 in onnxruntime/core/mlas/lib/jblas_defs.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/jblas_defs.h#L47

"virtual" is redundant since function is already declared as "override" [readability/inheritance] [4]
Raw output
onnxruntime/core/mlas/lib/jblas_defs.h:47:  "virtual" is redundant since function is already declared as "override"  [readability/inheritance] [4]
void* mTp;
};

Expand Down
16 changes: 7 additions & 9 deletions onnxruntime/core/mlas/lib/jblas_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ JblasQ4GemmCompF32(
kernel.mProA.reduce({A, K}, &reduceA, M, K, &single);
}
typename Launcher::BEpiParam blkargs{
B->template SPtr<int8_t>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(),
B->template SPtr<int8_t>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(),
reduceA.template get<float>(), reduceA.lda};

typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}};
Expand Down Expand Up @@ -121,7 +121,7 @@ JblasQ4GemmBatchDriver(
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams,
const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
int8_t* WorkSpace,
MLAS_THREADPOOL* ThreadPool
)
Expand Down Expand Up @@ -222,6 +222,7 @@ size_t
JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType)
{
GetCPUDevice();
// from low precision to high precision
switch (CompType) {
case CompInt8:
if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
Expand All @@ -233,7 +234,8 @@ JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPU
if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(int(BlkSize), N, K, isAsym);
}
break;
case CompBf16:
case CompFp16:
case CompFp32:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX512F>>(int(BlkSize), N, K, isAsym);
Expand All @@ -242,8 +244,6 @@ JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPU
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX2>>(int(BlkSize), N, K, isAsym);
}
break;
case CompBf16:
case CompFp16:
default:
return 0;
}
Expand Down Expand Up @@ -315,7 +315,8 @@ JblasQ4GemmPackB(
);
return true;
}
break;
case CompBf16:
case CompFp16:
case CompFp32:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
JblaNBitsGemmPackB<tLauncher_Fp32_S4_F32F32<tAVX512F>>(
Expand All @@ -329,9 +330,6 @@ JblasQ4GemmPackB(
);
return true;
}
break;
case CompBf16:
case CompFp16:
default:
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/jblas_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ JblasQ4GemmBatchDriver(
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams,
const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
int8_t* WorkSpace,
MLAS_THREADPOOL* ThreadPool
);
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/q4_dq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,4 +1059,4 @@
int rows,
int columns,
MLAS_THREADPOOL* thread_pool
);
);

Check warning on line 1062 in onnxruntime/core/mlas/lib/q4_dq.cpp

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/q4_dq.cpp#L1062

Closing ) should be moved to the previous line [whitespace/parens] [2]
Raw output
onnxruntime/core/mlas/lib/q4_dq.cpp:1062:  Closing ) should be moved to the previous line  [whitespace/parens] [2]
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/q4gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ MlasQ8Q4GemmBatch(
)
{
MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
}
}
8 changes: 1 addition & 7 deletions onnxruntime/core/mlas/lib/sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,6 @@ MlasNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsy
return 0;
}

bool MLASCALL
MlasIsNBitGemmAvailable(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType)
{
return MlasNBitsGemmPackBSize(N, K, BlkSize, nbits, isAsym, CompType) > 0;
}

void MLASCALL
MlasNBitsGemmPackB(
void* PackedBuf,
Expand Down Expand Up @@ -233,7 +227,7 @@ MlasNBitsGemmBatchPackedB(
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_NBITS_GEMM_DATA_SIMPLE_PARAMS* DataParams,
const MLAS_NBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
int8_t* WorkSpace,
MLAS_THREADPOOL* ThreadPool
)
Expand Down
16 changes: 8 additions & 8 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class JitBase : protected Xbyak::CodeGenerator {
jb(".maskflag");
cmp(_tmp, 0);
jl(".zeroflag");
uint64_t allmask = ((uint64_t)1 << N) - 1;
uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
if (N == 64) {
allmask = (uint64_t)-1;
allmask = static_cast<uint64_t>(-1);
}
mov(_tmp, allmask);
kmovq(_msk, _tmp);
Expand Down Expand Up @@ -256,19 +256,19 @@ class JitAmxtile : protected JitAvx512f {
// Configure C tiles
int t = 0;
for (; t < CNum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
// Configure A tiles
for (; t < CNum + ANum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_K * elesize);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
}
// Configure B tile. B effectively has 64 rows and 16 columns.
int kpack = 4 / elesize;
for (; t < CNum + ANum + BNum; ++t) {
tc.rows[t] = uint8_t(TILE_K / kpack);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,4 @@ enum class JBLAS_PROLOGUEB_IDS : uint32_t {
WeightKBlockF4,
KBlockEnd,
End,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,4 @@ class CpuBase {
int mNumThreads;
};
} // namespace device

} // namespace jblas
Loading
Loading