diff --git a/train.py b/train.py
index 879bb2e07968..6a4e2c3a2796 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,13 @@
 import argparse
 
+import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -42,7 +44,7 @@
        'shear': 0.0}  # image shear (+/- deg)
 
 
-def train(hyp):
+def train(hyp, tb_writer, opt, device):
     print(f'Hyperparameters {hyp}')
     log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'  # run directory
     wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
@@ -59,8 +61,13 @@ def train(hyp):
         yaml.dump(vars(opt), f, sort_keys=False)
 
     epochs = opt.epochs  # 300
-    batch_size = opt.batch_size  # 64
+    batch_size = opt.batch_size  # batch size per process
+    total_batch_size = opt.total_batch_size
     weights = opt.weights  # initial training weights
+    local_rank = opt.local_rank
+
+    # TODO: Init DDP logging. Only the first process is allowed to log.
+    # Since there are many print calls here, logging configuration is skipped for now, so repeated outputs may appear.
 
     # Configure
     init_seeds(1)
@@ -72,8 +79,9 @@ def train(hyp):
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
 
     # Remove previous results
-    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
-        os.remove(f)
+    if local_rank in [-1, 0]:
+        for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
+            os.remove(f)
 
     # Create model
     model = Model(opt.cfg, nc=nc).to(device)
@@ -84,8 +92,15 @@ def train(hyp):
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
-    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
+    # The default DDP implementation is slow for gradient accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
+    # The all-reduce operation is carried out during loss.backward().
+    # Thus, there would be redundant all-reduce communications in an accumulation procedure,
+    # which means the result is still correct but the training speed is slower.
+    # TODO: If acceleration is needed, there is an implementation of allreduce_post_accumulation
+    # in https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/run_pretraining.py
+    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
+    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
+
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
     for k, v in model.named_parameters():
         if v.requires_grad:
@@ -106,13 +121,10 @@ def train(hyp):
     print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
     del pg0, pg1, pg2
 
-    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
-    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
-    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-    # plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir)
-
     # Load Model
-    google_utils.attempt_download(weights)
+    # Avoid multiple downloads.
+    with torch_distributed_zero_first(local_rank):
+        google_utils.attempt_download(weights)
     start_epoch, best_fitness = 0, 0.0
     if weights.endswith('.pt'):  # pytorch format
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
@@ -125,7 +137,7 @@ def train(hyp):
         except KeyError as e:
             s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
                 "Please delete or update %s and try again, or use --weights '' to train from scratch." \
-                % (opt.weights, opt.cfg, opt.weights, opt.weights)
+                % (weights, opt.cfg, weights, weights)
             raise KeyError(s) from e
 
         # load optimizer
@@ -142,7 +154,7 @@ def train(hyp):
         start_epoch = ckpt['epoch'] + 1
         if epochs < start_epoch:
             print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
-                  (opt.weights, ckpt['epoch'], epochs))
+                  (weights, ckpt['epoch'], epochs))
             epochs += ckpt['epoch']  # finetune additional epochs
 
         del ckpt
@@ -151,25 +163,41 @@ def train(hyp):
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
 
-    # Distributed training
-    if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
-        dist.init_process_group(backend='nccl',  # distributed backend
-                                init_method='tcp://127.0.0.1:9999',  # init method
-                                world_size=1,  # number of nodes
-                                rank=0)  # node rank
-        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)  # requires world_size > 1
-        model = torch.nn.parallel.DistributedDataParallel(model)
+    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
+    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
+    # plot_lr_scheduler(optimizer, scheduler, epochs)
+
+    # DP mode
+    if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
+        model = torch.nn.DataParallel(model)
+
+    # Exponential moving average
+    # From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
+    # "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
+    # chenyzsjtu: EMA should be created after SyncBN, as SyncBN introduces new modules.
+    if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
+        print("SyncBN activated!")
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+    ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None
+
+    # DDP mode
+    if device.type != 'cpu' and local_rank != -1:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, local_rank=local_rank)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
     # Testloader
-    testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
-                                   hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
+    if local_rank in [-1, 0]:
+        # local_rank is set to -1 because only the first process is expected to do evaluation.
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
+                                       hyp=hyp, augment=False, cache=opt.cache_images, rect=True, local_rank=-1)[0]
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -180,48 +208,63 @@ def train(hyp):
     model.names = names
 
     # Class frequency
-    labels = np.concatenate(dataset.labels, 0)
-    c = torch.tensor(labels[:, 0])  # classes
-    # cf = torch.bincount(c.long(), minlength=nc) + 1.
-    # model._initialize_biases(cf.to(device))
-    plot_labels(labels, save_dir=log_dir)
     if tb_writer:
         # tb_writer.add_hparams(hyp, {})  # causes duplicate https://github.com/ultralytics/yolov5/pull/384
+        labels = np.concatenate(dataset.labels, 0)
+        c = torch.tensor(labels[:, 0])  # classes
+        # cf = torch.bincount(c.long(), minlength=nc) + 1.
+        # model._initialize_biases(cf.to(device))
+        plot_labels(labels)
         tb_writer.add_histogram('classes', c, 0)
 
     # Check anchors
     if not opt.noautoanchor:
         check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
-    # Exponential moving average
-    ema = torch_utils.ModelEMA(model)
-
     # Start training
     t0 = time.time()
     nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
     scheduler.last_epoch = start_epoch - 1  # do not move
-    print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
-    print('Using %g dataloader workers' % dataloader.num_workers)
-    print('Starting training for %g epochs...' % epochs)
+    if local_rank in [0, -1]:
+        print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
+        print('Using %g dataloader workers' % dataloader.num_workers)
+        print('Starting training for %g epochs...' % epochs)
     # torch.autograd.set_detect_anomaly(True)
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         model.train()
 
         # Update image weights (optional)
+        # In DDP mode, the generated indices are broadcast to keep the dataset synchronized across processes.
        if dataset.image_weights:
-            w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
-            image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
-            dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx
+            # Generate indices.
+            if local_rank in [-1, 0]:
+                w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
+                image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
+                dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx
+            # Broadcast.
+            if local_rank != -1:
+                indices = torch.zeros([dataset.n], dtype=torch.int)
+                if local_rank == 0:
+                    indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
+                dist.broadcast(indices, 0)
+                if local_rank != 0:
+                    dataset.indices = indices.cpu().numpy()
 
         # Update mosaic border
         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
 
         mloss = torch.zeros(4, device=device)  # mean losses
-        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
-        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
+        if local_rank != -1:
+            dataloader.sampler.set_epoch(epoch)
+        pbar = enumerate(dataloader)
+        if local_rank in [-1, 0]:
+            print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
+            pbar = tqdm(pbar, total=nb)  # progress bar
+        optimizer.zero_grad()
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
@@ -230,7 +273,7 @@ def train(hyp):
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
-                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -250,6 +293,9 @@ def train(hyp):
 
             # Loss
             loss, loss_items = compute_loss(pred, targets.to(device), model)
+            # loss is scaled with batch size in compute_loss(), but in DDP mode gradients are averaged between devices.
+            if local_rank != -1:
+                loss *= opt.world_size
             if not torch.isfinite(loss):
                 print('WARNING: non-finite loss, ending training ', loss_items)
                 return results
@@ -265,106 +311,110 @@ def train(hyp):
             if ni % accumulate == 0:
                 optimizer.step()
                 optimizer.zero_grad()
-                ema.update(model)
+                if ema is not None:
+                    ema.update(model)
 
             # Print
-            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
-            mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
-            s = ('%10s' * 2 + '%10.4g' * 6) % (
-                '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
-            pbar.set_description(s)
-
-            # Plot
-            if ni < 3:
-                f = str(Path(log_dir) / ('train_batch%g.jpg' % ni))  # filename
-                result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
-                if tb_writer and result is not None:
-                    tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
-                    # tb_writer.add_graph(model, imgs)  # add model to tensorboard
+            if local_rank in [-1, 0]:
+                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
+                mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                s = ('%10s' * 2 + '%10.4g' * 6) % (
+                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
+                pbar.set_description(s)
+
+                # Plot
+                if ni < 3:
+                    f = str(Path(log_dir) / ('train_batch%g.jpg' % ni))  # filename
+                    result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
+                    if tb_writer and result is not None:
+                        tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
+                        # tb_writer.add_graph(model, imgs)  # add model to tensorboard
 
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
         scheduler.step()
 
-        # mAP
-        ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
-        final_epoch = epoch + 1 == epochs
-        if not opt.notest or final_epoch:  # Calculate mAP
-            results, maps, times = test.test(opt.data,
-                                             batch_size=batch_size,
-                                             imgsz=imgsz_test,
-                                             save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                             model=ema.ema,
-                                             single_cls=opt.single_cls,
-                                             dataloader=testloader,
-                                             save_dir=log_dir)
-
-        # Write
-        with open(results_file, 'a') as f:
-            f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
-        if len(opt.name) and opt.bucket:
-            os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
-
-        # Tensorboard
-        if tb_writer:
-            tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
-                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                    'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
-            for x, tag in zip(list(mloss[:-1]) + list(results), tags):
-                tb_writer.add_scalar(tag, x, epoch)
-
-        # Update best mAP
-        fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
-        if fi > best_fitness:
-            best_fitness = fi
-
-        # Save model
-        save = (not opt.nosave) or (final_epoch and not opt.evolve)
-        if save:
-            with open(results_file, 'r') as f:  # create checkpoint
-                ckpt = {'epoch': epoch,
-                        'best_fitness': best_fitness,
-                        'training_results': f.read(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict()}
-
-            # Save last, best and delete
-            torch.save(ckpt, last)
-            if (best_fitness == fi) and not final_epoch:
-                torch.save(ckpt, best)
-            del ckpt
-
+        # Only the first process in DDP mode is allowed to log or save checkpoints.
+        if local_rank in [-1, 0]:
+            # mAP
+            if ema is not None:
+                ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
+            final_epoch = epoch + 1 == epochs
+            if not opt.notest or final_epoch:  # Calculate mAP
+                results, maps, times = test.test(opt.data,
+                                                 batch_size=total_batch_size,
+                                                 imgsz=imgsz_test,
+                                                 save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
+                                                 model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
+                                                 single_cls=opt.single_cls,
+                                                 dataloader=testloader,
+                                                 save_dir=log_dir)
+            # Explicitly keep the shape.
+            # Write
+            with open(results_file, 'a') as f:
+                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
+            if len(opt.name) and opt.bucket:
+                os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
+
+            # Tensorboard
+            if tb_writer:
+                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
+                        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
+                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
+                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
+                    tb_writer.add_scalar(tag, x, epoch)
+
+            # Update best mAP
+            fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
+            if fi > best_fitness:
+                best_fitness = fi
+
+            # Save model
+            save = (not opt.nosave) or (final_epoch and not opt.evolve)
+            if save:
+                with open(results_file, 'r') as f:  # create checkpoint
+                    ckpt = {'epoch': epoch,
+                            'best_fitness': best_fitness,
+                            'training_results': f.read(),
+                            'model': ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
+                            'optimizer': None if final_epoch else optimizer.state_dict()}
+
+                # Save last, best and delete
+                torch.save(ckpt, last)
+                if (best_fitness == fi) and not final_epoch:
+                    torch.save(ckpt, best)
+                del ckpt
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
-    # Strip optimizers
-    n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
-    fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
-    for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
-        if os.path.exists(f1):
-            os.rename(f1, f2)  # rename
-            ispt = f2.endswith('.pt')  # is *.pt
-            strip_optimizer(f2) if ispt else None  # strip optimizer
-            os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
-
-    # Finish
-    if not opt.evolve:
-        plot_results(save_dir=log_dir)  # save as results.png
-    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
+    if local_rank in [-1, 0]:
+        # Strip optimizers
+        n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
+        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
+        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
+            if os.path.exists(f1):
+                os.rename(f1, f2)  # rename
+                ispt = f2.endswith('.pt')  # is *.pt
+                strip_optimizer(f2) if ispt else None  # strip optimizer
+                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
+        # Finish
+        if not opt.evolve:
+            plot_results()  # save as results.png
+        print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
+    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
     torch.cuda.empty_cache()
     return results
 
 
 if __name__ == '__main__':
-    check_git_status()
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
     parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
     parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=16)
+    parser.add_argument('--batch-size', type=int, default=16, help="total batch size for all GPUs")
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const='get_last', default=False,
@@ -380,32 +430,54 @@ def train(hyp):
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
+    parser.add_argument("--sync-bn", action="store_true", help="Use SyncBatchNorm, only available in DDP mode")
+    # Parameter for DDP.
+    parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Do not modify it manually.")
     opt = parser.parse_args()
 
     last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
     if last and not opt.weights:
         print(f'Resuming training from {last}')
     opt.weights = last if opt.resume and not opt.weights else opt.weights
+    if opt.local_rank in [-1, 0]:
+        check_git_status()
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
     if opt.hyp:  # update hyps
         opt.hyp = check_file(opt.hyp)  # check file
         with open(opt.hyp) as f:
             hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
-    print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
+    opt.total_batch_size = opt.batch_size
     if device.type == 'cpu':
         mixed_precision = False
+        opt.world_size = 1
+    elif opt.local_rank != -1:
+        # DDP mode
+        assert torch.cuda.device_count() > opt.local_rank
+        torch.cuda.set_device(opt.local_rank)
+        device = torch.device("cuda", opt.local_rank)
+        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
+
+        opt.world_size = dist.get_world_size()
+        assert opt.batch_size % opt.world_size == 0
+        opt.batch_size = opt.total_batch_size // opt.world_size
+    print(opt)
 
     # Train
     if not opt.evolve:
-        tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
-        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
-        train(hyp)
+        if opt.local_rank in [-1, 0]:
+            print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
+            tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
+        else:
+            tb_writer = None
+        train(hyp, tb_writer, opt, device)
 
     # Evolve hyperparameters (optional)
     else:
+        assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
+        tb_writer = None
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         if opt.bucket:
@@ -444,7 +516,7 @@ def train(hyp):
                 hyp[k] = np.clip(hyp[k], v[0], v[1])
 
             # Train mutation
-            results = train(hyp.copy())
+            results = train(hyp.copy(), tb_writer, opt, device)
 
             # Write mutation results
             print_mutation(hyp, results, opt.bucket)
diff --git a/utils/datasets.py b/utils/datasets.py
index 4d8424c5a208..7da1c372a2f4 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -14,7 +14,7 @@
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from utils.utils import xyxy2xywh, xywh2xyxy
+from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
 
 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
@@ -46,21 +46,25 @@ def exif_size(img):
     return s
 
 
-def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
-    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                  augment=augment,  # augment images
-                                  hyp=hyp,  # augmentation hyperparameters
-                                  rect=rect,  # rectangular training
-                                  cache_images=cache,
-                                  single_cls=opt.single_cls,
-                                  stride=int(stride),
-                                  pad=pad)
+def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, local_rank=-1):
+    # Make sure only the first process in DDP processes the dataset first, so the following processes can use the cache.
+    with torch_distributed_zero_first(local_rank):
+        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
+                                      augment=augment,  # augment images
+                                      hyp=hyp,  # augmentation hyperparameters
+                                      rect=rect,  # rectangular training
+                                      cache_images=cache,
+                                      single_cls=opt.single_cls,
+                                      stride=int(stride),
+                                      pad=pad)
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=nw,
+                                             sampler=train_sampler,
                                              pin_memory=True,
                                              collate_fn=LoadImagesAndLabels.collate_fn)
     return dataloader, dataset
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 06d044779410..71cb73d8f1c6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -2,6 +2,7 @@
 import os
 import time
 from copy import deepcopy
+import pickle
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -208,8 +209,9 @@ def update(self, model):
             self.updates += 1
             d = self.decay(self.updates)
-            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
-            for k, v in self.ema.state_dict().items():
+            msd = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+            esd = self.ema.module.state_dict() if hasattr(self.ema, 'module') else self.ema.state_dict()
+            for k, v in esd.items():
                 if v.dtype.is_floating_point:
                     v *= d
                     v += (1. - d) * msd[k].detach()
diff --git a/utils/utils.py b/utils/utils.py
index ce1d9101a023..b0df9959b0f8 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -8,6 +8,7 @@
 from copy import copy
 from pathlib import Path
 from sys import platform
+from contextlib import contextmanager
 
 import cv2
 import matplotlib
@@ -31,6 +32,19 @@
 cv2.setNumThreads(0)
 
 
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """
+    Context manager that makes all processes in distributed training wait for the local master to do something first.
+ """ + if local_rank not in [-1, 0]: + torch.distributed.barrier() + yield + if local_rank == 0: + torch.distributed.barrier() + + + def init_seeds(seed=0): random.seed(seed) np.random.seed(seed) @@ -424,15 +438,16 @@ def forward(self, pred, true): def compute_loss(p, targets, model): # predictions, targets, model + device = targets.device ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor - lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) + lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device) tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets h = model.hyp # hyperparameters red = 'mean' # Loss reduction (sum or mean) # Define criteria - BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red) - BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red) + BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device) # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 cp, cn = smooth_BCE(eps=0.0) @@ -448,7 +463,7 @@ def compute_loss(p, targets, model): # predictions, targets, model balance = [1.0, 1.0, 1.0] for i, pi in enumerate(p): # layer index, layer predictions b, a, gj, gi = indices[i] # image, anchor, gridy, gridx - tobj = torch.zeros_like(pi[..., 0]) # target obj + tobj = torch.zeros_like(pi[..., 0]).to(device) # target obj nb = b.shape[0] # number of targets if nb: @@ -458,7 +473,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # GIoU pxy = ps[:, :2].sigmoid() * 2. - 0.5 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] - pbox = torch.cat((pxy, pwh), 1) # predicted box + pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target) lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss @@ -467,7 +482,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # Class if model.nc > 1: # cls loss (only if multiple classes) - t = torch.full_like(ps[:, 5:], cn) # targets + t = torch.full_like(ps[:, 5:], cn).to(device) # targets t[range(nb), tcls[i]] = cp lcls += BCEcls(ps[:, 5:], t) # BCE @@ -495,8 +510,10 @@ def compute_loss(p, targets, model): # predictions, targets, model def build_targets(p, targets, model): # Build targets for compute_loss(), input targets(image,class,x,y,w,h) - det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \ - else model.model[-1] # Detect() module + if hasattr(model, "module"): + det = model.module.model[-1] + else: + det = model.model[-1] na, nt = det.na, targets.shape[0] # number of anchors, targets tcls, tbox, indices, anch = [], [], [], [] gain = torch.ones(6, device=targets.device) # normalized to gridspace gain