Commit

use gemm to replace matmul + add (#234)
* matmul add fusion

* add shape check on Gemm input C

* work around the issue with RemoveNode

* update the version support

* If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N], so that it can work for Gemm

* Fuse Gemm+Activation into FusedGemm

* test

* revert the change that fused MatMul with shape [K]*[K, N] into Gemm as shape [1, K]*[K, N]; this may cause runtime failures, as we can't change the input data shape.

* revert the change that updated the MatMul shape from [K]*[K, N] to [1, K]*[K, N]. It enabled fusing MatMul + Add into Gemm, but the input data is not aware of the new shape, so it is still [K]*[K, N] at runtime and causes a runtime issue.

* 1. Fix the build issue for CUDA
2. Update Gemm so that we can fuse MatMul [K] * [K, N] + Add [1, N] into Gemm with shape [1, K] * [K, N] + [1, N]

* Fix build issue

* Fuse the activation node even if it connects to the graph output

* resolve the merge conflicts

* Add test model for Gemm+Activation fusion
HectorSVC authored Jan 22, 2019
1 parent 8b55596 commit 647cc2d
Showing 24 changed files with 475 additions and 26 deletions.
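
Before the file-by-file changes, the core idea in one place: for 2-D float inputs, MatMul followed by Add computes exactly what Gemm computes with alpha = beta = 1, which is why the pair can be collapsed into one node. A minimal standalone sketch of the equivalence (illustrative plain C++, not ORT code; all names are ours):

#include <vector>

// Y = A(M x K) * B(K x N) + C(1 x N): the MatMul + Add pattern this commit rewrites.
std::vector<float> MatMulThenAdd(const std::vector<float>& A, const std::vector<float>& B,
                                 const std::vector<float>& C, int M, int K, int N) {
  std::vector<float> Y(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      for (int k = 0; k < K; ++k) Y[m * N + n] += A[m * K + k] * B[k * N + n];
      Y[m * N + n] += C[n];  // Add, with C broadcast unidirectionally over the M rows
    }
  return Y;
}

// Gemm reference: Y = alpha * A * B + beta * C. With alpha = beta = 1 this
// returns exactly what MatMulThenAdd returns, which is what makes the fusion valid.
std::vector<float> GemmRef(const std::vector<float>& A, const std::vector<float>& B,
                           const std::vector<float>& C, int M, int K, int N,
                           float alpha = 1.0f, float beta = 1.0f) {
  std::vector<float> Y(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      Y[m * N + n] = alpha * acc + beta * C[n];
    }
  return Y;
}

With alpha = beta = 1 both functions return identical values; the fusion saves one graph node and the intermediate MatMul output tensor.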
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/contrib_kernels.cc
@@ -10,6 +10,7 @@ namespace contrib {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram);
@@ -38,6 +39,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {

kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram)>());
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_gemm.cc
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_gemm.h"

namespace onnxruntime {
namespace contrib {
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
    FusedGemm,
    1,
    float,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    FusedGemm<float, float, float, float>);
} // namespace contrib
} // namespace onnxruntime
26 changes: 26 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_gemm.h
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/cpu/math/gemm.h"

namespace onnxruntime {
namespace contrib {
template <typename T_X,
          typename T_W,
          typename T_B,
          typename T_Y>
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
 public:
  FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) {
    Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
    Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
  }

  Status Compute(OpKernelContext* context) const override {
    return Gemm<T_X, T_W, T_B, T_Y>::Compute(context);
  }
};
} // namespace contrib
} // namespace onnxruntime
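
FusedGemm::Compute simply forwards to the base Gemm kernel; the activation_ and leaky_relu_alpha_ members set in the constructor steer an elementwise epilogue inside that base implementation. A hedged sketch of what such an epilogue computes for the four activations this commit fuses (hypothetical helper, not the actual ORT kernel code):

#include <algorithm>
#include <cmath>
#include <string>

// Elementwise activation as selected by the "activation" attribute.
// An empty string (the attribute's default) means "no activation".
inline float ApplyActivation(const std::string& activation, float v,
                             float leaky_relu_alpha = 0.01f) {
  if (activation == "Relu") return std::max(v, 0.0f);
  if (activation == "LeakyRelu") return v >= 0.0f ? v : leaky_relu_alpha * v;
  if (activation == "Sigmoid") return 1.0f / (1.0f + std::exp(-v));
  if (activation == "Tanh") return std::tanh(v);
  return v;
}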
90 changes: 90 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -208,6 +208,96 @@ activation.)DOC")
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
});

ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(R"DOC(
The FusedGemm operator schema is the same as Gemm except that it also includes
the attributes activation and leaky_relu_alpha.)DOC")
    .Input(
        0,
        "A",
        "Input tensor A. "
        "The shape of A should be (M, K) if transA is 0, "
        "or (K, M) if transA is non-zero.",
        "T")
    .Input(
        1,
        "B",
        "Input tensor B. "
        "The shape of B should be (K, N) if transB is 0, "
        "or (N, K) if transB is non-zero.",
        "T")
    .Input(
        2,
        "C",
        "Input tensor C. "
        "The shape of C should be unidirectionally broadcastable to (M, N).",
        "T")
    .Output(0, "Y", "Output tensor of shape (M, N).", "T")
    .TypeConstraint(
        "T",
        {"tensor(float16)",
         "tensor(float)",
         "tensor(double)",
         "tensor(uint32)",
         "tensor(uint64)",
         "tensor(int32)",
         "tensor(int64)"},
        "Constrain input and output types to float/int tensors.")
    .Attr(
        "transA",
        "Whether A should be transposed",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Attr(
        "transB",
        "Whether B should be transposed",
        AttributeProto::INT,
        static_cast<int64_t>(0))
    .Attr(
        "alpha",
        "Scalar multiplier for the product of input tensors A * B.",
        AttributeProto::FLOAT,
        1.0f)
    .Attr(
        "beta",
        "Scalar multiplier for input tensor C.",
        AttributeProto::FLOAT,
        1.0f)
    .Attr(
        "activation",
        "",
        AttributeProto::STRING,
        OPTIONAL)
    .Attr(
        "leaky_relu_alpha",
        "",
        AttributeProto::FLOAT,
        OPTIONAL)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      if (hasNInputShapes(ctx, 2)) {
        auto transAAttr = ctx.getAttribute("transA");
        bool transA =
            transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
        auto transBAttr = ctx.getAttribute("transB");
        bool transB =
            transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
        auto& first_input_shape = getInputShape(ctx, 0);
        auto& second_input_shape = getInputShape(ctx, 1);
        if (first_input_shape.dim_size() != 2)
          fail_shape_inference("First input does not have rank 2");
        if (second_input_shape.dim_size() != 2)
          fail_shape_inference("Second input does not have rank 2");
        updateOutputShape(
            ctx,
            0,
            {first_input_shape.dim(transA ? 1 : 0),
             second_input_shape.dim(transB ? 0 : 1)});
      }
    });

ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
.SetDomain(kMSDomain)
.SinceVersion(1)
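To make the inference function above concrete: the output dims are taken from the possibly-transposed inputs, so with transA = 1 and A of shape (K, M) the first output dim is A.dim(1) = M, and with transB = 0 and B of shape (K, N) the second is B.dim(1) = N, giving Y = (M, N). A tiny standalone helper mirroring that rule (illustrative only, not ORT code):

#include <cstdint>
#include <utility>

// Mirrors the shape rule in the TypeAndShapeInferenceFunction above:
// Y = (transA ? A.dim(1) : A.dim(0), transB ? B.dim(0) : B.dim(1)).
std::pair<int64_t, int64_t> FusedGemmOutputShape(std::pair<int64_t, int64_t> a_shape, bool transA,
                                                 std::pair<int64_t, int64_t> b_shape, bool transB) {
  return {transA ? a_shape.second : a_shape.first,
          transB ? b_shape.first : b_shape.second};
}
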
108 changes: 108 additions & 0 deletions onnxruntime/core/graph/gemm_activation_fusion.cc
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/initializer.h"
#include "core/graph/gemm_activation_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>

using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace {
bool IsFusableActivation(const Node& node) {
  return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) ||
         utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) ||
         utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) ||
         utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
}

void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
  Node::EdgeSet output_edges;
  for (auto it = act.OutputEdgesBegin(); it != act.OutputEdgesEnd(); ++it) {
    output_edges.insert(*it);
  }

  // Remove the output edges of the activation node and connect the
  // fused_gemm node to the nodes that followed the activation node.
  for (auto& output_edge : output_edges) {
    NodeIndex dst_node_index = output_edge.GetNode().Index();
    int src_arg_index = output_edge.GetSrcArgIndex();
    int dst_arg_index = output_edge.GetDstArgIndex();
    g.RemoveEdge(act.Index(), dst_node_index, src_arg_index, dst_arg_index);
    g.AddEdge(fused_gemm.Index(), dst_node_index, 0, dst_arg_index);
  }
}

} // namespace

Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();

  std::deque<onnxruntime::NodeIndex> removed_nodes;
  for (auto index : order) {
    auto node = graph.GetNode(index);
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) ||
          utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }
    const Node& next_node = *(node->OutputNodesBegin());
    if (!IsFusableActivation(next_node)) {
      continue;
    }

    Node* gemm_node = node;
    const Node& act_node = next_node;

    Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm",
                                     "fused Gemm " + gemm_node->Name() + " with activation " + act_node.OpType(),
                                     gemm_node->MutableInputDefs(),
                                     graph.IsNodeOutputsInGraphOutputs(next_node) ? const_cast<Node&>(act_node).MutableOutputDefs() : gemm_node->MutableOutputDefs(),
                                     &gemm_node->GetAttributes(),
                                     "com.microsoft");

    // Add a new attribute to specify the activation type.
    fused_gemm.AddAttribute("activation", act_node.OpType());

    // Add optional attributes for activations, e.g. LeakyRelu's alpha becomes leaky_relu_alpha.
    if (act_node.OpType() == "LeakyRelu") {
      const NodeAttributes attrs = act_node.GetAttributes();
      for (auto it = attrs.begin(); it != attrs.end(); ++it) {
        fused_gemm.AddAttribute("leaky_relu_" + it->first, it->second);
      }
    }

    if (!graph.IsNodeOutputsInGraphOutputs(next_node)) {
      HandleActivationNodeEdges(graph, act_node, fused_gemm);

      // Replace the input defs of the nodes that consumed the activation node's output.
      const NodeArg* act_output_def = act_node.OutputDefs()[0];
      NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0];
      for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
        auto output_node = graph.GetNode((*it).Index());
        if (!output_node) {
          return Status(ONNXRUNTIME, INVALID_ARGUMENT);
        }

        auto& input_defs = output_node->MutableInputDefs();
        for (auto& def : input_defs) {
          if (def == act_output_def) {
            def = fused_gemm_output_def;
          }
        }
      }
    }

    removed_nodes.push_front(gemm_node->Index());
    removed_nodes.push_front(act_node.Index());
  }

  for (auto node : removed_nodes) {
    graph.RemoveNode(node);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }
  return Status::OK();
}
} // namespace onnxruntime
16 changes: 16 additions & 0 deletions onnxruntime/core/graph/gemm_activation_fusion.h
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/graph/graph_transformer.h"

namespace onnxruntime {

class GemmActivationFusion : public onnxruntime::GraphTransformer {
 public:
  GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {}
  Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};

} // namespace onnxruntime
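
The header above exposes the standard GraphTransformer::Apply entry point, so a caller can run the pass directly against a graph. A minimal driver sketch, assuming a loaded onnxruntime::Graph (how the session wires transformers in is outside this diff):

#include "core/graph/gemm_activation_fusion.h"

onnxruntime::common::Status RunGemmActivationFusion(onnxruntime::Graph& graph) {
  onnxruntime::GemmActivationFusion fusion;
  bool modified = false;
  // Apply matches Gemm -> {Relu, LeakyRelu, Sigmoid, Tanh} pairs and replaces
  // each with a single FusedGemm node; it re-resolves the graph itself.
  ORT_RETURN_IF_ERROR(fusion.Apply(graph, modified));
  return onnxruntime::common::Status::OK();
}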
106 changes: 106 additions & 0 deletions onnxruntime/core/graph/matmul_add_fusion.cc
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/initializer.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>

using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {

Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::deque<onnxruntime::NodeIndex> removed_nodes;

  for (auto node_index : node_topology_list) {
    auto node = graph.GetNode(node_index);
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) || utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }

    auto next_node_itr = node->OutputNodesBegin();
    if (next_node_itr == node->OutputNodesEnd()) {
      continue;
    }

    const Node& next_node = (*next_node_itr);
    if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
      continue;
    }

    Node* matmul_node = node;
    Node& add_node = const_cast<Node&>(next_node);
    std::vector<NodeArg> input_args, output_args;
    auto matmul_input_defs = matmul_node->MutableInputDefs();
    auto add_input_defs = add_node.MutableInputDefs();

    // Gemm only supports float, so the inputs of MatMul and Add must be float.
    auto matmul_type = matmul_input_defs[0]->Type();
    auto add_type = add_input_defs[0]->Type();
    if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
      continue;
    }

    // Gemm only supports matrices, so check the shapes of the MatMul inputs.
    auto matmul_a_shape = matmul_input_defs[0]->Shape();
    auto matmul_b_shape = matmul_input_defs[1]->Shape();
    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
      continue;
    } else if (1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size()) {
      // MatMul has shape [K] * [K, N]; reset it to [1, K] * [K, N] so that it can work for Gemm.
      auto mutable_matmul_a_shape = const_cast<onnx::TensorShapeProto*>(matmul_a_shape);
      auto dim_0 = mutable_matmul_a_shape->mutable_dim(0);
      auto dim_1 = mutable_matmul_a_shape->add_dim();
      (*dim_1) = (*dim_0);
      dim_0->set_dim_value(1);
    }

    if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
      // Gemm only supports matrices.
      continue;
    }

    auto matmul_output_name = matmul_node->OutputDefs()[0]->Name();
    auto gemm_input_defs = matmul_input_defs;
    if (matmul_output_name == add_input_defs[0]->Name()) {
      // MatMul output is Add's input A, so use Add's input B as input C for Gemm.
      // Gemm only supports unidirectional broadcast on C.
      if (add_input_defs[1]->Shape()->dim_size() > 2) {
        continue;
      }
      gemm_input_defs.push_back(add_input_defs[1]);
    } else {
      // MatMul output is Add's input B, so use Add's input A as input C for Gemm.
      // Gemm only supports unidirectional broadcast on C.
      if (add_input_defs[0]->Shape()->dim_size() > 2) {
        continue;
      }
      gemm_input_defs.push_back(add_input_defs[0]);
    }

    graph.AddNode(graph.GenerateNodeName("gemm"),
                  "Gemm",
                  "fused Matmul and Add " + add_node.OpType(),
                  gemm_input_defs,
                  add_node.MutableOutputDefs());

    removed_nodes.push_front(matmul_node->Index());
    removed_nodes.push_front(add_node.Index());
  }

  // Have to remove the nodes in reverse order for now, to work around the issue in RemoveNode.
  for (auto it = removed_nodes.begin(); it != removed_nodes.end(); ++it) {
    graph.RemoveNode(*it);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }

  return Status::OK();
}
} // namespace onnxruntime
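
Note that the [K] * [K, N] special case above only rewrites shape metadata: a length-K vector is laid out in memory exactly like a 1-by-K row-major matrix, so Gemm can consume the same buffer unchanged. A small standalone check of that equivalence (illustrative C++, not ORT code):

#include <cassert>
#include <vector>

int main() {
  const int K = 2, N = 3;
  std::vector<float> a = {1.0f, 2.0f};        // shape [K], same bytes as [1, K]
  std::vector<float> B = {1, 2, 3, 4, 5, 6};  // shape [K, N], row-major
  std::vector<float> c = {0.5f, 0.5f, 0.5f};  // shape [1, N] bias

  // MatMul([K], [K, N]) + Add([1, N]) ...
  std::vector<float> y_matmul_add(N, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) y_matmul_add[n] += a[k] * B[k * N + n];
    y_matmul_add[n] += c[n];
  }

  // ... equals Gemm([1, K], [K, N], [1, N]) with M = 1 and alpha = beta = 1.
  const int M = 1;
  std::vector<float> y_gemm(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += a[m * K + k] * B[k * N + n];
      y_gemm[m * N + n] = acc + c[n];
    }

  for (int n = 0; n < N; ++n) assert(y_matmul_add[n] == y_gemm[n]);
  return 0;
}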