Add support for ONNX type and tag broadcasting #2919
Merged
Closes #1370
This adds support for type and tag broadcasting for all ONNX models. While we did have type broadcasting for ONNX models before, those broadcasts did not contain much information. Now they contain the model's input/output channels and scale, just like for PyTorch and NCNN models. This finally allows us to compute output types for ONNX's Upscale Image node.
A model's input/output channels and scale(s) are detected using `onnx.shape_inference` [1] [2]. This allows us to statically determine the shape of the output tensor for a given input tensor shape without actually running the model (= no inference session). While this is roughly 10-100x faster than creating an `ort.InferenceSession`, it's still not free. Loading generic ONNX models will now take roughly 3-4x longer. This might sound bad, but it's not: an ESRGAN ONNX model previously took 0.11s to load and now takes 0.38s. A lot slower, but not a problem.
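For illustration, here is a minimal sketch of this kind of static shape inference (not the PR's actual code; the model path, the single NCHW image input, and the 32x32 probe size are assumptions):

```python
import onnx
from onnx import shape_inference

# Hypothetical model path; any single-input NCHW image model works.
model = onnx.load("model.onnx")

# Pin the input to a concrete 1x3x32x32 probe shape so shape inference
# yields numbers instead of symbolic dimensions.
probe_shape = [1, 3, 32, 32]
input_dims = model.graph.input[0].type.tensor_type.shape.dim
for dim, value in zip(input_dims, probe_shape):
    dim.dim_value = value  # setting dim_value clears any symbolic dim_param

# If the output already carries a symbolic shape annotation, clear it so
# shape inference re-derives it from the probe input.
model.graph.output[0].type.tensor_type.ClearField("shape")

# Statically propagate shapes through the graph (no ort.InferenceSession).
inferred = shape_inference.infer_shapes(model)

out_dims = [d.dim_value for d in inferred.graph.output[0].type.tensor_type.shape.dim]
in_c, in_h = probe_shape[1], probe_shape[2]
out_c, out_h = out_dims[1], out_dims[2]
print(f"channels: {in_c} -> {out_c}, scale: {out_h // in_h}x")
```

With a 32x32 probe input, an inferred 128x128 output would imply a 4x upscaling model, e.g. ESRGAN 4x.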
Changes:

- The `OnnxInfo` class now includes the model's input/output channels and scale.
- Removed the `is_rembg_model` function, because it wasn't necessary.
- `create_inference_session`.