Skip to content

Commit

Permalink
YOLOv5 Export Benchmarks (ultralytics#6613)
Browse files Browse the repository at this point in the history
* Add benchmarks.py

* Update

* Add requirements

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* dataset autodownload from root

* Update

* Redirect to /dev/null

* sudo --help

* Cleanup

* Add exports pd df

* Updates

* Updates

* Updates

* Cleanup

* dir handling fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* Cleanup2

* Cleanup3

* Cleanup model_type

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and eladco committed Mar 10, 2022
1 parent 4cee2fb commit 1d89d21
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 39 deletions.
17 changes: 17 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import time
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
Expand All @@ -67,6 +68,22 @@
from utils.torch_utils import select_device


def export_formats():
# YOLOv5 export formats
x = [['PyTorch', '-', '.pt'],
['TorchScript', 'torchscript', '.torchscript'],
['ONNX', 'onnx', '.onnx'],
['OpenVINO', 'openvino', '_openvino_model'],
['TensorRT', 'engine', '.engine'],
['CoreML', 'coreml', '.mlmodel'],
['TensorFlow SavedModel', 'saved_model', '_saved_model'],
['TensorFlow GraphDef', 'pb', '.pb'],
['TensorFlow Lite', 'tflite', '.tflite'],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite'],
['TensorFlow.js', 'tfjs', '_web_model']]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix'])


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
try:
Expand Down
53 changes: 16 additions & 37 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,30 +274,6 @@ def __init__(self, dimension=1):
def forward(self, x):
return torch.cat(x, self.d)

anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
strides = [8,16,32]
def box_decoding(outputs, anchors, strides, nc=80, imgsz=(640, 640)):
import tensorflow as tf
no = nc + 5 # number of outputs per anchor
nl = len(anchors) # number of detection layers
na = len(anchors[0]) // 2 # number of anchors
anchor_grid = tf.reshape(np.array(anchors), [nl, 1, -1, 1, 2])
anchor_grid = tf.cast(anchor_grid, tf.float32)
z = []
for i in range(nl):
y = outputs[i]
ny, nx = imgsz[0] // strides[i], imgsz[1] // strides[i]
xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
grid = tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
xy = (y[..., 0:2] * 2 - 0.5 + grid) * strides[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]
# Normalize xywh to 0-1 to reduce calibration error
#xy /= tf.constant([[imgsz[1], imgsz[0]]], dtype=tf.float32)
#wh /= tf.constant([[imgsz[1], imgsz[0]]], dtype=tf.float32)
y = tf.concat([xy, wh, y[..., 4:]], -1)
z.append(tf.reshape(y, [-1, na * ny * nx, no]))

return torch.tensor(tf.concat(z, 1).numpy())

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
Expand All @@ -318,10 +294,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):

super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix = Path(w).suffix.lower()
suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel', '.xml', '.h5']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, jit, onnx, engine, tflite, pb, saved_model, coreml, xml, h5 = (suffix == x for x in suffixes) # backends
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
w = attempt_download(w) # download if not local
if data: # data.yaml path (optional)
Expand Down Expand Up @@ -356,6 +329,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie
core = ie.IECore()
if not Path(w).is_file(): # if not *.xml
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths
executable_network = core.load_network(network, device_name='CPU', num_requests=1)
elif engine: # TensorRT
Expand Down Expand Up @@ -385,10 +360,6 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w)
elif h5:
LOGGER.info(f'Loading {w} for Keras inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
Expand Down Expand Up @@ -457,16 +428,12 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
else:
y = y[sorted(y)[-1]] # last output
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU, Keras)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = self.model(im, training=False).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.h5:
model_out = self.model(im, training=False)
y = box_decoding(model_out, anchors, strides)
return y
elif self.tflite: # Lite
input, output = self.input_details[0], self.output_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
Expand All @@ -491,6 +458,18 @@ def warmup(self, imgsz=(1, 3, 640, 640), half=False):
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
self.forward(im) # warmup

@staticmethod
def model_type(p='path/to/model.pt'):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
from export import export_formats
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
check_suffix(p, suffixes) # checks
p = Path(p).name # eliminate trailing separators
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
xml |= xml2 # *_openvino_model or *.xml
tflite &= not edgetpu # *.tflite
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
Expand Down
92 changes: 92 additions & 0 deletions utils/benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run YOLOv5 benchmarks on all supported export formats
Format | `export.py --include` | Model
--- | --- | ---
PyTorch | - | yolov5s.pt
TorchScript | `torchscript` | yolov5s.torchscript
ONNX | `onnx` | yolov5s.onnx
OpenVINO | `openvino` | yolov5s_openvino_model/
TensorRT | `engine` | yolov5s.engine
CoreML | `coreml` | yolov5s.mlmodel
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
TensorFlow GraphDef | `pb` | yolov5s.pb
TensorFlow Lite | `tflite` | yolov5s.tflite
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov5s_web_model/
Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
Usage:
$ python utils/benchmarks.py --weights yolov5s.pt --img 640
"""

import argparse
import sys
import time
from pathlib import Path

import pandas as pd

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
# ROOT = ROOT.relative_to(Path.cwd()) # relative

import export
import val
from utils import notebook_init
from utils.general import LOGGER, print_args


def run(weights=ROOT / 'yolov5s.pt', # weights path
imgsz=640, # inference size (pixels)
batch_size=1, # batch size
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
):
y, t = [], time.time()
formats = export.export_formats()
for i, (name, f, suffix) in formats.iterrows(): # index, (name, file, suffix)
try:
w = weights if f == '-' else export.run(weights=weights, imgsz=[imgsz], include=[f], device='cpu')[-1]
assert suffix in str(w), 'export failed'
result = val.run(data, w, batch_size, imgsz=imgsz, plots=False, device='cpu', task='benchmark')
metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls))
speeds = result[2] # times (preprocess, inference, postprocess)
y.append([name, metrics[3], speeds[1]]) # mAP, t_inference
except Exception as e:
LOGGER.warning(f'WARNING: Benchmark failure for {name}: {e}')
y.append([name, None, None]) # mAP, t_inference

# Print results
LOGGER.info('\n')
parse_opt()
notebook_init() # print system info
py = pd.DataFrame(y, columns=['Format', '[email protected]:0.95', 'Inference time (ms)'])
LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
LOGGER.info(str(py))
return py


def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt


def main(opt):
run(**vars(opt))


if __name__ == "__main__":
opt = parse_opt()
main(opt)
5 changes: 3 additions & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ def run(data,
# Dataloader
if not training:
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz), half=half) # warmup
pad = 0.0 if task == 'speed' else 0.5
pad = 0.0 if task in ('speed', 'benchmark') else 0.5
rect = False if task == 'benchmark' else pt # square inference for benchmarks
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=rect,
workers=workers, prefix=colorstr(f'{task}: '))[0]

seen = 0
Expand Down

0 comments on commit 1d89d21

Please sign in to comment.