forked from open-mmlab/mmdeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
improvement(ViT): use Crop to subtitude Gather (open-mmlab#477)
* improvement(ViT): use Crop to subtitude Gather * fix(CI): code format * fix(pytorch/ops/linear.py): bias maybe None * fix(test/test_pytorch_ops.py): op_type error * fix(test): pytest error * fix(test): torch version 1.8
- Loading branch information
1 parent
ad08e07
commit 85d1dee
Showing
9 changed files
with
147 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
#include "shape_inference.h" | ||
|
||
#include <algorithm> | ||
|
||
/** | ||
* @brief query output shape of target node | ||
* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#pragma once | ||
#include <algorithm> | ||
|
||
#include "utils.h" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
# Modified from: | ||
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py | ||
from torch.onnx.symbolic_helper import parse_args | ||
|
||
from mmdeploy.core import SYMBOLIC_REWRITER | ||
from mmdeploy.utils import Backend | ||
|
||
|
||
@parse_args('v', 'v', 'f', 'f', 'i', 'i') | ||
def linear_no_bias(g, input, weight): | ||
"""Symbolic function for `linear` without bias. | ||
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'. | ||
""" | ||
return g.op( | ||
'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1) | ||
|
||
|
||
@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i') | ||
def linear_normal(g, input, weight, bias): | ||
"""Symbolic function for `linear`. | ||
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'. | ||
""" | ||
return g.op( | ||
'Gemm', | ||
input, | ||
weight, | ||
bias, | ||
alpha_f=1.0, | ||
beta_f=1.0, | ||
transA_i=0, | ||
transB_i=1) | ||
|
||
|
||
@SYMBOLIC_REWRITER.register_symbolic( | ||
'linear', is_pytorch=True, backend=Backend.NCNN.value) | ||
def linear__ncnn(ctx, g, input, weight, bias): | ||
"""Support export linear This rewrite enable export Gemm.""" | ||
if bias is None: | ||
return linear_no_bias(g, input, weight) | ||
else: | ||
return linear_normal(g, input, weight, bias) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters