Skip to content

Commit

Permalink
add auto-dispatch of accuracy_level
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Nov 21, 2023
1 parent 5cfec05 commit e5fc97c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ class MatMulNBits final : public OpKernel {
const size_t nbits_;
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_;
bool is_asym_;
bool all_constant_;
int64_t accuracy_level_;
size_t packed_b_size_{0};
bool is_asym_{false};
bool all_constant_{false};
int64_t accuracy_level_{0};
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
Expand Down
12 changes: 5 additions & 7 deletions onnxruntime/core/mlas/lib/jblas_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ size_t
JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType)
{
GetCPUDevice();
// from low precision to high precision
switch (CompType) {
case CompInt8:
if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
Expand All @@ -233,7 +234,8 @@ JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPU
if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(int(BlkSize), N, K, isAsym);
}
break;
case CompBf16:
case CompFp16:
case CompFp32:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX512F>>(int(BlkSize), N, K, isAsym);
Expand All @@ -242,8 +244,6 @@ JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPU
return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX2>>(int(BlkSize), N, K, isAsym);
}
break;
case CompBf16:
case CompFp16:
default:
return 0;
}
Expand Down Expand Up @@ -315,7 +315,8 @@ JblasQ4GemmPackB(
);
return true;
}
break;
case CompBf16:
case CompFp16:
case CompFp32:
if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
JblaNBitsGemmPackB<tLauncher_Fp32_S4_F32F32<tAVX512F>>(
Expand All @@ -329,9 +330,6 @@ JblasQ4GemmPackB(
);
return true;
}
break;
case CompBf16:
case CompFp16:
default:
return false;
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -2471,7 +2471,7 @@ class CoreCodeBase {
static int constexpr KTILE = Code::KTILE;
static int constexpr PACK_ROW = Code::PackRow;
static int constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = 144;
static int constexpr PREFERRED_N = NTILE * 3;
static JBLAS_ISA constexpr ISA = (JBLAS_ISA)Code::ISA;
static uint32_t constexpr ID = CoreAttr::make_core_id(NTILE, PACK_ROW, COMP, ISA);
void configure() { (void)(0); }
Expand All @@ -2497,7 +2497,7 @@ class CoreCodeBaseAMX {
static int constexpr KTILE = Code::KTILE;
static int constexpr PACK_ROW = Code::PackRow;
static int constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = 144;
static int constexpr PREFERRED_N = NTILE * 3;
static JBLAS_ISA constexpr ISA = (JBLAS_ISA)Code::ISA;
static uint32_t constexpr ID = CoreAttr::make_core_id(_NTILE, PACK_ROW, COMP, ISA);
Xbyak::CodeGenerator cfgcode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ class SchedulerKBlock : public Scheduler2D {
}
}
mBlock[2] = utils::downdiv(mKBlock, scale);
}
mBlock[2] =utils::padto_le(mBlock[2],mStep[2]);

Check warning on line 406 in onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h#L406

Tab found; better to use spaces [whitespace/tab] [1]
Raw output
onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h:406:  Tab found; better to use spaces  [whitespace/tab] [1]

Check warning on line 406 in onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h#L406

Missing spaces around = [whitespace/operators] [4]
Raw output
onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h:406:  Missing spaces around =  [whitespace/operators] [4]

Check warning on line 406 in onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h#L406

Missing space after , [whitespace/comma] [3]
Raw output
onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h:406:  Missing space after ,  [whitespace/comma] [3]
}

Check warning on line 407 in onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h#L407

Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
Raw output
onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h:407:  Line ends in whitespace.  Consider deleting these extra spaces.  [whitespace/end_of_line] [4]
size_t size_remain = mL2Size - mBlock[1] * mBlock[2] * mEleSize[1];
// MBlock*KBlock*ASize+MBlock*NBlock*CSize*2<=size_remain
int maxMBlock = static_cast<int>(size_remain / (mBlock[1] * mEleSize[2] * 2 + mBlock[2] * mEleSize[0]));
Expand Down

0 comments on commit e5fc97c

Please sign in to comment.