Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvement(ViT): use Crop to subtitude Gather #477

Merged
merged 7 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "fuse_pass.h"

void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* gather = mutable_graph->mutable_node(i);
if (gather->op_type() != "Gather") {
continue;
}
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
if (indices.size() != 1) {
continue;
}

{
// reconstruct node connections
node_reference[gather->input(1)] -= 1;
std::string origin_inp = gather->input(0);
gather->clear_input();
gather->add_input(origin_inp);
}

{
// update axis, starts and ends
int axis = get_node_attr_i(*gather, "axis", 1) - 1;

gather->set_op_type("Crop");
gather->clear_attribute();

int indice = indices[0];
set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
}
}
}

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
Expand Down
5 changes: 5 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#include "shape_inference.h"
#include "utils.h"

void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);

void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
Expand Down
21 changes: 20 additions & 1 deletion csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ int main(int argc, char** argv) {
fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
}

// reduce common const weight node_reference
Expand Down Expand Up @@ -622,6 +623,8 @@ int main(int argc, char** argv) {
}
} else if (op == "Cos") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Crop") {
fprintf(pp, "%-16s", "Crop");
} else if (op == "DepthToSpace") {
fprintf(pp, "%-16s", "PixelShuffle");
} else if (op == "DetectionOutput") {
Expand Down Expand Up @@ -1194,6 +1197,22 @@ int main(int argc, char** argv) {
} else if (op == "Cos") {
int op_type = 10;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Crop") {
auto starts = get_node_attr_ai(node, "starts");
fprintf(pp, " -23309=%zu", starts.size());
for (size_t j = 0; j < starts.size(); ++j) {
fprintf(pp, ",%i", starts[j]);
}
auto ends = get_node_attr_ai(node, "ends");
fprintf(pp, " -23310=%zu", ends.size());
for (size_t j = 0; j < ends.size(); ++j) {
fprintf(pp, ",%i", ends[j]);
}
auto axis = get_node_attr_ai(node, "axis");
fprintf(pp, " -23311=%zu", axis.size());
for (size_t j = 0; j < axis.size(); ++j) {
fprintf(pp, ",%i", axis[j]);
}
} else if (op == "DepthToSpace") {
// pixelshuffle
int scale_factor = get_node_attr_i(node, "blocksize", 1);
Expand Down Expand Up @@ -1285,7 +1304,7 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=0");
fprintf(pp, " 0=1");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
Expand Down
2 changes: 2 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "shape_inference.h"

#include <algorithm>

/**
* @brief query output shape of target node
*
Expand Down
1 change: 0 additions & 1 deletion csrc/backend_ops/ncnn/onnx2ncnn/shape_inference.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once
#include <algorithm>

#include "utils.h"

Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/pytorch/functions/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def linear__ncnn(

dim = input.dim()

if dim == 2:
if dim == 2 or dim == 3 and input.shape[0] == 1:
return origin_func(input, weight, bias)
else:
out = origin_func(input, weight)
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/pytorch/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .layer_norm import layer_norm__ncnn
from .linear import linear__ncnn
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default

Expand All @@ -16,5 +17,5 @@
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
'layer_norm__ncnn'
'layer_norm__ncnn', 'linear__ncnn'
]
44 changes: 44 additions & 0 deletions mmdeploy/pytorch/ops/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from torch.onnx.symbolic_helper import parse_args

from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend


@parse_args('v', 'v', 'f', 'f', 'i', 'i')
def linear_no_bias(g, input, weight):
tpoisonooo marked this conversation as resolved.
Show resolved Hide resolved
"""Symbolic function for `linear` without bias.

PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)


@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i')
def linear_normal(g, input, weight, bias):
"""Symbolic function for `linear`.

PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm',
input,
weight,
bias,
alpha_f=1.0,
beta_f=1.0,
transA_i=0,
transB_i=1)


@SYMBOLIC_REWRITER.register_symbolic(
'linear', is_pytorch=True, backend=Backend.NCNN.value)
def linear__ncnn(ctx, g, input, weight, bias):
tpoisonooo marked this conversation as resolved.
Show resolved Hide resolved
"""Support export linear This rewrite enable export Gemm."""
if bias is None:
return linear_no_bias(g, input, weight)
else:
return linear_normal(g, input, weight, bias)
35 changes: 35 additions & 0 deletions tests/test_pytorch/test_pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,41 @@ def test_instance_norm():
assert nodes[4].domain == 'mmdeploy'


@pytest.mark.usefixtures('prepare_symbolics_ncnn')
class TestLinear:

def check(self, nodes):
print(nodes)

from packaging.version import parse as version_parse
version = version_parse(torch.__version__)
target = 'Gemm'
if version.major <= 1 and version.minor <= 8:
target = 'MatMul'
exist = False
for node in nodes:
if node.op_type == target:
exist = True
break

assert exist is True

def test_normal(self):
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
bias = torch.rand(2)
model = OpModel(torch.nn.functional.linear, w, bias).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)

def test_no_bias(self):
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
model = OpModel(torch.nn.functional.linear, w).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)


@pytest.mark.usefixtures('prepare_symbolics')
class TestSqueeze:

Expand Down