Fuse Clip->Q to Q (microsoft#10434)
* Fuse Clip->Q to Q

* Remove unused variable argmax_node

* Remove braces around scalar initializer

* Move GetClipConstantMinMax under ORT_MINIMAL_BUILD

* Consider epsilon so we can fuse more cases
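
In effect (an illustrative before/after, not taken from the commit): a Clip whose [min, max] covers everything the downstream quantizer can represent is a no-op and is removed. For a ReLU6-style pattern:

    Conv -> Clip(min = 0, max = 6) -> QuantizeLinear(scale = 6/255, zero_point = 0, uint8)

becomes

    Conv -> QuantizeLinear(scale = 6/255, zero_point = 0, uint8)

because the uint8 quantizer already saturates its output to [0, 6]. The epsilon tolerance admits cases where float rounding makes the Clip bounds and the quantizer's representable range differ by a few ULPs.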
yihonglyu authored Feb 3, 2022
1 parent 97b8f6f commit a405658
Showing 7 changed files with 280 additions and 65 deletions.
65 changes: 2 additions & 63 deletions onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -5,73 +5,12 @@
#include "core/graph/graph_utils.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"

using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace {
// get min/max values from Clip if they are constant. Returns false if mutable and cannot be used
static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max) {
min = std::numeric_limits<float>::lowest();
max = std::numeric_limits<float>::max();

// Clip opset 6 has min and max as attributes. they're inputs from opset 11 on.
bool min_max_are_attributes = graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {6});
bool min_max_are_constant_values = true;

if (min_max_are_attributes) {
min = graph_utils::GetNodeAttribute(node, "min")->f();
max = graph_utils::GetNodeAttribute(node, "max")->f();
} else {
// update min/max if provided via a constant initializer
// return true if value is default or coming from a constant initializer and update 'value'
// return false if value is mutable
auto update_if_constant_value = [&graph](const Node& node, size_t input_idx, float& value) {
const auto& input_defs = node.InputDefs();
const NodeArg* input = (input_defs.size() > input_idx) ? input_defs[input_idx] : nullptr;

if (input == nullptr || !input->Exists()) {
// optional input not specified so using default value
return true;
}

bool is_constant = true;
const ONNX_NAMESPACE::TensorProto* initializer = graph_utils::GetConstantInitializer(graph, input->Name());
if (initializer) {
Initializer i(*initializer, graph.ModelPath());
switch (initializer->data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
value = *i.data<float>();
break;
// double isn't currently supported
//case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
// value = static_cast<float>(*i.data<double>());
// break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
value = math::halfToFloat(i.data<MLFloat16>()->val);
break;
default:
ORT_THROW("Unexpected data type for Clip input of ", initializer->data_type());
}
} else {
is_constant = false;
}

return is_constant;
};

// 'min' is input 1, 'max' is input 2. both are optional.
// if the input is constant, 'min' or 'max' is updated by the call to get_if_constant_value
min_max_are_constant_values = update_if_constant_value(node, 1, min) &&
update_if_constant_value(node, 2, max);
}

return min_max_are_constant_values;
}

} // namespace

Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
@@ -173,7 +112,7 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
float min, max;
if (GetClipConstantMinMax(graph, next_node, min, max)) {
if (optimizer_utils::GetClipConstantMinMax(graph, next_node, min, max)) {
activation_params.push_back(min);
activation_params.push_back(max);
} else {
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -44,6 +44,7 @@
#include "core/optimizer/nhwc_transformer.h"
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/qdq_s8_to_u8.h"
#include "core/optimizer/qdq_transformer/relu_quantizelinear.h"
@@ -99,6 +100,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(std::make_unique<ConvAddFusion>());
rules.push_back(std::make_unique<ConvMulFusion>());
rules.push_back(std::make_unique<ConvBNFusion>());
rules.push_back(std::make_unique<ClipQuantFusion>());
rules.push_back(std::make_unique<ReluQuantFusion>());
break;

107 changes: 107 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc
@@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/initializer.h"
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
#include "core/optimizer/utils.h"
#include "core/graph/graph_utils.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {

static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& lower, float& upper) {
const auto& input_defs = node.InputDefs();

constexpr size_t input_cnt_required = 3;
if (input_defs.size() != input_cnt_required) {
return false;
}

constexpr size_t s_idx = 1;
const NodeArg* s_input = input_defs[s_idx];

const ONNX_NAMESPACE::TensorProto* s_tensor_proto = graph_utils::GetConstantInitializer(graph, s_input->Name());
if (!s_tensor_proto) {
return false;
}

Initializer s_initializer(*s_tensor_proto, graph.ModelPath());
if (s_initializer.dims().size() != 0 ||
s_initializer.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
return false;
}
const float scale = s_initializer.data<float>()[0];

constexpr size_t zp_idx = 2;
const NodeArg* zp_input = input_defs[zp_idx];

const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = graph_utils::GetConstantInitializer(graph, zp_input->Name());
if (!zp_tensor_proto) {
return false;
}

Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath());
if (zp_initializer.dims().size() != 0) {
return false;
}

switch (zp_initializer.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
const int8_t zero_point = zp_initializer.data<int8_t>()[0];
lower = scale * (-128 - zero_point);
upper = scale * (127 - zero_point);
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
const uint8_t zero_point = zp_initializer.data<uint8_t>()[0];
lower = scale * (0 - zero_point);
upper = scale * (255 - zero_point);
break;
}
default:
ORT_THROW("Unexpected data type for QuantizeLinear input y_zero_point of ", zp_initializer.data_type());
}
return true;
}
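
// Example with assumed values (illustrative only): for uint8 with scale = 0.1f
// and zero_point = 128, the representable range is
// [0.1f * (0 - 128), 0.1f * (255 - 128)] = [-12.8f, 12.7f].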

bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
return false;
}

// if Clip is followed by QuantizeLinear, it can potentially be fused into the QuantizeLinear
const auto& next_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "QuantizeLinear", {10, 13})) {
return false;
}

return true;
}

Status ClipQuantFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
float min, max;
if (!optimizer_utils::GetClipConstantMinMax(graph, node, min, max)) {
return Status::OK();
}

const Node& q_node = *graph.GetNode(node.OutputNodesBegin()->Index());

float lower, upper;
if (!GetQConstantLowerUpper(graph, q_node, lower, upper)) {
return Status::OK();
}

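// Fuse only when the Clip range [min, max] covers the quantizer's representable
// range [lower, upper]; the epsilon slack tolerates float rounding in the bounds.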
constexpr float epsilon = std::numeric_limits<float>::epsilon();
if (epsilon < min - lower || epsilon < upper - max) {
return Status::OK();
}

if (graph_utils::RemoveNode(graph, node)) {
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
}

return Status::OK();
}
} // namespace onnxruntime
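
A minimal standalone sketch of the same test (helper name and includes are assumed, not part of the commit):

#include <cstdint>
#include <limits>

// True when Clip(min, max) ahead of a uint8 QuantizeLinear is redundant,
// mirroring the representable-range and epsilon checks in Apply above.
bool ClipIsRedundantForU8Quant(float min, float max, float scale, uint8_t zero_point) {
  const float lower = scale * (0 - zero_point);    // smallest representable value
  const float upper = scale * (255 - zero_point);  // largest representable value
  constexpr float epsilon = std::numeric_limits<float>::epsilon();
  return (min - lower <= epsilon) && (upper - max <= epsilon);
}

For scale = 6.0f / 255 and zero_point = 0 this accepts Clip(0, 6) even though 255 * (6.0f / 255) does not come out bitwise-equal to 6.0f.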
29 changes: 29 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.h
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {

/**
@Class ClipQuantFusion
Rewrite rule that fuses Clip into the QuantizeLinear that follows it
*/
class ClipQuantFusion : public RewriteRule {
public:
ClipQuantFusion() noexcept : RewriteRule("ClipQuantRewrite") {}

std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Clip"};
}

private:
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
62 changes: 60 additions & 2 deletions onnxruntime/core/optimizer/utils.cc
@@ -281,9 +281,67 @@ bool IsOperationDeterministic(const std::string& domain, const std::string& op)
if (domain.compare(kOnnxDomain) == 0) {
auto iter = std::find(kOnnxDomainNonDeterministicOps.begin(), kOnnxDomainNonDeterministicOps.end(), op);
return iter == kOnnxDomainNonDeterministicOps.end();
}
// Unknown domain. Assume the op is not deterministic.
return false;
}

bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max) {
min = std::numeric_limits<float>::lowest();
max = std::numeric_limits<float>::max();

// Clip opsets 1 and 6 have min and max as attributes; they're inputs from opset 11 on.
bool min_max_are_attributes = graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6});
bool min_max_are_constant_values = true;

if (min_max_are_attributes) {
min = graph_utils::GetNodeAttribute(node, "min")->f();
max = graph_utils::GetNodeAttribute(node, "max")->f();
} else {
// update min/max if provided via a constant initializer
// return true if value is default or coming from a constant initializer and update 'value'
// return false if value is mutable
auto update_if_constant_value = [&graph](const Node& node, size_t input_idx, float& value) {
const auto& input_defs = node.InputDefs();
const NodeArg* input = (input_defs.size() > input_idx) ? input_defs[input_idx] : nullptr;

if (input == nullptr || !input->Exists()) {
// optional input not specified so using default value
return true;
}

bool is_constant = true;
const ONNX_NAMESPACE::TensorProto* initializer = graph_utils::GetConstantInitializer(graph, input->Name());
if (initializer) {
Initializer i(*initializer, graph.ModelPath());
switch (initializer->data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
value = *i.data<float>();
break;
// double isn't currently supported
//case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
// value = static_cast<float>(*i.data<double>());
// break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
value = math::halfToFloat(i.data<MLFloat16>()->val);
break;
default:
ORT_THROW("Unexpected data type for Clip input of ", initializer->data_type());
}
} else {
is_constant = false;
}

return is_constant;
};

// 'min' is input 1, 'max' is input 2; both are optional.
// if the input is constant, 'min' or 'max' is updated by the call to update_if_constant_value
min_max_are_constant_values = update_if_constant_value(node, 1, min) &&
update_if_constant_value(node, 2, max);
}

return min_max_are_constant_values;
}

#endif // #if !defined(ORT_MINIMAL_BUILD)
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/utils.h
@@ -102,6 +102,11 @@ bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_outp

bool IsOperationDeterministic(const std::string& domain, const std::string& op);

/** Get min/max values from Clip if they are constant.
@returns false if min/max are mutable (non-constant) and cannot be used.
*/
bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max);
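// Hypothetical usage (mirrors ClipQuantFusion::Apply; variable names assumed):
//   float min, max;
//   if (!optimizer_utils::GetClipConstantMinMax(graph, clip_node, min, max)) {
//     return Status::OK();  // bounds are mutable, so the rewrite is skipped
//   }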

#endif  // #if !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)