-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use gemm to replace matmul + add (#234)
* MatMul + Add fusion * Add shape check on Gemm input C * Work around the issue with RemoveNode * Update the version support * If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N] so that it can work for Gemm * Fuse Gemm+Activation into FusedGemm * Test * Revert the change that fused a MatMul with shape [K] * [K, N] into a Gemm with shape [1, K] * [K, N]; this may cause runtime failures, as we can't change the input data shape. * Revert the shape change for MatMul from [K] * [K, N] to [1, K] * [K, N]. It enabled fusing MatMul + Add into Gemm, but the actual data is not aware of this, so the data shape is still [K] * [K, N] and causes runtime issues. * 1. Fix build issue for CUDA 2. Update Gemm so that we can fuse MatMul [K] * [K, N] + Add [1, N] into Gemm with shape [1, K] * [K, N] + [1, N] * Fix build issue * Fuse the activation node even if it connects to the graph output * Resolve the merge conflicts * Add test model for Gemm+Activation fusion
- Loading branch information
Showing
24 changed files
with
475 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_gemm.h"

namespace onnxruntime {
namespace contrib {
// Registers the FusedGemm kernel (Gemm with a fused activation) for the CPU
// execution provider in the com.microsoft ("MS") custom-op domain.
// Registered at opset version 1, float tensors only ("T" constrained to
// tensor(float)), dispatching to the FusedGemm<float, float, float, float>
// specialization.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
    FusedGemm,
    1,
    float,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    FusedGemm<float, float, float, float>);
}  // namespace contrib
}  // namespace onnxruntime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/cpu/math/gemm.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
template <typename T_X, | ||
typename T_W, | ||
typename T_B, | ||
typename T_Y> | ||
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> { | ||
public: | ||
FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) { | ||
Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", ""); | ||
Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f); | ||
} | ||
|
||
Status Compute(OpKernelContext* context) const override { | ||
return Gemm<T_X, T_W, T_B, T_Y>::Compute(context); | ||
} | ||
}; | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/graph/initializer.h" | ||
#include "core/graph/gemm_activation_fusion.h" | ||
#include "core/graph/graph_utils.h" | ||
#include <deque> | ||
|
||
using namespace onnx; | ||
using namespace ::onnxruntime::common; | ||
namespace onnxruntime { | ||
|
||
namespace { | ||
bool IsFusableActivation(const Node& node) { | ||
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6); | ||
} | ||
|
||
// Reroutes the outgoing edges of the fused activation node `act` so that its
// downstream consumers are fed from `fused_gemm` (output slot 0) instead.
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
  // Snapshot the edges first: the graph must not be mutated while iterating
  // over act's live edge collection.
  Node::EdgeSet output_edges;
  for (auto it = act.OutputEdgesBegin(); it != act.OutputEdgesEnd(); ++it) {
    output_edges.insert(*it);
  }

  //remove output edge of activation
  //connect fused_gemm node and nodes after activation nodes
  for (auto& output_edge : output_edges) {
    NodeIndex dst_node_index = output_edge.GetNode().Index();
    int src_arg_index = output_edge.GetSrcArgIndex();
    int dst_arg_index = output_edge.GetDstArgIndex();
    g.RemoveEdge(act.Index(), dst_node_index, src_arg_index, dst_arg_index);
    // fused_gemm has a single output, hence source arg index 0.
    g.AddEdge(fused_gemm.Index(), dst_node_index, 0, dst_arg_index);
  }
}
|
||
} // namespace | ||
|
||
// Fuses Gemm + {LeakyRelu, Relu, Sigmoid, Tanh} pairs into a single
// com.microsoft FusedGemm node. Sets `modified` and re-resolves the graph
// when at least one fusion happened.
Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();

  std::deque<onnxruntime::NodeIndex> removed_nodes;
  for (auto index : order) {
    auto node = graph.GetNode(index);
    // Guard against a stale index (consistent with MatMulAddFusion::Apply).
    // Only fuse a Gemm (opset 7 or 9) whose single consumer is the activation.
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) || utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }
    const Node& next_node = *(node->OutputNodesBegin());
    if (!IsFusableActivation(next_node)) {
      continue;
    }

    Node* gemm_node = node;
    const Node& act_node = next_node;

    // When the activation's output is a graph output the fused node must
    // produce that exact output def; otherwise keep Gemm's output def and
    // rewire the activation's consumers below.
    Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm",
                                     "fused Gemm " + gemm_node->Name() + " with activation " + act_node.OpType(),
                                     gemm_node->MutableInputDefs(),
                                     graph.IsNodeOutputsInGraphOutputs(next_node) ? const_cast<Node&>(act_node).MutableOutputDefs() : gemm_node->MutableOutputDefs(),
                                     &gemm_node->GetAttributes(),
                                     "com.microsoft");

    // Record which activation was fused so the kernel can apply it.
    fused_gemm.AddAttribute("activation", act_node.OpType());

    // Forward optional activation attributes, prefixed per FusedGemm's
    // schema (e.g. LeakyRelu's "alpha" becomes "leaky_relu_alpha").
    // Take the attribute map by const reference: copying it per node is
    // pure waste.
    if (act_node.OpType() == "LeakyRelu") {
      const NodeAttributes& attrs = act_node.GetAttributes();
      for (auto it = attrs.begin(); it != attrs.end(); ++it) {
        fused_gemm.AddAttribute("leaky_relu_" + it->first, it->second);
      }
    }

    if (!graph.IsNodeOutputsInGraphOutputs(next_node)) {
      HandleActivationNodeEdges(graph, act_node, fused_gemm);

      // Replace the input defs of nodes that consumed the activation's
      // output so they now read the fused node's output.
      const NodeArg* act_output_def = act_node.OutputDefs()[0];
      NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0];
      for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
        auto output_node = graph.GetNode((*it).Index());
        if (!output_node) {
          return Status(ONNXRUNTIME, INVALID_ARGUMENT);
        }

        auto& input_defs = output_node->MutableInputDefs();
        for (auto& def : input_defs) {
          if (def == act_output_def) {
            def = fused_gemm_output_def;
          }
        }
      }
    }

    removed_nodes.push_front(gemm_node->Index());
    removed_nodes.push_front(act_node.Index());
  }

  for (auto node : removed_nodes) {
    graph.RemoveNode(node);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }
  return Status::OK();
}
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/graph/graph_transformer.h" | ||
|
||
namespace onnxruntime { | ||
|
||
// Graph transformer that rewrites a Gemm node followed by a fusable
// activation (LeakyRelu, Relu, Sigmoid, Tanh) into a single FusedGemm node
// in the com.microsoft domain.
class GemmActivationFusion : public onnxruntime::GraphTransformer {
 public:
  GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {}
  // Applies the fusion across `graph`; sets `modified` when any node pair
  // was fused.
  Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};
|
||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/graph/initializer.h" | ||
#include "core/graph/matmul_add_fusion.h" | ||
#include "core/graph/graph_utils.h" | ||
#include <deque> | ||
|
||
using namespace onnx; | ||
using namespace ::onnxruntime::common; | ||
namespace onnxruntime { | ||
|
||
// Fuses MatMul + Add pairs into a single Gemm node wherever Gemm's
// constraints are met: float tensors, 2-D matrix inputs (a 1-D [K] input A
// is rewritten to [1, K]), and a bias C with rank <= 2 so unidirectional
// broadcast applies. Sets `modified` and re-resolves the graph on change.
Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::deque<onnxruntime::NodeIndex> removed_nodes;

  for (auto node_index : node_topology_list) {
    auto node = graph.GetNode(node_index);
    // Only consider a MatMul (opset 1 or 9) whose single consumer is an Add.
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) || utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }

    auto next_node_itr = node->OutputNodesBegin();
    if (next_node_itr == node->OutputNodesEnd()) {
      continue;
    }

    const Node& next_node = (*next_node_itr);
    if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
      continue;
    }

    Node* matmul_node = node;
    Node& add_node = const_cast<Node&>(next_node);
    auto matmul_input_defs = matmul_node->MutableInputDefs();
    auto add_input_defs = add_node.MutableInputDefs();

    // Gemm only supports float, so check the inputs of MatMul and Add.
    // Type() may be null for untyped defs — skip rather than dereference.
    auto matmul_type = matmul_input_defs[0]->Type();
    auto add_type = add_input_defs[0]->Type();
    if (nullptr == matmul_type || nullptr == add_type ||
        (*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
      continue;
    }

    // Gemm only supports matrices, so check the shapes of MatMul's inputs.
    auto matmul_a_shape = matmul_input_defs[0]->Shape();
    auto matmul_b_shape = matmul_input_defs[1]->Shape();
    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
      continue;
    }
    if (1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size()) {
      // MatMul has shape [K] * [K, N]: rewrite A's shape to [1, K] so the
      // pair can be expressed as a Gemm.
      auto mutable_matmul_a_shape = const_cast<onnx::TensorShapeProto*>(matmul_a_shape);
      auto dim_0 = mutable_matmul_a_shape->mutable_dim(0);
      auto dim_1 = mutable_matmul_a_shape->add_dim();
      (*dim_1) = (*dim_0);
      dim_0->set_dim_value(1);
    }
    // Deliberately NOT an `else`: the 1-D case above has just been rewritten
    // to 2-D and must pass this check as well.
    if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
      // Gemm only supports matrices.
      continue;
    }

    // Whichever Add input is NOT the MatMul output becomes Gemm's C input.
    // Gemm only supports unidirectional broadcast on C, so its shape must
    // be known with rank <= 2.
    auto matmul_output_name = matmul_node->OutputDefs()[0]->Name();
    const int c_index = (matmul_output_name == add_input_defs[0]->Name()) ? 1 : 0;
    auto c_shape = add_input_defs[c_index]->Shape();
    if (nullptr == c_shape || c_shape->dim_size() > 2) {
      continue;
    }
    auto gemm_input_defs = matmul_input_defs;
    gemm_input_defs.push_back(add_input_defs[c_index]);

    graph.AddNode(graph.GenerateNodeName("gemm"),
                  "Gemm",
                  "fused Matmul and Add " + add_node.OpType(),
                  gemm_input_defs,
                  add_node.MutableOutputDefs());

    removed_nodes.push_front(matmul_node->Index());
    removed_nodes.push_front(add_node.Index());
  }

  // Remove nodes in reversed order for now to work around the issue in
  // Graph::RemoveNode.
  for (auto it = removed_nodes.begin(); it != removed_nodes.end(); ++it) {
    graph.RemoveNode(*it);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }

  return Status::OK();
}
} // namespace onnxruntime |
Oops, something went wrong.