-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
365 lines (311 loc) · 12.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdm
import random
import numpy as np
import os
import matplotlib.pyplot as plt
import argparse
def set_seed(seed_value=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator. When CUDA is available, all GPU generators are seeded as
    well and cuDNN is switched to deterministic, non-benchmarking mode.

    Args:
        seed_value (int): Seed shared by all generators. Defaults to 42.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if not torch.cuda.is_available():
        return
    # GPU-only settings: seed every device and disable cuDNN autotuning.
    torch.cuda.manual_seed_all(seed_value)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def initialize_model(model_name, num_classes=2):
    """Load an ImageNet-pretrained torchvision model from Torch Hub and
    replace its final classification layer.

    Args:
        model_name (str): torchvision model name. Only names starting with
            ``resnet`` or ``efficientnet`` are supported.
        num_classes (int): Number of outputs of the new head. Defaults to 2,
            which was the previously hard-coded value.

    Returns:
        torch.nn.Module: The pretrained model with a freshly initialised
        ``nn.Linear`` head of size ``num_classes``.

    Raises:
        NotImplementedError: If the architecture family is not supported.
            This is checked before anything is downloaded, so an unsupported
            name fails immediately instead of after fetching weights.
    """
    if not model_name.startswith(('resnet', 'efficientnet')):
        raise NotImplementedError(f'{model_name} not implemented')

    model = torch.hub.load('pytorch/vision:v0.13.0', model_name, weights='IMAGENET1K_V1')
    if model_name.startswith('resnet'):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    else:
        # torchvision EfficientNet heads are Sequential(Dropout, Linear);
        # index 1 is the Linear layer that gets replaced.
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    return model
def calculate_metrics(labels, preds):
    """Compute binary classification metrics for one pass over a dataset.

    Args:
        labels: Ground-truth class indices.
        preds: Predicted class indices, with the same length as ``labels``.

    Returns:
        tuple: ``(accuracy, precision, recall, f1)``. Precision, recall and F1
        use scikit-learn's default binary averaging, with class ``1`` as the
        positive label. If a metric is undefined (e.g. the model predicted no
        positives this epoch), it is reported as 0.0. This is the same value
        as before, but no ``UndefinedMetricWarning`` is emitted on every such
        epoch.
    """
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, zero_division=0)
    recall = recall_score(labels, preds, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)
    return acc, precision, recall, f1
def train(model, criterion, optimizer, train_loader, device, epoch):
    """
    Run one epoch of optimisation over ``train_loader`` and report metrics.

    Each batch is moved to ``device`` and pushed through ``model``. It is then
    scored with ``criterion`` and back-propagated, followed by a single
    ``optimizer`` step. Predictions are the arg-max over dimension 1 of the
    model output.

    Args:
        model (torch.nn.Module): Network being trained; switched to train mode.
        criterion (torch.nn.Module): Loss applied to ``(output, label)``.
        optimizer (torch.optim.Optimizer): Optimiser stepped once per batch.
        train_loader (torch.utils.data.DataLoader): Yields ``(data, label)`` batches.
        device (torch.device): Device each batch is transferred to.
        epoch (int): Epoch index, used only in the progress bar and log line.

    Returns:
        tuple: ``(avg_loss, accuracy, precision, recall, f1)``, where
        ``avg_loss`` is the mean of the per-batch loss values.
    """
    model.train()
    loss_sum = 0.0
    predicted, expected = [], []

    batches = tqdm(train_loader, desc=f"Epoch {epoch}")
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted.extend(logits.argmax(1).cpu().numpy())
        expected.extend(targets.cpu().numpy())

    acc, precision, recall, f1 = calculate_metrics(expected, predicted)
    avg_loss = loss_sum / len(train_loader)
    print(f"Epoch {epoch} | Train Loss: {avg_loss} | F1: {f1} | Accuracy: {acc} | Precision: {precision} | Recall: {recall}")
    return avg_loss, acc, precision, recall, f1
def validate(model, criterion, val_loader, device):
    """
    Evaluate the model on the validation set without updating it.

    Autograd is disabled inside this function. Callers no longer need to wrap
    it in ``torch.no_grad()``, although doing so is still harmless.

    Args:
        model (torch.nn.Module): The model to evaluate; switched to eval mode.
        criterion (torch.nn.Module): Loss function applied to each batch.
        val_loader (torch.utils.data.DataLoader): Yields ``(data, label)`` batches.
        device (torch.device | str): Device each batch is moved to.

    Returns:
        tuple: ``(avg_loss, accuracy, precision, recall, f1)``, where
        ``avg_loss`` is the mean of the per-batch loss values.
    """
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    # Without no_grad every forward pass would record an autograd graph and
    # keep activations alive, wasting memory during pure evaluation.
    with torch.no_grad():
        for data, label in tqdm(val_loader, desc="Validating"):
            data, label = data.to(device), label.to(device)
            output = model(data)
            loss = criterion(output, label)
            val_loss += loss.item()
            all_preds.extend(output.argmax(1).cpu().numpy())
            all_labels.extend(label.cpu().numpy())
    acc, precision, recall, f1 = calculate_metrics(all_labels, all_preds)
    avg_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_loss} | F1: {f1} | Accuracy: {acc} | Precision: {precision} | Recall: {recall}")
    return avg_loss, acc, precision, recall, f1
def _plot_training_curves(train_history, val_history, model_name):
    """Plot train vs. validation curves for each tracked metric and save them.

    Draws one subplot per metric on a 2x3 grid. The figure is saved to
    ``<model_name>/<model_name>_training_results.png``, shown, and then closed.

    Args:
        train_history (dict[str, list[float]]): Maps a metric display name to
            one value per completed epoch, in plotting order.
        val_history (dict[str, list[float]]): Same keys and lengths as
            ``train_history``.
        model_name (str): Output directory and file-name prefix.
    """
    figure = plt.figure(figsize=(20, 10))
    for position, metric in enumerate(train_history, start=1):
        epochs = range(1, len(train_history[metric]) + 1)
        plt.subplot(2, 3, position)
        plt.plot(epochs, train_history[metric], label=f'Train {metric}')
        plt.plot(epochs, val_history[metric], label=f'Validation {metric}')
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.title(f'{metric} vs Epochs')
        plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(model_name, f'{model_name}_training_results.png'))
    plt.show()
    # This runs once per trained model. Close the figure explicitly so
    # figures do not accumulate when show() does not release them
    # (e.g. non-interactive backends).
    plt.close(figure)


def train_and_evaluate(model, train_loader, val_loader, device, config, model_name):
    """
    Train ``model`` with linear LR warmup, cosine decay and early stopping.

    The following files are written into a directory named ``model_name``,
    which is created if missing:
      - ``<model_name>_best_model.pt``: weights from the epoch with the best
        validation F1. Always written on the first epoch.
      - ``<model_name>_last_model.pt``: weights when training stops.
      - ``<model_name>_training_results.png``: loss / metric curves.

    Args:
        model (torch.nn.Module): Model to train. This function does not move
            it, so it is expected to already be on ``device``.
        train_loader (torch.utils.data.DataLoader): Training batches.
        val_loader (torch.utils.data.DataLoader): Validation batches.
        device (torch.device): Device the batches are moved to.
        config (dict): Reads the keys 'EPOCHS', 'WARMUP_EPOCHS', 'WARMUP_LR',
            'LR0', 'WEIGHT_DECAY' and 'PATIENCE'.
        model_name (str): Output directory and file-name prefix.

    Returns:
        None
    """
    os.makedirs(model_name, exist_ok=True)
    best_path = os.path.join(model_name, f'{model_name}_best_model.pt')
    last_path = os.path.join(model_name, f'{model_name}_last_model.pt')

    total_epochs = config['EPOCHS']
    warmup_epochs = config['WARMUP_EPOCHS']
    warmup_lr = config['WARMUP_LR']
    peak_lr = config['LR0']

    # Loss and optimizer. Without a warmup phase, start directly at the peak
    # LR so the cosine decay is not applied to the tiny warmup LR.
    criterion = nn.CrossEntropyLoss()
    start_lr = warmup_lr if warmup_epochs > 0 else peak_lr
    optimizer = optim.Adam(model.parameters(), lr=start_lr, weight_decay=config['WEIGHT_DECAY'])

    # Cosine decay after warmup. The scheduler is stepped once *before* each
    # post-warmup epoch, i.e. (total_epochs - warmup_epochs) times. With T_max
    # equal to that count, the last step lands exactly on eta_min=0 and the
    # final epoch would train with lr=0. The +1 keeps every epoch's LR > 0.
    cosine_epochs = max(1, total_epochs - warmup_epochs + 1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=0)

    metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1 Score')
    train_history = {name: [] for name in metric_names}
    val_history = {name: [] for name in metric_names}
    # F1 is never below 0, so starting at -1 guarantees epoch 1 saves a best
    # checkpoint even if its F1 is 0.0. Otherwise no best file may exist.
    best_f1 = -1.0
    patience_counter = 0

    for epoch in range(1, total_epochs + 1):
        if epoch <= warmup_epochs:
            # Linear warmup; reaches peak_lr exactly at epoch == warmup_epochs.
            new_lr = warmup_lr + (peak_lr - warmup_lr) * (epoch / warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
        else:
            # Advance the cosine schedule before this epoch's training.
            scheduler.step()

        # train() and validate() return (avg_loss, acc, precision, recall, f1),
        # in the same order as metric_names.
        train_metrics = train(model, criterion, optimizer, train_loader, device, epoch)
        for name, value in zip(metric_names, train_metrics):
            train_history[name].append(value)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        with torch.no_grad():
            val_metrics = validate(model, criterion, val_loader, device)
        for name, value in zip(metric_names, val_metrics):
            val_history[name].append(value)

        f1 = val_metrics[-1]
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            print(f'Saving Best Model at Epoch {epoch}')
            torch.save(model.state_dict(), best_path)
        else:
            patience_counter += 1
            # Stops after PATIENCE + 1 consecutive epochs without improvement.
            if patience_counter > config['PATIENCE']:
                print(f'Early Stopping at Epoch {epoch}')
                break

    print('Saving Last Model')
    torch.save(model.state_dict(), last_path)
    print('Training Complete for', model_name)

    _plot_training_curves(train_history, val_history, model_name)
def parse_args(argv=None):
    """
    Parse the command-line options of the training script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before. Passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace: Has the attributes ``train_dir``, ``val_dir``,
        ``batch_size``, ``epochs`` and ``input_size``.
    """
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument(
        '--train_dir',
        type=str,
        help='Directory with training images.',
        required=True,
    )
    parser.add_argument(
        '--val_dir',
        type=str,
        help='Directory with validation images.',
        required=True,
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=8,
        help='Number of images to process at once. Default is 8.',
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=18,
        help='Number of epochs to train models. Default is 18.',
    )
    parser.add_argument(
        '--input_size',
        type=int,
        default=640,
        help='Dimensionality of images after resizing. Default is 640x640.',
    )
    return parser.parse_args(argv)
def main():
    """Train and evaluate every candidate architecture on the same data.

    Reads the data paths and main sizes from the command line and builds the
    train / validation loaders once. It then trains each model in
    ``model_names`` with ``train_and_evaluate``. Evaluation on a held-out
    test set is not done here; it is meant to be run via inference.py.
    """
    args = parse_args()
    set_seed(42)

    # Training configuration. Data paths, batch size, input size and epoch
    # count come from the CLI; the optimisation schedule is fixed here.
    config = {
        "TRAIN_PATH": args.train_dir,
        "VAL_PATH": args.val_dir,
        "BATCH_SIZE": args.batch_size,
        "INPUT_SIZE": args.input_size,
        "NUM_WORKERS": 2,
        "EPOCHS": args.epochs,
        "WARMUP_LR": 0.00001,
        "LR0": 0.0001,
        "PATIENCE": 6,
        "WARMUP_EPOCHS": 3,
        "WEIGHT_DECAY": 0.0003,
    }

    # Resize to a square input and normalise with the ImageNet statistics the
    # pretrained backbones expect.
    transform = transforms.Compose([
        transforms.Resize((config['INPUT_SIZE'], config['INPUT_SIZE'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(config['TRAIN_PATH'], transform=transform)
    val_dataset = datasets.ImageFolder(config['VAL_PATH'], transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=config['NUM_WORKERS'])
    val_loader = DataLoader(val_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=config['NUM_WORKERS'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

    model_names = ['resnet18', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3']
    for model_name in model_names:
        # Re-seed per model. Head initialisation and shuffling draw from the
        # global RNGs, so without this each model's run would depend on which
        # models were trained before it.
        set_seed(42)
        model = initialize_model(model_name)
        model = model.to(device)
        train_and_evaluate(model, train_loader, val_loader, device, config, model_name)
        # Release this model before loading the next one, so two models never
        # occupy device memory at the same time.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
if __name__ == '__main__':
    # Script entry point: start training only when this file is executed
    # directly, never when it is imported as a module.
    main()