From 1a14676b1570ab872ac7a7d1d467f19b7a41def6 Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 17 May 2022 16:30:23 +0800 Subject: [PATCH] fix(pytorch/ops/linear.py): bias maybe None --- csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 2 +- mmdeploy/pytorch/ops/linear.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index 0ea4e26d90..d78b883125 100644 --- a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -1304,7 +1304,7 @@ int main(int argc, char** argv) { } fprintf(pp, " 0=%d", axis); } else if (op == "Gelu") { - fprintf(pp, " 0=0"); + fprintf(pp, " 0=1"); } else if (op == "Gemm") { float alpha = get_node_attr_f(node, "alpha", 1.f); float beta = get_node_attr_f(node, "beta", 1.f); diff --git a/mmdeploy/pytorch/ops/linear.py b/mmdeploy/pytorch/ops/linear.py index 25e65b41f8..8cb997b400 100644 --- a/mmdeploy/pytorch/ops/linear.py +++ b/mmdeploy/pytorch/ops/linear.py @@ -13,14 +13,8 @@ def linear_no_bias(g, input, weight): PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'. """ - g.op( - 'mmdeploy::Gemm', - input, - weight, - alpha_f=1.0, - beta_f=1.0, - transA_i=0, - transB_i=1) + return g.op( + 'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1) @parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i') @@ -29,8 +23,8 @@ def linear_normal(g, input, weight, bias): PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'. """ - g.op( - 'mmdeploy::Gemm', + return g.op( + 'Gemm', input, weight, bias, @@ -41,10 +35,7 @@ def linear_normal(g, input, weight, bias): @SYMBOLIC_REWRITER.register_symbolic( - 'torch.nn.functional.linear', - is_pytorch=True, - # arg_descriptors=['v', 'v', 'v', 'f', 'f', 'i', 'i'], - backend=Backend.NCNN.value) + 'linear', is_pytorch=True, backend=Backend.NCNN.value) def linear__ncnn(ctx, g, input, weight, bias): """Support export linear This rewrite enable export Gemm.""" if bias is None: