Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ONNX type and tag broadcasting #2919

Merged
merged 4 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions backend/src/nodes/impl/onnx/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import onnx
import onnx.inliner
import re2

from .model import OnnxGeneric, OnnxInfo, OnnxModel, OnnxRemBg
from .utils import (
get_opset,
get_tensor_fp_datatype,
image_to_image_shape_inference,
)

# re2 (Google's RE2 bindings) is used instead of `re` so scanning
# multi-megabyte serialized models stays linear-time with no backtracking.
re2_options = re2.Options()
# "." must match newline bytes too, since we search raw binary protobuf data.
re2_options.dot_nl = True
# Latin-1 maps every byte 1:1 to a code point, making arbitrary bytes searchable.
re2_options.encoding = re2.Options.Encoding.LATIN1

# Byte signatures identifying known U^2-Net background-removal model variants
# by the tensor/node names serialized inside the ONNX protobuf.
U2NET_STANDARD = re2.compile(b"1959.+1960.+1961.+1962.+1963.+1964.+1965", re2_options)
U2NET_CLOTH = re2.compile(
    b"output.+d1.+Concat_1876.+Concat_1896.+Concat_1916.+Concat_1936.+Concat_1956",
    re2_options,
)
U2NET_SILUETA = re2.compile(b"1808.+1827.+1828.+2296.+1831.+1850.+1958", re2_options)
U2NET_ISNET = re2.compile(
    b"/stage1/rebnconvin/conv_s1/Conv.+/stage1/rebnconvin/relu_s1/Relu", re2_options
)


def load_onnx_model(model_or_bytes: onnx.ModelProto | bytes) -> OnnxModel:
    """Wrap a parsed or serialized ONNX model in the appropriate OnnxModel subtype.

    Accepts either an already-parsed ModelProto or the raw protobuf bytes;
    whichever form is given, the other is derived from it. Models matching a
    known U^2-Net background-removal byte signature are returned as OnnxRemBg;
    everything else becomes OnnxGeneric, with scale and channel metadata
    filled in via shape inference when possible.
    """
    if isinstance(model_or_bytes, onnx.ModelProto):
        model = model_or_bytes
        raw = model.SerializeToString()
    else:
        raw = model_or_bytes
        model = onnx.load_model_from_string(model_or_bytes)

    info = OnnxInfo(
        opset=get_opset(model),
        dtype=get_tensor_fp_datatype(model),
    )

    # The signatures only occur near the start or end of the protobuf, so
    # search small slices rather than the whole blob.
    head, tail = raw[:10000], raw[-1000:]
    is_standard_rembg = (
        U2NET_STANDARD.search(tail) is not None
        or U2NET_SILUETA.search(raw[-600:]) is not None
        or U2NET_ISNET.search(head) is not None
    )

    if is_standard_rembg:
        info.scale_width = 1
        info.scale_height = 1
        return OnnxRemBg(raw, info)

    if U2NET_CLOTH.search(tail) is not None:
        # Presumably the cloth-segmentation variant emits an output 3x the
        # input height — mirrors the scale used for this variant elsewhere.
        info.scale_width = 1
        info.scale_height = 3
        return OnnxRemBg(raw, info)

    try:
        in_hwc, out_hwc = image_to_image_shape_inference(model, (512, 512))
        in_h, in_w, in_c = in_hwc
        out_h, out_w, out_c = out_hwc

        def integer_ratio(before: int | None, after: int | None) -> int | None:
            # Only report a scale when it is an exact whole-number multiple.
            if before is None or after is None or after % before != 0:
                return None
            return after // before

        info.scale_width = integer_ratio(in_w, out_w)
        info.scale_height = integer_ratio(in_h, out_h)
        info.input_channels = in_c
        info.output_channels = out_c
    except Exception:
        # Best effort: shape inference can fail on exotic models; the generic
        # wrapper is still usable without scale/channel metadata.
        pass

    return OnnxGeneric(raw, info)
76 changes: 28 additions & 48 deletions backend/src/nodes/impl/onnx/model.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,45 @@
# This class defines an interface.
# It is important that it does not contain types that depend on ONNX.
from __future__ import annotations

from dataclasses import dataclass
from typing import Final, Literal, Union

# Discriminator values distinguishing the model wrapper classes below.
OnnxSubType = Literal["Generic", "RemBg"]

@dataclass
class OnnxInfo:
    """Metadata describing an ONNX model, free of any ONNX runtime types."""

    # ONNX opset version the model targets.
    opset: int
    # Name of the model's floating-point tensor datatype, as produced by
    # utils.get_tensor_fp_datatype (exact string values are defined there).
    dtype: str

    # Output size as an integer multiple of the input size, per axis.
    # None when the scale is unknown or not a whole-number ratio.
    scale_width: int | None = None
    scale_height: int | None = None

    # NOTE(review): never assigned in this diff — presumably set elsewhere
    # for models requiring a fixed spatial input size; confirm against callers.
    fixed_input_width: int | None = None
    fixed_input_height: int | None = None

    # Channel counts inferred from the model graph; None when unknown.
    input_channels: int | None = None
    output_channels: int | None = None


class OnnxGeneric:
    """A generic ONNX model: raw protobuf bytes plus parsed metadata."""

    def __init__(self, model_as_bytes: bytes, info: OnnxInfo):
        # Serialized ONNX protobuf, ready to hand to an inference session.
        self.bytes: bytes = model_as_bytes
        # Runtime discriminator telling the model wrapper classes apart.
        self.sub_type: Final[Literal["Generic"]] = "Generic"
        self.info: OnnxInfo = info


class OnnxRemBg:
    """A recognized U^2-Net background-removal ONNX model."""

    def __init__(
        self,
        model_as_bytes: bytes,
        info: OnnxInfo,
    ):
        # Serialized ONNX protobuf, ready to hand to an inference session.
        self.bytes: bytes = model_as_bytes
        # Runtime discriminator telling the model wrapper classes apart.
        self.sub_type: Final[Literal["RemBg"]] = "RemBg"
        self.info: OnnxInfo = info


# Tuple of every concrete model wrapper class, for isinstance() checks.
OnnxModels = (OnnxGeneric, OnnxRemBg)
# Union type alias covering every model this module can produce.
OnnxModel = Union[OnnxGeneric, OnnxRemBg]


def is_rembg_model(model_as_bytes: bytes) -> bool:
    """Return True if the serialized model matches a known U^2-Net signature.

    NOTE(review): this is the deletion side of the diff — the PR replaces it
    with the detection logic inside load.py's load_onnx_model. It depends on
    the pre-PR module-level U2NET_* regex constants.
    """
    # The signatures only occur near the start or end of the protobuf, so
    # only small slices are searched.
    if (
        U2NET_STANDARD.search(model_as_bytes[-600:]) is not None
        or U2NET_CLOTH.search(model_as_bytes[-1000:]) is not None
        or U2NET_SILUETA.search(model_as_bytes[-600:]) is not None
        or U2NET_ISNET.search(model_as_bytes[:10000]) is not None
    ):
        return True
    return False


def load_onnx_model(model_as_bytes: bytes) -> OnnxModel:
    """Wrap serialized model bytes in OnnxRemBg or OnnxGeneric.

    NOTE(review): deletion side of the diff — superseded by the richer
    implementation in load.py. It calls the pre-PR OnnxRemBg/OnnxGeneric
    constructor signatures and the pre-PR module-level U2NET_* constants.
    """
    if (
        U2NET_STANDARD.search(model_as_bytes[-1000:]) is not None
        or U2NET_SILUETA.search(model_as_bytes[-600:]) is not None
        or U2NET_ISNET.search(model_as_bytes[:10000]) is not None
    ):
        model = OnnxRemBg(model_as_bytes)
    elif U2NET_CLOTH.search(model_as_bytes[-1000:]) is not None:
        # Presumably the cloth variant's output is 3x the input height.
        model = OnnxRemBg(model_as_bytes, scale_height=3)
    else:
        model = OnnxGeneric(model_as_bytes)

    return model
70 changes: 41 additions & 29 deletions backend/src/nodes/impl/onnx/session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from typing import Any, Dict, Tuple, Union
from weakref import WeakKeyDictionary

import onnxruntime as ort

from .model import OnnxModel
from .utils import OnnxParsedTensorShape, parse_onnx_shape

ProviderDesc = Union[str, Tuple[str, Dict[Any, Any]]]


def create_inference_session(
    model: OnnxModel,
    gpu_index: int,
    execution_provider: str,
    should_tensorrt_fp16: bool = False,
    tensorrt_cache_path: str | None = None,
) -> ort.InferenceSession:
    """Create an onnxruntime InferenceSession for the given model.

    The requested execution provider is always backed by a CPU fallback;
    TensorRT additionally falls back to CUDA before CPU.

    NOTE(review): the leading parameters (model, gpu_index,
    execution_provider) were hidden behind a collapsed diff region; their
    names are reconstructed from how the body uses them — confirm against
    the repository.
    """
    tensorrt: ProviderDesc = (
        "TensorrtExecutionProvider",
        {
            "device_id": gpu_index,
            # Engine caching is only enabled when a cache path was supplied.
            "trt_engine_cache_enable": tensorrt_cache_path is not None,
            "trt_engine_cache_path": tensorrt_cache_path,
            "trt_fp16_enable": should_tensorrt_fp16,
        },
    )
    cuda: ProviderDesc = (
        "CUDAExecutionProvider",
        {
            "device_id": gpu_index,
        },
    )
    cpu: ProviderDesc = "CPUExecutionProvider"

    if execution_provider == "TensorrtExecutionProvider":
        providers = [tensorrt, cuda, cpu]
    elif execution_provider == "CUDAExecutionProvider":
        providers = [cuda, cpu]
    else:
        providers = [execution_provider, cpu]

    session = ort.InferenceSession(
        model.bytes,
        providers=providers,
    )
    return session

Expand Down Expand Up @@ -76,3 +72,19 @@ def get_onnx_session(
)
__session_cache[model] = cached
return cached


def get_input_shape(session: ort.InferenceSession) -> OnnxParsedTensorShape:
    """
    Returns the input shape, input channels, input width (optional), and input height (optional).
    """

    primary_input = session.get_inputs()[0]
    return parse_onnx_shape(primary_input.shape)


def get_output_shape(session: ort.InferenceSession) -> OnnxParsedTensorShape:
    """
    Returns the output shape, output channels, output width (optional), and output height (optional).
    """

    primary_output = session.get_outputs()[0]
    return parse_onnx_shape(primary_output.shape)
Loading
Loading