-
Notifications
You must be signed in to change notification settings - Fork 5
/
mlp_coral.py
321 lines (260 loc) · 9.27 KB
/
mlp_coral.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import argparse
import os
import shutil
import time
import torch
import torch.nn.functional as F
# Import from local helper file
from helper import parse_cmdline_args
from helper import compute_mae_and_rmse
from helper import get_dataloaders_fireman
# Parse the command-line options through the project-local helper.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
args = parse_cmdline_args(parser)

##########################
# Settings and Setup
##########################

# Run configuration, copied from the parsed arguments.
OUTPUT_DIR = args.output_dir
NUM_EPOCHS = args.epochs
BATCH_SIZE = args.batchsize
LEARNING_RATE = args.learningrate
NUM_WORKERS = args.numworkers
LOSS_PRINT_INTERVAL = args.loss_print_interval

# Start from an empty output directory: an existing one is only removed
# when overwriting was explicitly requested.
if os.path.exists(OUTPUT_DIR):
    if not args.overwrite:
        raise ValueError('Output directory already exists.')
    shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR)

LOGFILE_PATH = os.path.join(OUTPUT_DIR, 'training.log')
BEST_MODEL_PATH = os.path.join(OUTPUT_DIR, 'best_model.pt')

# Use the requested CUDA device when one is available, otherwise the CPU.
_use_cuda = args.cuda >= 0 and torch.cuda.is_available()
DEVICE = torch.device(f'cuda:{args.cuda}') if _use_cuda else torch.device('cpu')

# A seed of -1 means "do not seed".
RANDOM_SEED = None if args.seed == -1 else args.seed
############################
# Dataset
############################
# Number of input features and of ordinal classes used by this script.
NUM_FEATURES = 10
NUM_CLASSES = 16

# Options forwarded to the project-local data loader factory.
_loader_options = {
    'batch_size': BATCH_SIZE,
    'train_csv_path': './datasets/fireman_example_balanced_train.csv',
    'valid_csv_path': './datasets/fireman_example_balanced_valid.csv',
    'test_csv_path': './datasets/fireman_example_balanced_test.csv',
    'balanced': True,
    'num_workers': NUM_WORKERS,
    'num_classes': NUM_CLASSES,
}
train_loader, valid_loader, test_loader = get_dataloaders_fireman(
    **_loader_options)
##########################
# MODEL
##########################
class MultilayerPerceptron(torch.nn.Module):
    """Two-hidden-layer MLP with a CORAL-style output layer.

    The network maps each input to a single shared score; one learnable
    bias per binary ordinal task is then added to that score, giving
    `num_classes - 1` logits per example.

    Parameters
    ----------
    num_features : int
        Size of the input feature vector.
    num_classes : int
        Number of ordinal classes. Determines the number of output
        biases (`num_classes - 1`).
    num_hidden_1 : int
        Number of units in the first hidden layer.
    num_hidden_2 : int
        Number of units in the second hidden layer.
    """

    def __init__(self, num_features, num_classes,
                 num_hidden_1, num_hidden_2):
        super().__init__()

        self.num_classes = num_classes
        self.my_network = torch.nn.Sequential(

            # 1st hidden layer
            torch.nn.Linear(num_features, num_hidden_1, bias=False),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.BatchNorm1d(num_hidden_1),

            # 2nd hidden layer
            torch.nn.Linear(num_hidden_1, num_hidden_2, bias=False),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.BatchNorm1d(num_hidden_2),

            # Single shared output unit (no bias: the per-task biases
            # are kept separately in `output_biases`).
            torch.nn.Linear(num_hidden_2, 1, bias=False)
        )

        # Sized from the constructor argument rather than a module-level
        # constant, so the class honours the `num_classes` it is given.
        self.output_biases = torch.nn.Parameter(
            torch.zeros(num_classes - 1).float())

    def forward(self, x):
        """Return the logits for the `num_classes - 1` binary tasks.

        Parameters
        ----------
        x : torch.Tensor
            Input batch; expected shape (batch_size, num_features).

        Returns
        ----------
        logits : torch.Tensor, shape=(batch_size, num_classes-1)
        """
        logits = self.my_network(x)
        # (batch_size, 1) + (1, num_classes-1) broadcasts to one logit
        # per binary task.
        logits = logits + self.output_biases.view(1, -1)
        return logits
# Seed before the model is built so that the weight initialization is
# reproducible whenever a seed was requested.
if RANDOM_SEED is not None:
    torch.manual_seed(RANDOM_SEED)

# Use NUM_CLASSES rather than a literal so the model stays in sync with
# the data loaders and the level encoding used for the loss.
model = MultilayerPerceptron(num_features=NUM_FEATURES,
                             num_hidden_1=300,
                             num_hidden_2=300,
                             num_classes=NUM_CLASSES)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
#######################################
# Utility Functions
#######################################
def label_to_levels(label, num_classes, dtype=torch.float32):
    """Converts integer class label to extended binary label vector

    Parameters
    ----------
    label : int or single-element torch.Tensor
        Class label to be converted into an extended
        binary vector. Must satisfy 0 <= label <= num_classes-1.

    num_classes : int
        The number of class labels in the dataset. Assumes
        class labels start at 0. Determines the size of the
        output vector.

    dtype : torch data type (default=torch.float32)
        Data type of the torch output vector for the
        extended binary labels.

    Returns
    ----------
    levels : torch.tensor, shape=(num_classes-1,)
        Extended binary label vector. Type is determined
        by the `dtype` parameter.

    Raises
    ----------
    ValueError
        If `label` is negative or larger than num_classes-1.

    Examples
    ----------
    >>> label_to_levels(0, num_classes=5)
    tensor([0., 0., 0., 0.])
    >>> label_to_levels(1, num_classes=5)
    tensor([1., 0., 0., 0.])
    >>> label_to_levels(3, num_classes=5)
    tensor([1., 1., 1., 0.])
    >>> label_to_levels(4, num_classes=5)
    tensor([1., 1., 1., 1.])
    """
    # Work on a plain Python number so validation and list arithmetic
    # behave the same for int and tensor inputs.
    if isinstance(label, torch.Tensor):
        int_label = label.item()
    else:
        int_label = label

    # A negative label would otherwise produce a vector longer than
    # num_classes-1 without any error.
    if int_label < 0:
        raise ValueError('Class label must be non-negative. Got %d.'
                         % int_label)
    if not int_label <= num_classes-1:
        raise ValueError('Class label must be smaller or '
                         'equal to %d (num_classes-1). Got %d.'
                         % (num_classes-1, int_label))

    # `label` leading ones followed by zeros up to length num_classes-1.
    levels = [1]*int_label + [0]*(num_classes - 1 - int_label)
    return torch.tensor(levels, dtype=dtype)
def levels_from_labelbatch(labels, num_classes, dtype=torch.float32):
    """
    Converts a list of integer class label to extended binary label vectors

    Parameters
    ----------
    labels : list or 1D torch.tensor, shape=(num_labels,)
        A list or 1D torch.tensor with integer class labels
        to be converted into extended binary label vectors.
        The input is flattened, so a (num_labels, 1) column is
        accepted as well.

    num_classes : int
        The number of class labels in the dataset. Assumes
        class labels start at 0. Determines the size of the
        output vector.

    dtype : torch data type (default=torch.float32)
        Data type of the torch output vector for the
        extended binary labels.

    Returns
    ----------
    levels : torch.tensor, shape=(num_labels, num_classes-1)
        Created on the CPU. An empty `labels` input gives a
        tensor of shape (0, num_classes-1).

    Raises
    ----------
    ValueError
        If any label is negative or larger than num_classes-1.

    Examples
    ----------
    >>> levels_from_labelbatch(labels=[2, 1, 4], num_classes=5)
    tensor([[1., 1., 0., 0.],
            [1., 0., 0., 0.],
            [1., 1., 1., 1.]])
    """
    # One transfer to the CPU for the whole batch instead of one
    # `.item()` call (and device sync) per label.
    labels = torch.as_tensor(labels).detach().cpu().reshape(-1)

    # min()/max() are undefined for empty tensors, so only validate
    # when there is something to validate.
    if labels.numel() > 0:
        smallest = labels.min().item()
        largest = labels.max().item()
        if smallest < 0:
            raise ValueError('Class label must be non-negative. Got %d.'
                             % smallest)
        if largest > num_classes-1:
            raise ValueError('Class label must be smaller or '
                             'equal to %d (num_classes-1). Got %d.'
                             % (num_classes-1, largest))

    # Task k is "on" exactly when label > k, which yields `label`
    # leading ones followed by zeros in every row.
    thresholds = torch.arange(num_classes - 1)
    levels = (thresholds.unsqueeze(0) < labels.unsqueeze(1)).to(dtype)
    return levels
def loss_coral(logits, levels):
    """Computes the CORAL loss averaged over the examples.

    For every entry, the log-sigmoid of the logit is weighted by the
    level and `logsigmoid(logit) - logit` by `1 - level`; the negated
    sum over dim 1 is then averaged over the remaining dimension(s).

    Parameters
    ----------
    logits : torch.tensor
        Model outputs, one row per example.
    levels : torch.tensor
        Extended binary labels combined elementwise with `logits`.

    Returns
    ----------
    loss : torch.tensor
        Scalar tensor holding the mean loss.
    """
    log_sig = F.logsigmoid(logits)
    positive_term = log_sig*levels
    negative_term = (log_sig - logits)*(1-levels)
    per_example = -torch.sum(positive_term + negative_term, dim=1)
    return torch.mean(per_example)
def label_from_logits(logits):
    """Converts logits to class labels.

    This function is specific to CORAL: the predicted label of a row
    is the number of its logits whose sigmoid exceeds 0.5.
    """
    exceeds_half = torch.sigmoid(logits) > 0.5
    return exceeds_half.sum(dim=1)
#######################################
# Training
#######################################
def _report(message, mode='a'):
    """Print `message` and write it, followed by a newline, to the log file."""
    print(message)
    with open(LOGFILE_PATH, mode) as log_file:
        log_file.write(f'{message}\n')


# Lowest validation MAE seen so far; any first result improves on infinity.
best_valid_mae = torch.tensor(float('inf'))

# The run header starts a fresh log file ('w'); later messages append.
run_header = (f'Script: {__file__}\n'
              f'PyTorch version: {torch.__version__}\n'
              f'Device: {DEVICE}\n'
              f'Learning rate: {LEARNING_RATE}\n'
              f'Batch size: {BATCH_SIZE}\n')
_report(run_header, mode='w')

start_time = time.time()

for epoch in range(1, NUM_EPOCHS+1):

    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features, targets = features.to(DEVICE), targets.to(DEVICE)

        # Forward pass
        logits = model(features)

        # CORAL loss on the extended binary labels, cast to match the logits
        levels = levels_from_labelbatch(targets, num_classes=NUM_CLASSES)
        levels = levels.type_as(logits)
        loss = loss_coral(logits, levels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic progress message
        if batch_idx % LOSS_PRINT_INTERVAL == 0:
            _report(f'Epoch: {epoch:03d}/{NUM_EPOCHS:03d} | '
                    f'Batch {batch_idx:04d}/'
                    f'{len(train_loader):04d} | '
                    f'Loss: {loss:.4f}')

    # Evaluate on the validation set after every epoch
    model.eval()
    with torch.no_grad():
        valid_mae, valid_rmse = compute_mae_and_rmse(
            model=model,
            data_loader=valid_loader,
            device=DEVICE,
            label_from_logits_func=label_from_logits
        )

    # Keep a checkpoint of the model with the lowest validation MAE
    if valid_mae < best_valid_mae:
        best_valid_mae = valid_mae
        best_epoch = epoch
        torch.save(model.state_dict(), BEST_MODEL_PATH)

    epoch_summary = (f'MAE Current Valid: {valid_mae:.2f} Ep. {epoch}'
                     f' | Best Valid: {best_valid_mae:.2f} Ep. {best_epoch}')
    elapsed_minutes = (time.time() - start_time)/60
    epoch_summary += f'\nTime elapsed: {elapsed_minutes:.2f} min'
    _report(epoch_summary)
# Final: reload the checkpoint with the lowest validation MAE and report
# its performance on all three splits.
model.load_state_dict(torch.load(BEST_MODEL_PATH))
model.eval()
with torch.no_grad():
    train_mae, train_rmse = compute_mae_and_rmse(
        model=model,
        data_loader=train_loader,
        device=DEVICE,
        label_from_logits_func=label_from_logits
    )
    valid_mae, valid_rmse = compute_mae_and_rmse(
        model=model,
        data_loader=valid_loader,
        device=DEVICE,
        label_from_logits_func=label_from_logits
    )
    # The test metrics must come from the test split, not from the
    # validation split that was used for model selection.
    test_mae, test_rmse = compute_mae_and_rmse(
        model=model,
        data_loader=test_loader,
        device=DEVICE,
        label_from_logits_func=label_from_logits
    )

# One metric per line; without the explicit newlines the adjacent string
# literals run together into a single unreadable line.
s = ('\n\n=========================================\n\n'
     'Performance of best model based on validation set MAE:\n'
     f'Train MAE / RMSE: {train_mae:.2f} / {train_rmse:.2f}\n'
     f'Valid MAE / RMSE: {valid_mae:.2f} / {valid_rmse:.2f}\n'
     f'Test  MAE / RMSE: {test_mae:.2f} / {test_rmse:.2f}')

# Emit the summary the same way as the per-epoch messages: to stdout and
# appended to the training log.
print(s)
with open(LOGFILE_PATH, 'a') as f:
    f.write(f'{s}\n')