diff --git a/.dockerignore b/.dockerignore
index a68626df5f2e..42f241f28c7b 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -14,8 +14,10 @@ data/samples/*
 # Neural Network weights -----------------------------------------------------------------------------------------------
 **/*.weights
 **/*.pt
+**/*.pth
 **/*.onnx
 **/*.mlmodel
+**/*.torchscript
 
 
 # Below Copied From .gitignore -----------------------------------------------------------------------------------------
diff --git a/.gitignore b/.gitignore
index 5a95798f0f61..07993ab27f15 100755
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ gcp_test*.sh
 *.pt
 *.onnx
 *.mlmodel
+*.torchscript
 darknet53.conv.74
 yolov3-tiny.conv.15
diff --git a/Dockerfile b/Dockerfile
index 01551a0e49e4..357c6dbc4cb9 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
 # Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
-FROM nvcr.io/nvidia/pytorch:20.06-py3
+FROM nvcr.io/nvidia/pytorch:20.03-py3
 RUN pip install -U gsutil
 
 # Create working directory
@@ -47,4 +47,4 @@ COPY . /usr/src/app
 # sudo docker commit 6d525e299258 user/test_image && sudo docker run -it --gpus all --ipc=host -v "$(pwd)"/coco:/usr/src/coco --entrypoint=sh user/test_image
 
 # Clean up
-# docker system prune -a --volumes
\ No newline at end of file
+# docker system prune -a --volumes
diff --git a/README.md b/README.md
index 1e29d1835196..6306e55ec866 100755
--- a/README.md
+++ b/README.md
@@ -25,8 +25,8 @@ This repository represents Ultralytics open-source research into future object d
 ** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
-** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001`
-** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
+** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --data data/coco.yaml --img 736 --conf 0.001`
+** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --data data/coco.yaml --img 640 --conf 0.1`
 ** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).
diff --git a/detect.py b/detect.py
index bb84a0df0c2c..2650c202d49d 100644
--- a/detect.py
+++ b/detect.py
@@ -158,7 +158,7 @@ def detect(save_img=False):
     with torch.no_grad():
         detect()
 
-        # Update all models
+        # # Update all models
         # for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
-        #    detect()
-        #    create_pretrained(opt.weights, opt.weights)
+        #     detect()
+        #     create_pretrained(opt.weights, opt.weights)
diff --git a/models/common.py b/models/common.py
index 3c4a0d729210..2c2d600394c1 100644
--- a/models/common.py
+++ b/models/common.py
@@ -1,9 +1,15 @@
 # This file contains modules common to various models
-
 from utils.utils import *
 
 
+def autopad(k, p=None):  # kernel, padding
+    # Pad to 'same'
+    if p is None:
+        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
+    return p
+
+
 def DWConv(c1, c2, k=1, s=1, act=True):
     # Depthwise convolution
     return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
@@ -11,10 +17,9 @@ def DWConv(c1, c2, k=1, s=1, act=True):
 class Conv(nn.Module):
     # Standard convolution
-    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
         super(Conv, self).__init__()
-        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # padding
-        self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False)
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
         self.bn = nn.BatchNorm2d(c2)
         self.act = nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity()
@@ -46,7 +51,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
         self.cv1 = Conv(c1, c_, 1, 1)
         self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
         self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
-        self.cv4 = Conv(c2, c2, 1, 1)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
         self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
         self.act = nn.LeakyReLU(0.1, inplace=True)
         self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
@@ -79,9 +84,9 @@ def forward(self, x):
 class Focus(nn.Module):
     # Focus wh information into c-space
-    def __init__(self, c1, c2, k=1):
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
         super(Focus, self).__init__()
-        self.conv = Conv(c1 * 4, c2, k, 1)
+        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
 
     def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
         return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
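Note on the `cv4` fix above: since `c_ = int(c2 * e)`, the default `e=0.5` (with even `c2`) gives `2 * c_ == c2`, so default models are unchanged; `Conv(2 * c_, c2, 1, 1)` simply makes non-default expansion ratios work. A minimal shape check for the new `autopad`/`Conv`/`Focus` signatures, assuming the repo root is on `PYTHONPATH` (an illustrative sketch, not part of the patch):

```python
import torch
from models.common import Conv, Focus

x = torch.zeros(1, 3, 64, 64)
print(Conv(3, 16, k=3)(x).shape)       # autopad k//2 -> 'same' output: (1, 16, 64, 64)
print(Conv(3, 16, k=3, p=0)(x).shape)  # explicit padding now possible: (1, 16, 62, 62)
print(Focus(3, 16, k=3)(x).shape)      # space-to-depth then Conv: (1, 16, 32, 32)
```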
diff --git a/models/experimental.py b/models/experimental.py
index 60cb7aa14cd5..cff9d141446d 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -1,6 +1,40 @@
+# This file contains experimental modules
+
 from models.common import *
 
 
+class CrossConv(nn.Module):
+    # Cross Convolution
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
+        super(CrossConv, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, (1, 3), 1)
+        self.cv2 = Conv(c_, c2, (3, 1), 1, g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class C3(nn.Module):
+    # Cross Convolution CSP
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super(C3, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
+        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
+        self.act = nn.LeakyReLU(0.1, inplace=True)
+        self.m = nn.Sequential(*[CrossConv(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
+
+    def forward(self, x):
+        y1 = self.cv3(self.m(self.cv1(x)))
+        y2 = self.cv2(x)
+        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
 class Sum(nn.Module):
     # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
     def __init__(self, n, weight=False):  # n: number of inputs
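`CrossConv` factors a 3x3 convolution into a 1x3 followed by a 3x1 (the tuple kernels rely on the list branch of `autopad`), and `C3` repeats the `BottleneckCSP` layout from models/common.py with `CrossConv` as the inner block. A quick sanity check under the same import assumption as above:

```python
import torch
from models.experimental import C3, CrossConv

x = torch.zeros(2, 64, 32, 32)
print(CrossConv(64, 64)(x).shape)  # residual add applies since c1 == c2: (2, 64, 32, 32)
print(C3(64, 128, n=2)(x).shape)   # CSP split/merge: (2, 128, 32, 32)
```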
diff --git a/models/export.py b/models/export.py
index 2aa6ce403ac6..bb310f3f89a0 100644
--- a/models/export.py
+++ b/models/export.py
@@ -1,4 +1,4 @@
-"""Exports a YOLOv5 *.pt model to *.onnx and *.torchscript formats
+"""Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
Usage:
$ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
@@ -6,8 +6,6 @@
 import argparse
 
-import onnx
-
 from models.common import *
 from utils import google_utils
@@ -21,7 +19,7 @@
     print(opt)
 
     # Input
-    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size, (1, 3, 320, 192) iDetection
+    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection
 
     # Load PyTorch model
     google_utils.attempt_download(opt.weights)
@@ -30,20 +28,22 @@
     model.model[-1].export = True  # set Detect() layer export=True
     _ = model(img)  # dry run
 
-    # Export to torchscript
+    # TorchScript export
     try:
         f = opt.weights.replace('.pt', '.torchscript')  # filename
         ts = torch.jit.trace(model, img)
         ts.save(f)
-        print('Torchscript export success, saved as %s' % f)
-    except:
-        print('Torchscript export failed.')
+        print('TorchScript export success, saved as %s' % f)
+    except Exception as e:
+        print('TorchScript export failed: %s' % e)
 
-    # Export to ONNX
+    # ONNX export
     try:
+        import onnx
+
         f = opt.weights.replace('.pt', '.onnx')  # filename
         model.fuse()  # only for ONNX
-        torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
+        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                           output_names=['output'])  # output_names=['classes', 'boxes']
 
         # Checks
@@ -51,5 +51,5 @@
         onnx.checker.check_model(onnx_model)  # check onnx model
         print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable representation of the graph
         print('ONNX export success, saved as %s\nView with https://github.com/lutzroeder/netron' % f)
-    except:
-        print('ONNX export failed.')
+    except Exception as e:
+        print('ONNX export failed: %s' % e)
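Moving `import onnx` inside the try block means a missing onnx package now skips only the ONNX export instead of crashing before the TorchScript export runs. A successful TorchScript export can be smoke-tested roughly as follows (a sketch; the path and input shape must match the `--weights` and `--img` values used at export time):

```python
import torch

model = torch.jit.load('./weights/yolov5s.torchscript')  # file written by models/export.py
img = torch.zeros(1, 3, 640, 640)  # same shape as the export-time dry run
y = model(img)  # replays the traced forward pass
```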
diff --git a/train.py b/train.py
index 0c63b80ae27b..b7d202d2fbe4 100644
--- a/train.py
+++ b/train.py
@@ -67,13 +67,8 @@ def train(hyp, tb_writer, opt, device):
     total_batch_size = opt.batch_size if opt.local_rank == -1 else opt.batch_size * torch.distributed.get_world_size()  # 64
     weights = opt.weights  # initial training weights
 
-    if opt.local_rank in [-1, 0]:
-        # TODO: Init DDP logging. Only the first process is allowed to log.
-        # Since I see lots of print here, the logging is skipped here.
-        pass
-    else:
-        tb_writer = None
-
+    # TODO: Init DDP logging. Only the first process is allowed to log.
+    # Since I see lots of print here, the logging is skipped here.
 
     # Configure
     init_seeds(1)
@@ -84,13 +79,13 @@ def train(hyp, tb_writer, opt, device):
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
 
     # Remove previous results
-    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
-        os.remove(f)
+    if opt.local_rank in [-1, 0]:
+        for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
+            os.remove(f)
 
     # Create model
     model = Model(opt.cfg).to(device)
     assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
-    model.names = data_dict['names']
 
     # Image sizes
     gs = int(max(model.stride))  # grid size (max stride)
@@ -138,7 +133,7 @@ def train(hyp, tb_writer, opt, device):
             model.load_state_dict(ckpt['model'], strict=False)
         except KeyError as e:
             s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
-                "Please delete or update %s and try again, or use --weights '' to train from scatch." \
+                "Please delete or update %s and try again, or use --weights '' to train from scratch." \
                 % (opt.weights, opt.cfg, opt.weights, opt.weights)
             raise KeyError(s) from e
@@ -205,6 +200,7 @@ def train(hyp, tb_writer, opt, device):
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.names = data_dict['names']
 
     # Class frequency
     if tb_writer:
@@ -326,10 +322,9 @@ def train(hyp, tb_writer, opt, device):
                                                 batch_size=total_batch_size,
                                                 imgsz=imgsz_test,
                                                 save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                                model=ema.ema,
+                                                model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader)
-
             # Write
             with open(results_file, 'a') as f:
                 f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
@@ -368,22 +363,21 @@ def train(hyp, tb_writer, opt, device):
         # end epoch ----------------------------------------------------------------------------------------------------
 
     # end training
-    results = None
     if opt.local_rank in [-1, 0]:
-        n = opt.name
-        if len(n):
-            n = '_' + n if not n.isnumeric() else n
-        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
-        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
-            if os.path.exists(f1):
-                os.rename(f1, f2)  # rename
-                ispt = f2.endswith('.pt')  # is *.pt
-                strip_optimizer(f2) if ispt else None  # strip optimizer
-                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
+        # Strip optimizers
+        n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
+        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
+        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
+            if os.path.exists(f1):
+                os.rename(f1, f2)  # rename
+                ispt = f2.endswith('.pt')  # is *.pt
+                strip_optimizer(f2) if ispt else None  # strip optimizer
+                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
+
+        # Finish
         if not opt.evolve:
             plot_results()  # save as results.png
         print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
 
-    if opt.local_rank == -1:
+    if opt.local_rank == 0:
         dist.destroy_process_group()
     torch.cuda.empty_cache()
     return results
@@ -414,16 +408,16 @@ def train(hyp, tb_writer, opt, device):
     # Parameter For DDP.
     parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
     opt = parser.parse_args()
-    opt.weights = last if opt.resume else opt.weights
+    opt.weights = last if opt.resume and not opt.weights else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
-    # If local_rank is not -1, the DDP mode is triggered. Use local_rank to overwrite the opt.device config.
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
     if device.type == 'cpu':
         mixed_precision = False
     elif opt.local_rank != -1:
+        # DDP mode
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
         device = torch.device("cuda")
@@ -435,10 +429,10 @@ def train(hyp, tb_writer, opt, device):
     # Train
     if not opt.evolve:
         if opt.local_rank in [-1, 0]:
+            print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
             tb_writer = SummaryWriter(comment=opt.name)
         else:
             tb_writer = None
-        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         train(hyp, tb_writer, opt, device)
 
     # Evolve hyperparameters (optional)
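The condensed suffix expression in the checkpoint-renaming block above ("Strip optimizers") behaves exactly like the removed four-line version; a quick check with a hypothetical helper wrapping the same expression:

```python
def suffix(name):  # hypothetical wrapper around the one-liner from train.py
    return ('_' if len(name) and not name.isnumeric() else '') + name

assert suffix('') == ''         # default run: results.txt, last.pt unchanged
assert suffix('exp') == '_exp'  # named run: results_exp.txt, last_exp.pt
assert suffix('7') == '7'       # numeric names take no underscore: results7.txt
```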
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index a2f69c1a92cb..b9c1ad6155c5 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # Returns True if model is of type DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''
@@ -187,7 +192,6 @@ def update(self, model):
         with torch.no_grad():
             msd = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
             esd = self.ema.module.state_dict() if hasattr(self.ema, 'module') else self.ema.state_dict()
-
             for k, v in esd.items():
                 if v.dtype.is_floating_point:
                     v *= d
@@ -196,6 +200,6 @@ def update(self, model):
     def update_attr(self, model):
         # Assign attributes (which may change during training)
         for k in model.__dict__.keys():
-            if not k.startswith('_') and not isinstance(getattr(model, k),
-                                                        (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer)):
+            if not k.startswith('_') and (k != 'module' or not isinstance(getattr(model, k),
+                                                                          (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer))):
                 setattr(self.ema, k, getattr(model, k))
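Two notes on the utils/torch_utils.py changes: `is_parallel` mirrors the `hasattr(..., 'module')` unwrapping that train.py now applies to the EMA model, and the `model_info` FLOPS estimate profiles a 64x64 input then multiplies by 100, since convolutional FLOPS scale with pixel count and (640/64)^2 = 100. A small sketch of the unwrapping idea:

```python
import torch.nn as nn

def is_parallel(model):
    # Same check as the new helper: DP and DDP wrap the real model in .module
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

m = nn.Conv2d(3, 16, 3)
dp = nn.DataParallel(m)
print(is_parallel(m), is_parallel(dp))       # False True
bare = dp.module if is_parallel(dp) else dp  # unwrap before saving or EMA attribute copies
```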