-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Paper: DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks; Link: https://arxiv.org/abs/1704.04110; Ref Code: https://github.com/jingw2/demand_forecast, https://github.com/husnejahan/DeepAR-pytorch, https://github.com/arrigonialberto86/deepar.
- Loading branch information
Showing
9 changed files
with
476 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
import sys | ||
|
||
# TODO: remove it when basicts can be installed by pip | ||
sys.path.append(os.path.abspath(__file__ + "/../../..")) | ||
from easydict import EasyDict | ||
from basicts.data import TimeSeriesForecastingDataset | ||
|
||
from .arch import DeepAR | ||
from .runner import DeepARRunner | ||
from .loss import gaussian_loss | ||
|
||
# Root configuration object; every section below hangs off of it.
CFG = EasyDict()

# ------------------------------------------------------------------ #
# General settings
# ------------------------------------------------------------------ #
CFG.DESCRIPTION = "DeepAR model configuration"
CFG.RUNNER = DeepARRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "METR-LA"
CFG.DATASET_TYPE = "Traffic speed"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.GPU_NUM = 1
CFG.NULL_VAL = 0.0

# ------------------------------------------------------------------ #
# Environment
# ------------------------------------------------------------------ #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 1
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ------------------------------------------------------------------ #
# Model: architecture class plus its constructor keyword arguments
# ------------------------------------------------------------------ #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "DeepAR"
CFG.MODEL.ARCH = DeepAR
CFG.MODEL.PARAM = {
    "cov_feat_size": 2,
    "embedding_size": 32,
    "hidden_size": 64,
    "num_layers": 3,
    "use_ts_id": True,
    "id_feat_size": 32,
    "num_nodes": 207,
}
# Channel indices fed to the model / used as the prediction target.
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]

# ------------------------------------------------------------------ #
# Optimisation: loss function and optimizer
# ------------------------------------------------------------------ #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = gaussian_loss
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.003,
}

# ------------------------------------------------------------------ #
# Training schedule and training data
# ------------------------------------------------------------------ #
CFG.TRAIN.NUM_EPOCHS = 100
# Checkpoints go to checkpoints/<model name>_<number of epochs>.
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    "checkpoints",
    f"{CFG.MODEL.NAME}_{CFG.TRAIN.NUM_EPOCHS}",
)
CFG.TRAIN.DATA = EasyDict()
# Dataset location.
CFG.TRAIN.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = False

# ------------------------------------------------------------------ #
# Validation
# ------------------------------------------------------------------ #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
# Dataset location.
CFG.VAL.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.VAL.DATA.BATCH_SIZE = 64
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = False

# ------------------------------------------------------------------ #
# Test
# ------------------------------------------------------------------ #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
# Dataset location.
CFG.TEST.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.TEST.DATA.BATCH_SIZE = 64
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import os | ||
import sys | ||
|
||
# TODO: remove it when basicts can be installed by pip | ||
sys.path.append(os.path.abspath(__file__ + "/../../..")) | ||
from easydict import EasyDict | ||
from basicts.data import TimeSeriesForecastingDataset | ||
|
||
from .arch import DeepAR | ||
from .runner import DeepARRunner | ||
from .loss import gaussian_loss | ||
|
||
# Root configuration object; every section below hangs off of it.
CFG = EasyDict()

# ------------------------------------------------------------------ #
# General settings
# ------------------------------------------------------------------ #
CFG.DESCRIPTION = "DeepAR model configuration"
CFG.RUNNER = DeepARRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "PEMS08"
CFG.DATASET_TYPE = "Traffic flow"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.GPU_NUM = 1
CFG.NULL_VAL = 0.0

# ------------------------------------------------------------------ #
# Environment
# ------------------------------------------------------------------ #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 1
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ------------------------------------------------------------------ #
# Model: architecture class plus its constructor keyword arguments
# ------------------------------------------------------------------ #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "DeepAR"
CFG.MODEL.ARCH = DeepAR
CFG.MODEL.PARAM = {
    "cov_feat_size": 2,
    "embedding_size": 32,
    "hidden_size": 64,
    "num_layers": 3,
    "use_ts_id": True,
    "id_feat_size": 32,
    "num_nodes": 170,
}
# Channel indices fed to the model / used as the prediction target.
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]

# ------------------------------------------------------------------ #
# Optimisation: loss function and optimizer
# ------------------------------------------------------------------ #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = gaussian_loss
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.003,
}

# ------------------------------------------------------------------ #
# Training schedule and training data
# ------------------------------------------------------------------ #
CFG.TRAIN.NUM_EPOCHS = 100
# Checkpoints go to checkpoints/<model name>_<number of epochs>.
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    "checkpoints",
    f"{CFG.MODEL.NAME}_{CFG.TRAIN.NUM_EPOCHS}",
)
CFG.TRAIN.DATA = EasyDict()
# Dataset location.
CFG.TRAIN.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = False

# ------------------------------------------------------------------ #
# Validation
# ------------------------------------------------------------------ #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
# Dataset location.
CFG.VAL.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.VAL.DATA.BATCH_SIZE = 64
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = False

# ------------------------------------------------------------------ #
# Test
# ------------------------------------------------------------------ #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
# Dataset location.
CFG.TEST.DATA.DIR = f"datasets/{CFG.DATASET_NAME}"
# Optional dataloader arguments.
CFG.TEST.DATA.BATCH_SIZE = 64
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .deepar import DeepAR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .distributions import Gaussian | ||
|
||
|
||
class DeepAR(nn.Module):
    """
    Paper: DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks
    Ref Code:
        https://github.com/jingw2/demand_forecast
        https://github.com/husnejahan/DeepAR-pytorch
        https://github.com/arrigonialberto86/deepar
    Link: https://arxiv.org/abs/1704.04110
    """

    def __init__(self, cov_feat_size, embedding_size, hidden_size, num_layers, use_ts_id, id_feat_size=0, num_nodes=0) -> None:
        """Init DeepAR.

        Args:
            cov_feat_size (int): covariate feature size (e.g. time in day, day in week, etc.).
            embedding_size (int): output size of the input embedding layer.
            hidden_size (int): hidden size of the LSTM.
            num_layers (int): number of LSTM layers.
            use_ts_id (bool): whether to use time series id to construct spatial id embedding as additional features.
            id_feat_size (int, optional): size of the spatial id embedding. Defaults to 0.
            num_nodes (int, optional): number of nodes. Defaults to 0.
        """
        super().__init__()
        self.use_ts_id = use_ts_id
        # input embedding layer: embeds the scalar target value of one step
        self.input_embed = nn.Linear(1, embedding_size)
        # spatial id embedding layer: one learnable vector per node
        if use_ts_id:
            assert id_feat_size > 0, "id_feat_size must be greater than 0 if use_ts_id is True"
            assert num_nodes > 0, "num_nodes must be greater than 0 if use_ts_id is True"
            self.id_feat = nn.Parameter(torch.empty(num_nodes, id_feat_size))
            nn.init.xavier_uniform_(self.id_feat)
        else:
            # without id embedding it contributes nothing to the LSTM input width
            id_feat_size = 0
        # the LSTM layer
        self.encoder = nn.LSTM(embedding_size + cov_feat_size + id_feat_size, hidden_size, num_layers, bias=True, batch_first=True)
        # the likelihood function
        self.likelihood_layer = Gaussian(hidden_size, 1)

    def gaussian_sample(self, mu, sigma):
        """Draw one sample from N(mu, sigma) for every element.

        Args:
            mu (torch.Tensor): mean values of distributions.
            sigma (torch.Tensor): std values of distributions.

        Returns:
            torch.Tensor: one sample per distribution, with dim 1 squeezed
                out of `mu`/`sigma` when they have more than one dimension.
        """
        # The likelihood layer squeezes dim 0, so a single-row batch arrives
        # here already 1-D; squeeze(1) on such a tensor would raise IndexError.
        if mu.dim() > 1:
            mu = mu.squeeze(1)
        if sigma.dim() > 1:
            sigma = sigma.squeeze(1)
        return torch.distributions.Normal(mu, sigma).sample()

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, train: bool, **kwargs) -> tuple:
        """Feed forward of DeepAR.

        Reference code: https://github.com/jingw2/demand_forecast/blob/master/deepar.py

        Args:
            history_data (torch.Tensor): history data. [B, L, N, C].
            future_data (torch.Tensor): future data. [B, L, N, C].
            train (bool): is training or not.

        Returns:
            tuple: (preds, reals, mus, sigmas). Each covers the last
                L_in + L_out - 1 steps of the concatenated sequence and has
                shape [B, L_in + L_out - 1, N, 1]. `preds` are samples drawn
                from N(mus, sigmas); `reals` are the matching ground-truth
                targets (channel 0).
        """
        len_in, len_out = history_data.shape[1], future_data.shape[1]
        batch_size, _, num_nodes, _ = history_data.shape
        # channel 0 is the target series, the remaining channels are covariates
        input_feat_full = torch.cat([history_data[:, :, :, 0:1], future_data[:, :, :, 0:1]], dim=1)  # B, L_in+L_out, N, 1
        covar_feat_full = torch.cat([history_data[:, :, :, 1:], future_data[:, :, :, 1:]], dim=1)  # B, L_in+L_out, N, C-1

        # The id embedding does not depend on the time step: expand it once.
        id_feat = None
        if self.use_ts_id:
            id_feat = self.id_feat.unsqueeze(0).expand(batch_size, -1, -1).unsqueeze(1)  # B, 1, N, id_feat_size

        history_next = None
        state = None  # (h, c) of the LSTM; None lets the LSTM start from zeros
        preds, mus, sigmas = [], [], []

        for t in range(1, len_in + len_out):
            # Teacher forcing: feed the ground truth of step t-1. Only in the
            # decoding stage at inference time (t > len_in and not train) is the
            # previous sample kept as the next input instead.
            if train or t <= len_in:
                history_next = input_feat_full[:, t - 1:t, :, 0:1]
            step_feats = [self.input_embed(history_next), covar_feat_full[:, t:t + 1, :, :]]
            if id_feat is not None:
                step_feats.append(id_feat)
            encoder_input = torch.cat(step_feats, dim=-1)  # B, 1, N, F
            feat_dim = encoder_input.shape[-1]
            # lstm: every (batch, node) pair is an independent length-1 sequence
            encoder_input = encoder_input.transpose(1, 2).reshape(batch_size * num_nodes, -1, feat_dim)
            _, state = self.encoder(encoder_input, state)
            # distribution proj from the hidden state of the last LSTM layer
            mu, sigma = self.likelihood_layer(F.relu(state[0][-1, :, :]))
            history_next = self.gaussian_sample(mu, sigma).view(batch_size, 1, num_nodes, 1)
            mus.append(mu.reshape(batch_size, 1, num_nodes, 1))
            sigmas.append(sigma.reshape(batch_size, 1, num_nodes, 1))
            preds.append(history_next)
            assert not torch.isnan(history_next).any()

        preds = torch.cat(preds, dim=1)
        mus = torch.cat(mus, dim=1)
        sigmas = torch.cat(sigmas, dim=1)
        # ground truth aligned with the predicted steps (steps 1 .. L_in+L_out-1)
        reals = input_feat_full[:, -preds.shape[1]:, :, :]
        return preds, reals, mus, sigmas
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class Gaussian(nn.Module):
    """Gaussian likelihood head: projects a hidden state to (mu, sigma)."""

    def __init__(self, hidden_size, output_size):
        """
        Gaussian Likelihood Supports Continuous Data

        Args:
            hidden_size (int): hidden h_{i,t} column size
            output_size (int): number of (mu, sigma) pairs produced per input row
        """
        super(Gaussian, self).__init__()
        self.mu_layer = nn.Linear(hidden_size, output_size)
        self.sigma_layer = nn.Linear(hidden_size, output_size)

    def forward(self, h):
        """Compute the distribution parameters for hidden state `h`.

        Args:
            h (torch.Tensor): hidden state whose last dimension is `hidden_size`.

        Returns:
            tuple: (mu_t, sigma_t). `sigma_t` is strictly positive. Both have
                dim 0 squeezed out when it has size 1.
        """
        # softplus(x) == log(1 + exp(x)), but it stays finite for large x,
        # whereas the naive formula overflows to inf once exp(x) overflows.
        # The small constant keeps sigma strictly positive.
        sigma_t = nn.functional.softplus(self.sigma_layer(h)) + 1e-6
        sigma_t = sigma_t.squeeze(0)
        mu_t = self.mu_layer(h).squeeze(0)
        return mu_t, sigma_t
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .gaussian import gaussian_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch | ||
import numpy as np | ||
|
||
|
||
def gaussian_loss(prediction, real_value, mu, sigma, null_val = np.nan):
    """Masked gaussian loss (negative log-likelihood).

    Kindly note that the gaussian loss is calculated based on mu, sigma, and
    real_value. The prediction is sampled from N(mu, sigma), and is not used in
    the loss calculation (it will be used in the metrics calculation).

    Args:
        prediction (torch.Tensor): prediction of model. [B, L, N, 1].
        real_value (torch.Tensor): ground truth. [B, L, N, 1].
        mu (torch.Tensor): the mean of gaussian distribution. [B, L, N, 1].
        sigma (torch.Tensor): the std of gaussian distribution. [B, L, N, 1]
        null_val (optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: scalar loss, i.e. the negative log-likelihood averaged
            with masked (null) targets weighted zero; 0 if every target is null.
    """
    # mask: True where the target is a real observation
    if np.isnan(null_val):
        valid = ~torch.isnan(real_value)
    else:
        eps = 5e-5
        # full_like keeps the reference tensor on the dtype/device of the target
        valid = ~torch.isclose(real_value, torch.full_like(real_value, null_val), atol=eps, rtol=0.)
    mask = valid.float()
    # rescale so that the mean over all entries equals the mean over valid ones
    mask /= torch.mean(mask)
    # all targets masked -> 0 / 0 = nan; fall back to an all-zero mask
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)

    # Masked targets may be NaN. NaN * 0 is still NaN, so they must be replaced
    # by a finite placeholder before log_prob; their weight is zero anyway, so
    # the placeholder value affects neither the loss nor the gradients.
    safe_real = torch.where(valid, real_value, torch.zeros_like(real_value))

    distribution = torch.distributions.Normal(mu, sigma)
    likelihood = distribution.log_prob(safe_real) * mask
    loss_g = -torch.mean(likelihood)
    return loss_g
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .deepar_runner import DeepARRunner |
Oops, something went wrong.