Commit 8a99b40 (1 parent: fb5231b). Showing 7 changed files with 245 additions and 7 deletions.
.gitignore
@@ -128,4 +128,5 @@ dmypy.json
 # Pyre type checker
 .pyre/

 .idea/
+lightning_logs/
Binary file not shown.
Binary file not shown.
New file: prediction script
@@ -0,0 +1,102 @@
import torch
import pandas as pd
import matplotlib.pyplot as plt

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

data = pd.read_csv("data/MERCHANT_NUMBER_OF_TRX.csv")
data = data[["MERCHANT_1_NUMBER_OF_TRX", "date"]]
data["id"] = "M1"

# add time index: nanoseconds since the first timestamp, converted to hourly steps starting at 1
data["time_idx"] = pd.to_datetime(data.date).astype(int)
data["time_idx"] -= data["time_idx"].min()
data["time_idx"] = (data.time_idx / 3600000000000) + 1  # 3.6e12 ns per hour
data["time_idx"] = data["time_idx"].astype(int)

# add datetime variables as categorical features
data["month"] = pd.to_datetime(data.date).dt.month.astype(str).astype("category")
data["day_of_week"] = pd.to_datetime(data.date).dt.dayofweek.astype(str).astype("category")
data["hour"] = pd.to_datetime(data.date).dt.hour.astype(str).astype("category")

# cut atypical values at the end of the sample
train_data = data[:3200]
max_prediction_length = 24
max_encoder_length = 72
training_cutoff = train_data["time_idx"].max() - max_prediction_length

# rebuild the training dataset definition so the trained architecture can be reconstructed
training = TimeSeriesDataSet(
    train_data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="MERCHANT_1_NUMBER_OF_TRX",
    group_ids=["id"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["id"],
    time_varying_known_reals=["time_idx"],
    time_varying_known_categoricals=["hour", "month", "day_of_week"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["MERCHANT_1_NUMBER_OF_TRX"],
    target_normalizer=GroupNormalizer(groups=["id"], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# recreate the trained architecture, then load the saved weights
model = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches
    reduce_on_plateau_patience=4,
)
model.load_state_dict(torch.load("model/tft_regressor.pt"))

# roll over the held-out tail of the sample in 72-hour steps and plot each forecast window
for start in range(3200, 3700, 72):
    test_data = data[start:(start + max_encoder_length)]
    y_obs = data[(start + max_encoder_length):(start + max_encoder_length + max_prediction_length)]

    # with return_x=True, predict returns (predictions, x); take the first (and only) series
    y_hat = model.predict(
        test_data,
        mode="prediction",
        return_x=True
    )[0][0].tolist()

    fig, ax = plt.subplots()

    ax.plot(
        pd.Series(data=y_hat, index=pd.to_datetime(y_obs.date)),
        label='forecast'
    )

    ax.plot(
        y_obs.set_index(pd.to_datetime(y_obs.date))["MERCHANT_1_NUMBER_OF_TRX"],
        color='orange',
        alpha=0.8,
        label='observed',
        linestyle='--',
        linewidth=1.2
    )
    ax.legend()  # labels are set above but never rendered without this

    plt.show()
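Because the network is trained with QuantileLoss, the same loop could also show uncertainty bands rather than only the point forecast. The following is a minimal sketch, not part of the commit; it assumes the default 7 quantiles (column 0 the lowest, column 3 the median, column 6 the highest) and reuses test_data and y_obs from the last loop iteration above:

# sketch: plot the outermost quantiles as a band around the median forecast
quantiles = model.predict(test_data, mode="quantiles")[0].numpy()  # (pred_length, n_quantiles)
index = pd.to_datetime(y_obs.date)

fig, ax = plt.subplots()
ax.fill_between(index, quantiles[:, 0], quantiles[:, -1], alpha=0.3, label="quantile band")
ax.plot(index, quantiles[:, 3], label="median forecast")
ax.legend()
plt.show()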
requirements.txt
@@ -3,3 +3,4 @@ pyarrow==3.0.0
 pandas==1.2.5
 numpy==1.21.3
 torch==1.10.0
+matplotlib~=3.4.3
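A quick way to confirm a local environment matches these pins before running the scripts (a convenience snippet, not part of the commit):

# print installed versions to compare against requirements.txt
import matplotlib
import numpy
import pandas
import pyarrow
import torch

for mod in (pyarrow, pandas, numpy, torch, matplotlib):
    print(f"{mod.__name__}=={mod.__version__}")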
New file: training script
@@ -0,0 +1,133 @@
import warnings

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

data = pd.read_csv("data/MERCHANT_NUMBER_OF_TRX.csv")
data = data[["MERCHANT_1_NUMBER_OF_TRX", "date"]]
data["id"] = "M1"

# add time index: nanoseconds since the first timestamp, converted to hourly steps starting at 1
data["time_idx"] = pd.to_datetime(data.date).astype(int)
data["time_idx"] -= data["time_idx"].min()
data["time_idx"] = (data.time_idx / 3600000000000) + 1  # 3.6e12 ns per hour
data["time_idx"] = data["time_idx"].astype(int)

# add datetime variables as categorical features
data["month"] = pd.to_datetime(data.date).dt.month.astype(str).astype("category")
data["day_of_week"] = pd.to_datetime(data.date).dt.dayofweek.astype(str).astype("category")
data["hour"] = pd.to_datetime(data.date).dt.hour.astype(str).astype("category")

# cut atypical values at the end of the sample
# data = data[:3840]
data = data[:3200]

max_prediction_length = 24
max_encoder_length = 72
training_cutoff = data["time_idx"].max() - max_prediction_length

# last encoder window of the sample (not used below)
test_data = data[lambda x: x.time_idx > x.time_idx.max() - max_encoder_length]

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="MERCHANT_1_NUMBER_OF_TRX",
    group_ids=["id"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["id"],
    time_varying_known_reals=["time_idx"],
    time_varying_known_categoricals=["hour", "month", "day_of_week"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["MERCHANT_1_NUMBER_OF_TRX"],
    target_normalizer=GroupNormalizer(groups=["id"], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# create validation set (predict=True), i.e. predict the last max_prediction_length
# points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# baseline mean absolute error, i.e. predict the next value as the last available value from the history
actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader)
print(f"Baseline MAE: {(actuals - baseline_predictions).abs().mean().item():.2f}")

# configure network and trainer
pl.seed_everything(42)

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log results to TensorBoard

trainer = pl.Trainer(
    max_epochs=150,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    # limit_train_batches=30,  # uncomment for a quick run, validating every 30 batches
    # fast_dev_run=True,  # uncomment to check that network and dataset have no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log every 10 batches
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

# fit network
trainer.fit(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)

# save the trained weights for the prediction script
torch.save(tft.state_dict(), "model/tft_regressor.pt")