Skip to content

Commit

Permalink
fix build
Browse files: browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 8, 2024
1 parent c2d891b commit 69e49b3
Showing 3 changed files with 24 additions and 23 deletions.
Original file line number | Diff line number | Diff line change
@@ -339,35 +339,35 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
// to what we need. But it does not handle external data.
Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
std::optional<Initializer> zp_src;
std::optional<std::unique_ptr<Initializer>> zp_src_ptr;
Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
std::vector<int64_t>{N, quant_num, blob_bytes});
Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
std::vector<int64_t>{N * quant_num});
std::optional<Initializer> zp_dst;
std::optional<std::unique_ptr<Initializer>> zp_dst_ptr;

if (zp_tensor_proto) {
zp_src.emplace(Initializer(*zp_tensor_proto, graph.ModelPath()));
zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
zp_src_ptr.emplace(std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath()));
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
} else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
}

if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
MlasQDQTransposeBlockwiseQuantized<float, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
@@ -377,10 +377,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<float, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
@@ -392,10 +392,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
@@ -406,10 +406,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
@@ -420,15 +420,15 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,

ONNX_NAMESPACE::TensorProto weight_T_tp;
ONNX_NAMESPACE::TensorProto scale_T_tp;
std::unique_ptr<ONNX_NAMESPACE::TensorProto> zp_T_tp_ptr = nullptr;
std::optional<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> zp_T_tp_ptr;

// TODO(fajin): external_data to memory location to avoid arena allocation
// https://github.com/microsoft/onnxruntime/pull/12465
weight_dst.ToProto(weight_T_tp);
scale_dst.ToProto(scale_T_tp);
if (zp_dst) {
if (zp_dst_ptr) {
zp_T_tp_ptr = std::make_unique<ONNX_NAMESPACE::TensorProto>();
zp_dst->ToProto(*zp_T_tp_ptr);
zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value());
}

auto& input_defs = replacement_node.MutableInputDefs();
@@ -438,10 +438,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
replacement_node.MutableInputArgsCount().push_back(1);

if (zp_T_tp_ptr) {
input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr));
input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value()));
replacement_node.MutableInputArgsCount().push_back(1);
}

return Status::OK();
}

Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
#include <vector>

#include "core/optimizer/selectors_actions/actions.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {

Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
ORT_UNUSED_PARAMETER(q_nodes);
const auto& graph = graph_viewer.GetGraph();

// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output
// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output
if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) {
return false;
}

0 comments on commit 69e49b3

Please sign in to comment.