From 69e49b3444796417d48488efff597964a2aa087c Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Mon, 8 Jul 2024 11:30:54 -0700 Subject: [PATCH] fix build --- .../selectors_actions/qdq_actions.cc | 44 +++++++++---------- .../selectors_actions/qdq_actions.h | 1 + .../selectors_actions/qdq_selectors.cc | 2 +- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index b2f9c79b455f3..c1216d4c73ae9 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -339,24 +339,24 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, // to what we need. But it does not handle external data. Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); - std::optional zp_src; + std::optional> zp_src_ptr; Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, graph.GenerateNodeArgName(weight_arg->Name() + "_T"), std::vector{N, quant_num, blob_bytes}); Initializer scale_dst(static_cast(scale_src.data_type()), graph.GenerateNodeArgName(scale_arg->Name() + "_T"), std::vector{N * quant_num}); - std::optional zp_dst; + std::optional> zp_dst_ptr; if (zp_tensor_proto) { - zp_src.emplace(Initializer(*zp_tensor_proto, graph.ModelPath())); - zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName(zp_arg->Name() + "_T"), - std::vector{N * ((quant_num + 1) / 2)})); + zp_src_ptr.emplace(std::make_unique(*zp_tensor_proto, graph.ModelPath())); + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)})); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - zp_dst.emplace(Initializer(ONNX_NAMESPACE::TensorProto_DataType_UINT8, - graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), - std::vector{N * ((quant_num + 1) / 2)})); + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)})); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -364,10 +364,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -377,10 +377,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -392,10 +392,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -406,10 +406,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, weight_dst.data(), scale_dst.data(), - zp_dst ? zp_dst->data() : nullptr, + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, true, static_cast(K), static_cast(N), @@ -420,15 +420,15 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, ONNX_NAMESPACE::TensorProto weight_T_tp; ONNX_NAMESPACE::TensorProto scale_T_tp; - std::unique_ptr zp_T_tp_ptr = nullptr; + std::optional> zp_T_tp_ptr; // TODO(fajin): external_data to memory location to avoid arena allocation // https://github.com/microsoft/onnxruntime/pull/12465 weight_dst.ToProto(weight_T_tp); scale_dst.ToProto(scale_T_tp); - if (zp_dst) { + if (zp_dst_ptr) { zp_T_tp_ptr = std::make_unique(); - zp_dst->ToProto(*zp_T_tp_ptr); + zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value()); } auto& input_defs = replacement_node.MutableInputDefs(); @@ -438,10 +438,10 @@ Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, replacement_node.MutableInputArgsCount().push_back(1); if (zp_T_tp_ptr) { - input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr)); + input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value())); replacement_node.MutableInputArgsCount().push_back(1); } - + return Status::OK(); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index c73be519871cf..d80c3f9d183bf 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -8,6 +8,7 @@ #include #include "core/optimizer/selectors_actions/actions.h" +#include "core/platform/threadpool.h" namespace onnxruntime { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 96dc10a326692..692db4eb327b5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -421,7 +421,7 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, ORT_UNUSED_PARAMETER(q_nodes); const auto& graph = graph_viewer.GetGraph(); -// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { return false; }