From ce0c2092fc7f726c2d9b2cabc9e65af821711f27 Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Tue, 5 Dec 2023 01:29:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20add=20DeepAR?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Paper: DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks; Link: https://arxiv.org/abs/1704.04110; Ref Code: https://github.com/jingw2/demand_forecast, https://github.com/husnejahan/DeepAR-pytorch, https://github.com/arrigonialberto86/deepar. --- baselines/DeepAR/METR-LA.py | 100 +++++++++++++++++++ baselines/DeepAR/PEMS08.py | 100 +++++++++++++++++++ baselines/DeepAR/arch/__init__.py | 1 + baselines/DeepAR/arch/deepar.py | 106 +++++++++++++++++++++ baselines/DeepAR/arch/distributions.py | 22 +++++ baselines/DeepAR/loss/__init__.py | 1 + baselines/DeepAR/loss/gaussian.py | 29 ++++++ baselines/DeepAR/runner/__init__.py | 1 + baselines/DeepAR/runner/deepar_runner.py | 116 +++++++++++++++++++++++ 9 files changed, 476 insertions(+) create mode 100644 baselines/DeepAR/METR-LA.py create mode 100644 baselines/DeepAR/PEMS08.py create mode 100644 baselines/DeepAR/arch/__init__.py create mode 100644 baselines/DeepAR/arch/deepar.py create mode 100644 baselines/DeepAR/arch/distributions.py create mode 100644 baselines/DeepAR/loss/__init__.py create mode 100644 baselines/DeepAR/loss/gaussian.py create mode 100644 baselines/DeepAR/runner/__init__.py create mode 100644 baselines/DeepAR/runner/deepar_runner.py diff --git a/baselines/DeepAR/METR-LA.py b/baselines/DeepAR/METR-LA.py new file mode 100644 index 00000000..af48e695 --- /dev/null +++ b/baselines/DeepAR/METR-LA.py @@ -0,0 +1,100 @@ +import os +import sys + +# TODO: remove it when basicts can be installed by pip +sys.path.append(os.path.abspath(__file__ + "/../../..")) +from easydict import EasyDict +from basicts.data import TimeSeriesForecastingDataset + +from .arch import DeepAR +from .runner import 
DeepARRunner +from .loss import gaussian_loss + +CFG = EasyDict() + +# ================= general ================= # +CFG.DESCRIPTION = "DeepAR model configuration" +CFG.RUNNER = DeepARRunner +CFG.DATASET_CLS = TimeSeriesForecastingDataset +CFG.DATASET_NAME = "METR-LA" +CFG.DATASET_TYPE = "Traffic speed" +CFG.DATASET_INPUT_LEN = 12 +CFG.DATASET_OUTPUT_LEN = 12 +CFG.GPU_NUM = 1 +CFG.NULL_VAL = 0.0 + +# ================= environment ================= # +CFG.ENV = EasyDict() +CFG.ENV.SEED = 1 +CFG.ENV.CUDNN = EasyDict() +CFG.ENV.CUDNN.ENABLED = True + +# ================= model ================= # +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = "DeepAR" +CFG.MODEL.ARCH = DeepAR +CFG.MODEL.PARAM = { + "cov_feat_size" : 2, + "embedding_size" : 32, + "hidden_size" : 64, + "num_layers": 3, + "use_ts_id" : True, + "id_feat_size": 32, + "num_nodes": 207 +} +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] +CFG.MODEL.TARGET_FEATURES = [0] + +# ================= optim ================= # +CFG.TRAIN = EasyDict() +CFG.TRAIN.LOSS = gaussian_loss +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM= { + "lr":0.003, +} + +# ================= train ================= # +CFG.TRAIN.NUM_EPOCHS = 100 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + "checkpoints", + "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +# train data +CFG.TRAIN.DATA = EasyDict() +# read data +CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.PREFETCH = False +CFG.TRAIN.DATA.SHUFFLE = True +CFG.TRAIN.DATA.NUM_WORKERS = 2 +CFG.TRAIN.DATA.PIN_MEMORY = False + +# ================= validate ================= # +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +# validating data +CFG.VAL.DATA = EasyDict() +# read data +CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.VAL.DATA.BATCH_SIZE = 64 +CFG.VAL.DATA.PREFETCH = False +CFG.VAL.DATA.SHUFFLE = False +CFG.VAL.DATA.NUM_WORKERS = 2 
+CFG.VAL.DATA.PIN_MEMORY = False + +# ================= test ================= # +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +# test data +CFG.TEST.DATA = EasyDict() +# read data +CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.TEST.DATA.BATCH_SIZE = 64 +CFG.TEST.DATA.PREFETCH = False +CFG.TEST.DATA.SHUFFLE = False +CFG.TEST.DATA.NUM_WORKERS = 2 +CFG.TEST.DATA.PIN_MEMORY = False diff --git a/baselines/DeepAR/PEMS08.py b/baselines/DeepAR/PEMS08.py new file mode 100644 index 00000000..719bb948 --- /dev/null +++ b/baselines/DeepAR/PEMS08.py @@ -0,0 +1,100 @@ +import os +import sys + +# TODO: remove it when basicts can be installed by pip +sys.path.append(os.path.abspath(__file__ + "/../../..")) +from easydict import EasyDict +from basicts.data import TimeSeriesForecastingDataset + +from .arch import DeepAR +from .runner import DeepARRunner +from .loss import gaussian_loss + +CFG = EasyDict() + +# ================= general ================= # +CFG.DESCRIPTION = "DeepAR model configuration" +CFG.RUNNER = DeepARRunner +CFG.DATASET_CLS = TimeSeriesForecastingDataset +CFG.DATASET_NAME = "PEMS08" +CFG.DATASET_TYPE = "Traffic flow" +CFG.DATASET_INPUT_LEN = 12 +CFG.DATASET_OUTPUT_LEN = 12 +CFG.GPU_NUM = 1 +CFG.NULL_VAL = 0.0 + +# ================= environment ================= # +CFG.ENV = EasyDict() +CFG.ENV.SEED = 1 +CFG.ENV.CUDNN = EasyDict() +CFG.ENV.CUDNN.ENABLED = True + +# ================= model ================= # +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = "DeepAR" +CFG.MODEL.ARCH = DeepAR +CFG.MODEL.PARAM = { + "cov_feat_size" : 2, + "embedding_size" : 32, + "hidden_size" : 64, + "num_layers": 3, + "use_ts_id" : True, + "id_feat_size": 32, + "num_nodes": 170 + } +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] +CFG.MODEL.TARGET_FEATURES = [0] + +# ================= optim ================= # +CFG.TRAIN = EasyDict() +CFG.TRAIN.LOSS = gaussian_loss +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM= { + 
"lr":0.003, +} + +# ================= train ================= # +CFG.TRAIN.NUM_EPOCHS = 100 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + "checkpoints", + "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +# train data +CFG.TRAIN.DATA = EasyDict() +# read data +CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.TRAIN.DATA.BATCH_SIZE = 64 +CFG.TRAIN.DATA.PREFETCH = False +CFG.TRAIN.DATA.SHUFFLE = True +CFG.TRAIN.DATA.NUM_WORKERS = 2 +CFG.TRAIN.DATA.PIN_MEMORY = False + +# ================= validate ================= # +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +# validating data +CFG.VAL.DATA = EasyDict() +# read data +CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.VAL.DATA.BATCH_SIZE = 64 +CFG.VAL.DATA.PREFETCH = False +CFG.VAL.DATA.SHUFFLE = False +CFG.VAL.DATA.NUM_WORKERS = 2 +CFG.VAL.DATA.PIN_MEMORY = False + +# ================= test ================= # +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +# test data +CFG.TEST.DATA = EasyDict() +# read data +CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME +# dataloader args, optional +CFG.TEST.DATA.BATCH_SIZE = 64 +CFG.TEST.DATA.PREFETCH = False +CFG.TEST.DATA.SHUFFLE = False +CFG.TEST.DATA.NUM_WORKERS = 2 +CFG.TEST.DATA.PIN_MEMORY = False diff --git a/baselines/DeepAR/arch/__init__.py b/baselines/DeepAR/arch/__init__.py new file mode 100644 index 00000000..6ec10582 --- /dev/null +++ b/baselines/DeepAR/arch/__init__.py @@ -0,0 +1 @@ +from .deepar import DeepAR \ No newline at end of file diff --git a/baselines/DeepAR/arch/deepar.py b/baselines/DeepAR/arch/deepar.py new file mode 100644 index 00000000..72dded63 --- /dev/null +++ b/baselines/DeepAR/arch/deepar.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .distributions import Gaussian + + +class DeepAR(nn.Module): + """ + Paper: DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks + Ref Code: + 
https://github.com/jingw2/demand_forecast + https://github.com/husnejahan/DeepAR-pytorch + https://github.com/arrigonialberto86/deepar + Link: https://arxiv.org/abs/1704.04110 + """ + + def __init__(self, cov_feat_size, embedding_size, hidden_size, num_layers, use_ts_id, id_feat_size=0, num_nodes=0) -> None: + """Init DeepAR. + + Args: + cov_feat_size (int): covariate feature size (e.g. time in day, day in week, etc.). + embedding_size (int): output size of the input embedding layer. + hidden_size (int): hidden size of the LSTM. + num_layers (int): number of LSTM layers. + use_ts_id (bool): whether to use time series id to construct spatial id embedding as additional features. + id_feat_size (int, optional): size of the spatial id embedding. Defaults to 0. + num_nodes (int, optional): number of nodes. Defaults to 0. + """ + super().__init__() + self.use_ts_id = use_ts_id + # input embedding layer + self.input_embed = nn.Linear(1, embedding_size) + # spatial id embedding layer + if use_ts_id: + assert id_feat_size > 0, "id_feat_size must be greater than 0 if use_ts_id is True" + assert num_nodes > 0, "num_nodes must be greater than 0 if use_ts_id is True" + self.id_feat = nn.Parameter(torch.empty(num_nodes, id_feat_size)) + nn.init.xavier_uniform_(self.id_feat) + else: + id_feat_size = 0 + # the LSTM layer + self.encoder = nn.LSTM(embedding_size+cov_feat_size+id_feat_size, hidden_size, num_layers, bias=True, batch_first=True) + # the likelihood function + self.likelihood_layer = Gaussian(hidden_size, 1) + + def gaussian_sample(self, mu, sigma): + """Sampling. + + Args: + mu (torch.Tensor): mean values of distributions. + sigma (torch.Tensor): std values of distributions. 
+ """ + mu = mu.squeeze(1) + sigma = sigma.squeeze(1) + gaussian = torch.distributions.Normal(mu, sigma) + ypred = gaussian.sample([1]).squeeze(0) + return ypred + + def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train: bool, **kwargs) -> torch.Tensor: + """Feed forward of DeepAR. + Reference code: https://github.com/jingw2/demand_forecast/blob/master/deepar.py + + Args: + history_data (torch.Tensor): history data. [B, L, N, C]. + future_data (torch.Tensor): future data. [B, L, N, C]. + train (bool): is training or not. + """ + history_next = None + preds = [] + mus = [] + sigmas = [] + len_in, len_out = history_data.shape[1], future_data.shape[1] + B, _, N, C = history_data.shape + input_feat_full = torch.cat([history_data[:, :, :, 0:1], future_data[:, :, :, 0:1]], dim=1) # B, L_in+L_out, N, 1 + covar_feat_full = torch.cat([history_data[:, :, :, 1:], future_data[:, :, :, 1:]], dim=1) # B, L_in+L_out, N, C-1 + + for t in range(1, len_in + len_out): + if not (t > len_in and not train): # not in the decoding stage when inferencing + history_next = input_feat_full[:, t-1:t, :, 0:1] + else: + a = 1 + embed_feat = self.input_embed(history_next) + covar_feat = covar_feat_full[:, t:t+1, :, :] + if self.use_ts_id: + id_feat = self.id_feat.unsqueeze(0).expand(history_data.shape[0], -1, -1).unsqueeze(1) + encoder_input = torch.cat([embed_feat, covar_feat, id_feat], dim=-1) + else: + encoder_input = torch.cat([embed_feat, covar_feat], dim=-1) + # lstm + B, _, N, C = encoder_input.shape # _ is 1 + encoder_input = encoder_input.transpose(1, 2).reshape(B * N, -1, C) + _, (h, c) = self.encoder(encoder_input) if t == 1 else self.encoder(encoder_input, (h, c)) + # distribution proj + mu, sigma = self.likelihood_layer(F.relu(h[-1, :, :])) + history_next = self.gaussian_sample(mu, sigma).view(B, N).view(B, 1, N, 1) + mus.append(mu.view(B, N, 1).unsqueeze(1)) + sigmas.append(sigma.view(B, N, 1).unsqueeze(1)) + preds.append(history_next) + assert not
torch.isnan(history_next).any() + + preds = torch.concat(preds, dim=1) + mus = torch.concat(mus, dim=1) + sigmas = torch.concat(sigmas, dim=1) + reals = input_feat_full[:, -preds.shape[1]:, :, :] + return preds, reals, mus, sigmas diff --git a/baselines/DeepAR/arch/distributions.py b/baselines/DeepAR/arch/distributions.py new file mode 100644 index 00000000..0c84d512 --- /dev/null +++ b/baselines/DeepAR/arch/distributions.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + + +class Gaussian(nn.Module): + + def __init__(self, hidden_size, output_size): + """ + Gaussian Likelihood Supports Continuous Data + Args: + hidden_size (int): hidden h_{i,t} column size + output_size (int): output size of the mu and sigma layers + """ + super(Gaussian, self).__init__() + self.mu_layer = nn.Linear(hidden_size, output_size) + self.sigma_layer = nn.Linear(hidden_size, output_size) + + def forward(self, h): + sigma_t = torch.log(1 + torch.exp(self.sigma_layer(h))) + 1e-6 + sigma_t = sigma_t.squeeze(0) + mu_t = self.mu_layer(h).squeeze(0) + return mu_t, sigma_t diff --git a/baselines/DeepAR/loss/__init__.py b/baselines/DeepAR/loss/__init__.py new file mode 100644 index 00000000..9b08b8a3 --- /dev/null +++ b/baselines/DeepAR/loss/__init__.py @@ -0,0 +1 @@ +from .gaussian import gaussian_loss \ No newline at end of file diff --git a/baselines/DeepAR/loss/gaussian.py b/baselines/DeepAR/loss/gaussian.py new file mode 100644 index 00000000..c278f4b0 --- /dev/null +++ b/baselines/DeepAR/loss/gaussian.py @@ -0,0 +1,29 @@ +import torch +import numpy as np + + +def gaussian_loss(prediction, real_value, mu, sigma, null_val = np.nan): + """Masked gaussian loss. Kindly note that the gaussian loss is calculated based on mu, sigma, and real_value. The prediction is sampled from N(mu, sigma), and is not used in the loss calculation (it will be used in the metrics calculation). + + Args: + prediction (torch.Tensor): prediction of model. [B, L, N, 1]. + real_value (torch.Tensor): ground truth. [B, L, N, 1].
+ mu (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1]. + sigma (torch.Tensor): the std of gaussian distribution. [B, L, N, 1] + null_val (optional): null value. Defaults to np.nan. + """ + # mask + if np.isnan(null_val): + mask = ~torch.isnan(real_value) + else: + eps = 5e-5 + mask = ~torch.isclose(real_value, torch.tensor(null_val).expand_as(real_value).to(real_value.device), atol=eps, rtol=0.) + mask = mask.float() + mask /= torch.mean((mask)) + mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) + + distribution = torch.distributions.Normal(mu, sigma) + likelihood = distribution.log_prob(real_value) + likelihood = likelihood * mask + loss_g = -torch.mean(likelihood) + return loss_g diff --git a/baselines/DeepAR/runner/__init__.py b/baselines/DeepAR/runner/__init__.py new file mode 100644 index 00000000..1e41b855 --- /dev/null +++ b/baselines/DeepAR/runner/__init__.py @@ -0,0 +1 @@ +from .deepar_runner import DeepARRunner diff --git a/baselines/DeepAR/runner/deepar_runner.py b/baselines/DeepAR/runner/deepar_runner.py new file mode 100644 index 00000000..7fc12c4d --- /dev/null +++ b/baselines/DeepAR/runner/deepar_runner.py @@ -0,0 +1,116 @@ +from typing import List +import torch +from basicts.data.registry import SCALER_REGISTRY +from easytorch.utils.dist import master_only + +from basicts.runners import BaseTimeSeriesForecastingRunner + + +class DeepARRunner(BaseTimeSeriesForecastingRunner): + def __init__(self, cfg: dict): + super().__init__(cfg) + self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None) + self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None) + self.output_seq_len = cfg["DATASET_OUTPUT_LEN"] + + def select_input_features(self, data: torch.Tensor) -> torch.Tensor: + """Select input features and reshape data to fit the target model. + + Args: + data (torch.Tensor): input history data, shape [B, L, N, C]. 
+ + Returns: + torch.Tensor: reshaped data + """ + + # select feature using self.forward_features + if self.forward_features is not None: + data = data[:, :, :, self.forward_features] + return data + + def select_target_features(self, data: torch.Tensor) -> torch.Tensor: + """Select target features and reshape data back to the BasicTS framework + + Args: + data (torch.Tensor): prediction of the model with arbitrary shape. + + Returns: + torch.Tensor: reshaped data with shape [B, L, N, C] + """ + + # select feature using self.target_features + data = data[:, :, :, self.target_features] + return data + + def rescale_data(self, input_data: List[torch.Tensor]) -> List[torch.Tensor]: + """Rescale data. + + Args: + input_data (List[torch.Tensor]): list of data to be re-scaled, in the order [prediction, real_value, mus, sigmas]. + + Returns: + List[torch.Tensor]: list of re-scaled data. + """ + prediction, real_value, mus, sigmas = input_data + if self.if_rescale: + prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"]) + real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"]) + mus = SCALER_REGISTRY.get(self.scaler["func"])(mus, **self.scaler["args"]) + sigmas = SCALER_REGISTRY.get(self.scaler["func"])(sigmas, **self.scaler["args"]) + return [prediction, real_value, mus, sigmas] + + @torch.no_grad() + @master_only + def test(self): + """Evaluate the model. + + Note: + this method takes no arguments and evaluates on self.test_data_loader.
+ """ + + # test loop + prediction = [] + real_value = [] + for _, data in enumerate(self.test_data_loader): + forward_return = list(self.forward(data, epoch=None, iter_num=None, train=False)) + if not self.if_evaluate_on_gpu: + forward_return[0], forward_return[1] = forward_return[0].detach().cpu(), forward_return[1].detach().cpu() + prediction.append(forward_return[0]) # preds = forward_return[0] + real_value.append(forward_return[1]) # testy = forward_return[1] + prediction = torch.cat(prediction, dim=0) + real_value = torch.cat(real_value, dim=0) + # re-scale data + if self.if_rescale: + prediction = SCALER_REGISTRY.get(self.scaler["func"])(prediction, **self.scaler["args"])[:, -self.output_seq_len:, :, :] + real_value = SCALER_REGISTRY.get(self.scaler["func"])(real_value, **self.scaler["args"])[:, -self.output_seq_len:, :, :] + # evaluate + self.evaluate(prediction, real_value) + + def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: + """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. + + Args: + data (tuple): data (future data, history data). [B, L, N, C] for each of them + epoch (int, optional): epoch number. Defaults to None. + iter_num (int, optional): iteration number. Defaults to None. + train (bool, optional): if in the training process. Defaults to True. 
+ + Returns: + tuple: (prediction, real_value, mus, sigmas) + """ + + # preprocess + future_data, history_data = data + history_data = self.to_running_device(history_data) # B, L, N, C + future_data = self.to_running_device(future_data) # B, L, N, C + batch_size, length, num_nodes, _ = future_data.shape + + history_data = self.select_input_features(history_data) + future_data_4_dec = self.select_input_features(future_data) + + # feed forward + pred_values, real_values, mus, sigmas = self.model(history_data=history_data, future_data=future_data_4_dec, train=train) + # post process + prediction = self.select_target_features(pred_values) + real_value = self.select_target_features(real_values) + return prediction, real_value, mus, sigmas