-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathtrain_single_label_from_scratch.py
129 lines (103 loc) · 4.62 KB
/
train_single_label_from_scratch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# --------------------------------------------------------
# ImageNet-21K Pretraining for The Masses
# Copyright 2021 Alibaba MIIL (c)
# Licensed under MIT License [see the LICENSE file for details]
# Written by Tal Ridnik
# --------------------------------------------------------
import argparse
import time
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed
from torch.optim import lr_scheduler
from src_files.data_loading.data_loader import create_data_loaders
from src_files.helper_functions.distributed import print_at_master, to_ddp, reduce_tensor, num_distrib, setup_distrib
from src_files.helper_functions.general_helper_functions import accuracy, AverageMeter, silence_PIL_warnings
from src_files.models import create_model
from src_files.loss_functions.losses import CrossEntropyLS
from torch.cuda.amp import GradScaler, autocast
from src_files.optimizers.create_optimizer import create_optimizer, create_optimizer_sgd
# Command-line interface. The parsed namespace (`args`) is handed to the project
# helpers (setup_distrib, create_model, create_optimizer_sgd, create_data_loaders)
# and read directly by train_21k for the LR schedule, loss and throughput report.
parser = argparse.ArgumentParser(description='PyTorch ImageNet21K Single-label Training From Random Initialization')
parser.add_argument('--data_path', type=str)
# Peak learning rate of the OneCycleLR schedule built in train_21k.
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--model_name', default='tresnet_m')
parser.add_argument('--model_path', default='', type=str)
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--image_size', default=224, type=int)
parser.add_argument('--num_classes', default=11221, type=int)
# Appears to be the per-process batch size: train_21k multiplies it by the
# number of processes when reporting throughput.
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--epochs', default=140, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
# Distributed launchers differ in spelling: torch.distributed.launch passes
# `--local-rank` from PyTorch 2.0 on, older versions pass `--local_rank`.
# Accept both. argparse derives `dest` from the first long option string,
# so the value is still available as `args.local_rank`.
parser.add_argument("--local_rank", "--local-rank", default=0, type=int)
# Label-smoothing factor handed to CrossEntropyLS in train_21k.
parser.add_argument("--label_smooth", default=0.2, type=float)
def main():
    """Script entry point: parse options, build all components, then train.

    The steps run in this order:
    1. Parse the command line.
    2. Silence the PIL EXIF warnings.
    3. Run the distributed setup.
    4. Create the model, move it to the GPU and wrap it with to_ddp.
    5. Build the SGD optimizer and the train/validation loaders.
    6. Hand everything to train_21k.
    """
    opts = parser.parse_args()
    silence_PIL_warnings()  # keep EXIF-related PIL warnings out of the log
    setup_distrib(opts)

    net = to_ddp(create_model(opts).cuda(), opts)
    sgd = create_optimizer_sgd(net, opts)

    loaders = create_data_loaders(opts)
    train_loader, val_loader = loaders

    train_21k(net, train_loader, val_loader, sgd, opts)
def _run_train_epoch(model, train_loader, optimizer, loss_fn, scaler, scheduler):
    """Make one mixed-precision pass over `train_loader`, stepping the LR schedule per batch."""
    for images, labels in train_loader:
        # Forward pass and loss are computed under autocast (mixed precision).
        with autocast():
            logits = model(images)
            loss = loss_fn(logits, labels)
        model.zero_grad()
        scaler.scale(loss).backward()
        # GradScaler.step skips optimizer.step() when the unscaled grads contain inf/NaN.
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()


def train_21k(model, train_loader, val_loader, optimizer, args):
    """Train `model` for `args.epochs` epochs, validating after every epoch.

    Loss is CrossEntropyLS(args.label_smooth). The learning rate follows a
    OneCycleLR schedule with these settings:
    - stepped once per batch;
    - peak LR args.lr, starting LR args.lr / 20;
    - 10% of the steps spent increasing;
    - momentum not cycled.

    Gradients are scaled with a GradScaler for mixed-precision training.
    After each epoch the master process prints the training throughput, and
    validate_21k runs on `val_loader`.
    """
    criterion = CrossEntropyLS(args.label_smooth)
    sched = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        steps_per_epoch=len(train_loader),
        epochs=args.epochs,
        pct_start=0.1,
        cycle_momentum=False,
        div_factor=20,
    )
    grad_scaler = GradScaler()

    for epoch in range(args.epochs):
        if num_distrib() > 1:
            # Presumably a DistributedSampler: set_epoch changes its shuffle per epoch.
            train_loader.sampler.set_epoch(epoch)

        print_at_master(f"\nEpoch {epoch}")
        tic = time.time()
        _run_train_epoch(model, train_loader, optimizer, criterion, grad_scaler, sched)
        elapsed = time.time() - tic

        # Images/second summed over all processes.
        imgs_per_sec = len(train_loader) * args.batch_size / elapsed * max(num_distrib(), 1)
        print_at_master(f"\nFinished Epoch, Training Rate: {imgs_per_sec:.1f} [img/sec]")

        validate_21k(val_loader, model)
def validate_21k(val_loader, model):
    """Evaluate `model` on `val_loader` and print top-1 / top-5 accuracy.

    The model is put in eval mode for the pass. Its previous train/eval mode
    is restored afterwards, even if evaluation raises. Callers that pass a
    training-mode model, as train_21k does, get it back in training mode as
    before.

    With more than one process, each batch's accuracies are combined across
    processes with reduce_tensor before being accumulated.

    Args:
        val_loader: iterable yielding (input, target) batches.
        model: network to evaluate.
    """
    print_at_master("starting validation")
    was_training = model.training
    model.eval()
    top1 = AverageMeter()
    top5 = AverageMeter()
    try:
        with torch.no_grad():
            for images, target in val_loader:
                # Mixed-precision forward pass; logits cast back to float32 for the metrics.
                with autocast():
                    logits = model(images).float()
                acc1, acc5 = accuracy(logits, target, topk=(1, 5))
                if num_distrib() > 1:
                    # NOTE(review): reduce_tensor(t, n) presumably averages t over n processes — confirm.
                    acc1 = reduce_tensor(acc1, num_distrib())
                    acc5 = reduce_tensor(acc5, num_distrib())
                torch.cuda.synchronize()
                top1.update(acc1.item(), images.size(0))
                top5.update(acc5.item(), images.size(0))
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode,
        # and do so even when an exception escapes the evaluation loop.
        model.train(was_training)
    print_at_master("Validation results:")
    print_at_master('Acc_Top1 [%] {:.2f}, Acc_Top5 [%] {:.2f} '.format(top1.avg, top5.avg))
if __name__ == "__main__":
    # Launch training only when this file is executed as a script, not when imported.
    main()