ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2235646909 #15349

Open
toobashahid210 opened this issue Apr 4, 2023 · 7 comments

@toobashahid210

Hi
I successfully converted an XLM-Roberta-Large model to the ONNX format using torch.onnx.export(). The converted model is around 2.1 GB.
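For reference, the export step itself is not shown in the issue; a minimal sketch of how such an export might look (the checkpoint name, paths, and input/output names here are assumptions, not taken from the report):

import torch
from transformers import XLMRobertaModel, XLMRobertaTokenizer

# return_dict=False makes the model return a plain tuple, which exports cleanly.
model = XLMRobertaModel.from_pretrained("xlm-roberta-large", return_dict=False)
model.eval()

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large")
sample = tokenizer("a short example sentence", return_tensors="pt")

torch.onnx.export(
    model,
    (sample["input_ids"], sample["attention_mask"]),
    "/onnx_model/model.onnx",
    opset_version=11,  # matches the opset 11 reported in the optimizer log below
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
)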

However, when I try to optimize the converted model using onnxruntime.transformers.optimizer, I get the following error:

python3 -m onnxruntime.transformers.optimizer --input {onnx_model_path} --output {onnx_optimized_model_path} --hidden_size 1024 --num_heads 16 --opt_level 99 --float16
[W:onnxruntime:, inference_session.cc:1256 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. The generated model may contain hardware and execution provider specific optimizations, and should only be used in the same environment the model was optimized for.
adjust_reshape_and_expand: Removed Reshape and Expand count: 0
               apply: Fused SkipLayerNormalization count: 48
                fuse: Position embedding path is not found. Embed layer cannot be fused.
         prune_graph: Graph pruned: 0 inputs, 0 outputs and 0 nodes are removed
               apply: Fused SkipLayerNormalization(add bias) count: 48
            optimize: opset verion: 11
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/lib/python3.8/site-packages/onnxruntime/transformers/optimizer.py", line 360, in <module>
    main()
  File "/usr/lib/python3.8/site-packages/onnxruntime/transformers/optimizer.py", line 346, in main
    optimizer.convert_model_float32_to_float16()
  File "/usr/lib/python3.8/site-packages/onnxruntime/transformers/onnx_model.py", line 416, in convert_model_float32_to_float16
    self.model = oc.float16.convert_float_to_float16(self.model, keep_io_types=cast_input_output)
  File "/usr/lib/python3.8/site-packages/onnxconverter_common/float16.py", line 140, in convert_float_to_float16
    model = func_infer_shape(model)
  File "/usr/lib/python3.8/site-packages/onnx/shape_inference.py", line 36, in infer_shapes
    model_str = model.SerializeToString()
ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2235646909

Here are the library versions I am using:

  • onnx==1.9.0
  • onnxruntime==1.7.0
  • torch==1.13.1
  • protobuf==3.20.2

Can anyone suggest a solution to this problem or provide any guidance on how to optimize this model without running into this error?

I would appreciate quick help with this.

@baijumeswani
Contributor

protobuf has a 2GB limit on serialized message size.

You could consider saving to external data using: https://github.com/onnx/onnx/blob/main/docs/ExternalData.md#converting-an-onnx-model-to-external-data.

Then pass the --use_external_data_format flag when running the python3 -m onnxruntime.transformers.optimizer ... command.
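For reference, a minimal sketch of the conversion described in that doc (the paths and file names here are assumptions):

import onnx
from onnx.external_data_helper import convert_model_to_external_data

model = onnx.load("/onnx_model/model.onnx")

# Mark large tensors to be stored outside the protobuf; only references
# remain in the ModelProto, so it serializes well under the 2GB limit.
convert_model_to_external_data(
    model,
    all_tensors_to_one_file=True,
    location="model.onnx.data",  # written next to the saved model
    size_threshold=1024,         # only tensors above 1KB go external
)

# save_model writes the external data file(s) relative to the output path.
onnx.save_model(model, "/onnx_model/model_external.onnx")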

Let me know if that works for you.

@tianleiwu
Contributor

tianleiwu commented Apr 4, 2023

The cause is that the tool uses ONNX shape inference, which cannot handle models larger than 2GB. The solution is to upgrade to the latest onnxruntime-gpu, which uses symbolic shape inference and avoids the limitation (see the sketch below).

Try installing the latest onnxruntime-gpu (fp16 needs the CUDA EP):

pip3 uninstall onnxruntime
pip3 install -U onnxruntime-gpu
pip3 install -U onnx
python3 -m onnxruntime.transformers.optimizer --input {onnx_model_path} --output {onnx_optimized_model_path} --hidden_size 1024 --num_heads 16 --opt_level 0 --float16 --use_external_data_format --use_gpu

Note that --opt_level 0 --float16 --use_external_data_format --use_gpu is used in the command.

I think --use_external_data_format is optional here, since the fp16 model will be about half the size and may fit within 2GB.
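A minimal sketch of the symbolic shape inference route, assuming the installed onnxruntime ships onnxruntime.tools.symbolic_shape_infer and reusing the paths from earlier:

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("/onnx_model/model.onnx")

# Symbolic shape inference walks the graph node by node instead of
# serializing the whole ModelProto, so it sidesteps the 2GB protobuf
# limit that onnx.shape_inference.infer_shapes hits in the traceback above.
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

onnx.save_model(
    inferred,
    "/onnx_model/model_inferred.onnx",
    save_as_external_data=True,  # keep weights outside the protobuf
    all_tensors_to_one_file=True,
    location="model_inferred.onnx.data",
)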

@toobashahid210
Author

toobashahid210 commented Apr 6, 2023

@baijumeswani

I tried the solution you provided, but it didn't work. I encountered the same error again.

Here's the code I used:

import onnx

onnx_model = onnx.load("/onnx_model/model.onnx")
onnx.save_model(
    onnx_model,
    "/optimized_model/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    size_threshold=1024,
    convert_attribute=False,
    location="external_data",
)

onnx_model is the model converted using torch.onnx.export. Let me know if I am missing something.
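As an aside, for code that calls shape inference directly, onnx also has a file-based variant that avoids building a >2GB message in memory; a minimal sketch reusing the assumed paths above (note this does not change the optimizer's internal call, which needs the onnxruntime upgrade mentioned above):

import onnx.shape_inference

# infer_shapes_path reads and writes the model on disk instead of
# round-tripping a serialized ModelProto through memory, so it is not
# subject to the 2GB protobuf message limit that infer_shapes hits.
onnx.shape_inference.infer_shapes_path(
    "/onnx_model/model.onnx",
    "/onnx_model/model_inferred.onnx",
)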

@puyuanOT

I got the same error when running this command:

optimum-cli export onnx --model EleutherAI/pythia-2.8b /local_disk0/optimized_pythia_2.8b_fp16 --optimize O4 --task text-generation-with-past --framework pt --device cuda

tianleiwu added a commit that referenced this issue Sep 6, 2023
…ta (#17427)

Some initializers are added without the raw=True flag. That prevents those
tensors from being saved to external data. If those tensors exceed 2GB in
total, the optimized model cannot be saved due to the protobuf limit.

This change saves attention weights and biases as raw data.

Note: using raw data for shape tensors is optional since they are tiny.

### Motivation and Context
#17212
#15349
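To illustrate the raw=True point from the commit message above, a minimal sketch of creating an initializer with raw data (the tensor name and shape are placeholders):

import numpy as np
import onnx
from onnx import helper

weights = np.zeros((1024, 1024), dtype=np.float32)  # placeholder weight

# With raw=True the values go into the tensor's raw_data field, which the
# external-data helpers can move out of the protobuf. Without raw=True the
# values land in float_data and stay inside the ModelProto, counting against
# the 2GB limit when the model is serialized.
tensor = helper.make_tensor(
    name="attention.weight",
    data_type=onnx.TensorProto.FLOAT,
    dims=weights.shape,
    vals=weights.tobytes(),
    raw=True,
)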
@Tord-Zhang

> @baijumeswani I tried the solution you provided, but it didn't work. I encountered the same error again. [...]

@toobashahid210 Hi, I ran into the same problem. Any good solution?

@tianleiwu
Contributor

tianleiwu commented Oct 6, 2023

If anyone encounters a similar error while using onnxruntime.transformers.optimizer with onnxruntime version < 1.17, try installing the ort-nightly or ort-nightly-gpu package:

python -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy

python -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort_nightly

@toobashahid210
Author

toobashahid210 commented Oct 27, 2023

> @toobashahid210 Hi, I ran into the same problem. Any good solution?

There might have been some library compatibility issues on my side, I guess. I resolved it using the following library versions:

torch==1.11.0
onnx==1.11.0
onnxconverter-common==1.9.0
onnxruntime==1.7.0
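For anyone reproducing this pinned environment, the equivalent install command would presumably be:

pip3 install torch==1.11.0 onnx==1.11.0 onnxconverter-common==1.9.0 onnxruntime==1.7.0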

tianleiwu added a commit that referenced this issue Oct 31, 2023
…ta (#17427)
kleiti pushed a commit to kleiti/onnxruntime that referenced this issue Mar 22, 2024
…ta (microsoft#17427)