add ut for matmul_nbits_cpu
luoyu-intel committed Oct 31, 2023
1 parent 9482b24 commit db4bd31
Showing 1 changed file with 5 additions and 3 deletions.
onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc (5 additions, 3 deletions)
@@ -28,8 +28,10 @@ TEST(MatMulNBitsCPU, MatMul2DSymPerN) {
   constexpr int64_t M = 100;
   constexpr int64_t N = 288;
   constexpr int64_t K = 52;
-
-  const auto buf_size = MlasQ4GemmPackBSize(BlkQ4Sym, (size_t)N, (size_t)K);
+  constexpr int BlkSize = 32;
+  constexpr bool IsAsym = false;
+  constexpr MLAS_COMPUTE_TYPE CompType = CompFp32;
+  const auto buf_size = MlasJblasQ4GemmPackBSize((size_t)N, (size_t)K, BlkSize, IsAsym, CompType);
   if (buf_size == 0) {
     GTEST_SKIP();  // operation not supported on this hardware platform yet.
   }
@@ -58,7 +60,7 @@ TEST(MatMulNBitsCPU, MatMul2DSymPerN) {
     }
   }
   std::vector<uint8_t> input1_vals(buf_size);
-  MlasQ4GemmPackB(BlkQ4SymPerN, input1_vals.data(), input1_f_vals.data(), (size_t)N, (size_t)K, (size_t)N);
+  MlasJblasQ4GemmPackB(input1_vals.data(), input1_f_vals.data(), (size_t)N, (size_t)K, (size_t)N, BlkSize, IsAsym, CompType, NULL);
 
   std::vector<float> expected_vals(M * N);
   for (int64_t m = 0; m < M; m++) {
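For reference, here is a minimal, self-contained sketch (not part of the commit) of the packing flow the updated test exercises, assuming the MlasJblasQ4GemmPackBSize/MlasJblasQ4GemmPackB signatures used above. The header name, the wrapper name PackFp32BToQ4, and the reading of the trailing NULL argument as an optional thread pool are assumptions made for illustration.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "mlas_q4.h"  // assumed header exposing the MlasJblasQ4Gemm* entry points

// Quantize and pack a row-major fp32 B matrix (K x N, ldb == N) into the
// jblas int4 blob, mirroring the calls in the test above.
std::vector<uint8_t> PackFp32BToQ4(const std::vector<float>& b_fp32, size_t N, size_t K) {
  constexpr int BlkSize = 32;                       // quantization block size
  constexpr bool IsAsym = false;                    // symmetric int4 quantization
  constexpr MLAS_COMPUTE_TYPE CompType = CompFp32;  // compute/accumulate in fp32

  // Query the packed-buffer size first; 0 means the kernel is not supported
  // on this hardware, which is what triggers GTEST_SKIP() in the test.
  const size_t buf_size = MlasJblasQ4GemmPackBSize(N, K, BlkSize, IsAsym, CompType);
  if (buf_size == 0) {
    return {};  // not supported on this platform
  }

  std::vector<uint8_t> packed(buf_size);
  MlasJblasQ4GemmPackB(packed.data(), b_fp32.data(), N, K, /*ldb=*/N,
                       BlkSize, IsAsym, CompType, /*ThreadPool=*/nullptr);
  return packed;
}

A caller would fill b_fp32 with K x N float weights and hand the returned blob to the matching Q4 GEMM compute routine, which the test then checks against the fp32 reference it builds in expected_vals.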
