Skip to content

Commit

Permalink
PyTorch 1.6.0 update with native AMP (#573)
Browse files Browse the repository at this point in the history
* PyTorch 1.6.0 has native Automatic Mixed Precision (AMP) training.

* Fixed inconsistent code line length and indentation

* Fixed inconsistent code line length and indentation

* Mixed precision training is turned on by default
  • Loading branch information
Lornatang authored Jul 31, 2020
1 parent 48e15be commit c020875
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 46 deletions.
80 changes: 36 additions & 44 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

Expand All @@ -14,13 +15,6 @@
from utils.datasets import *
from utils.utils import *

mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
except:
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

# Hyperparameters
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
Expand Down Expand Up @@ -63,6 +57,7 @@ def train(hyp, tb_writer, opt, device):
yaml.dump(vars(opt), f, sort_keys=False)

# Configure
cuda = device.type != 'cpu'
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
Expand Down Expand Up @@ -113,7 +108,7 @@ def train(hyp, tb_writer, opt, device):
optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
Expand Down Expand Up @@ -160,24 +155,20 @@ def train(hyp, tb_writer, opt, device):

del ckpt

# Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

# DP mode
if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
if cuda and rank == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# SyncBatchNorm
if opt.sync_bn and device.type != 'cpu' and rank != -1:
if opt.sync_bn and cuda and rank != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
print('Using SyncBatchNorm()')

# Exponential moving average
ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None

# DDP mode
if device.type != 'cpu' and rank != -1:
if cuda and rank != -1:
model = DDP(model, device_ids=[rank], output_device=rank)

# Trainloader
Expand Down Expand Up @@ -223,6 +214,7 @@ def train(hyp, tb_writer, opt, device):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
if rank in [0, -1]:
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
Expand All @@ -232,15 +224,14 @@ def train(hyp, tb_writer, opt, device):
model.train()

# Update image weights (optional)
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
if dataset.image_weights:
# Generate indices.
# Generate indices
if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
k=dataset.n) # rand weighted idx
# Broadcast.
# Broadcast if DDP
if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int)
if rank == 0:
Expand All @@ -263,7 +254,7 @@ def train(hyp, tb_writer, opt, device):
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
ni = i + nb * epoch # number integrated batches (since train start)
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0

# Warmup
if ni <= nw:
Expand All @@ -284,35 +275,34 @@ def train(hyp, tb_writer, opt, device):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
pred = model(imgs)
# Autocast
with amp.autocast():
# Forward
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# print('WARNING: non-finite loss, ending training ', loss_items)
# return results

# Backward
if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
scaler.scale(loss).backward()

# Optimize
if ni % accumulate == 0:
optimizer.step()
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema is not None:
ema.update(model)

# Print
if rank in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)
Expand All @@ -330,7 +320,7 @@ def train(hyp, tb_writer, opt, device):
# Scheduler
scheduler.step()

# Only the first process in DDP mode is allowed to log or save checkpoints.
# DDP process 0 or single-GPU
if rank in [-1, 0]:
# mAP
if ema is not None:
Expand Down Expand Up @@ -377,7 +367,7 @@ def train(hyp, tb_writer, opt, device):

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt
# end epoch ----------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -429,10 +419,12 @@ def train(hyp, tb_writer, opt, device):
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args()

# Resume
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights

if opt.local_rank in [-1, 0]:
check_git_status()
opt.cfg = check_file(opt.cfg) # check file
Expand All @@ -442,21 +434,20 @@ def train(hyp, tb_writer, opt, device):
with open(opt.hyp) as f:
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
opt.total_batch_size = opt.batch_size
opt.world_size = 1
if device.type == 'cpu':
mixed_precision = False
elif opt.local_rank != -1:
# DDP mode

# DDP mode
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device("cuda", opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

opt.world_size = dist.get_world_size()
assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
opt.batch_size = opt.total_batch_size // opt.world_size

print(opt)

# Train
Expand All @@ -466,11 +457,12 @@ def train(hyp, tb_writer, opt, device):
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
else:
tb_writer = None

train(hyp, tb_writer, opt, device)

# Evolve hyperparameters (optional)
else:
assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'

tb_writer = None
opt.notest, opt.nosave = True, True # only test/save final epoch
Expand Down
4 changes: 2 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def init_seeds(seed=0):
cudnn.benchmark = True


def select_device(device='', apex=False, batch_size=None):
def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
cpu_request = device.lower() == 'cpu'
if device and not cpu_request: # if device requested other than 'cpu'
Expand All @@ -36,7 +36,7 @@ def select_device(device='', apex=False, batch_size=None):
if ng > 1 and batch_size: # check that batch_size is compatible with device_count
assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
x = [torch.cuda.get_device_properties(i) for i in range(ng)]
s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
s = 'Using CUDA '
for i in range(0, ng):
if i == 1:
s = ' ' * len(s)
Expand Down

0 comments on commit c020875

Please sign in to comment.