Skip to content

Commit

Permalink
Added distributed training support for distillation of CNN-2. (#208)
Browse files Browse the repository at this point in the history
Signed-off-by: Xinyu Ye <[email protected]>
  • Loading branch information
XinyuYe-Intel authored Dec 5, 2022
1 parent 5855117 commit ebe9e2a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,14 @@ python train_without_distillation.py --model_type CNN-10 --epochs 200 --lr 0.1 -
# for distillation of the student model CNN-2 with the teacher model CNN-10
python main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
```

We also supported Distributed Data Parallel training on single node and multi nodes settings for distillation. To use Distributed Data Parallel to speedup training, the bash command needs a small adjustment.
<br>
For example, bash command will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node, it won't be necessary for single node case, *`<NUM_PROCESSES_PER_NODE>`* is the desired processes to use in current node, for node with GPU, usually set to number of GPUs in this node, for node without GPU and use CPU for training, it's recommended set to 1, *`<NUM_NODES>`* is the number of nodes to use, *`<NODE_RANK>`* is the rank of the current node, rank starts from 0 to *`<NUM_NODES>`*`-1`.
<br>
Also please note that to use CPU for training in each node with multi nodes settings, argument `--no_cuda` is mandatory. In multi nodes setting, following command needs to be launched in each node, and all the commands should be the same except for *`<NODE_RANK>`*, which should be integer from 0 to *`<NUM_NODES>`*`-1` assigned to each node.

```bash
python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
main.py --epochs 200 --lr 0.02 --name CNN-2-distillation --student_type CNN-2 --teacher_type CNN-10 --teacher_model runs/CNN-10/model_best.pth.tar --tensorboard
```
82 changes: 49 additions & 33 deletions examples/pytorch/image_recognition/CNN-2/distillation/eager/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from accelerate import Accelerator
from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book

# used for logging to TensorBoard
Expand Down Expand Up @@ -60,6 +61,7 @@
help='loss weights of distillation, should be a list of length 2, '
'and sum to 1.0, first for student targets loss weight, '
'second for teacher student loss weight.')
parser.add_argument("--no_cuda", action='store_true', help='use cpu for training.')
parser.set_defaults(augment=True)


Expand All @@ -75,10 +77,13 @@ def set_seed(seed):
def main():
global args, best_prec1
args, _ = parser.parse_known_args()
accelerator = Accelerator(cpu=args.no_cuda)

best_prec1 = 0
if args.seed is not None:
set_seed(args.seed)
if args.tensorboard: configure("runs/%s" % (args.name))
with accelerator.local_main_process_first():
if args.tensorboard: configure("runs/%s"%(args.name))

# Data loading code
normalize = transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761])
Expand Down Expand Up @@ -121,9 +126,9 @@ def main():
raise NotImplementedError('Unsupported student model type')

# get the number of model parameters
print('Number of teacher model parameters: {}'.format(
accelerator.print('Number of teacher model parameters: {}'.format(
sum([p.data.nelement() for p in teacher_model.parameters()])))
print('Number of student model parameters: {}'.format(
accelerator.print('Number of student model parameters: {}'.format(
sum([p.data.nelement() for p in student_model.parameters()])))

kwargs = {'num_workers': 0, 'pin_memory': True}
Expand All @@ -135,10 +140,10 @@ def main():
if args.loss_weights[1] > 0:
from tqdm import tqdm
def get_logits(teacher_model, train_dataset):
print("***** Getting logits of teacher model *****")
print(f" Num examples = {len(train_dataset) }")
accelerator.print("***** Getting logits of teacher model *****")
accelerator.print(f" Num examples = {len(train_dataset) }")
logits_file = os.path.join(os.path.dirname(args.teacher_model), 'teacher_logits.npy')
if not os.path.exists(logits_file):
if not os.path.exists(logits_file) and accelerator.is_local_main_process:
teacher_model.eval()
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)
train_dataloader = tqdm(train_dataloader, desc="Evaluating")
Expand All @@ -147,8 +152,8 @@ def get_logits(teacher_model, train_dataset):
outputs = teacher_model(input)
teacher_logits += [x for x in outputs.numpy()]
np.save(logits_file, np.array(teacher_logits))
else:
teacher_logits = np.load(logits_file)
accelerator.wait_for_everyone()
teacher_logits = np.load(logits_file)
train_dataset.targets = [{'labels':l, 'teacher_logits':tl} \
for l, tl in zip(train_dataset.targets, teacher_logits)]
return train_dataset
Expand All @@ -163,29 +168,34 @@ def get_logits(teacher_model, train_dataset):
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
accelerator.print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
student_model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
accelerator.print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
accelerator.print("=> no checkpoint found at '{}'".format(args.resume))

# define optimizer
optimizer = torch.optim.SGD(student_model.parameters(), args.lr,
momentum=args.momentum, nesterov = args.nesterov,
weight_decay=args.weight_decay)

# cosine learning rate
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, len(train_loader) * args.epochs // accelerator.num_processes
)

student_model, teacher_model, train_loader, val_loader, optimizer = \
accelerator.prepare(student_model, teacher_model, train_loader, val_loader, optimizer)

def train_func(model):
return train(train_loader, model, scheduler, distiller, best_prec1)
return train(train_loader, model, scheduler, distiller, best_prec1, accelerator)

def eval_func(model):
return validate(val_loader, model, distiller)
return validate(val_loader, model, distiller, accelerator)

from neural_compressor.experimental import Distillation, common
from neural_compressor.experimental.common.criterion import PyTorchKnowledgeDistillationLoss
Expand All @@ -204,11 +214,12 @@ def eval_func(model):

directory = "runs/%s/"%(args.name)
os.makedirs(directory, exist_ok=True)
model._model = accelerator.unwrap_model(model.model)
model.save(directory)
# change to framework model for further use
model = model.model

def train(train_loader, model, scheduler, distiller, best_prec1):
def train(train_loader, model, scheduler, distiller, best_prec1, accelerator):
distiller.on_train_begin()
for epoch in range(args.start_epoch, args.epochs):
"""Train for one epoch on the training set"""
Expand All @@ -233,13 +244,15 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
loss = distiller.on_after_compute_loss(input, output, loss, teacher_logits)

# measure accuracy and record loss
output = accelerator.gather(output)
target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
losses.update(loss.data.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
losses.update(accelerator.gather(loss).sum().data.item(), input.size(0)*accelerator.num_processes)
top1.update(prec1.item(), input.size(0)*accelerator.num_processes)

# compute gradient and do SGD step
distiller.optimizer.zero_grad()
loss.backward()
accelerator.backward(loss) # loss.backward()
distiller.optimizer.step()
scheduler.step()

Expand All @@ -248,7 +261,7 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
end = time.time()

if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
accelerator.print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
Expand All @@ -260,19 +273,20 @@ def train(train_loader, model, scheduler, distiller, best_prec1):
# remember best prec@1 and save checkpoint
is_best = distiller.best_score > best_prec1
best_prec1 = max(distiller.best_score, best_prec1)
save_checkpoint({
'epoch': distiller._epoch_runned + 1,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best)
# log to TensorBoard
if args.tensorboard:
log_value('train_loss', losses.avg, epoch)
log_value('train_acc', top1.avg, epoch)
log_value('learning_rate', scheduler._last_lr[0], epoch)
if accelerator.is_local_main_process:
save_checkpoint({
'epoch': distiller._epoch_runned + 1,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best)
# log to TensorBoard
if args.tensorboard:
log_value('train_loss', losses.avg, epoch)
log_value('train_acc', top1.avg, epoch)
log_value('learning_rate', scheduler._last_lr[0], epoch)


def validate(val_loader, model, distiller):
def validate(val_loader, model, distiller, accelerator):
"""Perform validation on the validation set"""
batch_time = AverageMeter()
top1 = AverageMeter()
Expand All @@ -287,6 +301,8 @@ def validate(val_loader, model, distiller):
output = model(input)

# measure accuracy
output = accelerator.gather(output)
target = accelerator.gather(target)
prec1 = accuracy(output.data, target, topk=(1,))[0]
top1.update(prec1.item(), input.size(0))

Expand All @@ -295,15 +311,15 @@ def validate(val_loader, model, distiller):
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
accelerator.print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time,
top1=top1))

print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
accelerator.print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
# log to TensorBoard
if args.tensorboard:
if accelerator.is_local_main_process and args.tensorboard:
log_value('val_acc', top1.avg, distiller._epoch_runned)
return top1.avg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
torch==1.5.0+cpu
torchvision==0.6.0+cpu
tensorboard_logger
accelerate

0 comments on commit ebe9e2a

Please sign in to comment.