diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h index e32f8485b9161..a1607c9012187 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h @@ -393,7 +393,9 @@ class SchedulerKBlock : public Scheduler2D { int BlkNum = utils::updiv(mSize[2], mKBlock); int KSplitSize = utils::padto(utils::updiv(mSize[2], KSplitStage), mStep[2]); mBlock[1] = NRef < mThdSize[1] ? NRef : mThdSize[1]; - if (KSplitSize >= mKBlock) { + if (KSplitStage * mStep[2] >= mSize[2]) { + mBlock[2] = mSize[2]; + } else if (KSplitSize >= mKBlock) { mBlock[2] = mKBlock; } else { int scale = utils::downdiv(KSplitStage, BlkNum); @@ -403,7 +405,7 @@ class SchedulerKBlock : public Scheduler2D { } } mBlock[2] = utils::downdiv(mKBlock, scale); - mBlock[2] =utils::padto_le(mBlock[2],mStep[2]); + mBlock[2] = utils::padto_le(mBlock[2], mStep[2]); } size_t size_remain = mL2Size - mBlock[1] * mBlock[2] * mEleSize[1]; // MBlock*KBlock*ASize+MBlock*NBlock*CSize*2<=size_remain