
[Bug][ONNX][EdgeCase] Crash on Scalar Reduce #10738

Closed
ganler opened this issue Mar 23, 2022 · 2 comments · Fixed by #10780
Comments


ganler commented Mar 23, 2022

import torch

class Net(torch.nn.Module):
    def forward(self, x):
        return x.sum()

net = Net().eval()

i = torch.ones((), dtype=torch.float)

with torch.no_grad():
    torch.onnx.export(net, (i), "output.onnx", verbose=True, opset_version=14)

A scalar reduce like the example above is allowed in ONNX and executes correctly in ONNXRuntime, but it crashes at tvm::relay::GetReduceAxes in TVM. The same thing happens for min/max/mean/sum.
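For reference, the intended semantics are straightforward; a minimal NumPy sketch (NumPy stands in here for what ONNXRuntime computes):

```python
import numpy as np

# A rank-0 (scalar) tensor, matching torch.ones((), dtype=torch.float)
x = np.ones((), dtype=np.float32)

# Reducing over all dimensions of a scalar is well-defined:
# the result is just the scalar itself.
print(x.sum())                       # 1.0
print(x.min(), x.max(), x.mean())    # 1.0 1.0 1.0
```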

Even if TVM chooses not to support scalar reduce (understandable, since it is unnecessary in most cases and can be worked around in user code), I would still consider this an ONNX conversion issue.

In ONNX, if axes is not provided, a reduce simply flattens the tensor (including a scalar tensor) into a scalar. https://github.com/onnx/onnx/blob/main/docs/Changelog.md#inputs-1---2-3

The default is to reduce over all the dimensions of the input tensor if 'noop_with_empty_axes' is false, else act as an Identity op when 'noop_with_empty_axes' is true.
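The quoted behavior can be sketched in plain NumPy; this is a hedged illustration of the ONNX ReduceSum semantics, not the actual onnx package API (the function name and signature are made up for illustration):

```python
import numpy as np

def reduce_sum(data, axes=None, noop_with_empty_axes=False, keepdims=True):
    """Sketch of ONNX ReduceSum semantics for missing/empty axes."""
    if not axes:
        if noop_with_empty_axes:
            # Empty axes + noop_with_empty_axes=True: act as Identity.
            return data
        # Otherwise reduce over all dimensions; for a rank-0 input
        # this is the empty tuple, and the reduction is a no-op.
        axes = tuple(range(data.ndim))
    return np.sum(data, axis=tuple(axes), keepdims=keepdims)
```

Note that for a scalar input both branches yield the input unchanged, which is exactly the case TVM's converter currently mishandles.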

cc: @masahi @AndrewZhaoLuo

ganler commented Mar 23, 2022

def @main(%v0: float32) {
  sum(%v0, axis=[])
}

This Relay IR crashes during type inference.

ganler commented Mar 23, 2022

One quick way to fix it is to check the rank of the tensor being reduced during conversion:

class Reduce(OnnxOpConverter):
    # ...
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):
+            return inputs[0]
+
        if "axes" in attr:
            axis = attr.get("axes", 0)
        else:
            axis_len = len(infer_shape(inputs[0]))
            axis = list(range(axis_len))

        return cls.run_calculation(inputs, axis, attr.get("keepdims", True))

But I don't think this patch alone covers the other reduce operators, such as ReduceL2, etc.
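If the check lived in a shared helper instead of a single converter, every Reduce-family op could reuse it. A hypothetical standalone sketch of that idea (the names here are illustrative, not TVM's actual API):

```python
import numpy as np

def convert_reduce(data, reduce_fn, axes=None, keepdims=True):
    """Hypothetical shared helper for Reduce-family converters.

    A rank-0 input makes any full reduction a no-op, so we return
    the input unchanged instead of emitting something like
    sum(%v0, axis=[]), which crashes Relay's type inference.
    """
    if data.ndim == 0:
        return data
    if axes is None:
        axes = tuple(range(data.ndim))
    return reduce_fn(data, axis=tuple(axes), keepdims=keepdims)
```

Here `reduce_fn` stands in for whichever reduction each converter dispatches to (sum, min, max, mean, ...), so the scalar check is written once rather than per operator.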
