
Cannot export MiVOLO model into onnx format using torch.onnx.export #14

Closed
MasterHM-ml opened this issue Aug 10, 2023 · 6 comments

@MasterHM-ml
I am adding a line here to convert self.model into ONNX format. Here is my code snippet:

        random_input = torch.randn(1, 6, 224, 224, device=self.device)
        onnx_model_name = "mi_volo.onnx"
        # pytorch to onnx
        torch.onnx.export(self.model, random_input, onnx_model_name, verbose=True, opset_version=18)

but I am getting the following error:

============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "/home/master/.local/lib/python3.10/site-packages/torch/onnx/symbolic_opset18.py", line 52, in col2im
    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
TypeError: 'NoneType' object is not subscriptable

I tried debugging the error, but couldn't understand it due to my limited familiarity with the conversion process. The line that raises the error is num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]. I inspected the output_size variable and it is:

    1072 defined in (%1072 : int[] = prim::ListConstruct(%958, %963), scope: mivolo.model.mivolo_model.MiVOLOModel::/torch.nn.modules.container.Sequential::network.0/timm.models.volo.Outlooker::network.0.0/timm.models.volo.OutlookAttention::attn)

MiVOLO - latest pull
PyTorch version - 2.0.1
onnx version - 1.14.1
OS - Ubuntu 22.04.3 LTS

Any help/direction/discussion will be highly appreciated, thank you.

WildChlamydia (Owner) commented Aug 11, 2023

Hello!
It won't be easy.

  1. You need the timm library directly from source (you can remove the installed module, clone the repo from GitHub, then add it to PYTHONPATH). This is necessary for the upcoming steps.

  2. Explicitly convert the variables 'H' and 'W' to int() here; it's a timm bug:
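For reference, the one-line fix (confirmed later in this thread) is the F.fold call in timm/models/volo.py; the exact line may differ between timm versions:

    # cast H and W to plain Python ints so the ONNX tracer sees concrete sizes
    x = F.fold(x, output_size=(int(H), int(W)), kernel_size=self.kernel_size,
               padding=self.padding, stride=self.stride)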

  3. During the ONNX conversion process, you will encounter the following error:

torch.onnx.errors.CheckerError: Unrecognized attribute: axes for operator ReduceMax

==> Context: Bad node spec for node. Name: /ReduceMax OpType: ReduceMax

To address this issue, you need to modify the Torch sources. Open the file torch/onnx/utils.py, locate the _export function, and comment out the line that checks the ONNX proto:

if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and (
    not val_use_external_data_format
):
    try:
-       _C._check_onnx_proto(proto)
+       pass  # _C._check_onnx_proto(proto)
    except RuntimeError as e:
        raise errors.CheckerError(e) from e
  4. After that, the model will be saved but won't work because of a torch.onnx bug. You have to rewrite the graph:
import numpy as np
import onnx
from onnx import numpy_helper

onnx_model = onnx.load(output_file)
# Get the graph from the model
graph = onnx_model.graph

# Iterate through all nodes in the graph
for node in graph.node:
    if "ReduceMax" in node.op_type:
        for index in range(len(node.attribute)):
            if node.attribute[index].name == "axes":
                # Drop the "axes" attribute (invalid for ReduceMax in opset 18)...
                del node.attribute[index]
                # ...and supply the axes as a tensor input instead
                axes_input = onnx.helper.make_tensor_value_info("axes", onnx.TensorProto.INT64, [1])
                axes_value = numpy_helper.from_array(np.array([1], dtype=np.int64), "axes")
                onnx_model.graph.input.extend([axes_input])
                onnx_model.graph.initializer.extend([axes_value])
                node.input.append("axes")
                break

Now, save this model. It will work.
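For example (a sketch, reusing the output_file path from the snippet above):

    # write the patched graph back to disk
    onnx.save(onnx_model, output_file)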

And all of this is simply not worth it: the ONNX model performs poorly with batch processing, and TensorRT is currently not an option due to its lack of support for col2im.
The best way for now is TorchScript; a minimal sketch follows.
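For reference, a TorchScript export could look like this (a sketch only; the model handle, input shape, and file name are illustrative, not code from this repo):

    import torch

    # sketch: `model` is assumed to be the loaded MiVOLO model
    model.eval()
    # MiVOLO takes a 6-channel input (face and body crops stacked)
    example = torch.randn(1, 6, 224, 224)
    with torch.no_grad():
        ts_model = torch.jit.trace(model, example)
    ts_model.save("mivolo.torchscript.pt")

    # the saved module can later be loaded without the Python sources,
    # including from C++ via torch::jit::load
    loaded = torch.jit.load("mivolo.torchscript.pt")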

Good luck and thank you for your star.

@MasterHM-ml (Author)

And all of this is simply not worth it: the ONNX model performs poorly with batch processing, and TensorRT is currently not an option due to its lack of support for col2im. The best way for now is TorchScript.

Thank you for your detailed reply. Through documentation and GitHub issues, I completed the first 3 steps and was ready to convert the model into OpenVINO IR, then found out that the OpenVINO runtime does not support the col2im operation. I was not aware of the issue you described in step 4. Massive thanks for providing the guide. I will do that.

And, yeah, as you said:

Hello! It won't be easy.

This was really not easy. I've spent 3 days and still see no chance of reaching the finish line today.

@Hab2Verer

@MasterHM-ml @WildChlamydia
Hello guys, good work.

Did you find a definitive way to convert the model to ONNX?

vishal19217 commented Nov 13, 2024

Hi @WildChlamydia,
Can you share the TorchScript method to convert the model? I have to use it in a C++ environment.
Also, I will be running this model on a system without a GPU. Will performance degrade drastically without a GPU, or is a GPU mandatory to run the model?

liuxufenfeiya commented Feb 4, 2025

Thanks @WildChlamydia, the conversion to ONNX succeeded.
Step 1 is not necessary; just change /usr/local/lib/python3.8/dist-packages/timm/models/volo.py line 126 directly:

    x = F.fold(x, output_size=(int(H), int(W)), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)

@tomchen1000

ONNX model performs poorly with batch processing

Hi @WildChlamydia, what exactly do you mean by "ONNX model performs poorly with batch processing"? By "batch processing", do you refer to the batch size (when it's more than 1), or to the BatchNorm layers?

Thanks!
