Fix ONNX exports for Optimum compatible models #31311
Conversation
```python
def safe_int(x):
    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)

old_grid_size = safe_int(posemb_grid.size(0) ** 0.5)
```
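For context, here is a minimal sketch (my own illustration, not code from this PR) of the failure mode these guards address: under `torch.jit.trace`, a Python `int()` cast is recorded as a constant, so the traced graph silently ignores the true runtime shape.

```python
import torch

def make_zeros(x):
    n = int(x.size(0))  # becomes a hard-coded constant while tracing
    return torch.zeros(n)

traced = torch.jit.trace(make_zeros, torch.ones(4))
print(traced(torch.ones(4)).shape)  # torch.Size([4])
print(traced(torch.ones(8)).shape)  # still torch.Size([4]): the size was baked in
```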
```python
new_height = (torch.ceil(orig_height / patch_height) * patch_height).to(torch.int64)
new_width = (torch.ceil(orig_width / patch_width) * patch_width).to(torch.int64)
```
Same comment as above - doesn't `interpolate` require `(int, int)` when not tracing?
I'll check tracing, thanks for the heads up
Very nice! Thanks for fixing this for all these models ❤️
Just a few small comments
src/transformers/utils/generic.py (outdated diff):

```diff
@@ -750,3 +750,15 @@ def infer_framework(model_class):
         return "flax"
     else:
         raise TypeError(f"Could not infer framework from class {model_class}.")
+
+
+def safe_int(x):
```
Docstrings would be helpful here e.g. for inspecting in IDEs: what does it mean for an int to be safe?
Indeed, a better name is probably a good idea 😅 I called it `safe_int` to mean "safely cast some value (which could be a Python number or a tensor) to an integer in a way that respects tracing".
I'll swap with `torch_int` and `torch_float`.
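For reference, a sketch of what the renamed helpers could look like with the requested docstrings (my assumption of the final shape; the actual implementation in `generic.py` may differ):

```python
import torch

def torch_int(x):
    """Cast `x` (a Python number or a tensor) to an int in a tracing-safe way.

    Returns an int64 tensor while `torch.jit.trace` is recording, so the value
    stays symbolic in the traced graph; returns a plain Python int otherwise.
    """
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)

def torch_float(x):
    """Cast `x` (a Python number or a tensor) to a float in a tracing-safe way."""
    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)
```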
```python
new_width = int(math.ceil(orig_width / patch_width) * patch_width)
new_height = (
    safe_float(torch.ceil(orig_height / patch_height) * patch_height)
    if torch.jit.is_tracing()
```
Do we need the conditional here? This is already handled in the `safe_float` and `safe_int` functions.
I think it's required for `torch.ceil`, no?
tbh, I don't know, is there a reason we couldn't use `torch.ceil` directly?
If I'm passing an int or float, `torch.ceil` will be called first and it will fail, because `torch.ceil` can only be called with tensors AFAIK.
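(A quick illustration of that point, not from the diff:)

```python
import math
import torch

print(math.ceil(2.5))                 # 3: math.ceil accepts plain Python numbers
print(torch.ceil(torch.tensor(2.5)))  # tensor(3.): torch.ceil needs a tensor
# torch.ceil(2.5)  # raises TypeError: argument 'input' must be Tensor, not float
```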
Makes sense!
Only other Q here then is: why do we use a float when tracing and an int otherwise?
Sorry, I think I was mistaken with that one; you're right, I fixed it :)
@amyeroberts the failing tests seem unrelated to this PR; I can't re-run them, can you re-run?
@merveenoyan yes yes - done!
Thanks for fixing all of these!
Just for my own understanding - is there any reason not to use the torch-compatible float/int when we're not tracing?
@amyeroberts to my understanding, torch ONNX export internally calls `torch.jit.trace`.
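To illustrate why that matters for export (a hedged sketch with a toy module, not code from this PR): because the exporter traces the forward pass, a Python `int()` cast turns what should be a dynamic size into a constant in the exported graph, defeating `dynamic_axes` for the downstream computation.

```python
import torch

class Interp(torch.nn.Module):
    def forward(self, x):
        # int() bakes the size in at trace time; a tensor op would keep it symbolic
        new_len = int(x.shape[-1]) * 2
        return torch.nn.functional.interpolate(x, size=new_len, mode="linear")

# The last axis is declared dynamic, but the Resize target size is already a
# constant in the traced graph, so other input lengths won't resize correctly.
torch.onnx.export(
    Interp(), torch.randn(1, 3, 16), "interp.onnx",
    input_names=["x"], dynamic_axes={"x": {2: "length"}},
)
```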
@amyeroberts can you merge if you think it's ok?
Right, I see why we need to do it for the ONNX export, but for day-to-day use could we just use torch primitives instead of a Python int/float?
@amyeroberts I guess if it's just torch modelling code then yes. Would you like me to swap everything?
Also asking the same question to @xenova.
@merveenoyan Yes please! This will be cleaner and easier to follow in the code :)
I agree with @amyeroberts - if there is a way to "do everything in torch land", that's the best solution! However, there are cases where I'm not entirely sure how to do this. For example, with `nn.functional.interpolate`.
See here for example code (DinoV2 backbone):

```python
if torch.jit.is_tracing():
    sqrt_N = N ** 0.5
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
        size=(w0, h0),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
else:
    sqrt_N = math.sqrt(N)
    sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
```

Very ugly... I know :/
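(If I read the two branches right, part of why this is hard to unify is that they call `interpolate` differently: the tracing branch passes `size` as tensors, so the output size stays symbolic in the traced graph, while the eager branch passes `scale_factor` as plain Python floats. Collapsing them into one call would change which resize path is taken, so the branch seems hard to avoid here.)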
@xenova sounds good, very glad to work with you! tbh I didn't know it would be required for inference.
@merveenoyan My understanding from above was that the PR would be updated to remove all the if/else structures wherever possible (though, as @xenova points out, that isn't possible everywhere, unfortunately).
@amyeroberts from what I understood, we should still keep the if/else so as not to break inference (I'm also wary of possible edge cases), so I'd rather keep them. What I can do is test all of them to see whether they break when everything is a tensor, and remove the branching where a Python type isn't required.
@merveenoyan OK. Let's just merge then and we can follow up in future PRs 👍
@amyeroberts as discussed and also pinging @xenova for review :') (who also fixed DPT)
I prioritized Optimum-compatible models because I'm launching a project with Optimum examples for vision models. I will open a separate PR for the models that aren't compatible with Optimum. The rest of the Optimum-compatible models export without problems.