forked from nnzhan/Graph-WaveNet
-
Notifications
You must be signed in to change notification settings - Fork 24
/
train.py
107 lines (98 loc) · 4.59 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import numpy as np
import pandas as pd
import time
import util
from engine import Trainer
import os
from durbango import pickle_save
from fastprogress import progress_bar
from model import GWNet
from util import calc_tstep_metrics
from exp_results import summary
def main(args, **model_kwargs):
    """Train a GWNet model with early stopping, then report test-set metrics.

    Checkpoints the model with the lowest validation MAE to
    ``<args.save>/best_model.pth``, logs per-epoch metrics to
    ``<args.save>/metrics.csv``, and finally reloads the best checkpoint to
    compute per-timestep test metrics (``test_metrics.csv``).

    Args:
        args: parsed CLI namespace (see the argparse setup in ``__main__``).
        **model_kwargs: extra keyword overrides forwarded to ``GWNet.from_args``.
    """
    device = torch.device(args.device)
    # Same batch size for train/val/test loaders.
    data = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size,
                             n_obs=args.n_obs, fill_zeroes=args.fill_zeroes)
    scaler = data['scaler']
    aptinit, supports = util.make_graph_inputs(args, device)
    model = GWNet.from_args(args, device, supports, aptinit, **model_kwargs)
    if args.checkpoint:
        model.load_checkpoint(torch.load(args.checkpoint))
    model.to(device)
    engine = Trainer.from_args(model, scaler, args)
    metrics = []
    best_model_save_path = os.path.join(args.save, 'best_model.pth')
    # Any finite validation loss beats inf, so the first epoch always checkpoints.
    lowest_mae_yet = float('inf')
    mb = progress_bar(list(range(1, args.epochs + 1)))
    epochs_since_best_mae = 0
    for _ in mb:
        train_loss, train_mape, train_rmse = [], [], []
        data['train_loader'].shuffle()
        for batch_idx, (x, y) in enumerate(data['train_loader'].get_iterator()):
            # Loader yields (batch, time, node, feature); model wants (batch, feature, node, time).
            trainx = torch.Tensor(x).to(device).transpose(1, 3)
            trainy = torch.Tensor(y).to(device).transpose(1, 3)
            yspeed = trainy[:, 0, :, :]  # channel 0 is the target (speed) series
            if yspeed.max() == 0:
                continue  # skip all-zero targets (nothing meaningful to fit)
            mae, mape, rmse = engine.train(trainx, yspeed)
            train_loss.append(mae)
            train_mape.append(mape)
            train_rmse.append(rmse)
            if args.n_iters is not None and batch_idx >= args.n_iters:
                break  # optional iteration cap, handy for smoke tests
        engine.scheduler.step()
        _, valid_loss, valid_mape, valid_rmse = eval_(data['val_loader'], device, engine)
        m = dict(train_loss=np.mean(train_loss), train_mape=np.mean(train_mape),
                 train_rmse=np.mean(train_rmse), valid_loss=np.mean(valid_loss),
                 valid_mape=np.mean(valid_mape), valid_rmse=np.mean(valid_rmse))
        m = pd.Series(m)
        metrics.append(m)
        if m.valid_loss < lowest_mae_yet:
            # New best validation MAE: checkpoint and reset the patience counter.
            torch.save(engine.model.state_dict(), best_model_save_path)
            lowest_mae_yet = m.valid_loss
            epochs_since_best_mae = 0
        else:
            epochs_since_best_mae += 1
        met_df = pd.DataFrame(metrics)
        mb.comment = f'best val_loss: {met_df.valid_loss.min(): .3f}, current val_loss: {m.valid_loss:.3f}, current train loss: {m.train_loss: .3f}'
        met_df.round(6).to_csv(os.path.join(args.save, 'metrics.csv'))
        if epochs_since_best_mae >= args.es_patience:
            break  # early stopping: no validation improvement for es_patience epochs
    # Metrics on test data, using the best checkpoint seen during training.
    engine.model.load_state_dict(torch.load(best_model_save_path))
    realy = torch.Tensor(data['y_test']).transpose(1, 3)[:, 0, :, :].to(device)
    test_met_df, yhat = calc_tstep_metrics(engine.model, device, data['test_loader'], scaler, realy, args.seq_length)
    test_met_df.round(6).to_csv(os.path.join(args.save, 'test_metrics.csv'))
    print(summary(args.save))
def eval_(ds, device, engine):
"""Run validation."""
valid_loss = []
valid_mape = []
valid_rmse = []
s1 = time.time()
for (x, y) in ds.get_iterator():
testx = torch.Tensor(x).to(device).transpose(1, 3)
testy = torch.Tensor(y).to(device).transpose(1, 3)
metrics = engine.eval(testx, testy[:, 0, :, :])
valid_loss.append(metrics[0])
valid_mape.append(metrics[1])
valid_rmse.append(metrics[2])
total_time = time.time() - s1
return total_time, valid_loss, valid_mape, valid_rmse
if __name__ == "__main__":
parser = util.get_shared_arg_parser()
parser.add_argument('--epochs', type=int, default=100, help='')
parser.add_argument('--clip', type=int, default=3, help='Gradient Clipping')
parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay rate')
parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
parser.add_argument('--lr_decay_rate', type=float, default=0.97, help='learning rate')
parser.add_argument('--save', type=str, default='experiment', help='save path')
parser.add_argument('--n_iters', default=None, help='quit after this many iterations')
parser.add_argument('--es_patience', type=int, default=20, help='quit if no improvement after this many iterations')
args = parser.parse_args()
t1 = time.time()
if not os.path.exists(args.save):
os.mkdir(args.save)
pickle_save(args, f'{args.save}/args.pkl')
main(args)
t2 = time.time()
mins = (t2 - t1) / 60
print(f"Total time spent: {mins:.2f} seconds")