
use gemm to replace matmul + add #234

Merged (23 commits) on Jan 22, 2019

Commits:
48e7468
matmul add fusion
HectorSVC Dec 20, 2018
df38220
add shape check on Gemm input C
HectorSVC Dec 20, 2018
f57dd20
walk around the issue with RemoveNode
HectorSVC Dec 21, 2018
167c03d
update the version support
HectorSVC Dec 21, 2018
748bdf4
Merge branch 'master' of https://github.com/Microsoft/onnxruntime int…
HectorSVC Dec 22, 2018
d47eb0f
Merge branch 'master' into hecli/gemm
HectorSVC Jan 3, 2019
d576ce9
If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N], so th…
HectorSVC Jan 3, 2019
e92ff53
Merge branch 'master' into hecli/gemm
HectorSVC Jan 3, 2019
eef69d7
Fuse Gemm+Activation into FusedGemm
HectorSVC Jan 4, 2019
c073a06
Merge branch 'master' into hecli/gemm
HectorSVC Jan 4, 2019
750b2c7
test
HectorSVC Jan 5, 2019
10379f3
revert the change which fuse the matmul with shape [K]*[K, N] to Gemm…
HectorSVC Jan 7, 2019
033cb46
revert the change which change the shape for Matmul from [K]*[K, N] t…
HectorSVC Jan 8, 2019
361785f
1. Fix build issue for CUDA
HectorSVC Jan 17, 2019
b0779f9
Merge branch 'master' into hecli/gemm_fusion
HectorSVC Jan 17, 2019
8d1162e
revert the hack in C API
HectorSVC Jan 17, 2019
3a20760
Merge branch 'hecli/gemm' of https://github.com/Microsoft/onnxruntime…
HectorSVC Jan 17, 2019
8bc9c95
Fix build issue
HectorSVC Jan 17, 2019
15161ee
Fuse the activation node even if it connects to the output
HectorSVC Jan 18, 2019
0c4d6e9
Merge branch 'master' into hecli/gemm
HectorSVC Jan 18, 2019
ee56b50
resolve the merge conflicts
HectorSVC Jan 19, 2019
7f616e4
Merge branch 'master' into hecli/gemm
HectorSVC Jan 22, 2019
4a34c7f
Add test model for Gemm+Activation fusion
HectorSVC Jan 22, 2019
94 changes: 94 additions & 0 deletions onnxruntime/core/graph/matmul_add_fusion.cc
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/graph/initializer.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/graph/graph_utils.h"

using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {

Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::vector<onnxruntime::NodeIndex> removed_nodes;

for (auto node_index : node_topology_list) {
auto node = graph.GetNode(node_index);
if (nullptr == node ||
!utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9) ||
node->GetOutputEdgesCount() != 1) {
continue;
}

auto next_node_itr = node->OutputNodesBegin();
if (next_node_itr == node->OutputNodesEnd()) {
[Review comment — pranavsharma (Contributor), Dec 21, 2018]
curious: do we have to check for this condition if we've already checked for node->GetOutputEdgesCount() != 1? #Resolved

[Reply — HectorSVC (Author)]
good point, will remove this
continue;
}

const Node& next_node = (*next_node_itr);
if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
continue;
}

Node* matmul_node = node;
Node& add_node = const_cast<Node&>(next_node);
std::vector<NodeArg> input_args, output_args;
auto matmul_input_defs = matmul_node->MutableInputDefs();
auto add_input_defs = add_node.MutableInputDefs();

// Gemm only supports float, so check that the MatMul and Add inputs are float tensors
auto matmul_type = matmul_input_defs[0]->Type();
auto add_type = add_input_defs[0]->Type();
if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
continue;
}

// Gemm only supports matrices, so check the shapes of the MatMul and Add inputs
[Review comment — HectorSVC (Author), Dec 21, 2018, on the shape check above]
if MatMul is [K] * [K, N], we should be able to update the shape to [1, K] * [K, N] and make it work for Gemm. Will update this. #Resolved

[Reply — HectorSVC (Author)]
Had to revert this change, as consumers of the data are not aware of the reshape.
auto matmul_a_shape = matmul_input_defs[0]->Shape();
auto matmul_b_shape = matmul_input_defs[1]->Shape();
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape ||
2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
continue;
}

auto matmul_output_name = matmul_node->OutputDefs()[0]->Name();
auto gemm_input_defs = matmul_input_defs;
if (matmul_output_name == add_input_defs[0]->Name()) {
// MatMul output is Add input A, so use Add input B as Gemm input C
// Gemm only supports unidirectional broadcast on C
if (add_input_defs[1]->Shape()->dim_size() > 2) {
continue;
}
gemm_input_defs.push_back(add_input_defs[1]);
} else {
// MatMul output is Add input B, so use Add input A as Gemm input C
// Gemm only supports unidirectional broadcast on C
if (add_input_defs[0]->Shape()->dim_size() > 2) {
continue;
}
gemm_input_defs.push_back(add_input_defs[0]);
}

graph.AddNode(graph.GenerateNodeName("gemm"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
add_node.MutableOutputDefs());

removed_nodes.push_back(add_node.Index());
removed_nodes.push_back(matmul_node->Index());
}

for (auto i : removed_nodes) {
graph.RemoveNode(i);
}

if (!removed_nodes.empty()) {
modified = true;
ORT_RETURN_IF_ERROR(graph.Resolve());
}
return Status::OK();
}
} // namespace onnxruntime
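The dim_size() > 2 checks above guard Gemm's constraint that input C must be unidirectionally broadcastable to the [M, N] output (shapes like [N], [1, N], or [M, N] work; higher-rank C does not). A hedged numpy sketch of which C shapes qualify — note this sketch checks full unidirectional broadcastability, while the transformer above uses the simpler rank test:

```python
import numpy as np

M, N = 2, 4  # Gemm output shape [M, N]

def broadcastable_to_mn(c_shape, m=M, n=N):
    """Unidirectional broadcast: C must expand to [M, N] without
    forcing [M, N] itself to expand."""
    try:
        return np.broadcast_shapes(c_shape, (m, n)) == (m, n)
    except ValueError:
        return False

assert broadcastable_to_mn((N,))        # per-column bias — fusable
assert broadcastable_to_mn((1, N))      # fusable
assert broadcastable_to_mn((M, N))      # fusable
assert not broadcastable_to_mn((M, N, 2))  # rank > 2 — fusion is skipped
```

When the Add input fails this constraint, the transformer simply leaves the MatMul + Add pair in place rather than producing an invalid Gemm.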
16 changes: 16 additions & 0 deletions onnxruntime/core/graph/matmul_add_fusion.h
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/graph/graph_transformer.h"

namespace onnxruntime {

class MatMulAddFusion : public onnxruntime::GraphTransformer {
[Review comment — Contributor]
Can we try to implement this as a rewrite rule? It is a pity to traverse the whole graph for these local transformations. You can check the identity elimination as well as the PR here. It is a bit outdated, but I will most probably rebase it today.
public:
MatMulAddFusion() noexcept : onnxruntime::GraphTransformer("MatMulAddFusion", "Fusing MatMul and Add into Gemm") {}
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};

} // namespace onnxruntime
38 changes: 38 additions & 0 deletions onnxruntime/test/ir/graph_transform_test.cc
@@ -11,6 +11,7 @@
#include "core/graph/conv_mul_fusion.h"
#include "core/graph/conv_add_fusion.h"
#include "core/graph/conv_activation_fusion.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/platform/env.h"

#include "test/capturing_sink.h"
@@ -194,5 +195,42 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
ASSERT_TRUE(session_object.Initialize().IsOK());
}

TEST(GraphTransformationTests, MatMulAddFusion_two_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx";

SessionOptions so;
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
InferenceSession session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(model_uri).IsOK());

std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());

std::unique_ptr<MatMulAddFusion> matmul_add_fusion_transformer = std::make_unique<MatMulAddFusion>();

session_object.RegisterGraphTransformer(std::move(matmul_add_fusion_transformer));

ASSERT_TRUE(session_object.Initialize().IsOK());
}

TEST(GraphTransformationTests, MatMulAddFusion_three_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/model.onnx";

SessionOptions so;
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
InferenceSession session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(model_uri).IsOK());

std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());

std::unique_ptr<MatMulAddFusion> matmul_add_fusion_transformer = std::make_unique<MatMulAddFusion>();

session_object.RegisterGraphTransformer(std::move(matmul_add_fusion_transformer));

ASSERT_TRUE(session_object.Initialize().IsOK());
}


} // namespace test
} // namespace onnxruntime
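The later commits also fuse an activation that follows the Gemm into a FusedGemm node (commit "Fuse Gemm+Activation into FusedGemm", with a test model added in the final commit). As a reference for what such a fused node computes — a sketch assuming a Relu activation and alpha = beta = 1; the actual FusedGemm contract is defined by the onnxruntime kernel, not this snippet:

```python
import numpy as np

def fused_gemm_relu(a, b, c, alpha=1.0, beta=1.0):
    """Reference computation: Gemm followed by Relu, collapsed into one step."""
    return np.maximum(alpha * (a @ b) + beta * c, 0.0)

A = np.array([[1.0, -2.0]], dtype=np.float32)   # [1, 2]
B = np.array([[3.0], [4.0]], dtype=np.float32)  # [2, 1]
C = np.array([1.0], dtype=np.float32)           # [1]

# A @ B = [[-5.0]]; + C = [[-4.0]]; Relu clamps to [[0.0]]
out = fused_gemm_relu(A, B, C)
```

Fusing the activation saves a graph node and an intermediate tensor, which is the same motivation as the MatMul + Add fusion itself.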
Binary file not shown. (9 binary test-model files added)