
Commit: pass compile
luoyu-intel committed Oct 30, 2023
1 parent 127b30b commit 53411e8
Showing 6 changed files with 113 additions and 18 deletions.
19 changes: 14 additions & 5 deletions cmake/onnxruntime_mlas.cmake
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)

+set(MLAS_WITH_JBLAS ON)
#
# All hardware agnostic source files here
# hardware specific files would cause trouble in
@@ -44,6 +44,13 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

+function(add_jblas)
+  add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
+  target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+  target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_JBLAS)
+  set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
+endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

@@ -197,8 +204,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()
-    add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
-    target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+    if(MLAS_WITH_JBLAS)
+      add_jblas()
+    endif()
else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -561,8 +569,9 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
-  add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
-  target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+  if(MLAS_WITH_JBLAS)
+    add_jblas()
+  endif()
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
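Note that the MLAS_WITH_JBLAS switch above only controls cmake; what the C++ sources actually test is the MLAS_JBLAS definition that add_jblas() attaches to onnxruntime_mlas. A minimal sketch of the resulting compile-time gate (the comments are illustrative, not part of the commit):

// Compiled only when cmake ran add_jblas(), i.e. MLAS_WITH_JBLAS was ON.
#ifdef MLAS_JBLAS
#include "jblas/jit_blas_weight_compression.h"  // jblas-backed int4 kernels
#endif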
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/matmul_nbits_cpu.cc
@@ -24,12 +24,12 @@ class MatMulNBitsCPU final : public OpKernel {
const auto t = info.GetAttrOrDefault<int64_t>("blk_quant_type", static_cast<int64_t>(1));
const auto c = info.GetAttrOrDefault<int64_t>("compute_type", static_cast<int64_t>(1));
blk_quant_type_ = t == 0 ? BlkQ4Sym : BlkQ4Zp8;
-    compute_type_ = c == 0 ? FP32 : INT8;
+    compute_type_ = c == 0 ? CompFp32 : CompInt8;
}

Status Compute(OpKernelContext* context) const override;
MLAS_BLK_QUANT_TYPE blk_quant_type_{BlkQ4Zp8};
-  BLK_QUANT_COMPUTE_TYPE compute_type_{INT8};
+  BLK_QUANT_COMPUTE_TYPE compute_type_{CompInt8};
};

Status MatMulNBitsCPU::Compute(OpKernelContext* ctx) const {
@@ -40,7 +40,7 @@ Status MatMulNBitsCPU::Compute(OpKernelContext* ctx) const {
const Tensor* b = ctx->Input<Tensor>(1);
const auto blob_shape = b->Shape();
ORT_ENFORCE(blob_shape.NumDimensions() == 1, "Second input of MatMulNBitsCPU must be a 1D blob!");
-  const auto blob_len = blob_shape[0];
+  // const auto blob_len = blob_shape[0];

const Tensor* bshape_tr = ctx->Input<Tensor>(2);
TensorShape b_shape(bshape_tr->DataAsSpan<int64_t>());
8 changes: 4 additions & 4 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2042,10 +2042,10 @@ no offset
}

y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
-  auto blk_quant_v = getAttribute(ctx, "blk_quant_type", 1);
-  auto compute_v = getAttribute(ctx, "compute_type", 1);
-  MLAS_BLK_QUANT_TYPE blk_quant_type = BlkQ4SymPerN;
-  BLK_QUANT_COMPUTE_TYPE compute_type = compute_v == 0 ? FP32 : INT8;
+  /* auto blk_quant_v = getAttribute(ctx, "blk_quant_type", 1);
+  auto compute_v = getAttribute(ctx, "compute_type", 1);*/
+  /*MLAS_BLK_QUANT_TYPE blk_quant_type = BlkQ4SymPerN;
+  BLK_QUANT_COMPUTE_TYPE compute_type = compute_v == 0 ? CompFp32 : CompInt8;*/

//matmulQ4ShapeInference(ctx, 0, 1, 2, blk_quant_type);
}));
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -41,8 +41,8 @@ typedef enum {
* @brief Define compute types of block quantization
*/
typedef enum {
-    FP32 = 0,     /*!< int4 Symmetric Block Quantization, zero_point = 0 */
-    INT8 = 1,     /*!< int4 Block Quantization, zero_point is int8 type */
+    CompFp32 = 0, /*!< compute the GEMM with fp32 activations */
+    CompInt8 = 1, /*!< compute the GEMM with int8-quantized activations */
} BLK_QUANT_COMPUTE_TYPE;


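Why the rename matters: FP32 and INT8 are common macro/enum names in third-party headers, so unqualified enumerators with those names invite exactly the kind of clash a "pass compile" fix removes — that rationale is inferred from the commit, not stated in it. A minimal sketch of the renamed values in use, mirroring the MatMulNBitsCPU constructor earlier in this diff (FromComputeAttr is a hypothetical helper):

#include <cstdint>
#include "core/mlas/inc/mlas_q4.h"

inline BLK_QUANT_COMPUTE_TYPE FromComputeAttr(int64_t c) {
  // 0 selects the fp32 compute path; anything else selects int8,
  // matching the attribute handling in matmul_nbits_cpu.cc above.
  return c == 0 ? CompFp32 : CompInt8;
}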
9 changes: 5 additions & 4 deletions onnxruntime/core/mlas/lib/q4gemm.cpp
@@ -17,7 +17,8 @@ Module Name:
--*/

#include "q4gemm.h"
-#if !defined(__APPLE__)
+
+#ifdef MLAS_JBLAS
#include "jblas/jit_blas_weight_compression.h"
#endif

@@ -139,7 +140,7 @@ MlasQ4GemmBatchDriver(MLAS_BLK_QUANT_TYPE QType,
});
}

-#if !defined(__APPLE__)
+#ifdef MLAS_JBLAS
template <class T, JBLAS_ISA ISA>
using WeiS4ClipFp32PerN =
jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32PerN<T, ISA>;
@@ -225,7 +226,7 @@ MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool);
}

-
+#ifdef MLAS_JBLAS
void MLASCALL
JblasQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType,
MLAS_BLK_QUANT_TYPE QType,
@@ -240,7 +241,7 @@ JblasQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType,
return JblasQ4GemmBatchDriver(CType, M, N, K, BatchN, DataParams, ThreadPool);
}
}
-
+#endif

void MLASCALL
MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
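With these guards, JblasQ4GemmBatch only exists when MLAS_JBLAS is defined, so callers need a compile-time fallback. A minimal sketch of that dispatch, assuming the parameter types of the public MlasQ4GemmBatch declaration (the wrapper itself is hypothetical, not part of this commit):

#include "core/mlas/inc/mlas_q4.h"

// Route a batched int4 GEMM to jblas when it was compiled in, otherwise
// to the generic MLAS driver, which takes no compute-type argument.
void RunQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType, MLAS_BLK_QUANT_TYPE QType,
                    size_t M, size_t N, size_t K, size_t BatchN,
                    const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
                    MLAS_THREADPOOL* ThreadPool) {
#ifdef MLAS_JBLAS
  JblasQ4GemmBatch(CType, QType, M, N, K, BatchN, DataParams, ThreadPool);
#else
  (void)CType;  // unused on the generic path
  MlasQ4GemmBatch(QType, M, N, K, BatchN, DataParams, ThreadPool);
#endif
}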
85 changes: 85 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_nbits_cpu.cc
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef ORT_MINIMAL_BUILD

#include "core/common/span_utils.h"
#include "core/framework/tensor.h"
#include "core/mlas/inc/mlas_q4.h"
#include "core/session/inference_session.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/framework/test_utils.h"
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "core/util/qmath.h"

#include <chrono>
#include <random>

#include "gtest/gtest.h"
#include "gmock/gmock.h"

namespace onnxruntime {
namespace test {

TEST(MatMulNBitsCPU, MatMul2DSymPerN) {
// (100 x 52) X (52 x 288)
constexpr int64_t M = 100;
constexpr int64_t N = 288;
constexpr int64_t K = 52;

const auto buf_size = MlasQ4GemmPackBSize(BlkQ4SymPerN, (size_t)N, (size_t)K);
if (buf_size == 0) {
GTEST_SKIP(); // operation not supported on this hardware platform yet.
}

OpTester test("MatMulNBitsCPU", 1, kMSDomain);
test.AddAttribute<int64_t>("blk_quant_type", BlkQ4SymPerN);
test.AddAttribute<int64_t>("compute_type", 1);

std::vector<float> input0_vals(M * K);
float fv = -135.f;
for (auto& f : input0_vals) {
f = fv / 128;
fv++;
if (fv > 135.f) {
fv = -135.f;
}
}

std::vector<float> input1_f_vals(N * K);
int v = -2;
for (size_t i = 0; i < N * K; i++) {
if (v == 0 || v == -3 || v == 3) v++;
input1_f_vals[i] = (float)v;
if (++v >= 8) {
v = -8;
}
}
std::vector<uint8_t> input1_vals(buf_size);
MlasQ4GemmPackB(BlkQ4SymPerN, input1_vals.data(), input1_f_vals.data(), (size_t)N, (size_t)K, (size_t)N);

std::vector<float> expected_vals(M * N);
for (int64_t m = 0; m < M; m++) {
for (int64_t n = 0; n < N; n++) {
float sum = 0.0f;
for (int64_t k = 0; k < K; k++) {
sum += input0_vals[m * K + k] * input1_f_vals[k * N + n];
}
expected_vals[m * N + n] = sum;
}
}

test.AddInput<float>("A", {M, K}, input0_vals, false);
test.AddInput<uint8_t>("B", {(int64_t)input1_vals.size()}, input1_vals, true);
test.AddInput<int64_t>("B_shape", {(int64_t)2}, {(int64_t)K, (int64_t)N}, true);

test.AddOutput<float>("Y", {M, N}, expected_vals);

test.Run();
}
} // namespace test
} // namespace onnxruntime

#endif // ORT_MINIMAL_BUILD
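The GTEST_SKIP() guard in this test doubles as the runtime support probe for the Q4 kernels. A minimal sketch of the same idiom outside the test harness (the helper name is hypothetical):

#include "core/mlas/inc/mlas_q4.h"

// MlasQ4GemmPackBSize() returns 0 when the CPU or build cannot run the
// requested int4 kernels — the same check the test makes before running.
bool Q4GemmSupported(MLAS_BLK_QUANT_TYPE qtype, size_t N, size_t K) {
  return MlasQ4GemmPackBSize(qtype, N, K) != 0;
}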
