Improved ONNX support with dynamic shapes #117

Closed
xenova opened this issue Jun 20, 2024 · 9 comments

Comments

@xenova

xenova commented Jun 20, 2024

Hi there! 👋 Following the conversation in #103, I wanted to export the models so that they (1) support dynamic shapes and (2) return the normal information, mainly so I could run them with Transformers.js. I got them working, and I've uploaded them to the Hugging Face Hub:

(you can find the .onnx weights, both fp32 and fp16, in the onnx subfolder)

Feel free to use them yourself or add the links to the README for increased visibility! 🤗 PS: I'd also recommend uploading your original PyTorch checkpoints to separate repos (instead of a single repo). Let me know if I can help with any of this!

Regarding the export, there were a few things to consider, mainly fixing the modelling code to avoid Python type casts (so that the dynamic shapes still work during tracing). I also made a few modifications to support CPU exports. Here's my conversion code:

import torch
import math
import torch.nn as nn

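class NullContext:
    """No-op context manager used as a stand-in for torch.autocast."""

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass

# Replace torch.autocast with a no-op so the model is not autocast to bf16 or CUDA
torch.autocast = NullContext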

class Metric3DExportModel(nn.Module):
    """
    The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )

    def normalize_image(self, image):
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )

        pred_depth = pred_depth.squeeze(1)
        pred_normal = output_dict['prediction_normal'][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
        normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details

        return pred_depth, pred_normal, normal_confidence


def patch_model(model):

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Comment out this code (so we always interpolate)
        # if npatch == N and w == h:
        #     return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        if torch.jit.is_tracing():
            # Tracing path: sqrt_N is handled as a tensor (note the .to() calls),
            # avoiding Python casts such as int() that would fix the shape.
            sqrt_N = N ** 0.5
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
                size=(w0, h0),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )
        else:
            # Eager path: plain Python numbers, interpolating by scale factor
            sqrt_N = math.sqrt(N)
            sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
                scale_factor=(sx, sy),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )

    def get_bins(self, bins_num):
        depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num)
        depth_bins_vec = torch.exp(depth_bins_vec)
        return depth_bins_vec

    model.depth_model.decoder.get_bins = (
        get_bins.__get__(
            model.depth_model.decoder, model.depth_model.decoder.__class__
        )
    )

    return model

# Load model
model_name = "metric3d_vit_small" # or "metric3d_vit_large" or "metric3d_vit_giant2"
model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
model.eval()

# Patch model so we can export to ONNX
model = patch_model(model)
export_model = Metric3DExportModel(model)
export_model.eval()

# Export the model
dummy_image = torch.randn([2, 3, 280, 420])
onnx_output = f"{model_name}.onnx"
torch.onnx.export(
    export_model,
    (dummy_image, ),
    onnx_output,
    input_names=["pixel_values"],
    output_names=["predicted_depth", "predicted_normal", "normal_confidence"],
    opset_version=11,
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "predicted_depth": {0: "batch_size", 1: "height", 2: "width"},
        "predicted_normal": {0: "batch_size", 2: "height", 3: "width"},
        "normal_confidence": {0: "batch_size", 1: "height", 2: "width"},
    },
)
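
As a rough illustration of the arithmetic inside interpolate_pos_encoding: the stored patch position embeddings form a square grid of side sqrt(N), and they are resampled to the grid implied by the input size. The sketch below is plain Python with no model involved. The patch size of 14 and the 37 x 37 pretrained grid are assumptions made for illustration (neither is stated in this thread), and `target_grid` is a hypothetical helper, not part of Metric3D.

```python
import math

def target_grid(height, width, patch_size=14, num_pos_tokens=1 + 37 * 37):
    """Return (source_side, (rows, cols)) for position-embedding interpolation.

    patch_size and num_pos_tokens are assumed values, not read from the model.
    """
    n = num_pos_tokens - 1       # drop the class token, as the patched method does
    side = math.isqrt(n)         # stored embeddings form a side x side grid
    if side * side != n:
        raise ValueError("position embeddings do not form a square grid")
    rows = height // patch_size  # integer division, like `w // self.patch_size`
    cols = width // patch_size
    return side, (rows, cols)

# The dummy export input is 280 x 420 pixels.
side, grid = target_grid(280, 420)
print(side, grid)  # 37 (20, 30)
```

Because the target grid depends on the input size, it has to stay symbolic during tracing for the exported height and width axes to remain dynamic.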
@xenova
Author

xenova commented Jun 20, 2024

There are minor differences in output, but these can be attributed to (1) implementation differences between ORT and PyTorch, and (2) default dtypes.

Diff between normalized images:

[Images: PyTorch output | ONNX output | Diff]
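
To put a number on differences like these, one option is to compare the two outputs element-wise. The sketch below assumes both predictions are available as NumPy arrays of the same shape. `compare_outputs` is a hypothetical helper, and the synthetic arrays only stand in for real model outputs.

```python
import numpy as np

def compare_outputs(reference, candidate):
    """Summarize the element-wise difference between two same-shaped arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        # Relative error, guarded against division by zero.
        "max_rel_diff": float((abs_diff / np.maximum(np.abs(reference), 1e-12)).max()),
    }

# Synthetic stand-ins for the PyTorch and ONNX Runtime depth maps.
rng = np.random.default_rng(0)
torch_depth = rng.uniform(1.0, 10.0, size=(1, 280, 420))
onnx_depth = torch_depth + rng.normal(0.0, 1e-4, size=torch_depth.shape)
stats = compare_outputs(torch_depth, onnx_depth)
print(stats)
```

With real outputs, the same call would take the PyTorch prediction (converted to NumPy) and the corresponding array returned by ONNX Runtime.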

@xenova
Author

xenova commented Jun 20, 2024

Example usage in Python:

  1. Download the model:
wget https://huggingface.co/onnx-community/metric3d-vit-small/resolve/main/onnx/model.onnx
  2. Run the model:
import onnxruntime as ort
import requests
import numpy as np
from PIL import Image

# Load session
ort_session = ort.InferenceSession("./model.onnx", providers=['CPUExecutionProvider'])

# Load image (force 3 channels so the array is always H x W x 3)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# Predict depth (mean/std normalization is part of the exported model, so pass raw pixel values)
pixel_values = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
pixel_values = np.expand_dims(pixel_values, 0)  # Add batch dim
onnxruntime_input = {'pixel_values': pixel_values.astype(np.float32)}
pred_depth, pred_normal, normal_confidence = ort_session.run(None, onnxruntime_input)
  3. Visualize the results:
min_val = pred_depth.min()
max_val = pred_depth.max()
normalized = 255 * ((pred_depth - min_val)/(max_val-min_val))
Image.fromarray(normalized[0].astype(np.uint8)).save('depth.png')
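
The predicted normals can be visualized in a similar way. The sketch below assumes `pred_normal` has shape (batch, 3, height, width), as declared in the export's dynamic axes, and that each component can be mapped from [-1, 1] to [0, 255]. The value range is an assumption, and `normals_to_rgb` is a hypothetical helper.

```python
import numpy as np

def normals_to_rgb(normal_map):
    """Map a (3, H, W) normal map to an (H, W, 3) uint8 image."""
    normal = np.asarray(normal_map, dtype=np.float32)
    # Re-normalize to unit length so every component falls in [-1, 1].
    length = np.linalg.norm(normal, axis=0, keepdims=True)
    normal = normal / np.maximum(length, 1e-6)
    rgb = (normal + 1.0) * 0.5 * 255.0  # [-1, 1] -> [0, 255]
    return rgb.transpose(1, 2, 0).round().astype(np.uint8)

# Synthetic example: every normal points along +z.
dummy = np.zeros((3, 4, 6), dtype=np.float32)
dummy[2] = 1.0
rgb = normals_to_rgb(dummy)
print(rgb.shape, rgb.dtype)
```

With the outputs from step 2, something like Image.fromarray(normals_to_rgb(pred_normal[0])).save('normal.png') would write the image to disk.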


@YvanYin
Owner

YvanYin commented Jun 21, 2024

Hi @xenova, thanks for your support. Would you mind joining this project and adding your work to our README?

@xenova
Author

xenova commented Jun 21, 2024

@YvanYin you're welcome! :) Do you mean submitting a PR? If so, then sure!

@YvanYin
Owner

YvanYin commented Jun 23, 2024

@xenova I invited you.

YvanYin closed this as completed Jun 23, 2024

@bananajoe182

@xenova Great work!
I don't know if I'm missing something, but the onnx-community/metric3d-vit-giant2 model seems to be incomplete, judging by its size (model.onnx is 1.64 MB).
If I try to load model_fp16, it throws an error:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from weight/giant/model_fp16.onnx failed:Type Error: Type parameter (T) of Optype (Add) bound to different types (tensor(int64) and tensor(float16) in node (/depth_model/decoder/Add_6).

The large model works great!

@stepstefan

Hi! Great work on both the models and the export. The smaller ones work great.
I'm hitting the same issue with the giant2 model in both versions (fp16 and fp32). Any update on this?
Tagging in case these comments are overlooked in a closed issue: @xenova @YvanYin

@haofengsiji

Same problem here. Can anyone share a working inference script for the ONNX giant2 model? I'd really appreciate it.

@adricostas

adricostas commented Oct 8, 2024

Hello,

I'm using the code that you provided (here) with my own image, and I see that the output size is not the same as the input size. Is this expected?

Input shape:  (3, 960, 1280)
Pred depth shape:  (1, 952, 1272)

Thank you in advance!
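
If a depth map at exactly the input resolution is needed, one possible workaround is to resample the prediction back to the input size. This does not explain where the mismatch comes from. It is only a NumPy sketch of corner-aligned bilinear resizing, and `resize_bilinear` is a hypothetical helper, not part of the model or of ONNX Runtime.

```python
import numpy as np

def resize_bilinear(depth, out_height, out_width):
    """Resize a 2-D array with corner-aligned bilinear interpolation."""
    depth = np.asarray(depth, dtype=np.float32)
    in_height, in_width = depth.shape
    # Source-grid sample position for every output row and column.
    ys = np.linspace(0.0, in_height - 1, out_height)
    xs = np.linspace(0.0, in_width - 1, out_width)
    y0 = np.floor(ys).astype(np.int64)
    x0 = np.floor(xs).astype(np.int64)
    y1 = np.minimum(y0 + 1, in_height - 1)
    x1 = np.minimum(x0 + 1, in_width - 1)
    wy = (ys - y0)[:, None]  # vertical blend weights, shape (out_height, 1)
    wx = (xs - x0)[None, :]  # horizontal blend weights, shape (1, out_width)
    top = depth[y0][:, x0] * (1 - wx) + depth[y0][:, x1] * wx
    bottom = depth[y1][:, x0] * (1 - wx) + depth[y1][:, x1] * wx
    return (top * (1 - wy) + bottom * wy).astype(np.float32)

# Shapes reported above: input 960 x 1280, prediction 952 x 1272.
pred = np.ones((952, 1272), dtype=np.float32)
resized = resize_bilinear(pred, 960, 1280)
print(resized.shape)  # (960, 1280)
```

With the real output, the call would be resize_bilinear(pred_depth[0], 960, 1280), since pred_depth carries a leading batch dimension.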
