Is it possible to convert the onnx model to fp16 model? #489

Closed
yuananf opened this issue Sep 13, 2022 · 27 comments
Labels: stale (Issues that haven't received updates)

yuananf commented Sep 13, 2022

The torch example accepts the parameter revision="fp16"; can the ONNX model do the same optimization? Currently, ONNX inference (using CUDAExecutionProvider) is slower than the torch version and uses more GPU memory (12 GB vs. 4 GB).
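For reference, the PyTorch fp16 usage being referred to presumably looks something like the following sketch (the model ID and prompt are illustrative assumptions, not taken from this issue):

import torch
from diffusers import StableDiffusionPipeline

# fp16 weights via the "fp16" revision, as in the diffusers examples of the time
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")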

anton-l (Member) commented Sep 13, 2022

Hi @yuananf! At the moment the ONNX pipeline is less optimized than its PyTorch counterpart, so all computation happens in float32 and there's overhead from CPU-GPU tensor copies in the inference sampling loop.
For now only the CPU runtime offers a significant speedup over PyTorch, but we're working with the onnxruntime team on a GPU revamp. If you see something obvious that could improve performance on GPU, feel free to open a PR and we'll integrate it! :)

yuananf (Author) commented Sep 14, 2022

Thank you for your response!

yuananf (Author) commented Sep 14, 2022

Hi @yuananf! At the moment the ONNX pipeline is less optimized than its PyTorch counterpart, so all computation happens in float32 and there's overhead from CPU-GPU tensor copies in the inference sampling loop. For now only the CPU runtime offers a significant speedup over PyTorch, but we're working with the onnxruntime team on a GPU revamp. If you see something obvious that could improve performance on GPU, feel free to open a PR and we'll integrate it! :)

To avoid data copies between CPU and GPU, onnxruntime provides the IOBinding feature: https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
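A minimal sketch of what IOBinding looks like with onnxruntime, assuming a CUDA build; the model path and the input/output names ("sample", "out_sample") are placeholders, not the actual pipeline names:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

# Copy the input to the GPU once, instead of on every run() call.
x = np.zeros((1, 4, 64, 64), dtype=np.float32)
x_gpu = ort.OrtValue.ortvalue_from_numpy(x, "cuda", 0)

binding = sess.io_binding()
binding.bind_ortvalue_input("sample", x_gpu)
binding.bind_output("out_sample", "cuda", 0)  # keep the output on the GPU

sess.run_with_iobinding(binding)
result = binding.get_outputs()[0].numpy()  # copy back to CPU only when needed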

wareya commented Sep 15, 2022

It's possible to convert individual .onnx models to fp16 with this script from onnxconverter-common: https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/auto_mixed_precision_model_path.py. I need to do this to get 512x512 Stable Diffusion to fit in dedicated VRAM on my GPU (an 8 GB RX 480), which gives me 2 seconds per iteration instead of 16. However, the process is very cumbersome, because you need to feed each stage the exact inputs it typically expects (I ended up copy/pasting the ONNX pipeline file and adding code to capture the inputs to each stage), and it's very slow, because it essentially brute-forces which parts can be converted to fp16 without significantly changing the model's output. Also, at least on my system, I had to edit the auto mixed precision script to handle stale file handles by force-deleting the files instead of deleting them normally (it would crash with a file access violation because they were still open in stale handles).

@tdeboissiere

@wareya could you share the details of your scripts / conversion process?

wareya commented Sep 16, 2022

First, get the full-precision onnx model locally from the onnx exporter (convert_stable_diffusion_checkpoint_to_onnx.py). For example:

python convert_stable_diffusion_checkpoint_to_onnx.py --model_path="CompVis/stable-diffusion-v1-4" --output_path="./stable_diffusion_onnx" --height=512 --width=512

Then modify pipeline_stable_diffusion_onnx.py to call out to auto_convert_mixed_precision_model_path with the relevant paths and inputs at the right points in the pipeline. I modified it like so: https://gist.github.com/wareya/0d5d111b1e2448a3e99e8be2b39fbcf1 (I've modified this since the last time I ran it, so it might be slightly broken. YMMV. Also, this is extremely hacky and bad and you shouldn't encourage other people to do it this way or do it this way in any release pipelines, but it was the fastest way for me to get it done locally.)
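For orientation, a rough sketch of such a call for a single stage. The keyword names and tolerances are assumptions that may differ between versions of auto_mixed_precision_model_path.py (check its docstring), and the input names, shapes, and paths are placeholders:

import numpy as np
from onnxconverter_common.auto_mixed_precision_model_path import (
    auto_convert_mixed_precision_model_path,
)

# Representative inputs for one stage (here: the UNet). In practice these
# should be captured from a real pipeline run, as described above.
sample_inputs = {
    "sample": np.random.randn(2, 4, 64, 64).astype(np.float32),
    "timestep": np.array([1], dtype=np.int64),
    "encoder_hidden_states": np.random.randn(2, 77, 768).astype(np.float32),
}

auto_convert_mixed_precision_model_path(
    "./stable_diffusion_onnx/unet/model.onnx",        # source FP32 model
    sample_inputs,                                     # inputs used for validation
    "./stable_diffusion_onnx_fp16/unet/model.onnx",    # target mixed-precision model
    provider=["CUDAExecutionProvider"],                # or DmlExecutionProvider on AMD/Windows
    rtol=0.01,                                         # tolerance for output comparison
    atol=0.001,
)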

However, on my system, this crashes inside auto_mixed_precision_model_path.py because it tries to delete files that it still has open. It might be a bug in the exact version of the onnx runtime that I'm running (a nightly build). To work around it, I modified _clean_output_folder in auto_mixed_precision_model_path.py like this (this is evil and might only work on Windows 10, not other versions of Windows, and will not work on non-Windows OSes):

    if os.path.exists(tmp_tensor_path):
        try:
            os.remove(tmp_tensor_path)
        except:
            # os.remove fails because a stale handle still holds the file open,
            # so fall back to a forced delete via cmd (Windows-only workaround).
            try:
                tmp_tensor_path = tmp_tensor_path.replace("/", "\\")
                print(f"force deleting {tmp_tensor_path}")
                os.system(f'cmd /c "del /f {tmp_tensor_path}"')
            except:
                print("no idea what broke here but something did!")

tianleiwu (Contributor) commented Sep 28, 2022

Latest script can be found here: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

Example script to convert FP32 to FP16:

# You can clone the onnxruntime source code and run this script as follows:
#    git clone https://github.com/microsoft/onnxruntime
#    cd onnxruntime/onnxruntime/python/tools/transformers
#    Save this script in that directory as sd_fp16.py (modify root_dir below if needed), then run:
#    python sd_fp16.py
    
import os
import shutil
import onnx
from onnxruntime.transformers.optimizer import optimize_model

# root directory of the onnx pipeline data files
root_dir = "./sd_onnx"

for name in ["unet", "vae_encoder", "vae_decoder", "text_encoder", "safety_checker"]:
    onnx_model_path = f"{root_dir}/{name}/model.onnx"

    # The following will fuse LayerNormalization and Gelu. Do this before the fp16 conversion, otherwise they cannot be fused later.
    # Right now, onnxruntime does not save models larger than 2 GB, so we use this script to optimize the unet offline instead.
    m = optimize_model(
        onnx_model_path,
        model_type="bert",
        num_heads=0,
        hidden_size=0,
        opt_level=0,
        optimization_options=None,
        use_gpu=False,
    )

    # Use op_block_list to force some operators to compute in FP32.
    # TODO: this might need some tuning; adding more operators to op_block_list can reduce accuracy loss.
    if name == "safety_checker":
        m.convert_float_to_float16(op_block_list=["Where"])
    else:
        m.convert_float_to_float16(op_block_list=["RandomNormalLike"])

    # Overwrite the existing models. You can change this to another directory, but then you need to copy the other files (like the tokenizer) manually.
    optimized_model_path = f"{root_dir}/{name}/model.onnx"
    output_dir = os.path.dirname(optimized_model_path)
    shutil.rmtree(output_dir)
    os.mkdir(output_dir)

    onnx.save_model(m.model, optimized_model_path)

To get the best performance, set the providers like the following:

   providers = [("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": '1'})]

See https://onnxruntime.ai/docs/performance/tune-performance.html#convolution-heavy-models-and-the-cuda-ep for more info.
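As context, a minimal sketch of passing that providers list to an ONNX Runtime session; the model path here is a placeholder:

import onnxruntime as ort

providers = [
    ("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": "1"}),
    "CPUExecutionProvider",  # fallback
]
session = ort.InferenceSession("./sd_onnx/unet/model.onnx", providers=providers)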

Latency (Seconds per Query) for GPU

Stable Diffusion Pipeline (text to 512x512 image)    T4     V100    A100
PyTorch FP16                                         12.8    5.1     3.1
Onnx FP32                                            26.2    8.3     4.9
Onnx FP16                                             9.6    3.8     2.4

@patrickvonplaten (Contributor)

Nice! @anton-l do you want to take a look here?

@cprivitere

I tried @tianleiwu's script, and all my results, while twice as fast, are very washed out. I'm also getting a lot of all-black results. This was on an AMD card using the DML provider on Windows. The files that were not converted to fp16 work fine.

@tianleiwu (Contributor)

@cprivitere, thanks for the feedback. To improve accuracy, the ONNX model needs to be converted to mixed precision by adding some operators (like LayerNormalization, Gelu, etc.) to op_block_list in the script. The list can be tuned with some A/B testing (comparing against FP32 results) on a test set.
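Following the script above, such a mixed-precision variant only changes the conversion call; which operators to block is an assumption to be tuned per model, for example:

# Keep a few accuracy-sensitive operators in FP32; this list is illustrative
# and should be tuned by comparing outputs against the FP32 model.
op_block_list = ["RandomNormalLike", "LayerNormalization", "Gelu"]
m.convert_float_to_float16(op_block_list=op_block_list)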

@cprivitere

I've had some time since my last post to actually install and set all this up on Linux. Sadly, none of the speed increases here come close to the speed of running the fp16 models on ROCm. On a 6750 XT we're talking 4.5 it/s using LMS with fp16 ROCm models on Linux versus 1.4 it/s using LMS with the broken fp16 ONNX models.

@patrickvonplaten (Contributor)

Any updates here @anton-l ?

anton-l (Member) commented Oct 27, 2022

FP16 models are now supported when tracing on GPU, thanks to @SkyTNT: #932
AFAIK @tianleiwu is working on a CPU-only conversion script :)

@cprivitere

Just a note for folks that this fp16 conversion of the ONNX models does NOT support AMD GPUs. It only works on NVIDIA.

wareya commented Oct 27, 2022

The one I posted works on AMD GPUs, at least.

averad commented Nov 2, 2022

@anton-l I ran the FP32-to-FP16 script @tianleiwu provided and was able to convert an ONNX FP32 model to an ONNX FP16 model.

Windows 11
AMD RX580 8GB
Python 3.10
Diffusers 0.6.0
DmlExecutionProvider (onnxruntime-directml)

When attempting to load the FP16 model with OnnxStableDiffusionPipeline.from_pretrained, the following error is raised:

onnxruntime_inference_collection.py", line 384, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ./models/model\vae_encoder\model.onnx failed:Type Error: Type (tensor(float16)) of output arg (onnx::Cast_882) of node (RandomNormalLike_496) does not match expected type (tensor(float)).

@kamalkraj (Contributor)

FP16 models are now supported when tracing on GPU, thanks to @SkyTNT: #932 AFAIK @tianleiwu is working on a CPU-only conversion script :)

Hi,
The ONNX fp16 model takes ~17 s per image, while the PyTorch fp16 model takes ~9 s on a T4 GPU.

@anton-l @patrickvonplaten

anton-l (Member) commented Nov 28, 2022

@kamalkraj the current ONNX pipeline design hurts GPU latency, so for now its main use cases are CPU inference and supporting environments which torch doesn't support (e.g. some AMD GPUs).
Opened an issue here if you'd like to improve GPU support: #1452

Another cause could be that our conversion to ONNX is not taking advantage of all of the optimization features in ONNX and ONNX Runtime. We're working with the Optimum team to improve that.

tianleiwu (Contributor) commented Dec 2, 2022

@averad, you can add "RandomNormalLike" to op_block_list to avoid the error. The latest script is here:
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

@kamalkraj, you can run like the following to reproduce ~10s on T4.

git clone https://github.com/tianleiwu/diffusers
cd diffusers
git checkout tlwu/benchmark
pip install -e .
pip install onnxruntime-gpu
python scripts/benchmark.py --pipeline ./sd_onnx

@holgerpieta

In case someone is still interested: I used a script inspired by @tianleiwu's to convert just the UNet to fp16, leaving everything else in fp32.
That reduced the time per iteration by about 50 to 70% with no notable loss of quality. Results looked different here and there, but neither was significantly better or worse than the other.
Converting everything else except the VAE decoder changed nothing, neither speed nor quality.
Converting the VAE decoder greatly reduced quality without any impact on speed.

Though I had to add only_onnxruntime=True to the arguments of optimize_model to make it work; otherwise it crashed with some tensor-dimension problems.
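A minimal sketch of that UNet-only variant, reusing the layout of the earlier script; root_dir, the op_block_list choice, and the in-place overwrite are assumptions carried over from that script, not holgerpieta's actual code:

import onnx
from onnxruntime.transformers.optimizer import optimize_model

root_dir = "./sd_onnx"
unet_path = f"{root_dir}/unet/model.onnx"

# Optimize and convert only the unet to FP16; the other submodels stay FP32.
m = optimize_model(
    unet_path,
    model_type="bert",
    num_heads=0,
    hidden_size=0,
    opt_level=0,
    optimization_options=None,
    use_gpu=False,
    only_onnxruntime=True,  # workaround for the tensor-dimension crash mentioned above
)
m.convert_float_to_float16(op_block_list=["RandomNormalLike"])
onnx.save_model(m.model, unet_path)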

@kamalkraj (Contributor)

@averad, you can add "RandomNormalLike" to op_block_list to avoid the error. The latest script is here: https://github.com/tianleiwu/diffusers/blob/tlwu/benchmark/scripts/convert_sd_onnx_to_fp16.py

@kamalkraj, you can run like the following to reproduce ~10s on T4.

git clone https://github.com/tianleiwu/diffusers
cd diffusers
git checkout tlwu/benchmark
pip install -e .
pip install onnxruntime-gpu
python scripts/benchmark.py --pipeline ./sd_onnx

Hi @tianleiwu ,

Will this also work with Stable Diffusion 2/2.1?

@tianleiwu (Contributor)

@kamalkraj, I will try SD 2/2.1 and get back to you later.

@saikrishna2893

Latest script can be downloaded here: https://github.com/tianleiwu/diffusers/blob/tlwu/benchmark/scripts/convert_sd_onnx_to_fp16.py

[The rest of the quote repeats the FP32-to-FP16 conversion script and GPU latency table from @tianleiwu's earlier comment.]

@tianleiwu When I converted the stable-diffusion v1-4 ONNX model from FP32 using the script provided, I observed that the converted model size is reduced, but when I loaded the model in Netron the inputs and outputs are still shown as FP32 precision. Is this expected? Can't we generate a complete FP16 model with the available scripts? While running inference with the CPU_FP16 flag and the OpenVINO execution provider, the device is shown as CPU_OPENVINO_CPU_FP32 instead of CPU_FP16. The reason might be that, since the so-called FP16 model still has FP32 inputs and outputs, the inference device falls back to FP32. Any thoughts on this?

@tianleiwu (Contributor)

@saikrishna2893,
to convert the inputs and outputs to FP16 as well, you can add a parameter like in the following line:
https://github.com/microsoft/onnxruntime/blob/b1abb8c656c597bf221bd85682ae3d9e350d9aba/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py#L154
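A minimal sketch of what that looks like in the conversion script above, assuming the keep_io_types keyword matches the linked line (the exact keyword may differ between onnxruntime versions):

# Convert the graph inputs/outputs to FP16 as well, not just the weights.
m.convert_float_to_float16(
    op_block_list=["RandomNormalLike"],
    keep_io_types=False,  # False: graph inputs/outputs become float16 too
)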

@kamalkraj, you can try out our optimizations for SD 2 or SD 2.1. For SD 2.1, you will need to add Attention (or MultiHeadAttention) to op_block_list so it runs in float32 instead of float16; otherwise you will see black images.

Note that the script contains optimizations for the CUDA EP only, since some of the optimized operators might not be available to other EPs. See the comments at the beginning of the script for usage. The Python environment used in testing is described here: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt

@wchao1115

The main conversion script in the main branch supports the --fp16 option.
https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py

@tianleiwu Do we still need your special script to convert to FP16, or is just using --fp16 good enough? What's the difference between the two? I also noticed that you need to install torch with CUDA to run the --fp16 option.

tianleiwu (Contributor) commented Feb 22, 2023

@wchao1115,

The main conversion script can export the checkpoint to an FP16 model. The model is composed of official ONNX operators, so it can be supported by different execution providers and inference engines (like ONNX Runtime, TensorRT, etc.).

However, inference engines still need to process the model and optimize the graph. For example, they fuse parts of the graph into custom operators (like MultiHeadAttention, which does not exist in the ONNX spec) and dispatch them to optimized CUDA kernels (like Flash Attention). Such optimization differs slightly between inference engines, and even between execution providers in ONNX Runtime.

My script contains SD optimizations for the CUDA execution provider of ONNX Runtime only. There is also a benchmark comparing the speed with PyTorch + xFormers and PyTorch 2.0.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 27, 2023
@github-actions github-actions bot closed this as completed Apr 4, 2023