# -*- coding: utf-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Description: Train the face classification models ('MiniXception' or 'SimpleCNN').
"""
import os
import logging
import numpy as np
import random
import paddle
from tqdm import tqdm
from models.simple_cnn import SimpleCNN
from models.mini_xception import MiniXception
from data.dataset import load_imdb, split_imdb_data
from data.dataset import FaceDataset
from visualdl import LogWriter
from argparse import ArgumentParser
from paddle.io import DataLoader
from paddle import distributed as dist
from config.confg import parse_args
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)


def train():
    # Load the dataset
    logging.info("Loading the dataset ...")
    data = load_imdb(os.path.join(data_args.imdb_dir, 'imdb.mat'))
    train_set, val_set = split_imdb_data(data, args.validation_split)
    train_set = FaceDataset(train_set, img_path_prefix=data_args.imdb_dir, grayscale=data_args.grayscale,
                            do_transform=True,
                            do_random_crop=data_args.do_rand_crop,
                            translation_factor=data_args.translation_factor,
                            vertical_flip_probability=data_args.vertical_flip_probability,
                            read_img_at_once=data_args.read_img_at_once,
                            img_size=data_args.img_size,
                            )
    val_set = FaceDataset(val_set, img_path_prefix=data_args.imdb_dir, grayscale=data_args.grayscale,
                          read_img_at_once=data_args.read_img_at_once,
                          img_size=data_args.img_size,
                          )
    logging.info(
        f"The numbers of samples in the training and validation sets are {len(train_set)} and {len(val_set)}.")
    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, num_workers=args.num_workers)

    # For multi-GPU training
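    # Note: init_parallel_env() below only takes effect under a distributed launch
    # (see the example invocations at the top of this file); single-card runs are unaffected.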
    dist.init_parallel_env()

    # Build the model, optimizer and loss function
    if args.model_name == "MiniXception":
        model = MiniXception(args.n_classes, args.in_channels)
    else:
        model = SimpleCNN(args.n_classes, args.in_channels)
    scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=0.001, factor=0.5, patience=50, verbose=True,
                                                    epsilon=1e-6)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()

    # Restore parameters
    if args.restore:
        model = restore_params(model, args.model_state_dict)
        optimizer = restore_params(optimizer, args.opt_state_dict)
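        # Resume from the epoch encoded in the checkpoint filename, which follows the
        # "{model_name}-{epoch}-{acc}.params" pattern used when saving below.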
        epoch = int(args.model_state_dict.split('-')[1]) + 1
    else:
        epoch = 1

    # For multi-GPU training
    model = paddle.DataParallel(model)

    # For VisualDL
    train_logger = LogWriter(logdir=os.path.join(args.logdir, 'train'))
    val_logger = LogWriter(logdir=os.path.join(args.logdir, 'val'))

    # Start training
    saved_params = []
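    # Checkpoints are only written once the validation accuracy exceeds this threshold
    # (or at the final epoch); each new best raises the bar.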
    max_val_acc = 0.95
    for epoch in range(epoch, args.epochs + 1):
        logging.info("=" * 50 + f"Epoch {epoch}" + "=" * 50)

        # Training
        n_samples, sum_acc, sum_loss = 0, 0., 0.
        model.train()
        with tqdm(train_loader) as t:
            for step, batch in enumerate(t, 1):
                inputs, labels = batch[0], batch[1]
                pred = model(inputs)
                loss = loss_fn(pred, labels)
                acc = paddle.metric.accuracy(pred, labels.unsqueeze(1))
                n = len(inputs)
                sum_loss += loss.item() * n
                sum_acc += acc.item() * n
                n_samples += n
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                t.set_postfix(train_loss=loss.item(), train_acc=acc.item())
        loss, acc = sum_loss / n_samples, sum_acc / n_samples
        train_logger.add_scalar('loss', loss, epoch)
        train_logger.add_scalar('acc', acc, epoch)
        train_logger.add_scalar('lr', optimizer.get_lr(), epoch)
        logging.info(f"Epoch {epoch}, train loss {loss}, train acc {acc}, lr {optimizer.get_lr()}.")

        # Validation
        n_samples, sum_acc, sum_loss = 0, 0., 0.
        model.eval()
        for batch in tqdm(val_loader):
            inputs, labels = batch[0], batch[1]
            pred = model(inputs)
            loss = loss_fn(pred, labels)
            acc = paddle.metric.accuracy(pred, labels.unsqueeze(1))
            n = len(inputs)
            sum_loss += loss.item() * n
            sum_acc += acc.item() * n
            n_samples += n
        loss, acc = sum_loss / n_samples, sum_acc / n_samples
        val_logger.add_scalar('loss', loss, epoch)
        val_logger.add_scalar('acc', acc, epoch)
        logging.info(f"Epoch {epoch}, val loss {loss}, val acc {acc}")
        scheduler.step(loss, epoch)

        # Save the model state dict
        if acc > max_val_acc or epoch >= args.epochs:
            logging.info(f"Epoch {epoch}, saving model with val acc {acc} to {args.model_dir}")
            model_path = os.path.join(args.model_dir, f"{args.model_name}-{epoch}-{acc}.params")
            opt_path = os.path.join(args.model_dir, f"{args.model_name}-{epoch}-{acc}.opt")
            paddle.save(model.state_dict(), model_path)
            paddle.save(optimizer.state_dict(), opt_path)
            if args.just_keep_the_best:  # Just keep the best model, remove the others
                for p in saved_params:
                    os.remove(p)
            saved_params = [model_path, opt_path]
            max_val_acc = acc
            if acc >= 0.96:  # Stop training once the accuracy reaches 0.96
                logging.info("Finished training.")
                break
    logging.info("Finished training.")


def restore_params(model, model_path):
    """Restore parameters of the model (or optimizer).

    Args:
        model: paddle.nn.Layer or paddle.optimizer.Optimizer, the object to be restored.
        model_path: The file path of the parameters.
    """
    logging.info(f"Restoring parameters from {model_path}")
    model_state_dict = paddle.load(model_path)
    model.set_state_dict(model_state_dict)
    return model


def set_seed(seed: int):
    """Set the random seed."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


if __name__ == '__main__':
    parser = ArgumentParser("Training parameters")
    parser.add_argument('--conf_path', '-c', type=str, default='config/conf.yaml', help='Path to the config.')
    parser.add_argument('--model_name', '-m', type=str, choices=['MiniXception', 'SimpleCNN'],
                        help='Choose a model to train.')
    args = parse_args(parser)  # parse arguments
    data_args = args.dataset  # arguments of the dataset
    set_seed(args.seed)  # set the random seed
    train()  # begin training