Skip to content

Commit

Permalink
Make SyncBN a choice
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhi.chen committed Jul 14, 2020
1 parent e90d400 commit e742dd9
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def train(hyp, tb_writer, opt, device):
local_rank = opt.local_rank

# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging is skipped here.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

# Configure
init_seeds(1)
Expand Down Expand Up @@ -177,7 +177,8 @@ def train(hyp, tb_writer, opt, device):
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
if device.type != 'cpu' and local_rank != -1:
if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
print("SyncBN activated!")
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None

Expand Down Expand Up @@ -258,11 +259,10 @@ def train(hyp, tb_writer, opt, device):
mloss = torch.zeros(4, device=device) # mean losses
if local_rank != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
if local_rank in [-1, 0]:
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
else:
pbar = enumerate(dataloader)
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
ni = i + nb * epoch # number integrated batches (since train start)
Expand Down Expand Up @@ -429,6 +429,7 @@ def train(hyp, tb_writer, opt, device):
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
# Parameter For DDP.
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
opt = parser.parse_args()
Expand All @@ -437,7 +438,7 @@ def train(hyp, tb_writer, opt, device):
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
with torch_distributed_zero_first(opt.local_rank):
if opt.local_rank in [-1, 0]:
check_git_status()
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
Expand Down

0 comments on commit e742dd9

Please sign in to comment.