forked from yoshall/AirFormer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
64 lines (56 loc) · 2.46 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torch.utils.data import DataLoader, TensorDataset
import torch
from torch import Tensor
import numpy as np
import os
from src.utils.scaler import StandardScaler
def get_dataloader(datapath, batch_size, output_dim, mode='train'):
    '''
    Build train/val/test DataLoaders from preprocessed ``.npz`` splits.

    Each split is read from ``<datapath>/<split>.npz`` and must contain
    arrays under the keys ``'x'`` and ``'y'``. For each of the first
    ``output_dim`` channels of the last axis, a ``StandardScaler`` is built
    from the mean/std of the training inputs ``x_train[..., i]`` and then
    applied to channel ``i`` of both ``x`` and ``y`` in every split.

    Args:
        datapath: directory containing ``train.npz``, ``val.npz``, ``test.npz``.
        batch_size: batch size shared by all three loaders.
        output_dim: number of channels (last axis, starting at 0) to normalize.
        mode: currently unused; kept for backward compatibility.

    Returns:
        dict with keys ``'train_loader'`` (shuffled), ``'val_loader'`` and
        ``'test_loader'`` (not shuffled), plus ``'scalers'``: the list of
        ``output_dim`` scalers, one per normalized channel.
    '''
    splits = ('train', 'val', 'test')

    data = {}
    for category in splits:
        path = os.path.join(datapath, category + '.npz')
        # The context manager closes the underlying zip archive; indexing an
        # NpzFile reads the array fully into memory, so the arrays stay valid.
        with np.load(path) as cat_data:
            for key in ('x', 'y'):
                arr = cat_data[key]
                # Normalized values are written back in place below; an
                # integer array would silently truncate them, so promote it.
                if not np.issubdtype(arr.dtype, np.floating):
                    arr = arr.astype(np.float64)
                data[key + '_' + category] = arr

    # Statistics come from the *training inputs* only and are fitted before
    # any split is transformed.
    x_train = data['x_train']
    scalers = [StandardScaler(mean=x_train[..., i].mean(),
                              std=x_train[..., i].std())
               for i in range(output_dim)]

    processed = {}
    for category in splits:
        x = data['x_' + category]
        y = data['y_' + category]
        # The target series reuses the input scaler of the same channel.
        for i, scaler in enumerate(scalers):
            x[..., i] = scaler.transform(x[..., i])
            y[..., i] = scaler.transform(y[..., i])
        # Tensor(...) copies the array into a float32 tensor.
        processed[category] = TensorDataset(Tensor(x), Tensor(y))

    results = {
        'train_loader': DataLoader(processed['train'], batch_size, shuffle=True),
        'val_loader': DataLoader(processed['val'], batch_size, shuffle=False),
        'test_loader': DataLoader(processed['test'], batch_size, shuffle=False),
    }
    print('train: {}\t valid: {}\t test:{}'.format(len(results['train_loader'].dataset),
                                                  len(results['val_loader'].dataset),
                                                  len(results['test_loader'].dataset)))
    results['scalers'] = scalers
    return results
def check_device(device=None):
    '''
    Resolve ``device`` into a ``torch.device``.

    A ``torch.device`` instance is returned unchanged; any other non-None
    value is passed to ``torch.device(...)``. When ``device`` is None, a
    message is printed and ``cuda`` is chosen if ``torch.cuda.is_available()``
    reports True, otherwise ``cpu``.
    '''
    if device is not None:
        return device if isinstance(device, torch.device) else torch.device(device)

    print("`device` is missing, try to train and evaluate the model on default device.")
    if not torch.cuda.is_available():
        print("cuda device is not available, place the model on cpu.")
        return torch.device("cpu")
    print("cuda device is available, place the model on the device.")
    return torch.device("cuda")
def get_num_nodes(dataset):
    '''
    Look up the node count registered for ``dataset``.

    The table currently contains only ``'AIR_TINY'`` -> 1085.

    Raises:
        ValueError: if ``dataset`` is not in the table.
    '''
    num_nodes = {'AIR_TINY': 1085}
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would turn this into an unexplained KeyError.
    if dataset not in num_nodes:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(num_nodes)}")
    return num_nodes[dataset]