-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
32 lines (24 loc) · 869 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import json
import torch
import warnings
warnings.filterwarnings('ignore')
from dataset import Dataset
from FormerTime import FormerTime
from process import Trainer
from args import args
import torch.utils.data as Data
def main():
    """Build the train/test data loaders, construct FormerTime, and run training.

    All configuration is read from the module-level ``args`` object. As a
    side effect, ``args.data_shape`` is set from the training dataset before
    the model is constructed.
    """
    # Cap the number of threads torch uses for intra-op CPU parallelism.
    torch.set_num_threads(6)

    # Training split: batches are reshuffled by the loader.
    training_set = Dataset(device=args.device, mode='train')
    training_batches = Data.DataLoader(
        training_set,
        batch_size=args.train_batch_size,
        shuffle=True,
    )
    # Record the training data's shape on args; args is handed to the model
    # and the trainer below.
    args.data_shape = training_set.shape()

    # Test split: loader uses its default (non-shuffled) ordering.
    evaluation_set = Dataset(device=args.device, mode='test')
    evaluation_batches = Data.DataLoader(
        evaluation_set,
        batch_size=args.test_batch_size,
    )

    print(args.data_shape)
    print('dataset initial ends')

    network = FormerTime(args)
    print('model initial ends')

    Trainer(
        args,
        network,
        training_batches,
        evaluation_batches,
        verbose=True,
    ).train()
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()