-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
116 lines (105 loc) · 4.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import re
import os
import sys
import torch
import logging
import argparse
from io import StringIO
from pathlib import Path
from pprint import pprint
from utils.data_utils import get_dataloader_and_adj_mx
from model.trainer import model_train
def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that ignores positions where ``y_true == 0``.

    Zero targets are treated as missing readings and excluded. The mask is
    rescaled by its mean, so that ``.mean()`` over all elements equals the
    average error over the non-zero targets only. NaNs in the element-wise
    loss are replaced by 0 before averaging.

    Args:
        y_pred: Predicted values (presumably the same shape as ``y_true`` --
            confirm against the trainer).
        y_true: Ground-truth values; entries equal to 0 are excluded.

    Returns:
        torch.Tensor: Scalar masked MAE; 0 when every target is 0.
    """
    mask = (y_true != 0).float()
    valid_fraction = mask.mean()
    # Guard the all-zero case: dividing by 0 would turn the whole mask into NaN.
    if valid_fraction > 0:
        mask = mask / valid_fraction
    # Apply the mask. Without this step the function is just a plain MAE.
    loss = torch.abs(y_pred - y_true) * mask
    # Zero out NaNs (e.g. from NaN predictions/targets) so they do not poison the mean.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()
def get_run_num(log_dir="outputs/tensorboard"):
    """Return the next free run number under *log_dir*.

    Scans *log_dir* for entries named exactly ``run_<N>`` (N made of digits)
    and returns ``max(N) + 1``. Returns 1 when there are no such entries or
    when *log_dir* does not exist yet, e.g. on the very first run.

    NOTE(review): by default this counts runs in ``outputs/tensorboard``,
    while ``config_logging`` writes under ``logs/``. The two can drift apart
    if one directory is cleaned without the other.

    Args:
        log_dir: Directory whose ``run_<N>`` entries are counted.

    Returns:
        int: The run number to use for the new run.
    """
    pattern = re.compile(r'run_(\d+)')
    try:
        entries = os.listdir(log_dir)
    except FileNotFoundError:
        # Nothing has been written yet, so this is run 1.
        return 1
    run_nums = (int(m.group(1)) for m in map(pattern.fullmatch, entries) if m)
    return max(run_nums, default=0) + 1
def config_logging(run_num, log_dir='logs'):
    """Set up file logging for one run and redirect ``sys.stderr``.

    Creates ``<log_dir>/run_<run_num>/``, including any missing parents. It
    then attaches an INFO-level ``FileHandler`` writing ``log.txt`` in that
    directory to the root logger. Finally it replaces ``sys.stderr`` with
    ``stderr.txt`` in the same directory.

    Args:
        run_num: Run number used to name the run directory.
        log_dir: Parent directory that holds all run directories.

    Returns:
        logging.Logger: The root logger, set to INFO level.
    """
    run_dir = Path(log_dir, f'run_{run_num}')
    # parents=True so a fresh checkout without ``log_dir`` does not crash here.
    run_dir.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger()
    fh = logging.FileHandler(run_dir / 'log.txt', encoding='utf-8')
    fh.setLevel(logging.INFO)
    logger.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Deliberately never closed: this file becomes the process's stderr.
    # Line-buffered so output reaches disk line by line instead of only when
    # the buffer fills or the interpreter exits.
    sys.stderr = open(run_dir / 'stderr.txt', 'w', encoding='utf-8', buffering=1)
    return logger
def _build_parser():
    """Build the command-line argument parser used by ``main``."""
    parser = argparse.ArgumentParser(description='Train the model')
    parser.add_argument('--traffic_path', type=str, required=True,
                        help='path to traffic data (pkl gz format)')
    parser.add_argument('--lipschitz_path', type=str, required=True,
                        help='path to lipschitz data (npz)')
    parser.add_argument('--adj_path', type=str, required=True,
                        help='path to adjacency data (pickle)')
    parser.add_argument('--seen_path', type=str, required=True,
                        help='path to seen nodes index (npy)')
    parser.add_argument('--keep_tod', default=False, action='store_true',
                        help='whether to keep time of day (boolean flag)')
    parser.add_argument('--future', type=int, default=12,
                        help='how far in the future to predict')
    parser.add_argument('--past', type=int, default=12,
                        help='how far in the past to look')
    parser.add_argument('--nepochs', type=int, required=True,
                        help='number of epochs')
    parser.add_argument('--nlayers', type=int, default=10,
                        help='number of layers used in the GNN')
    parser.add_argument('--gnn_input_dim', type=int, required=True,
                        help='number of input dimensions taken by gnn')
    parser.add_argument('--gnn_hidden_dim', type=int, required=True,
                        help='number of hidden dimensions of gnn')
    parser.add_argument('--enc_input_dim', type=int, required=True,
                        help='number of input dimensions taken by lstm\'s encoder')
    parser.add_argument('--enc_hidden_dim', type=int, required=True,
                        help='number of hidden dimensions of lstm encoder')
    parser.add_argument('--dec_hidden_dim', type=int, required=True,
                        help='number of hidden dimensions of lstm decoder')
    parser.add_argument('--output_dim', type=int, required=True,
                        help='number of output dimensions')
    # On argument errors argparse calls print_usage; print the full help instead.
    parser.print_usage = parser.print_help
    return parser


def main():
    """Parse CLI arguments, load the data, and train the model.

    Steps, in order:

    1. Allocate a run number and set up logging for the run.
    2. Log the parsed arguments.
    3. Load the dataloaders and edge weights.
    4. Call ``model_train`` with an L1 loss.
    """
    from pprint import pformat

    pargs = _build_parser().parse_args()

    # Subset of the CLI arguments forwarded to the model (order preserved).
    model_arg_names = (
        'gnn_input_dim',
        'gnn_hidden_dim',
        'enc_input_dim',
        'enc_hidden_dim',
        'dec_hidden_dim',
        'output_dim',
        'nlayers',
    )
    model_args = {name: getattr(pargs, name) for name in model_arg_names}

    run_num = get_run_num()
    logger = config_logging(run_num)
    print(f"This is run number: {run_num}\n Logs will be saved in logs/run_{run_num}")
    # Record the configuration before the (possibly slow or failing) data
    # loading step, so every run's log states what it was started with.
    logger.info(pformat(vars(pargs), indent=4))

    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    dataloaders, edge_weight = get_dataloader_and_adj_mx(
        pargs.traffic_path,
        pargs.lipschitz_path,
        pargs.adj_path,
        pargs.seen_path,
        keep_tod=pargs.keep_tod,
        f=pargs.future,
        p=pargs.past,
        nlayers=pargs.nlayers,
    )
    edge_weight = edge_weight.to(device=device, dtype=torch.float32)

    # Plain L1 loss over all entries. masked_mae_loss (defined in this file)
    # was tried here as an alternative; the author noted it did not help.
    model_train(model_args, device, pargs.nepochs, dataloaders, edge_weight,
                torch.nn.L1Loss(), run_num, logger)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()