Skip to content

Commit

Permalink
add K dimension check
Browse files: browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Dec 5, 2023
1 parent d7601b5 commit e5aa4ec
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions onnxruntime/core/mlas/lib/jblas_gemm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,7 +59,7 @@ JblasSQ4GemmCompF32(
kernel.mProA.reduce({A, K}, &reduceA, M, K, &single);
}
typename Launcher::BEpiParam blkargs{
B->template SPtr<int8_t>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(),
B->template SPtr<int8_t>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(),
reduceA.template get<float>(), reduceA.lda};

typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}};
Expand Down Expand Up @@ -341,14 +341,17 @@ JblasQ4BuSize(int block_size, size_t N, size_t K, bool isAsym)
auto stor = launcher.mProB.createStorage(
N, K, block_size, JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, isAsym
);
// TODO(Yu) support more S4 quant type, scale dtype
// TODO(Yu) support more scale dtype
return stor.mSize;
}

size_t
JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType)
{
GetCPUDevice();
if (K % BlkSize != 0) {
return 0;
}
// from low precision to high precision
switch (CompType) {
case CompInt8:
Expand Down Expand Up @@ -423,7 +426,7 @@ JblasQ4GemmPackB(
)
{
GetCPUDevice();
// explicit statement fall through.
// explicit statement fall through.
switch (CompType) {
case CompInt8:
if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
Expand Down

0 comments on commit e5aa4ec

Please sign in to comment.