
[Bug] [Relay] [ONNX] Incorrect shape inference of Squeeze in DynamicToStatic #17050

Closed
shaoyuyoung opened this issue May 31, 2024 · 1 comment
Labels: needs-triage, type: bug

shaoyuyoung commented May 31, 2024

Description

This PyTorch model has only two ops: ReflectionPad3d and squeeze.

First, I export the torch model to ONNX and get the model below.
[Figure: the exported ONNX model]

The ONNX exporter rewrites part of the model: because of the squeeze operator, the exported graph is dynamic and contains an If branch.
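The If structure can be sketched at the shape level. This is a minimal sketch, assuming the exporter follows the usual opset 11+ lowering for torch.squeeze(x, dim) when it cannot prove that the dimension has size 1: a condition Equal(Shape(x)[dim], 1), a Squeeze then-branch and an Identity else-branch. The function exported_squeeze_shape is hypothetical and works on shape tuples rather than tensors.

```python
def exported_squeeze_shape(shape, dim):
    """Hypothetical shape-level model of the If node emitted for torch.squeeze(x, dim).

    cond        = Equal(Shape(x)[dim], 1)
    then-branch = Squeeze(x, axes=[dim])   -> drops the axis
    else-branch = Identity(x)              -> returns x unchanged
    """
    shape = tuple(shape)
    dim = dim % len(shape)  # normalize negative axes, e.g. -1 -> last axis
    cond = shape[dim] == 1
    if cond:
        # then-branch: remove the singleton axis `dim`
        return shape[:dim] + shape[dim + 1:]
    # else-branch: the input passes through untouched
    return shape


# Shape produced by the padding in this model: the then-branch is taken.
print(exported_squeeze_shape((13, 1, 1, 1), 1))   # (13, 1, 1)
# If dimension 1 were 13 instead, the else-branch would keep all four axes.
print(exported_squeeze_shape((13, 13, 1, 1), 1))  # (13, 13, 1, 1)
```

The two branches return outputs of different rank, so the final output shape depends on what the padding produces for dimension 1.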

ONNX considers this model valid.
However, when I convert the model with Relay, I get a shape mismatch error. The correct shape should be Tensor[(13, 1, 1, 1), float32], but TVM gets Tensor[(13, 13, 1, 1), float32].
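The expected shape follows from the padding arithmetic. A minimal sketch, assuming PyTorch's documented ReflectionPad3d semantics, where a 4-D input is treated as unbatched (C, D, H, W); reflection_pad3d_shape is a hypothetical helper that only computes shapes:

```python
def reflection_pad3d_shape(shape, pad):
    """Hypothetical helper: output shape of ReflectionPad3d for an unbatched (C, D, H, W) input.

    `pad` uses PyTorch's order (left, right, top, bottom, front, back):
    left/right act on W, top/bottom on H, front/back on D.
    Negative entries crop that side instead of padding it.
    """
    c, d, h, w = shape
    left, right, top, bottom, front, back = pad
    return (c, d + front + back, h + top + bottom, w + left + right)


padded = reflection_pad3d_shape((13, 47, 44, 1), (0, 0, -43, 0, 0, -46))
print(padded)  # (13, 1, 1, 1)

# torch.squeeze(x, dim=1) then drops dimension 1, which now has size 1.
squeezed = padded[:1] + padded[2:]
print(squeezed)  # (13, 1, 1)
```

So the pad output should have size 1 in dimension 1, whereas the shape TVM infers has 13 there.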

I suspect there is a bug in TVM's DynamicToStatic pass :(

Code

import onnx
import torch
import torch.nn as nn
import torch.onnx
from tvm import relay


def get_onnx_shape(onnx_model):
    # Build {input_name: shape} for relay.frontend.from_onnx.
    # Dynamic or unknown dimensions (dim_value <= 0) are replaced with 1.
    input_shapes = {}
    for inp in onnx_model.graph.input:
        shape = []
        for dim in inp.type.tensor_type.shape.dim:
            if dim.dim_value > 0:
                shape.append(dim.dim_value)
            else:
                shape.append(1)

        input_shapes[inp.name] = shape
    return input_shapes


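class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Padding order is (left, right, top, bottom, front, back);
        # the negative values crop H by 43 and D by 46.
        self.pad = nn.ReflectionPad3d((0, 0, -43, 0, 0, -46))

    def forward(self, x):
        # A 4-D input is treated as unbatched (C, D, H, W):
        # (13, 47, 44, 1) -> (13, 1, 1, 1)
        x = self.pad(x)
        # Squeeze dim 1 (size 1 after padding) -> (13, 1, 1)
        x = torch.squeeze(x, dim=1)
        return x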



model = Model()

input_tensor = torch.randn(13, 47, 44, 1)

onnx_file_path = "model.onnx"
torch.onnx.export(model,
                  input_tensor,
                  onnx_file_path,
                  export_params=True,
                  opset_version=14,
                  do_constant_folding=False,
                  input_names=['input'],
                  output_names=['output']
                  )

onnx_model = onnx.load(onnx_file_path)
shape_dict = get_onnx_shape(onnx_model)

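# The shape mismatch below is raised from inside the DynamicToStatic pass
# (see the traceback) while converting with freeze_params=True.
mod, params = relay.frontend.from_onnx(
    onnx_model, shape_dict, freeze_params=True
)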

Error Log

TVMError: Traceback (most recent call last):
  20: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  19: tvm::transform::Pass::operator()(tvm::IRModule) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: _ZN3tvm7runtime13PackedFun
  15: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  14: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  13: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  12: tvm::transform::Pass::operator()(tvm::IRModule) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: _ZN3tvm7runtime13PackedFun
  5: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::DiagnosticContext::Render()
  3: tvm::DiagnosticRenderer::Render(tvm::DiagnosticContext const&)
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1}>(tvm::TerminalRenderer(std::ostream&)::{lambda(tvm::DiagnosticContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::ReportAt(tvm::DiagnosticContext const&, std::ostream&, tvm::Span const&, tvm::Diagnostic const&)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/ir/diagnostic.cc", line 264
TVMError: The source maps are not populated for this module. Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error reporting.
Error: The Relay type checker is unable to show the following types match:
  Tensor[(13, 13, 1, 1), float32]
  Tensor[(13, 1, 1, 1), float32]
In particular:
  dimension 1 conflicts: 13 does not match 1.

Environment

TVM d1ac1c0
Ubuntu 20

cc @KJlaccHoeUM9l @shingjan

shaoyuyoung added the needs-triage and type: bug labels on May 31, 2024
xhmelon (Contributor) commented Sep 18, 2024

This issue has been fixed by #17383.
