From a47254eaef87a4572078d3165f4b9d1c4b4c52b1 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga <adlizarraga@microsoft.com> Date: Tue, 24 Sep 2024 21:02:17 -0700 Subject: [PATCH] Remove empty (DQ -> Q -> graph output) sequence in TransposeOptimizer (#22172) ### Description Updates the TransposeOptimizer to also remove empty (DQ -> Q) sequences that occur at a graph output. An empty DQ->Q sequence results from a Transpose being optimized out. Consider the following example model: ![image](https://github.com/user-attachments/assets/4e7bc4eb-ea8a-463b-9672-c4ec5ef779b2) The TransposeOptimizer removes the final Transpose and leaves an empty DQ->Q->output_0 sequence. This PR ensures that the final DQ->Q is also removed. ### Motivation and Context Models with quantized output can run on QNN EP. The inference latency of a customer model is impacted by the unnecessary DQ->Q sequence at the output. --------- Co-authored-by: Scott McKay <skottmckay@gmail.com> --- .../onnx_transpose_optimization.cc | 63 +++++++++- .../optimizer/transpose_optimizer_test.cc | 84 +++++++++++++ ...se_optimizer_empty_dq_q_at_output_model.py | 117 ++++++++++++++++++ ..._optimizer_empty_dq_q_at_graph_output.onnx | Bin 0 -> 1410 bytes 4 files changed, 258 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/test/testdata/make_transpose_optimizer_empty_dq_q_at_output_model.py create mode 100644 onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 5d689a9d933e8..470838d36ec1c 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -2749,7 +2749,9 @@ static bool CanModifyNode(const OptimizerCtx& ctx, const api::NodeRef& node) { /// <summary> /// Try to remove empty DQ -> Q pair that results from moving a Transpose downstream or a Transpose being canceled out. -/// (DQ -> Q -> consumer node) => consumer node +/// Handles the following scenarios: +/// - (DQ -> Q -> consumer node) => consumer node +/// - (parent node -> DQ -> Q -> graph output) => parent node -> graph output /// </summary> /// <param name="ctx">Optimizer context</param> /// <param name="q_node">QuantizeLinear node</param> @@ -2764,12 +2766,27 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) { } auto& dq_node = *input_node; - std::unique_ptr<api::NodeRef> single_consumer_node; - // remove empty DQ -> Q before a consumer node if the DQ and Q have matching types, scale and zp. - if (OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, single_consumer_node) && - OutputValueHasSingleConsumerNode(ctx.graph, q_node, 0, single_consumer_node) && - CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + // DQ should have a single consumer (the Q) + std::unique_ptr<api::NodeRef> dq_consumer_node; + if (!OutputValueHasSingleConsumerNode(ctx.graph, dq_node, 0, dq_consumer_node)) { + return false; + } + + // The DQ and Q should have matching types, scale and zp. + if (!CheckQDQNodePairMatch(ctx.graph, dq_node, q_node)) { + return false; + } + + std::string_view q_output = q_node.Outputs()[0]; + auto q_consumers = ctx.graph.GetValueConsumers(q_output); + const size_t num_q_consumers = q_consumers->nodes.size(); + const bool q_has_single_consumer = q_consumers->comprehensive && (num_q_consumers == 1); + + // (DQ -> Q -> consumer node) => consumer node + if (q_has_single_consumer) { + std::unique_ptr<api::NodeRef> single_consumer_node = std::move(q_consumers->nodes[0]); + // connect Q consumer to DQ input for (size_t j_idx = 0, j_end = single_consumer_node->Inputs().size(); j_idx < j_end; ++j_idx) { if (single_consumer_node->Inputs()[j_idx] == q_node.Outputs()[0]) { @@ -2787,6 +2804,40 @@ static bool TryRemoveEmptyDQQ(OptimizerCtx& ctx, api::NodeRef& q_node) { return true; } + // (parent node -> DQ -> Q -> graph output) => (parent node -> graph output) + if (num_q_consumers == 0 && ctx.graph.IsGraphOutput(q_output)) { + // Get the DQ's parent node. + std::string_view dq_input = dq_node.Inputs()[0]; + auto dq_parent_node = ctx.graph.GetNodeProducingOutput(dq_input); + if (!dq_parent_node) { + return false; // Don't handle DQ that consumes a graph input. + } + + // Find index of output from DQ's parent node + auto dq_parent_outputs = dq_parent_node->Outputs(); + size_t dq_parent_output_index = 0; + for (dq_parent_output_index = 0; dq_parent_output_index < dq_parent_outputs.size(); ++dq_parent_output_index) { + if (dq_parent_outputs[dq_parent_output_index] == dq_input) break; + } + + // The DQ's parent should only have a single consumer (i.e., the DQ itself). + std::unique_ptr<api::NodeRef> dq_parent_consumer; + if (!OutputValueHasSingleConsumerNode(ctx.graph, *dq_parent_node, dq_parent_output_index, dq_parent_consumer)) { + return false; + } + + // Move Q's output to come out of DQ's parent node so the graph output value name is maintained. + dq_node.SetInput(0, ""); // Disconnect DQ from its parent first. + ctx.graph.MoveOutput(q_node, 0, *dq_parent_node, dq_parent_output_index); + + // Disconnect Q and remove both DQ and Q from the graph. + q_node.SetInput(0, ""); + ctx.graph.RemoveNode(dq_node); + ctx.graph.RemoveNode(q_node); + + return true; + } + return false; } diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 8d90e48db97c1..35ba1a3369597 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -4964,6 +4964,90 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerAxisDQUnsqueezeTranspose) { testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>())); } +// Test that the TransposeOptimizer's qdq-fixup pass converts the sequence (Op -> DQ -> Q -> GRAPH_OUTPUT) to +// (Op -> GRAPH_OUTPUT). +TEST(TransposeOptimizerTests, RemoveEmptyDQQAtGraphOutput) { + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx"); + + RandomValueGenerator random{123}; + std::vector<int64_t> input_dims{1, 3, 4, 4}; + std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f); + + auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators(); + OrtValue input0; + CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0); + + NameMLValMap feeds{{"input0", input0}}; + + std::vector<std::string> output_names{"output0"}; + std::vector<OrtValue> fetches_orig; + std::vector<OrtValue> fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + so.graph_optimization_level = TransformerLevel::Default; // off + + // get results with no modifications to the model + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + namespace alias_oto = onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, + TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph); + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); + + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()), + // ToPathString("updated_model_empty_dqq_graph_output.onnx"))); + + std::map<std::string, int> op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "2 pre-existing Transposes at the I/O cancel. "; + + // Check that the graph ends in the sequence (Mul -> Q -> GRAPH_OUTPUT) + Node* mul_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Mul") { + mul_node = &node; + break; + } + } + + // Mul should be followed by a Q node. + ASSERT_TRUE(mul_node != nullptr); + const auto& last_q_node = *(mul_node->OutputNodesBegin()); + EXPECT_EQ(last_q_node.OpType(), "QuantizeLinear"); + + // The Q node should generate the graph's output. + const std::string& q_out_name = last_q_node.OutputDefs()[0]->Name(); + const std::string& graph_out_name = graph.GetOutputs()[0]->Name(); + EXPECT_EQ(q_out_name, graph_out_name); + + // Run optimized model. + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<uint8_t>(), + testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<uint8_t>())); +} + // Tests the in-place unsqueeze and transpose of a constant consumed by a per-axis DQ. TEST(TransposeOptimizerTests, InPlaceUnsqueezeTransposePerAxisDQ) { // Model contains a Mul with a constant/broadcastable/per-axis DQ input[1]. diff --git a/onnxruntime/test/testdata/make_transpose_optimizer_empty_dq_q_at_output_model.py b/onnxruntime/test/testdata/make_transpose_optimizer_empty_dq_q_at_output_model.py new file mode 100644 index 0000000000000..3666f2299de46 --- /dev/null +++ b/onnxruntime/test/testdata/make_transpose_optimizer_empty_dq_q_at_output_model.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def make_model(model_path: str): + """ + Creates a QDQ model with a (DQ -> Transpose -> Q -> GRAPH OUTPUT) sequence. The Transpose is optimized out + and the TransposeOptimizer should also remove the empty (DQ -> Q) sequence. + """ + input0_shape = (1, 3, 4, 4) + + inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)] + outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.UINT8, None)] + + mul_weight_scale_data = np.array(1.0, dtype=np.float32) + mul_weight_zp_data = np.array(0, dtype=np.int8) + + initializers = [ + onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"), + onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"), + onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"), + onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"), + onnx.numpy_helper.from_array(mul_weight_scale_data, "mul_weight_scale"), + onnx.numpy_helper.from_array(mul_weight_zp_data, "mul_weight_zp"), + ] + nodes = [] + + # Transpose to channel-last + tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1)) + nodes.append(tp0_node) + + # Q_0 + q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node") + nodes.append(q0_node) + + # DQ_0 + dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node") + nodes.append(dq0_node) + + # Sigmoid + sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node") + nodes.append(sigmoid_node) + + # Q_1 + q1_node = onnx.helper.make_node( + "QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node" + ) + nodes.append(q1_node) + + # DQ_1 + dq1_node = onnx.helper.make_node( + "DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node" + ) + nodes.append(dq1_node) + + # DQ for mul input[1] + mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8) + mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight") + initializers.append(mul_weight) + + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["mul_weight", "mul_weight_scale", "mul_weight_zp"], + ["mul_input_1"], + name="dq_mul_input_1", + ) + ) + + # Mul + mul_node = onnx.helper.make_node("Mul", ["dq1_out", "mul_input_1"], ["mul_out"], name="mul_node") + nodes.append(mul_node) + + # Q_2 + q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node") + nodes.append(q2_node) + + # DQ_2 + dq2_node = onnx.helper.make_node( + "DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node" + ) + nodes.append(dq2_node) + + # Transpose to channel-first + tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["tp1_out"], name="tp1_node", perm=(0, 3, 1, 2)) + nodes.append(tp1_node) + + # Q_3 to graph output + nodes.append( + onnx.helper.make_node("QuantizeLinear", ["tp1_out", "scale_inv_255", "zp_0"], ["output0"], name="q3_node") + ) + + graph = onnx.helper.make_graph( + nodes, + "transpose_opt_empty_dqq_graph_output", + inputs, + outputs, + initializer=initializers, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 19), + ] + qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + print("[INFO]: Running onnx.checker on qdq model") + qdq_model = onnx.shape_inference.infer_shapes(qdq_model) + onnx.checker.check_model(qdq_model, True) + + print(f"[INFO]: Saving {model_path}") + onnx.save_model(qdq_model, model_path) + + +if __name__ == "__main__": + make_model("transpose_optimizer_empty_dq_q_at_graph_output.onnx") diff --git a/onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx b/onnxruntime/test/testdata/transpose_optimizer_empty_dq_q_at_graph_output.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e5a0675c61cae8ec2ef73c6287a44bad80834ed0 GIT binary patch literal 1410 zcmaJ=O>fgc5cQX`W<o006c9N<1s6+HNl+zdC2AQ?790?q5YoyJD|MClBe6l`6#fbS zk}EUob#2#94wiTA_ujsld27POKY%wdh~l)!CpOE|2~UcAY|<i5=HcnRFY6%Qq{$|9 zj=@O7^~zOT&DCA?ugc&ODhja8k6;<{DGatLpUz&}gRInI_!d$}pNb&Pqiy&hio;+H zZ$U!3TsD_vlNN$)y$`cC|0Pr~d$@@ft0bBW^Y^Qb;IYr;FL(z{6pTv0QT&t7&d<R> z*Aw(S6%q{Jy`DYvEFk@SG*k+pZ>iAr{S9DMEcvf6T3qK~)oU&+!Km?No7zL#iUINI z_-LN-#{RUcenwp>EjF1pBbt~kEktqp$6^UrQ2psIX^_oiOJ@69A|U;IbW;joOQb^A z^aYs~Fv_58AuWW@-V9;_U5zs6QU(vr1dFD0k6HGrw*vclUS*soY0kq{n*ZkWEaQuH zkX}>u$ajt{)vH488;Y{Kaz>Ws$+-LK&UuW8cC38k&u(|=3=t;+=BtV`MpS3$Qt-sm zz4rKt@zB)qdm}Wyu4?+%r?>a^5uBK+ZGmR0_~`f;^n8B?OcrkwMC(Fs5F#(VfYd5N zfz)pBoboo_5vdU<A}{P-kr%Dow--idiCG>#tA_L$egduJ1DMzya;sCh@*W(tP)4!c XvpeP1kUk5dt$i+UhvONbSP%XK@zP@+ literal 0 HcmV?d00001