refactor: 💡 update the rescale_data function
zezhishao committed Dec 4, 2023
1 parent 88e0bf2 commit 9fed131
Showing 3 changed files with 61 additions and 32 deletions.
6 changes: 3 additions & 3 deletions basicts/losses/losses.py
@@ -1,17 +1,17 @@
-import torch
 import numpy as np
+import torch
 import torch.nn.functional as F
 
 from ..utils import check_nan_inf
 
 
-def l1_loss(input_data, target_data, **kwargs):
+def l1_loss(input_data, target_data):
     """unmasked mae."""
 
     return F.l1_loss(input_data, target_data)
 
 
-def l2_loss(input_data, target_data, **kwargs):
+def l2_loss(input_data, target_data):
     """unmasked mse"""
 
     check_nan_inf(input_data)
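Dropping `**kwargs` means these unmasked losses no longer silently swallow extra arguments; callers now pass exactly two tensors. A minimal sketch of the simplified call (toy tensors, not BasicTS code):

```python
import torch
import torch.nn.functional as F

prediction = torch.rand(32, 12, 207, 1)  # [batch, horizon, nodes, channels]
target = torch.rand(32, 12, 207, 1)

mae = F.l1_loss(prediction, target)   # what l1_loss now reduces to
mse = F.mse_loss(prediction, target)  # what l2_loss computes after check_nan_inf
print(mae.item(), mse.item())
```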
85 changes: 57 additions & 28 deletions basicts/runners/base_tsf_runner.py
@@ -1,6 +1,6 @@
 import math
 import functools
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, List
 
 import torch
 import numpy as np
@@ -37,7 +37,11 @@ def __init__(self, cfg: dict):
         self.need_setup_graph = cfg["MODEL"].get("SETUP_GRAPH", False)
 
         # read scaler for re-normalization
-        self.scaler = load_pkl("{0}/scaler_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True)))
+        self.scaler = load_pkl("{0}/scaler_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["TRAIN"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True)))
         # define loss
         self.loss = cfg["TRAIN"]["LOSS"]
         # define metric
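As the diff below shows, the loaded pickle is a dict with a `"func"` key (the registered inverse-transform name) and an `"args"` key (its statistics). A hedged sketch of how such a scaler is applied through `SCALER_REGISTRY`; the function name and the mean/std values here are illustrative, not taken from the repository:

```python
import torch

# stand-in registry: a plain dict exposes the same .get() call used in rescale_data
SCALER_REGISTRY = {
    "re_standard_transformation": lambda data, mean, std, **kw: data * std + mean,
}

scaler = {"func": "re_standard_transformation",   # assumed registry key
          "args": {"mean": 54.4, "std": 19.5}}    # illustrative statistics

x_norm = torch.randn(4, 12, 207, 1)               # normalized tensor [B, L, N, C]
x = SCALER_REGISTRY.get(scaler["func"])(x_norm, **scaler["args"])
```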
@@ -130,7 +134,7 @@ def build_train_dataset(self, cfg: dict):
             2. Normalize on EACH channel (i.e., calculate the mean and std of each channel).
         The reason why there are two different preprocessing methods is that each channel of the dataset may have a different value range.
-            1. Normalizing the WHOLE data set will preserve the relative size relationship between channels.
+            1. Normalizing the WHOLE data set will preserve the relative size relationship between channels.
                 Larger channels usually produce larger loss values, so more attention will be paid to these channels when optimizing the model.
                 Therefore, this approach will achieve better performance when we evaluate on the rescaled dataset.
                 For example, when evaluating rescaled data for two channels with values in the range [0, 1], [9000, 10000], the prediction on channel [0,1] is trivial.
@@ -143,7 +147,8 @@ def build_train_dataset(self, cfg: dict):
             For example, the first approach is often adopted in the field of Spatial-Temporal Forecasting (STF).
             The second approach is often adopted in the field of Long-term Time Series Forecasting (LTSF).
-        To avoid confusion for users and facilitate them to obtain results comparable to existing studies, we automatically select data based on the cfg.get("RESCALE") flag (default to True).
+        To avoid confusion for users and facilitate them to obtain results comparable to existing studies, we
+        automatically select data based on the cfg.get("RESCALE") flag (default to True).
         if_rescale == True: use the data that is normalized across the WHOLE dataset
         if_rescale == False: use the data that is normalized on EACH channel
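A toy illustration of the two schemes described in this docstring (not BasicTS code): after whole-dataset normalization the [9000, 10000] channel still dwarfs the [0, 1] channel, while per-channel normalization puts both on the same footing.

```python
import numpy as np

rng = np.random.default_rng(0)
data = np.stack([rng.uniform(0, 1, 1000),           # channel 1: [0, 1]
                 rng.uniform(9000, 10000, 1000)])   # channel 2: [9000, 10000]

# 1. z-score over the WHOLE dataset: relative magnitudes are preserved
whole = (data - data.mean()) / data.std()

# 2. z-score on EACH channel: the value-range gap is eliminated
each = (data - data.mean(axis=1, keepdims=True)) / data.std(axis=1, keepdims=True)

print(whole[0].std(), whole[1].std())  # ~6e-05 vs ~0.06: channel 2 dominates the loss
print(each[0].std(), each[1].std())    # 1.0 vs 1.0
```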
@@ -153,8 +158,16 @@ def build_train_dataset(self, cfg: dict):
         Returns:
             train dataset (Dataset)
         """
-        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
-        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
+        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["TRAIN"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
+        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["TRAIN"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
 
         # build dataset args
         dataset_args = cfg.get("DATASET_ARGS", {})
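Under an illustrative config (dataset directory and lengths assumed, not from this commit), the formatted path resolves as below; baking the RESCALE flag into the file name lets both preprocessing variants coexist on disk:

```python
cfg = {"TRAIN": {"DATA": {"DIR": "datasets/PEMS04"}},   # hypothetical directory
       "DATASET_INPUT_LEN": 12, "DATASET_OUTPUT_LEN": 12, "RESCALE": False}

data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(
    cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"],
    cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
print(data_file_path)  # datasets/PEMS04/data_in_12_out_12_rescale_False.pkl
```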
@@ -182,8 +195,16 @@ def build_val_dataset(cfg: dict):
             validation dataset (Dataset)
         """
         # see build_train_dataset for details
-        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["VAL"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
-        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["VAL"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
+        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["VAL"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
+        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["VAL"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
 
         # build dataset args
         dataset_args = cfg.get("DATASET_ARGS", {})
@@ -207,8 +228,16 @@ def build_test_dataset(cfg: dict):
         Returns:
             train dataset (Dataset)
         """
-        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TEST"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
-        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(cfg["TEST"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"], cfg.get("RESCALE", True))
+        data_file_path = "{0}/data_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["TEST"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
+        index_file_path = "{0}/index_in_{1}_out_{2}_rescale_{3}.pkl".format(
+            cfg["TEST"]["DATA"]["DIR"],
+            cfg["DATASET_INPUT_LEN"],
+            cfg["DATASET_OUTPUT_LEN"],
+            cfg.get("RESCALE", True))
 
         # build dataset args
         dataset_args = cfg.get("DATASET_ARGS", {})
@@ -277,17 +306,21 @@ def metric_forward(self, metric_func, args):
             raise TypeError("Unknown metric type: {0}".format(type(metric_func)))
         return metric_item
 
-    def rescale_data(self, data: torch.Tensor) -> torch.Tensor:
+    def rescale_data(self, input_data: List[torch.Tensor]) -> List[torch.Tensor]:
         """Rescale data.
 
         Args:
-            data (torch.Tensor): data to be re-scaled.
+            data (List[torch.Tensor]): list of data to be re-scaled.
 
         Returns:
-            torch.Tensor: re-scaled data.
+            List[torch.Tensor]: list of re-scaled data.
         """
 
-        return SCALER_REGISTRY.get(self.scaler["func"])(data, **self.scaler["args"])
+        # prediction, real_value = input_data[:2]
+        if self.if_rescale:
+            input_data[0] = SCALER_REGISTRY.get(self.scaler["func"])(input_data[0], **self.scaler["args"])
+            input_data[1] = SCALER_REGISTRY.get(self.scaler["func"])(input_data[1], **self.scaler["args"])
+        return input_data
 
     def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
         """Training details.
@@ -304,20 +337,16 @@ def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
         iter_num = (epoch-1) * self.iter_per_epoch + iter_index
         forward_return = list(self.forward(data=data, epoch=epoch, iter_num=iter_num, train=True))
         # re-scale data
-        prediction = self.rescale_data(forward_return[0]) if self.if_rescale else forward_return[0]
-        real_value = self.rescale_data(forward_return[1]) if self.if_rescale else forward_return[1]
+        forward_return = self.rescale_data(forward_return)
         # loss
         if self.cl_param:
             cl_length = self.curriculum_learning(epoch=epoch)
-            forward_return[0] = prediction[:, :cl_length, :, :]
-            forward_return[1] = real_value[:, :cl_length, :, :]
-        else:
-            forward_return[0] = prediction
-            forward_return[1] = real_value
+            forward_return[0] = forward_return[0][:, :cl_length, :, :]  # prediction
+            forward_return[1] = forward_return[1][:, :cl_length, :, :]  # real_value
         loss = self.metric_forward(self.loss, forward_return)
         # metrics
         for metric_name, metric_func in self.metrics.items():
-            metric_item = self.metric_forward(metric_func, [prediction, real_value])
+            metric_item = self.metric_forward(metric_func, forward_return[:2])
             self.update_epoch_meter("train_"+metric_name, metric_item.item())
         return loss
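For context, curriculum learning here truncates supervision to the first cl_length horizon steps and grows that window over epochs. A hedged sketch of such a schedule (parameter names and the warm-up policy are illustrative, not the exact BasicTS implementation):

```python
def curriculum_learning(epoch: int, warmup_epochs: int = 0,
                        cl_epochs: int = 3, prediction_length: int = 12) -> int:
    """Return how many horizon steps to supervise at this epoch."""
    if epoch <= warmup_epochs:
        return prediction_length  # assumed: full horizon during warm-up
    # widen the window one step every cl_epochs epochs, capped at the horizon
    return min((epoch - warmup_epochs - 1) // cl_epochs + 1, prediction_length)

print([curriculum_learning(e) for e in (1, 3, 4, 7, 40)])  # [1, 1, 2, 3, 12]
```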

@@ -329,13 +358,12 @@ def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]):
             data (Union[torch.Tensor, Tuple]): Data provided by DataLoader
         """
 
-        forward_return = self.forward(data=data, epoch=None, iter_num=iter_index, train=False)
+        forward_return = list(self.forward(data=data, epoch=None, iter_num=iter_index, train=False))
         # re-scale data
-        prediction = self.rescale_data(forward_return[0]) if self.if_rescale else forward_return[0]
-        real_value = self.rescale_data(forward_return[1]) if self.if_rescale else forward_return[1]
+        forward_return = self.rescale_data(forward_return)
         # metrics
         for metric_name, metric_func in self.metrics.items():
-            metric_item = self.metric_forward(metric_func, [prediction, real_value])
+            metric_item = self.metric_forward(metric_func, forward_return[:2])
             self.update_epoch_meter("val_"+metric_name, metric_item.item())
 
     def evaluate(self, prediction, real_value):
@@ -385,8 +413,9 @@ def test(self):
         prediction = torch.cat(prediction, dim=0)
         real_value = torch.cat(real_value, dim=0)
         # re-scale data
-        prediction = self.rescale_data(prediction) if self.if_rescale else prediction
-        real_value = self.rescale_data(real_value) if self.if_rescale else real_value
+        if self.if_rescale:
+            prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"])
+            real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"])
         # evaluate
         self.evaluate(prediction, real_value)

2 changes: 1 addition & 1 deletion basicts/utils/xformer.py
@@ -6,7 +6,7 @@ def data_transformation_4_xformer(history_data: torch.Tensor, future_data: torch
     Args:
         history_data (torch.Tensor): history data with shape: [B, L1, N, C].
-        future_data (torch.Tensor): future data with shape: [B, L2, N, C].
+        future_data (torch.Tensor): future data with shape: [B, L2, N, C].
                                     L1 and L2 are input sequence length and output sequence length, respectively.
         start_token_length (int): length of the decoder start token. Ref: Informer paper.
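For reference, the Informer-style decoder start token mentioned in this docstring concatenates the tail of the history with zero placeholders for the future. A hedged sketch of that construction (shapes only, not the full transformation in this file):

```python
import torch

B, L1, L2, N, C = 8, 96, 24, 7, 1        # batch, history len, horizon, nodes, channels
history_data = torch.randn(B, L1, N, C)
start_token_length = 48                   # length of the decoder start token

# decoder input: last `start_token_length` history steps + zeros for the horizon
x_dec = torch.cat([history_data[:, -start_token_length:, ...],
                   torch.zeros(B, L2, N, C)], dim=1)
print(x_dec.shape)                        # torch.Size([8, 72, 7, 1])
```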
