Skip to content

Commit

Permalink
TorchScript single-output fix (#7261)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Apr 3, 2022
1 parent 05cf0d1 commit 8bc839e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
18 changes: 12 additions & 6 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,18 @@

def export_formats():
# YOLOv5 export formats
x = [['PyTorch', '-', '.pt', True], ['TorchScript', 'torchscript', '.torchscript', True],
['ONNX', 'onnx', '.onnx', True], ['OpenVINO', 'openvino', '_openvino_model', False],
['TensorRT', 'engine', '.engine', True], ['CoreML', 'coreml', '.mlmodel', False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True], ['TensorFlow GraphDef', 'pb', '.pb', True],
['TensorFlow Lite', 'tflite', '.tflite', False], ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
['TensorFlow.js', 'tfjs', '_web_model', False]]
x = [
['PyTorch', '-', '.pt', True],
['TorchScript', 'torchscript', '.torchscript', True],
['ONNX', 'onnx', '.onnx', True],
['OpenVINO', 'openvino', '_openvino_model', False],
['TensorRT', 'engine', '.engine', True],
['CoreML', 'coreml', '.mlmodel', False],
['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
['TensorFlow GraphDef', 'pb', '.pb', True],
['TensorFlow Lite', 'tflite', '.tflite', False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
['TensorFlow.js', 'tfjs', '_web_model', False],]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])


Expand Down
7 changes: 4 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,10 @@ def wrap_frozen_graph(gd, inputs, outputs):
def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.pt or self.jit: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
if self.pt: # PyTorch
y = self.model(im, augment=augment, visualize=visualize)[0]
elif self.jit: # TorchScript
y = self.model(im)[0]
elif self.dnn: # ONNX OpenCV DNN
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
Expand Down

0 comments on commit 8bc839e

Please sign in to comment.