Skip to content

Commit

Permalink
Fixed bug of not saving simplified ONNX file (#1489) (#1490)
Browse files Browse the repository at this point in the history
(cherry picked from commit 7ab603b)
  • Loading branch information
BloodAxe authored Sep 27, 2023
1 parent 29eea5b commit cefaffe
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
14 changes: 11 additions & 3 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Union, Optional, List, Tuple

import numpy as np
import onnx
import onnxsim
import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -495,9 +496,12 @@ def export(
if onnx_simplify:
# If TRT engine is used, we need to run onnxsim.simplify BEFORE attaching NMS,
# because EfficientNMS_TRT is not supported by onnxsim and would lead to a runtime error.
onnxsim.simplify(output)
model_opt, simplify_successful = onnxsim.simplify(output)
if not simplify_successful:
raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.")
onnx.save(model_opt, output)
logger.debug(f"Ran onnxsim.simplify on model {output}")
# Disable onnx_simplify to avoid running it twice.
# Disable onnx_simplify to avoid running it a second time.
onnx_simplify = False

nms_attach_method = attach_tensorrt_nms
Expand Down Expand Up @@ -528,7 +532,11 @@ def export(
)

if onnx_simplify:
onnxsim.simplify(output)
model_opt, simplify_successful = onnxsim.simplify(output)
if not simplify_successful:
raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.")
onnx.save(model_opt, output)

logger.debug(f"Ran onnxsim.simplify on {output}")
finally:
if quantization_mode == ExportQuantizationMode.INT8:
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/export_detection_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def test_export_with_fp16_quantization(self):

max_predictions_per_image = 300
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = "."
out_path = os.path.join(tmpdirname, "ppyoloe_s_with_fp16_quantization.onnx")

ppyolo_e: ExportableObjectDetectionModel = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")
Expand Down

0 comments on commit cefaffe

Please sign in to comment.