Skip to content

Commit

Permalink
Added distributed training support for distillation of MobileNetV2. (#…
Browse files Browse the repository at this point in the history
…166)

Signed-off-by: Xinyu Ye <[email protected]>
  • Loading branch information
XinyuYe-Intel authored Dec 5, 2022
1 parent ebe9e2a commit d33ebe6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,15 @@ pip install -r requirements.txt
python train_without_distillation.py --epochs 200 --lr 0.1 --layers 40 --widen-factor 2 --name WideResNet-40-2 --tensorboard
# for distillation of the teacher model WideResNet40-2 to the student model MobileNetV2-0.35
python main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9
```

We also supported Distributed Data Parallel training on single node and multi nodes settings for distillation. To use Distributed Data Parallel to speedup training, the bash command needs a small adjustment.
<br>
For example, bash command will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node, it won't be necessary for single node case, *`<NUM_PROCESSES_PER_NODE>`* is the desired processes to use in current node, for node with GPU, usually set to number of GPUs in this node, for node without GPU and use CPU for training, it's recommended set to 1, *`<NUM_NODES>`* is the number of nodes to use, *`<NODE_RANK>`* is the rank of the current node, rank starts from 0 to *`<NUM_NODES>`*`-1`.
<br>
Also please note that to use CPU for training in each node with multi nodes settings, argument `--no_cuda` is mandatory. In multi nodes setting, following command needs to be launched in each node, and all the commands should be the same except for *`<NODE_RANK>`*, which should be integer from 0 to *`<NUM_NODES>`*`-1` assigned to each node.

```bash
python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
main.py --epochs 200 --lr 0.02 --name MobileNetV2-0.35-distillation --teacher_model runs/WideResNet-40-2/model_best.pth.tar --tensorboard --seed 9
```
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from accelerate import Accelerator
from wideresnet import WideResNet

# used for logging to TensorBoard
Expand Down Expand Up @@ -60,6 +61,7 @@
help='loss weights of distillation, should be a list of length 2, '
'and sum to 1.0, first for student targets loss weight, '
'second for teacher student loss weight.')
parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
parser.set_defaults(augment=True)

def set_seed(seed):
Expand All @@ -73,10 +75,13 @@ def set_seed(seed):
def main():
global args, best_prec1
args, _ = parser.parse_known_args()
accelerator = Accelerator(cpu=args.no_cuda)

best_prec1 = 0
if args.seed is not None:
set_seed(args.seed)
if args.tensorboard: configure("runs/%s"%(args.name))
with accelerator.local_main_process_first():
if args.tensorboard: configure("runs/%s"%(args.name))

# Data loading code
normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
Expand Down Expand Up @@ -111,9 +116,9 @@ def main():
student_model = mobilenet.MobileNetV2(num_classes=10, width_mult=0.35)

# get the number of model parameters
print('Number of teacher model parameters: {}'.format(
accelerator.print('Number of teacher model parameters: {}'.format(
sum([p.data.nelement() for p in teacher_model.parameters()])))
print('Number of student model parameters: {}'.format(
accelerator.print('Number of student model parameters: {}'.format(
sum([p.data.nelement() for p in student_model.parameters()])))

kwargs = {'num_workers': 0, 'pin_memory': True}
Expand All @@ -125,10 +130,10 @@ def main():
if args.loss_weights[1] > 0:
from tqdm import tqdm
def get_logits(teacher_model, train_dataset):
print("***** Getting logits of teacher model *****")
print(f" Num examples = {len(train_dataset) }")
accelerator.print("***** Getting logits of teacher model *****")
accelerator.print(f" Num examples = {len(train_dataset) }")
logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
if not os.path.exists(logits_file):
if not os.path.exists(logits_file) and accelerator.is_local_main_process:
teacher_model.eval()
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
train_dataloader = tqdm(train_dataloader, desc="Evaluating")
Expand All @@ -137,8 +142,8 @@ def get_logits(teacher_model, train_dataset):
outputs = teacher_model(input)
teacher_logits += [x for x in outputs.numpy()]
np.save(logits_file, np.array(teacher_logits))
else:
teacher_logits = np.load(logits_file)
accelerator.wait_for_everyone()
teacher_logits = np.load(logits_file)
train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
for l, tl in zip(train_dataset.targets, teacher_logits)]
return train_dataset
Expand All @@ -153,29 +158,34 @@ def get_logits(teacher_model, train_dataset):
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
accelerator.print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
student_model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
accelerator.print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
accelerator.print("=> no checkpoint found at '{}'".format(args.resume))

# define optimizer
optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
momentum=args.momentum, nesterov = args.nesterov,
weight_decay=args.weight_decay)

# cosine learning rate
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, len(train_loader) * args.epochs // accelerator.num_processes
)

student_model, teacher_model, train_loader, val_loader, optimizer = \
accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)

def train_func(model):
return train(train_loader, model, scheduler, distiller, best_prec1)
return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)

def eval_func(model):
return validate(val_loader, model, distiller)
return validate(val_loader, model, distiller, accelerator)

from neural_compressor.experimental import Distillation, common
from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
Expand All @@ -194,11 +204,12 @@ def eval_func(model):

directory = "runs/%s/"%(args.name)
os.makedirs(directory, exist_ok=True)
model._model = accelerator.unwrap_model(model.model)
model.save(directory)
# change to framework model for further use
model = model.model

def train(train_loader, model, scheduler, distiller, best_prec1):
def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
Expand All @@ -222,13 +233,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)

# measure accuracy and record loss
output = accelerator.gather(output)
target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
losses.update(loss.data.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes)
top1.update(prec1.item(), input.size(0)*accelerator.num_processes)

# compute gradient and do SGD step
distiller.optimizer.zero_grad()
loss.backward()
accelerator.backward(loss) # loss.backward()
distiller.optimizer.step()
scheduler.step()

Expand All @@ -237,7 +250,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
end = time.time()

if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
accelerator.print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
Expand All @@ -249,19 +262,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
# remember best prec@1 and save checkpoint
is_best = distiller.best_score > best_prec1
best_prec1 = max(distiller.best_score, best_prec1)
save_checkpoint({
'epoch': distiller._epoch_runned + 1,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best)
# log to TensorBoard
if args.tensorboard:
log_value('train_loss', losses.avg, epoch)
log_value('train_acc', top1.avg, epoch)
log_value('learning_rate', scheduler._last_lr[0], epoch)
if accelerator.is_local_main_process:
save_checkpoint({
'epoch': distiller._epoch_runned + 1,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best)
# log to TensorBoard
if args.tensorboard:
log_value('train_loss', losses.avg, epoch)
log_value('train_acc', top1.avg, epoch)
log_value('learning_rate', scheduler._last_lr[0], epoch)


def validate(val_loader, model, distiller):
def validate(val_loader, model, distiller, accelerator):
"""Perform validation on the validation set"""
batch_time = AverageMeter()
top1 = AverageMeter()
Expand All @@ -276,6 +290,8 @@ def validate(val_loader, model, distiller):
output = model(input)

# measure accuracy
output = accelerator.gather(output)
target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
top1.update(prec1.item(), input.size(0))

Expand All @@ -284,15 +300,15 @@ def validate(val_loader, model, distiller):
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
accelerator.print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time,
top1=top1))

print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
# log to TensorBoard
if args.tensorboard:
if accelerator.is_local_main_process and args.tensorboard:
log_value('val_acc', top1.avg, distiller._epoch_runned)
return top1.avg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
torch==1.5.0+cpu
torchvision==0.6.0+cpu
tensorboard_logger
accelerate

0 comments on commit d33ebe6

Please sign in to comment.