diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index dc5efeb5814c6..126abac47c181 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -99,13 +99,28 @@ MlasJblasQ4GemmPackB(void* PackedBuf, MLAS_COMPUTE_TYPE CompType, MLAS_THREADPOOL* ThreadPool) { + GetCPUDevice(); switch (CompType) { case CompInt8: - return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, int(N), - int(K), isAsym, int(ldb), ThreadPool); + if (_cd->AVX512_VNNI()) { + return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, + int(N), int(K), isAsym, int(ldb), ThreadPool); + } + if (_cd->AVX_VNNI()) { + return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, + int(N), int(K), isAsym, int(ldb), ThreadPool); + } + break; case CompFp32: - return JblaQ4GemmPackB(JblasAvx512fS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, int(N), - int(K), isAsym, int(ldb), ThreadPool); + if (_cd->AVX512F()) { + return JblaQ4GemmPackB(JblasAvx512fS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, + int(N), int(K), isAsym, int(ldb), ThreadPool); + } + if (_cd->AVX2()) { + return JblaQ4GemmPackB(JblasAvx2S4Fp32Fp32, int(BlkSize), PackedBuf, FpData, + int(N), int(K), isAsym, int(ldb), ThreadPool); + } + break; case CompBf16: case CompFp16: default: diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp index 891c346c14466..7f2c8ae3a02e9 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm.cpp @@ -168,20 +168,34 @@ JblasQ4GemmCompF32(const int M, int8_t* WorkSpace, jblas::parallel::IThreading* th) { - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = JBLAS_FP32_S4_F32F32; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M, K, B->mBlockSize); - if (B->mIsAsym) { - reduceA.assign(WorkSpace); - kernel.mProA.reduce({A, K}, &reduceA, M, K, th); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), - reduceA.template get(), reduceA.lda}; + if (M <= 32) { + using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = JBLAS_FP32_S4_F32F32; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M, K, B->mBlockSize); + if (B->mIsAsym) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, K}, &reduceA, M, K, &single); + } + typename Launcher::BEpiParam blkargs{ + B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), + reduceA.template get(), reduceA.lda}; - typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); + typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}}; + jblas::parallel::GemmKBlockRun(kernel, args, th); + } else { + using Parallel = jblas::parallel::gemm::SchedulerBase; + using Launcher = + jblas::wrapper::gemm::LauncherBase; + static Launcher kernel; + + typename Launcher::Param args{M, N, K, {A, K}, {B}, {C, N}}; + jblas::parallel::GemmBaseRun(kernel, args, th); + } } template @@ -203,7 +217,12 @@ JblasQ4GemmCompInt8(const int M, static Launcher kernel; auto quanA = kernel.mProA.createStorage(M, K, B->mBlockSize, B->mIsAsym); quanA.assign(WorkSpace); - kernel.mProA.quantize({A, K, &quanA}, M, K, th); + if (M <= 32) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, K, &quanA}, M, K, &single); + } else { + kernel.mProA.quantize({A, K, &quanA}, M, K, th); + } typename Launcher::Param args{ M, N,