-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
61 lines (57 loc) · 2.91 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from model.pytorch.supervisor import LSCGFSupervisor
parser = argparse.ArgumentParser()
parser.add_argument('--config_filename', default='data/model/para-bay', type=str)
parser.add_argument('--temperature', default=0.5, type=float, help='temperature value for gumbel-softmax.')
# basic settings
parser.add_argument('--device',default='cuda:1',type=str)
parser.add_argument('--log_dir',default='data/model',type=str,help='')
parser.add_argument('--log_level',default='INFO',type=str)
parser.add_argument('--log_every',default=1,type=int)
parser.add_argument('--save_model',default=0,type=int)
#data settings
parser.add_argument('--batch_size',default=64,type=int)
parser.add_argument('--dataset_dir',default='data/METR-LA',type=str)
parser.add_argument('--test_batch_size',default=64,type=int)
parser.add_argument('--valid_batch_size',default=64,type=int)
# model settings
parser.add_argument('--cl_decay_steps',default=2000,type=int)
parser.add_argument('--filter_type',default='dual_random_walk',type=str)
parser.add_argument('--horizon',default=12,type=int)
parser.add_argument('--input_dim',default=2,type=int)
parser.add_argument('--ll_decay',default=0,type=int)
parser.add_argument('--max_diffusion_step',default=2,type=int)
parser.add_argument('--num_rnn_layers',default=1,type=int)
parser.add_argument('--output_dim',default=1,type=int)
parser.add_argument('--rnn_units',default=64,type=int)
parser.add_argument('--seq_len',default=12,type=int)
parser.add_argument('--use_curriculum_learning',default=True,type=bool)
parser.add_argument('--embedding_size',default=256,type=int)
parser.add_argument('--kernel_size',default=12,type=int)
parser.add_argument('--freq',default=288,type=int)
parser.add_argument('--requires_graph',default=2,type=int)
# train settings0
parser.add_argument('--base_lr',default=0.003,type=float)
parser.add_argument('--dropout',default=0.3,type=float)
parser.add_argument('--epoch',default=0,type=int)
parser.add_argument('--epochs',default=200,type=int)
parser.add_argument('--epsilon',default=1.0e-3,type=float)
parser.add_argument('--global_step',default=0,type=int)
parser.add_argument('--lr_decay_ratio',default=0.1,type=float)
parser.add_argument('--max_grad_norm',default=5,type=int)
parser.add_argument('--max_to_keep',default=100,type=int)
parser.add_argument('--min_learning_rate',default=2.0e-05,type=float)
parser.add_argument('--optimizer',default='adam',type=str)
parser.add_argument('--patience',default=50,type=int)
parser.add_argument('--steps',default=[20, 30, 40],type=list)
parser.add_argument('--test_every_n_epochs', default=5, type=int)
parser.add_argument('--num_sample', default=10, type=int)
args = parser.parse_args()
if __name__ == '__main__':
print(args)
save_adj_name = args.config_filename[11:-5]
supervisor = LSCGFSupervisor(save_adj_name, args=args)
supervisor.train(args)