-
Notifications
You must be signed in to change notification settings - Fork 123
/
Copy pathsimple_tsf_runner.py
161 lines (119 loc) · 6.01 KB
/
simple_tsf_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Dict
import torch
from ..base_tsf_runner import BaseTimeSeriesForecastingRunner
class SimpleTimeSeriesForecastingRunner(BaseTimeSeriesForecastingRunner):
    """
    A Simple Runner for Time Series Forecasting:
    Selects forward and target features. This runner is designed to handle most cases.

    The three optional ``cfg['MODEL']`` entries read here are:
        - ``FORWARD_FEATURES``: index into the last axis of the data fed to the model.
        - ``TARGET_FEATURES``: index into the last axis of the data used as target.
        - ``TARGET_TIME_SERIES``: index into the third axis (time series / nodes)
          applied to target and prediction during postprocessing.
    Each defaults to ``None``, which means "no selection, keep everything".

    ``self.scaler``, ``self.model`` and ``self.to_running_device`` are not defined in
    this class; they are expected to be provided by the parent runner.

    Args:
        cfg (Dict): Configuration dictionary. Must contain a 'MODEL' entry.
    """

    def __init__(self, cfg: Dict):
        super().__init__(cfg)
        model_cfg = cfg['MODEL']
        self.forward_features = model_cfg.get('FORWARD_FEATURES', None)
        self.target_features = model_cfg.get('TARGET_FEATURES', None)
        self.target_time_series = model_cfg.get('TARGET_TIME_SERIES', None)

    def preprocessing(self, input_data: Dict) -> Dict:
        """Preprocess data.

        Applies ``self.scaler.transform`` to the 'target' and 'inputs' entries when a
        scaler is configured. The dictionary is updated in place and returned.

        Args:
            input_data (Dict): Dictionary containing data to be processed.

        Returns:
            Dict: Processed data.
        """

        if self.scaler is not None:
            input_data['target'] = self.scaler.transform(input_data['target'])
            input_data['inputs'] = self.scaler.transform(input_data['inputs'])
        # TODO: add more preprocessing steps as needed.
        return input_data

    def postprocessing(self, input_data: Dict) -> Dict:
        """Postprocess data.

        Optionally undoes the scaling on 'prediction', 'target' and 'inputs' (only
        when the scaler exists and its ``rescale`` flag is truthy), then restricts
        'target' and 'prediction' to the configured target time series. The
        dictionary is updated in place and returned.

        Args:
            input_data (Dict): Dictionary containing data to be processed.

        Returns:
            Dict: Processed data.
        """

        # rescale data
        if self.scaler is not None and self.scaler.rescale:
            input_data['prediction'] = self.scaler.inverse_transform(input_data['prediction'])
            input_data['target'] = self.scaler.inverse_transform(input_data['target'])
            input_data['inputs'] = self.scaler.inverse_transform(input_data['inputs'])

        # subset forecasting ('inputs' is intentionally left untouched);
        # the helper is a no-op when no target time series is configured.
        input_data['target'] = self.select_target_time_series(input_data['target'])
        input_data['prediction'] = self.select_target_time_series(input_data['prediction'])

        # TODO: add more postprocessing steps as needed.
        return input_data

    def forward(self, data: Dict, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> Dict:
        """
        Performs the forward pass for training, validation, and testing.

        Args:
            data (Dict): A dictionary containing 'target' (future data) and 'inputs' (history data).
                Both are passed through ``self.preprocessing`` first.
            epoch (int, optional): Current epoch number. Defaults to None.
            iter_num (int, optional): Current iteration number. Defaults to None.
            train (bool, optional): Indicates whether the forward pass is for training. Defaults to True.

        Returns:
            Dict: A dictionary containing the keys:
                  - 'inputs': Selected input features.
                  - 'prediction': Model predictions.
                  - 'target': Selected target features.

        Raises:
            AssertionError: If the first three dimensions of the model output do not match [B, L, N].
        """

        data = self.preprocessing(data)

        # Preprocess input data
        future_data, history_data = data['target'], data['inputs']
        history_data = self.to_running_device(history_data)    # Shape: [B, L, N, C]
        future_data = self.to_running_device(future_data)      # Shape: [B, L, N, C]
        batch_size, length, num_nodes, _ = future_data.shape

        # Select input features
        history_data = self.select_input_features(history_data)
        future_data_4_dec = self.select_input_features(future_data)

        if not train:
            # For non-training phases, use only temporal features: channel 0 of the
            # decoder input is overwritten so the model cannot read future values.
            # Clone first: when FORWARD_FEATURES is None (or a slice) the selection
            # above returns `future_data` itself (or a view of it), and writing in
            # place would corrupt the ground truth used for 'target' below.
            future_data_4_dec = future_data_4_dec.clone()
            future_data_4_dec[..., 0] = torch.empty_like(future_data_4_dec[..., 0])

        # Forward pass through the model
        model_return = self.model(history_data=history_data, future_data=future_data_4_dec,
                                  batch_seen=iter_num, epoch=epoch, train=train)

        # Parse model return: a bare tensor is treated as the prediction.
        if isinstance(model_return, torch.Tensor):
            model_return = {'prediction': model_return}
        if 'inputs' not in model_return:
            # NOTE(review): `history_data` has already been reduced by
            # FORWARD_FEATURES here, so TARGET_FEATURES indexes into the selected
            # channels, whereas 'target' below indexes the full future data.
            # Kept as in the original; confirm this asymmetry is intended.
            model_return['inputs'] = self.select_target_features(history_data)
        if 'target' not in model_return:
            model_return['target'] = self.select_target_features(future_data)

        # Ensure the output shape is correct
        assert list(model_return['prediction'].shape)[:3] == [batch_size, length, num_nodes], \
            "The shape of the output is incorrect. Ensure it matches [B, L, N, C]."

        model_return = self.postprocessing(model_return)

        return model_return

    def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
        """
        Selects input features based on the forward features specified in the configuration.

        Args:
            data (torch.Tensor): Input history data with shape [B, L, N, C1].

        Returns:
            torch.Tensor: Data with selected features with shape [B, L, N, C2].
                Returned unchanged (same object) when no forward features are configured.
        """

        if self.forward_features is not None:
            data = data[:, :, :, self.forward_features]
        return data

    def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
        """
        Selects target features based on the target features specified in the configuration.

        Args:
            data (torch.Tensor): Model prediction data with shape [B, L, N, C1].

        Returns:
            torch.Tensor: Data with selected target features and shape [B, L, N, C2].
                Returned unchanged (same object) when no target features are configured.
        """

        # Indexing with None would insert a new axis rather than keep all features,
        # so the unset case must be skipped explicitly.
        if self.target_features is not None:
            data = data[:, :, :, self.target_features]
        return data

    def select_target_time_series(self, data: torch.Tensor) -> torch.Tensor:
        """
        Select target time series based on the target time series specified in the configuration.

        Args:
            data (torch.Tensor): Model prediction data with shape [B, L, N1, C].

        Returns:
            torch.Tensor: Data with selected target time series and shape [B, L, N2, C].
                Returned unchanged (same object) when no target time series are configured.
        """

        # Same reasoning as for target features: a None index adds an axis.
        if self.target_time_series is not None:
            data = data[:, :, self.target_time_series, :]
        return data