Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 1106 Minor fixes in export api #1412

Merged
merged 5 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions documentation/source/models_export.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ export_result

import onnxruntime
import numpy as np
session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in batch format:
Expand Down Expand Up @@ -117,7 +117,7 @@ image = load_image("https://deci-pretrained-models.s3.amazonaws.com/sample_image
image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))
image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -337,10 +337,10 @@ export_result

import onnxruntime
import numpy as np
session = onnxruntime.InferenceSession("yolo_nas_s.onnx")
session = onnxruntime.InferenceSession("yolo_nas_s.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
example_input_image = np.zeros(1, 3, 640, 640).astype(np.uint8)
example_input_image = np.zeros((1, 3, 640, 640)).astype(np.uint8)
predictions = session.run(outputs, {inputs[0]: example_input_image})

Exported model has predictions in flat format:
Expand All @@ -359,7 +359,7 @@ Now we exported a model that produces predictions in `flat` format. Let's run th


```python
session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -437,7 +437,7 @@ export_result = model.export(
output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -471,7 +471,7 @@ export_result = model.export(
quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})
Expand Down Expand Up @@ -514,15 +514,15 @@ export_result = model.export(
calibration_loader=dummy_calibration_loader
)

session = onnxruntime.InferenceSession(export_result.output)
session = onnxruntime.InferenceSession(export_result.output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
inputs = [o.name for o in session.get_inputs()]
outputs = [o.name for o in session.get_outputs()]
result = session.run(outputs, {inputs[0]: image_bchw})

show_predictions_from_flat_format(image, result)
```

25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.87s/it]
25%|█████████████████████████████████████████████████ | 4/16 [00:11<00:34, 2.90s/it]



Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified documentation/source/models_export_files/models_export_30_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 5 additions & 5 deletions src/super_gradients/examples/model_export/models_export.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
"image = cv2.resize(image, (export_result.input_image_shape[1], export_result.input_image_shape[0]))\n",
"image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -486,7 +486,7 @@
}
],
"source": [
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -605,7 +605,7 @@
" output_predictions_format = DetectionOutputFormatMode.FLAT_FORMAT\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -659,7 +659,7 @@
" quantization_mode=ExportQuantizationMode.INT8 # or ExportQuantizationMode.FP16\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down Expand Up @@ -729,7 +729,7 @@
" calibration_loader=dummy_calibration_loader\n",
")\n",
"\n",
"session = onnxruntime.InferenceSession(export_result.output)\n",
"session = onnxruntime.InferenceSession(export_result.output, providers=[\"CUDAExecutionProvider\", \"CPUExecutionProvider\"])\n",
"inputs = [o.name for o in session.get_inputs()]\n",
"outputs = [o.name for o in session.get_outputs()]\n",
"result = session.run(outputs, {inputs[0]: image_bchw})\n",
Expand Down
59 changes: 41 additions & 18 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from super_gradients.conversion.gs_utils import import_onnx_graphsurgeon_or_install
from super_gradients.training.utils.export_utils import infer_format_from_file_name, infer_image_shape_from_model, infer_image_input_channels
from super_gradients.training.utils.quantization.fix_pytorch_quantization_modules import patch_pytorch_quantization_modules_if_needed
from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules

from super_gradients.training.utils.utils import infer_model_device, check_model_contains_quantized_modules, infer_model_dtype

logger = get_logger(__name__)

Expand Down Expand Up @@ -50,6 +49,19 @@ def forward(self, predictions: Any) -> Tuple[Tensor, Tensor]:
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
"""
This method is used to infer the total number of predictions for a given input resolution.
The function takes raw predictions from the model and returns the total number of predictions.
It is needed to check whether max_predictions_per_image and num_pre_nms_predictions are not greater than
the total number of predictions for a given resolution.

:param predictions: Predictions from the model itself.
:return: A total number of predictions for a given resolution
"""
raise NotImplementedError(f"forward() method is not implemented for class {self.__class__.__name__}. ")

def get_output_names(self) -> List[str]:
"""
Returns the names of the outputs of the module.
Expand Down Expand Up @@ -122,7 +134,7 @@ def export(
confidence_threshold: Optional[float] = None,
nms_threshold: Optional[float] = None,
engine: Optional[ExportTargetBackend] = None,
quantization_mode: ExportQuantizationMode = Optional[None],
quantization_mode: Optional[ExportQuantizationMode] = None,
selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
calibration_loader: Optional[DataLoader] = None,
calibration_method: str = "percentile",
Expand Down Expand Up @@ -325,6 +337,27 @@ def export(
num_pre_nms_predictions = postprocessing_module.num_pre_nms_predictions
max_predictions_per_image = max_predictions_per_image or num_pre_nms_predictions

dummy_input = torch.randn(input_shape).to(device=infer_model_device(model), dtype=infer_model_dtype(model))
with torch.no_grad():
number_of_predictions = postprocessing_module.infer_total_number_of_predictions(model.eval()(dummy_input))

if num_pre_nms_predictions > number_of_predictions:
logger.warning(
f"num_pre_nms_predictions ({num_pre_nms_predictions}) is greater than the total number of predictions ({number_of_predictions}) for input"
f"shape {input_shape}. Setting num_pre_nms_predictions to {number_of_predictions}"
)
num_pre_nms_predictions = number_of_predictions
# We have to re-created the postprocessing_module with the new value of num_pre_nms_predictions
postprocessing_kwargs["num_pre_nms_predictions"] = num_pre_nms_predictions
postprocessing_module: AbstractObjectDetectionDecodingModule = model.get_decoding_module(**postprocessing_kwargs)

if max_predictions_per_image > num_pre_nms_predictions:
logger.warning(
f"max_predictions_per_image ({max_predictions_per_image}) is greater than num_pre_nms_predictions ({num_pre_nms_predictions}). "
f"Setting max_predictions_per_image to {num_pre_nms_predictions}"
)
max_predictions_per_image = num_pre_nms_predictions

nms_threshold = nms_threshold or getattr(model, "_default_nms_iou", None)
if nms_threshold is None:
raise ValueError(
Expand All @@ -339,12 +372,6 @@ def export(
"Please specify the confidence_threshold explicitly: model.export(..., confidence_threshold=0.5)"
)

if max_predictions_per_image > num_pre_nms_predictions:
raise ValueError(
f"max_predictions_per_image={max_predictions_per_image} is greater than "
f"num_pre_nms_predictions={num_pre_nms_predictions}. "
f"Please specify max_predictions_per_image <= {num_pre_nms_predictions}."
)
else:
attach_nms_postprocessing = False
postprocessing_module = None
Expand Down Expand Up @@ -523,19 +550,15 @@ def export(
usage_instructions.append("")
usage_instructions.append(" import onnxruntime")
usage_instructions.append(" import numpy as np")
usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}")')
usage_instructions.append(f' session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])')
usage_instructions.append(" inputs = [o.name for o in session.get_inputs()]")
usage_instructions.append(" outputs = [o.name for o in session.get_outputs()]")

dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name
if preprocessing:
usage_instructions.append(
f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})" # noqa
)
else:
usage_instructions.append(
f" example_input_image = np.zeros({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}).astype(np.{dtype_name})" # noqa
)
usage_instructions.append(
f" example_input_image = np.zeros(({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})" # noqa
)

usage_instructions.append(" predictions = session.run(outputs, {inputs[0]: example_input_image})")
usage_instructions.append("")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Union, Optional, List, Tuple
from typing import Union, Optional, List, Tuple, Any

import torch
from torch import Tensor
Expand Down Expand Up @@ -81,6 +81,20 @@ def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]) -> T

return output_pred_bboxes, output_pred_scores

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"""

:param inputs:
:return:
"""
if torch.jit.is_tracing():
pred_bboxes, pred_scores = predictions
else:
pred_bboxes, pred_scores = predictions[0]

return pred_bboxes.size(1)


class PPYoloE(SgModule, ExportableObjectDetectionModel, HasPredict):
def __init__(self, arch_params):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)
shift_x = torch.arange(end=w) + self.grid_cell_offset
shift_y = torch.arange(end=h) + self.grid_cell_offset
shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset
if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import math
import warnings
from typing import Union, Type, List, Tuple, Optional
from typing import Union, Type, List, Tuple, Optional, Any
from functools import lru_cache

import numpy as np
Expand Down Expand Up @@ -281,9 +281,9 @@ def forward(self, inputs):
def _make_grid(nx: int, ny: int, dtype: torch.dtype):
if torch_version_is_greater_or_equal(1, 10):
# https://github.com/pytorch/pytorch/issues/50276
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)], indexing="ij")
else:
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).to(dtype)


Expand Down Expand Up @@ -745,3 +745,18 @@ def forward(self, predictions):
output_pred_scores = pred_scores.reshape(-1, pred_scores.size(2))[flat_indices, :].reshape(pred_scores.size(0), nms_top_k, pred_scores.size(2))

return output_pred_bboxes, output_pred_scores

def get_num_pre_nms_predictions(self) -> int:
return self.num_pre_nms_predictions

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
"""

:param inputs:
:return:
"""
if isinstance(predictions, (tuple, list)):
predictions = predictions[0]

return predictions.size(1)
Original file line number Diff line number Diff line change
Expand Up @@ -281,14 +281,16 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)
shift_x = torch.arange(end=w) + self.grid_cell_offset
shift_y = torch.arange(end=h) + self.grid_cell_offset

shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset

if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
else:
shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
anchor_point = torch.stack([shift_x, shift_y], dim=-1)
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
anchor_points = torch.cat(anchor_points)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Union, Tuple, Optional
from typing import Union, Tuple, Optional, Any

import torch
from omegaconf import DictConfig
Expand Down Expand Up @@ -28,6 +28,23 @@ def __init__(
super().__init__()
self.num_pre_nms_predictions = num_pre_nms_predictions

@torch.jit.ignore
def infer_total_number_of_predictions(self, predictions: Any) -> int:
"""

:param inputs:
:return:
"""
if torch.jit.is_tracing():
pred_bboxes, pred_scores = predictions
else:
pred_bboxes, pred_scores = predictions[0]

return pred_bboxes.size(1)

def get_num_pre_nms_predictions(self) -> int:
return self.num_pre_nms_predictions

def forward(self, inputs: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]):
if torch.jit.is_tracing():
pred_bboxes, pred_scores = inputs
Expand Down
19 changes: 19 additions & 0 deletions tests/unit_tests/export_detection_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,25 @@ def setUp(self) -> None:
this_dir = os.path.dirname(__file__)
self.test_image_path = os.path.join(this_dir, "../data/tinycoco/images/val2017/000000444010.jpg")

def test_export_model_on_small_size(self):
with tempfile.TemporaryDirectory() as tmpdirname:
for model_type in [
Models.YOLO_NAS_S,
Models.PP_YOLOE_S,
Models.YOLOX_S,
]:
out_path = os.path.join(tmpdirname, model_type + ".onnx")
ppyolo_e: ExportableObjectDetectionModel = models.get(model_type, pretrained_weights="coco")
result = ppyolo_e.export(
out_path,
input_image_shape=(64, 64),
num_pre_nms_predictions=2000,
max_predictions_per_image=1000,
output_predictions_format=DetectionOutputFormatMode.BATCH_FORMAT,
)
assert result.input_image_dtype == torch.uint8
assert result.input_image_shape == (64, 64)

def test_the_most_common_export_use_case(self):
"""
Test the most common export use case - export to ONNX with all default parameters
Expand Down