-
Notifications
You must be signed in to change notification settings - Fork 120
/
base_tsf_runner.py
457 lines (361 loc) · 19.3 KB
/
base_tsf_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import functools
import inspect
import json
import math
import os
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
from easydict import EasyDict
from easytorch.utils import master_only
from tqdm import tqdm
from ..metrics import (masked_mae, masked_mape, masked_mse, masked_rmse,
masked_wape)
from .base_epoch_runner import BaseEpochRunner
class BaseTimeSeriesForecastingRunner(BaseEpochRunner):
    """
    Runner for multivariate time series forecasting tasks.

    Features:
        - Supports evaluation at pre-defined horizons (optional, ``EVAL.HORIZONS``) and overall performance assessment.
        - Metrics: MAE, RMSE, MAPE, WAPE, and MSE by default; customizable via ``METRICS.FUNCS``.
          The best model is selected on the validation set according to ``METRICS.TARGET``
          (default: the loss), minimized or maximized per ``METRICS.BEST`` (default: ``'min'``).
        - Supports `setup_graph` for models that create parameters during their first forward pass
          (similar to TensorFlow).
        - The loss function is read from ``TRAIN.LOSS`` (required) and is called through `metric_forward`.
        - Supports curriculum learning (``TRAIN.CL``).
        - Users only need to implement the `forward` function.

    Customization:
        - Model:
            - Args:
                - history_data (torch.Tensor): Historical data with shape [B, L, N, C],
                  where B is the batch size, L is the sequence length, N is the number of nodes,
                  and C is the number of features.
                - future_data (torch.Tensor or None): Future data with shape [B, L, N, C].
                  Can be None if there is no future data available.
                - batch_seen (int): The number of batches seen so far.
                - epoch (int): The current epoch number.
                - train (bool): Indicates whether the model is in training mode.
            - Return:
                - Dict or torch.Tensor:
                    - If returning a Dict, it must contain the 'prediction' key. Other keys are optional
                      and will be passed to the loss and metric functions.
                    - If returning a torch.Tensor, it should represent the model's predictions,
                      with shape [B, L, N, C].

        - Loss & Metrics (optional):
            - Args:
                - prediction (torch.Tensor): Model's predictions, with shape [B, L, N, C].
                - target (torch.Tensor): Ground truth data, with shape [B, L, N, C].
                - null_val (float): The value representing missing data in the dataset.
                - Other args (optional): Additional arguments will be matched with keys in the model's
                  return dictionary, if applicable.
            - Return:
                - torch.Tensor: The computed loss or metric value.

        - Dataset (optional):
            - Return: The returned data will be passed to the `forward` function as the `data` argument.
    """

    def __init__(self, cfg: Dict):
        super().__init__(cfg)
        # Some models (e.g., DCRNN, GTS) create parameters lazily on their first forward pass.
        self.need_setup_graph = cfg['MODEL'].get('SETUP_GRAPH', False)
        # Optional data scaler; None when the config declares no 'SCALER' section.
        self.scaler = self.build_scaler(cfg)
        # Loss function (required config entry).
        self.loss = cfg['TRAIN']['LOSS']

        # Metrics: mapping from metric name to a callable (or functools.partial).
        metrics_cfg = cfg.get('METRICS', {})
        self.metrics = metrics_cfg.get('FUNCS', {
            'MAE': masked_mae,
            'RMSE': masked_rmse,
            'MAPE': masked_mape,
            'WAPE': masked_wape,
            'MSE': masked_mse
        })
        # Validation quantity used to select the best checkpoint, and its optimization direction.
        self.target_metrics = metrics_cfg.get('TARGET', 'loss')
        self.metrics_best = metrics_cfg.get('BEST', 'min')
        assert self.target_metrics in self.metrics or self.target_metrics == 'loss', f'Target metric {self.target_metrics} not found in metrics.'
        assert self.metrics_best in ['min', 'max'], f'Invalid best metric {self.metrics_best}.'

        # Handle null values in datasets, e.g., 0.0 or np.nan.
        self.null_val = metrics_cfg.get('NULL_VAL', np.nan)

        # Curriculum learning setup. Read from the CL section itself so this works for
        # plain dicts as well as EasyDict (attribute access only works for the latter).
        self.cl_param = cfg['TRAIN'].get('CL', None)
        if self.cl_param is not None:
            self.warm_up_epochs = self.cl_param.get('WARM_EPOCHS', 0)
            self.cl_epochs = self.cl_param.get('CL_EPOCHS')
            self.prediction_length = self.cl_param.get('PREDICTION_LENGTH')
            self.cl_step_size = self.cl_param.get('STEP_SIZE', 1)

        # Evaluation settings
        eval_cfg = cfg.get('EVAL', EasyDict())
        self.if_evaluate_on_gpu = eval_cfg.get('USE_GPU', True)
        # HORIZONS in the config are 1-based; store them as 0-based indices into the time axis.
        self.evaluation_horizons = [h - 1 for h in eval_cfg.get('HORIZONS', [])]
        assert len(self.evaluation_horizons) == 0 or min(self.evaluation_horizons) >= 0, 'The horizon should start counting from 1.'

    def build_scaler(self, cfg: Dict):
        """Build scaler.

        Args:
            cfg (Dict): Configuration.

        Returns:
            Scaler instance created as ``cfg['SCALER']['TYPE'](**cfg['SCALER']['PARAM'])``,
            or None if no scaler is declared.
        """
        if 'SCALER' in cfg:
            return cfg['SCALER']['TYPE'](**cfg['SCALER']['PARAM'])
        return None

    def setup_graph(self, cfg: Dict, train: bool):
        """Setup all parameters and the computation graph.

        Some models (e.g., DCRNN, GTS) require creating parameters during the first forward pass,
        similar to TensorFlow. This runs one forward pass on the first batch of the train
        (``train=True``) or test (``train=False``) data loader.

        Args:
            cfg (Dict): Configuration.
            train (bool): Whether the setup is for training or inference.
        """
        dataloader = self.build_train_data_loader(cfg=cfg) if train else self.build_test_data_loader(cfg=cfg)
        data = next(iter(dataloader))  # get the first batch
        self.forward(data=data, epoch=1, iter_num=0, train=train)

    def count_parameters(self):
        """Count and log the number of trainable parameters in the model."""
        num_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(f'Number of parameters: {num_parameters}')

    def _register_split_meters(self, split: str):
        """Register the loss meter and one meter per configured metric for `split` ('train'/'val'/'test')."""
        self.register_epoch_meter(f'{split}/loss', split, '{:.4f}')
        for key in self.metrics:
            self.register_epoch_meter(f'{split}/{key}', split, '{:.4f}')

    def init_training(self, cfg: Dict):
        """Initialize training components, including loss, meters, etc.

        Args:
            cfg (Dict): Configuration.
        """
        if self.need_setup_graph:
            self.setup_graph(cfg=cfg, train=True)
            self.need_setup_graph = False
        super().init_training(cfg)
        self.count_parameters()
        self._register_split_meters('train')

    def init_validation(self, cfg: Dict):
        """Initialize validation components, including meters.

        Args:
            cfg (Dict): Configuration.
        """
        super().init_validation(cfg)
        self._register_split_meters('val')

    def init_test(self, cfg: Dict):
        """Initialize test components, including meters.

        Args:
            cfg (Dict): Configuration.
        """
        if self.need_setup_graph:
            self.setup_graph(cfg=cfg, train=False)
            self.need_setup_graph = False
        super().init_test(cfg)
        self._register_split_meters('test')

    def _build_split_dataset(self, cfg: Dict, split: str, mode: str, display_name: str):
        """Build the dataset for one split and log its length.

        Args:
            cfg (Dict): Configuration.
            split (str): Config section holding the per-split dataset ('TRAIN', 'VAL' or 'TEST').
            mode (str): Mode passed to the dataset ('train', 'valid' or 'test').
            display_name (str): Prefix used in the length log message.

        Returns:
            Dataset: The constructed dataset.
        """
        if 'DATASET' not in cfg:
            # TODO: support building different datasets for training, validation, and test.
            dataset_cfg = cfg[split]['DATA']['DATASET']
            dataset_cls = dataset_cfg['TYPE']
            init_params = inspect.signature(dataset_cls.__init__).parameters
            # Only inject `logger`/`mode` if the dataset constructor accepts them.
            # NOTE: this writes into the config's PARAM dict in place.
            if 'logger' in init_params:
                dataset_cfg['PARAM']['logger'] = self.logger
            if 'mode' in init_params:
                dataset_cfg['PARAM']['mode'] = mode
            dataset = dataset_cls(**dataset_cfg['PARAM'])
        else:
            dataset = cfg['DATASET']['TYPE'](mode=mode, logger=self.logger, **cfg['DATASET']['PARAM'])
        self.logger.info(f'{display_name} dataset length: {len(dataset)}')
        return dataset

    def build_train_dataset(self, cfg: Dict):
        """Build the training dataset and set ``self.iter_per_epoch``.

        Args:
            cfg (Dict): Configuration.

        Returns:
            Dataset: The constructed training dataset.
        """
        dataset = self._build_split_dataset(cfg, 'TRAIN', 'train', 'Train')
        batch_size = cfg['TRAIN']['DATA']['BATCH_SIZE']
        self.iter_per_epoch = math.ceil(len(dataset) / batch_size)
        return dataset

    def build_val_dataset(self, cfg: Dict):
        """Build the validation dataset.

        Args:
            cfg (Dict): Configuration.

        Returns:
            Dataset: The constructed validation dataset.
        """
        return self._build_split_dataset(cfg, 'VAL', 'valid', 'Validation')

    def build_test_dataset(self, cfg: Dict):
        """Build the test dataset.

        Args:
            cfg (Dict): Configuration.

        Returns:
            Dataset: The constructed test dataset.
        """
        return self._build_split_dataset(cfg, 'TEST', 'test', 'Test')

    def curriculum_learning(self, epoch: int = None) -> int:
        """Calculate task level (number of supervised horizon steps) for curriculum learning.

        During the first ``WARM_EPOCHS`` epochs the full ``PREDICTION_LENGTH`` is used. Afterwards the
        length starts at ``STEP_SIZE`` and grows by ``STEP_SIZE`` every ``CL_EPOCHS`` epochs, capped at
        ``PREDICTION_LENGTH``.

        Args:
            epoch (int, optional): Current (1-based) epoch if in training process; None otherwise. Defaults to None.

        Returns:
            int: Task level for the current epoch.
        """
        if epoch is None:
            return self.prediction_length
        epoch -= 1  # switch to 0-based epoch counting
        # generate curriculum length
        if epoch < self.warm_up_epochs:
            # still in warm-up phase
            cl_length = self.prediction_length
        else:
            progress = ((epoch - self.warm_up_epochs) // self.cl_epochs + 1) * self.cl_step_size
            cl_length = min(progress, self.prediction_length)
        return cl_length

    def forward(self, data: Union[Dict, Tuple], epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> Dict:
        """
        Performs the forward pass for training, validation, and testing. Must be implemented by subclasses.

        Note: The outputs are not re-scaled.

        Args:
            data (Dict): A batch from the DataLoader, expected to contain 'target' (future data) and
                'inputs' (history data) (normalized by self.scaler); the exact structure is defined by the dataset.
            epoch (int, optional): Current epoch number. Defaults to None.
            iter_num (int, optional): Current iteration number. Defaults to None.
            train (bool, optional): Indicates whether the forward pass is for training. Defaults to True.

        Returns:
            Dict: A dictionary containing the keys:
                - 'inputs': Selected input features.
                - 'prediction': Model predictions.
                - 'target': Selected target features.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def metric_forward(self, metric_func, args: Dict) -> torch.Tensor:
        """Compute metrics using the given metric function.

        Only entries of `args` whose keys match the metric's parameter names are passed. If the metric
        accepts `null_val` and it is not already bound (via functools.partial keywords), ``self.null_val``
        is supplied.

        Args:
            metric_func (function or functools.partial): Metric function.
            args (Dict): Arguments for metrics computation.

        Returns:
            torch.Tensor: Computed metric value.

        Raises:
            TypeError: If `metric_func` is not callable.
        """
        # Check callability first: inspect.signature would otherwise raise its own TypeError
        # before this descriptive one could ever be reached.
        if not callable(metric_func):
            raise TypeError(f'Unknown metric type: {type(metric_func)}')
        covariate_names = inspect.signature(metric_func).parameters.keys()
        args = {k: v for k, v in args.items() if k in covariate_names}
        if 'null_val' in covariate_names:
            # null_val is required; supply it unless a partial already binds it.
            already_bound = isinstance(metric_func, functools.partial) and 'null_val' in metric_func.keywords
            if not already_bound:
                args['null_val'] = self.null_val
        return metric_func(**args)

    def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
        """Training iteration process.

        Args:
            epoch (int): Current epoch.
            iter_index (int): Current iteration index.
            data (Union[torch.Tensor, Tuple]): Data provided by DataLoader.

        Returns:
            torch.Tensor: Loss value.
        """
        iter_num = (epoch - 1) * self.iter_per_epoch + iter_index
        forward_return = self.forward(data=data, epoch=epoch, iter_num=iter_num, train=True)
        if self.cl_param:
            # Curriculum learning: supervise only the first `cl_length` steps along dim 1.
            cl_length = self.curriculum_learning(epoch=epoch)
            forward_return['prediction'] = forward_return['prediction'][:, :cl_length, :, :]
            forward_return['target'] = forward_return['target'][:, :cl_length, :, :]
        loss = self.metric_forward(self.loss, forward_return)
        self.update_epoch_meter('train/loss', loss.item())
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, forward_return)
            self.update_epoch_meter(f'train/{metric_name}', metric_item.item())
        return loss

    def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]):
        """Validation iteration process.

        Args:
            iter_index (int): Current iteration index.
            data (Union[torch.Tensor, Tuple]): Data provided by DataLoader.
        """
        forward_return = self.forward(data=data, epoch=None, iter_num=iter_index, train=False)
        loss = self.metric_forward(self.loss, forward_return)
        self.update_epoch_meter('val/loss', loss.item())
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, forward_return)
            self.update_epoch_meter(f'val/{metric_name}', metric_item.item())

    def compute_evaluation_metrics(self, returns_all: Dict):
        """Compute metrics for evaluating model performance during the test process.

        Per-horizon metrics (for each configured horizon, skipping MASE) are logged; overall metrics are
        written to the 'test/<name>' epoch meters.

        Args:
            returns_all (Dict): Must contain keys: inputs, prediction, target.

        Returns:
            Dict: ``{'horizon_<k>': {metric: value}, ..., 'overall': {metric: value}}`` with float values.
        """
        metrics_results = {}
        for i in self.evaluation_horizons:
            horizon_key = f'horizon_{i + 1}'
            pred = returns_all['prediction'][:, i, :, :]
            real = returns_all['target'][:, i, :, :]
            metrics_results[horizon_key] = {}
            metric_repr = ''
            for metric_name, metric_func in self.metrics.items():
                if metric_name.lower() == 'mase':
                    continue  # MASE needs to be calculated after all horizons
                metric_item = self.metric_forward(metric_func, {'prediction': pred, 'target': real})
                metric_repr += f', Test {metric_name}: {metric_item.item():.4f}'
                metrics_results[horizon_key][metric_name] = metric_item.item()
            self.logger.info(f'Evaluate best model on test data for horizon {i + 1}{metric_repr}')

        metrics_results['overall'] = {}
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, returns_all)
            self.update_epoch_meter(f'test/{metric_name}', metric_item.item())
            metrics_results['overall'][metric_name] = metric_item.item()
        return metrics_results

    @torch.no_grad()
    @master_only
    def test(self, train_epoch: Optional[int] = None, save_metrics: bool = False, save_results: bool = False) -> Dict:
        """Test process.

        Args:
            train_epoch (Optional[int]): Current epoch if in training process.
            save_metrics (bool): Save the test metrics. Defaults to False.
            save_results (bool): Save the test results. Defaults to False.

        Returns:
            Dict: ``{'prediction', 'target', 'inputs'}`` tensors concatenated over all test batches
            (on CPU unless ``EVAL.USE_GPU`` is true).
        """
        prediction, target, inputs = [], [], []

        for data in tqdm(self.test_data_loader):
            forward_return = self.forward(data, epoch=None, iter_num=None, train=False)

            loss = self.metric_forward(self.loss, forward_return)
            self.update_epoch_meter('test/loss', loss.item())

            if not self.if_evaluate_on_gpu:
                # Move to CPU before accumulating so all test outputs do not stay on the GPU.
                for key in ('prediction', 'target', 'inputs'):
                    forward_return[key] = forward_return[key].detach().cpu()

            prediction.append(forward_return['prediction'])
            target.append(forward_return['target'])
            inputs.append(forward_return['inputs'])

        prediction = torch.cat(prediction, dim=0)
        target = torch.cat(target, dim=0)
        inputs = torch.cat(inputs, dim=0)

        returns_all = {'prediction': prediction, 'target': target, 'inputs': inputs}
        metrics_results = self.compute_evaluation_metrics(returns_all)

        # save
        if save_results:
            # save returns_all to self.ckpt_save_dir/test_results.npz
            test_results = {k: v.cpu().numpy() for k, v in returns_all.items()}
            np.savez(os.path.join(self.ckpt_save_dir, 'test_results.npz'), **test_results)

        if save_metrics:
            # save metrics_results to self.ckpt_save_dir/test_metrics.json
            with open(os.path.join(self.ckpt_save_dir, 'test_metrics.json'), 'w', encoding='utf-8') as f:
                json.dump(metrics_results, f, indent=4)

        return returns_all

    @master_only
    def on_validating_end(self, train_epoch: Optional[int]):
        """Callback at the end of the validation process; saves the best model by 'val/<TARGET>'.

        Args:
            train_epoch (Optional[int]): Current epoch if in training process.
        """
        # Larger is better unless METRICS.BEST == 'min'.
        greater_best = self.metrics_best != 'min'
        if train_epoch is not None:
            self.save_best_model(train_epoch, 'val/' + self.target_metrics, greater_best=greater_best)