From 53411e875f7ac0f817a180650078466c9e77feaf Mon Sep 17 00:00:00 2001
From: luoyu-intel
Date: Mon, 30 Oct 2023 16:54:54 +0800
Subject: [PATCH] pass compile

---
 cmake/onnxruntime_mlas.cmake                  | 19 +++--
 .../contrib_ops/cpu/matmul_nbits_cpu.cc       |  6 +-
 .../core/graph/contrib_ops/contrib_defs.cc    |  8 +-
 onnxruntime/core/mlas/inc/mlas_q4.h           |  4 +-
 onnxruntime/core/mlas/lib/q4gemm.cpp          |  9 +-
 .../test/contrib_ops/matmul_nbits_cpu.cc      | 85 +++++++++++++++++++
 6 files changed, 113 insertions(+), 18 deletions(-)
 create mode 100644 onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 18e8c1ea0edb2..fdf00529974fa 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 
 set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
-
+set(MLAS_WITH_JBLAS ON)
 #
 # All hardware agnostic source files here
 # hardware specific files would cause trouble in
@@ -44,6 +44,13 @@ endif()
 
 set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
 
+function(add_jblas)
+  add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
+  target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+  target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_JBLAS)
+  set_target_properties(onnxruntime_mlas PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
+endfunction()
+
 #TODO: set MASM flags properly
 function(setup_mlas_source_for_windows)
@@ -197,8 +204,9 @@ function(setup_mlas_source_for_windows)
       ${MLAS_SRC_DIR}/q4gemm_avx512.cpp
     )
   endif()
-  add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
-  target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+  if(MLAS_WITH_JBLAS)
+    add_jblas()
+  endif()
 else()
   target_sources(onnxruntime_mlas PRIVATE
     ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -561,8 +569,9 @@ else()
     )
     set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
     set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
-    add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
-    target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+    if(MLAS_WITH_JBLAS)
+      add_jblas()
+    endif()
   endif()
 
 if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/onnxruntime/contrib_ops/cpu/matmul_nbits_cpu.cc b/onnxruntime/contrib_ops/cpu/matmul_nbits_cpu.cc
index f75279c0e0b71..0fb55f1213cb6 100644
--- a/onnxruntime/contrib_ops/cpu/matmul_nbits_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/matmul_nbits_cpu.cc
@@ -24,12 +24,12 @@ class MatMulNBitsCPU final : public OpKernel {
     const auto t = info.GetAttrOrDefault<int64_t>("blk_quant_type", static_cast<int64_t>(1));
     const auto c = info.GetAttrOrDefault<int64_t>("compute_type", static_cast<int64_t>(1));
     blk_quant_type_ = t == 0 ? BlkQ4Sym : BlkQ4Zp8;
-    compute_type_ = c == 0 ? FP32 : INT8;
+    compute_type_ = c == 0 ? CompFp32 : CompInt8;
   }
 
   Status Compute(OpKernelContext* context) const override;
 
   MLAS_BLK_QUANT_TYPE blk_quant_type_{BlkQ4Zp8};
-  BLK_QUANT_COMPUTE_TYPE compute_type_{INT8};
+  BLK_QUANT_COMPUTE_TYPE compute_type_{CompInt8};
 };
 
 Status MatMulNBitsCPU::Compute(OpKernelContext* ctx) const {
@@ -40,7 +40,7 @@ Status MatMulNBitsCPU::Compute(OpKernelContext* ctx) const {
   const Tensor* b = ctx->Input<Tensor>(1);
   const auto blob_shape = b->Shape();
   ORT_ENFORCE(blob_shape.NumDimensions() == 1, "Second input of MatMulNBitsCPU must be a 1D blob!");
-  const auto blob_len = blob_shape[0];
+  //const auto blob_len = blob_shape[0];
   const Tensor* bshape_tr = ctx->Input<Tensor>(2);
   TensorShape b_shape(bshape_tr->DataAsSpan<int64_t>());
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 846a0d0b3b9e2..8b9e0cb62b3c4 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2042,10 +2042,10 @@ no offset
         }
         y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
 
-        auto blk_quant_v = getAttribute(ctx, "blk_quant_type", 1);
-        auto compute_v = getAttribute(ctx, "compute_type", 1);
-        MLAS_BLK_QUANT_TYPE blk_quant_type = BlkQ4SymPerN;
-        BLK_QUANT_COMPUTE_TYPE compute_type = compute_v == 0 ? FP32 : INT8;
+        /* auto blk_quant_v = getAttribute(ctx, "blk_quant_type", 1);
+        auto compute_v = getAttribute(ctx, "compute_type", 1);*/
+        /*MLAS_BLK_QUANT_TYPE blk_quant_type = BlkQ4SymPerN;
+        BLK_QUANT_COMPUTE_TYPE compute_type = compute_v == 0 ? CompFp32 : CompInt8;*/
 
         //matmulQ4ShapeInference(ctx, 0, 1, 2, blk_quant_type);
       }));
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index f6a2227234c11..224c10d941246 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -41,8 +41,8 @@ typedef enum {
 /**
  * @brief Define compute types of block quantization
  */
 typedef enum {
-    FP32 = 0,     /*!< int4 Symmetric Block Quantization, zero_point = 0 */
-    INT8 = 1,     /*!< int4 Block Quantization, zero_point is int8 type */
+    CompFp32 = 0, /*!< compute the GEMM in fp32 */
+    CompInt8 = 1, /*!< compute the GEMM with int8 arithmetic */
 } BLK_QUANT_COMPUTE_TYPE;
 
diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp
index acd88661b2a67..dbb5253cd2972 100644
--- a/onnxruntime/core/mlas/lib/q4gemm.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm.cpp
@@ -17,7 +17,8 @@ Module Name:
 --*/
 
 #include "q4gemm.h"
-#if !defined(__APPLE__)
+
+#ifdef MLAS_JBLAS
 #include "jblas/jit_blas_weight_compression.h"
 #endif
 
@@ -139,7 +140,7 @@ MlasQ4GemmBatchDriver(MLAS_BLK_QUANT_TYPE QType,
     });
 }
 
-#if !defined(__APPLE__)
+#ifdef MLAS_JBLAS
 template <class GemmCore_T>
 using WeiS4ClipFp32PerN =
     jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32PerN<GemmCore_T, GemmCore_T::ISA>;
@@ -225,7 +226,7 @@ MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
     MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
 }
-
+#ifdef MLAS_JBLAS
 void
 MLASCALL
 JblasQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType,
                  MLAS_BLK_QUANT_TYPE QType,
@@ -240,7 +241,7 @@ JblasQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType,
         return JblasQ4GemmBatchDriver(CType, M, N, K, BatchN, DataParams, ThreadPool);
     }
 }
-
+#endif
 void
 MLASCALL
 MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
diff --git a/onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc b/onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc
new file mode 100644
index 0000000000000..e67938668b724
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef ORT_MINIMAL_BUILD
+
+#include "core/common/span_utils.h"
+#include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas_q4.h"
+#include "core/session/inference_session.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/framework/test_utils.h"
+#include "test/optimizer/graph_transform_test_builder.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "core/util/qmath.h"
+
+#include <chrono>
+#include <random>
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+namespace onnxruntime {
+namespace test {
+
+TEST(MatMulNBitsCPU, MatMul2DSymPerN) {
+  // (100 x 52) X (52 x 288)
+  constexpr int64_t M = 100;
+  constexpr int64_t N = 288;
+  constexpr int64_t K = 52;
+
+  const auto buf_size = MlasQ4GemmPackBSize(BlkQ4SymPerN, (size_t)N, (size_t)K);
+  if (buf_size == 0) {
+    GTEST_SKIP();  // operation not supported on this hardware platform yet.
+  }
+
+  OpTester test("MatMulNBitsCPU", 1, kMSDomain);
+  test.AddAttribute<int64_t>("blk_quant_type", BlkQ4SymPerN);
+  test.AddAttribute<int64_t>("compute_type", 1);
+
+  std::vector<float> input0_vals(M * K);
+  float fv = -135.f;
+  for (auto& f : input0_vals) {
+    f = fv / 128;
+    fv++;
+    if (fv > 135.f) {
+      fv = -135.f;
+    }
+  }
+
+  std::vector<float> input1_f_vals(N * K);
+  int v = -2;
+  for (size_t i = 0; i < N * K; i++) {
+    if (v == 0 || v == -3 || v == 3) v++;
+    input1_f_vals[i] = (float)v;
+    if (++v >= 8) {
+      v = -8;
+    }
+  }
+  std::vector<uint8_t> input1_vals(buf_size);
+  MlasQ4GemmPackB(BlkQ4SymPerN, input1_vals.data(), input1_f_vals.data(), (size_t)N, (size_t)K, (size_t)N);
+
+  std::vector<float> expected_vals(M * N);
+  for (int64_t m = 0; m < M; m++) {
+    for (int64_t n = 0; n < N; n++) {
+      float sum = 0.0f;
+      for (int64_t k = 0; k < K; k++) {
+        sum += input0_vals[m * K + k] * input1_f_vals[k * N + n];
+      }
+      expected_vals[m * N + n] = sum;
+    }
+  }
+
+  test.AddInput<float>("A", {M, K}, input0_vals, false);
+  test.AddInput<uint8_t>("B", {(int64_t)input1_vals.size()}, input1_vals, true);
+  test.AddInput<int64_t>("B_shape", {(int64_t)2}, {(int64_t)K, (int64_t)N}, true);
+
+  test.AddOutput<float>("Y", {M, N}, expected_vals);
+
+  test.Run();
+}
+
+}  // namespace test
+}  // namespace onnxruntime
+
+#endif  // ORT_MINIMAL_BUILD
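
Reviewer note (commentary, not part of the patch): the new test feeds
MatMulNBitsCPU a pre-packed 1-D "B" blob plus a separate "B_shape"
initializer holding {K, N}. Below is a minimal standalone sketch of that
pack/unpack flow, assuming this branch's mlas_q4.h (which declares
BlkQ4SymPerN) and a link against onnxruntime_mlas; the fill pattern and the
round-trip check are illustrative only, not taken from the patch.

    // Pack a float B (K x N, row-major, ldb == N) into the int4 blob that
    // MatMulNBitsCPU consumes, then unpack it to inspect quantization loss.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include "core/mlas/inc/mlas_q4.h"

    int main() {
      const size_t N = 288, K = 52;  // same shapes as MatMul2DSymPerN above

      // A packed size of 0 means this QType is unsupported on this platform.
      const size_t buf_size = MlasQ4GemmPackBSize(BlkQ4SymPerN, N, K);
      if (buf_size == 0) {
        std::printf("BlkQ4SymPerN not supported on this platform\n");
        return 0;
      }

      // Small integers in [-7, 7] survive int4 quantization almost exactly.
      std::vector<float> b(N * K);
      for (size_t i = 0; i < b.size(); i++) {
        b[i] = static_cast<float>(static_cast<int>(i % 15) - 7);
      }

      // This blob is what the test passes as the 1-D "B" initializer.
      std::vector<uint8_t> packed(buf_size);
      MlasQ4GemmPackB(BlkQ4SymPerN, packed.data(), b.data(), N, K, N);

      // Round-trip: unpack and report the worst-case reconstruction error.
      std::vector<float> b2(N * K);
      MlasQ4GemmUnPackB(BlkQ4SymPerN, b2.data(), packed.data(), N, K, N);
      float max_err = 0.0f;
      for (size_t i = 0; i < b.size(); i++) {
        const float e = b[i] > b2[i] ? b[i] - b2[i] : b2[i] - b[i];
        if (e > max_err) max_err = e;
      }
      std::printf("packed %zu bytes, max round-trip error %g\n", buf_size, max_err);
      return 0;
    }

Because the packed layout is opaque and QType-dependent, the buffer sized by
MlasQ4GemmPackBSize must use the same QType as the MlasQ4GemmPackB call
(BlkQ4SymPerN here); a mismatch risks an under-sized buffer.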