Change optimizer parameters group method (#1239)
* Change optimizer parameters group method

* Add torch nn

* Change isinstance method (torch.Tensor to nn.Parameter)

* parameter freeze fix, PEP8 reformat

* freeze bug fix

Co-authored-by: Glenn Jocher <[email protected]>
hoonyyhoon and glenn-jocher authored Nov 1, 2020
1 parent 96fcde4 commit 187f7c2
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions train.py
@@ -10,6 +10,7 @@
 import math
 import numpy as np
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
@@ -80,27 +81,26 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create
 
     # Freeze
-    freeze = ['', ]  # parameter names to freeze (full or partial)
-    if any(freeze):
-        for k, v in model.named_parameters():
-            if any(x in k for x in freeze):
-                print('freezing %s' % k)
-                v.requires_grad = False
+    freeze = []  # parameter names to freeze (full or partial)
+    for k, v in model.named_parameters():
+        v.requires_grad = True  # train all layers
+        if any(x in k for x in freeze):
+            print('freezing %s' % k)
+            v.requires_grad = False
 
     # Optimizer
     nbs = 64  # nominal batch size
     accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
 
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
-    for k, v in model.named_parameters():
-        v.requires_grad = True
-        if '.bias' in k:
-            pg2.append(v)  # biases
-        elif '.weight' in k and '.bn' not in k:
-            pg1.append(v)  # apply weight decay
-        else:
-            pg0.append(v)  # all else
+    for k, v in model.named_modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+            pg2.append(v.bias)  # biases
+        if isinstance(v, nn.BatchNorm2d):
+            pg0.append(v.weight)  # no decay
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+            pg1.append(v.weight)  # apply decay
 
     if opt.adam:
         optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
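For context on the "freeze bug fix" bullet: in the old block, `any([''])` is falsy, so the guarded loop never ran, and an empty string would match every parameter name, so dropping the guard would have frozen the whole model. The new block re-enables gradients on every parameter first and only freezes names that match an explicit entry in `freeze`. Below is a minimal sketch of that behaviour using a toy model and a hypothetical `freeze` entry; in the merged code `freeze` defaults to an empty list, so nothing is frozen.

```python
import torch.nn as nn

# Toy stand-in for the YOLOv5 model; real parameter names differ.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Conv2d(16, 32, 3))

freeze = ['0.']  # hypothetical entry: freeze parameters whose name contains '0.'
for k, v in model.named_parameters():
    v.requires_grad = True  # train all layers by default
    if any(x in k for x in freeze):
        print('freezing %s' % k)
        v.requires_grad = False

print([k for k, v in model.named_parameters() if not v.requires_grad])
# -> ['0.weight', '0.bias']  (only the first conv layer is frozen)
```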

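A similar self-contained sketch of the new module-based grouping from the second hunk, which keys off module types and `nn.Parameter` instances instead of parameter-name substrings like '.bias' and '.bn'. The toy model, the SGD hyperparameter values, and the `add_param_group` wiring are assumptions for illustration only; the optimizer construction past the `opt.adam` branch is collapsed in this diff.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Conv2d(16, 32, 1))  # toy model

pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in model.named_modules():
    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
        pg2.append(v.bias)  # all biases
    if isinstance(v, nn.BatchNorm2d):
        pg0.append(v.weight)  # BatchNorm weights: no weight decay
    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
        pg1.append(v.weight)  # Conv/Linear weights: weight decay applies

# Assumed wiring (collapsed in this diff): pg1 gets weight decay, pg2 does not.
optimizer = optim.SGD(pg0, lr=0.01, momentum=0.937, nesterov=True)  # example hyperparameters
optimizer.add_param_group({'params': pg1, 'weight_decay': 5e-4})
optimizer.add_param_group({'params': pg2})

print(len(pg0), len(pg1), len(pg2))  # 1 BatchNorm weight, 2 conv weights, 3 biases
```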