diff --git a/mmdeploy/backend/tensorrt/utils.py b/mmdeploy/backend/tensorrt/utils.py index 7ad190428d..0cbee7fb3b 100644 --- a/mmdeploy/backend/tensorrt/utils.py +++ b/mmdeploy/backend/tensorrt/utils.py @@ -165,9 +165,13 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto], parser = trt.OnnxParser(network, logger) if isinstance(onnx_model, str): - onnx_model = onnx.load(onnx_model) + parse_valid = parser.parse_from_file(onnx_model) + elif isinstance(onnx_model, onnx.ModelProto): + parse_valid = parser.parse(onnx_model.SerializeToString()) + else: + raise TypeError('Unsupported onnx model type!') - if not parser.parse(onnx_model.SerializeToString()): + if not parse_valid: error_msgs = '' for error in range(parser.num_errors): error_msgs += f'{parser.get_error(error)}\n'