From 5f6b3a4f47128505d3e5c27aa3ae21530d535eba Mon Sep 17 00:00:00 2001 From: Qiao Zhongzheng Date: Fri, 4 Oct 2024 15:38:19 +0800 Subject: [PATCH] Initial version of Seasonal Naive --- .../seasonal_naive_eval/data/etth1_test.yaml | 7 + .../seasonal_naive_eval/data/etth1_val.yaml | 7 + .../data/gluonts_test.yaml | 4 + .../seasonal_naive_eval/data/gluonts_val.yaml | 4 + .../seasonal_naive_eval/data/lsf_test.yaml | 4 + .../seasonal_naive_eval/data/lsf_val.yaml | 4 + cli/conf/seasonal_naive_eval/data/monash.yaml | 4 + cli/conf/seasonal_naive_eval/default.yaml | 24 + .../model/moirai_lightning_ckpt.yaml | 6 + .../data/electricity.yaml | 2 + .../seasonal_naive_finetune/data/etth1.yaml | 3 + .../seasonal_naive_finetune/data/etth2.yaml | 2 + .../seasonal_naive_finetune/data/ettm1.yaml | 2 + .../seasonal_naive_finetune/data/ettm2.yaml | 2 + .../seasonal_naive_finetune/data/weather.yaml | 2 + cli/conf/seasonal_naive_finetune/default.yaml | 85 ++ .../model/moirai_1.1_R_small.yaml | 43 + .../val_data/electricity.yaml | 13 + .../val_data/etth1.yaml | 9 + .../val_data/etth1_multi.yaml | 16 + .../val_data/etth2.yaml | 13 + .../val_data/ettm1.yaml | 14 + .../val_data/ettm2.yaml | 14 + .../val_data/weather.yaml | 9 + .../model/seasonal_naive_moirai/__init__.py | 25 + .../model/seasonal_naive_moirai/finetune.py | 790 ++++++++++++ .../model/seasonal_naive_moirai/forecast.py | 1056 +++++++++++++++++ .../model/seasonal_naive_moirai/module.py | 178 +++ src/uni2ts/transform/__init__.py | 8 + src/uni2ts/transform/seasonal_naive.py | 193 +++ 30 files changed, 2543 insertions(+) create mode 100644 cli/conf/seasonal_naive_eval/data/etth1_test.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/etth1_val.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/gluonts_test.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/gluonts_val.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/lsf_test.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/lsf_val.yaml create mode 100644 cli/conf/seasonal_naive_eval/data/monash.yaml create mode 100644 cli/conf/seasonal_naive_eval/default.yaml create mode 100644 cli/conf/seasonal_naive_eval/model/moirai_lightning_ckpt.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/electricity.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/etth1.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/etth2.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/ettm1.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/ettm2.yaml create mode 100644 cli/conf/seasonal_naive_finetune/data/weather.yaml create mode 100644 cli/conf/seasonal_naive_finetune/default.yaml create mode 100644 cli/conf/seasonal_naive_finetune/model/moirai_1.1_R_small.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/electricity.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/etth1.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/etth1_multi.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/etth2.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/ettm1.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/ettm2.yaml create mode 100644 cli/conf/seasonal_naive_finetune/val_data/weather.yaml create mode 100644 src/uni2ts/model/seasonal_naive_moirai/__init__.py create mode 100644 src/uni2ts/model/seasonal_naive_moirai/finetune.py create mode 100644 src/uni2ts/model/seasonal_naive_moirai/forecast.py create mode 100644 src/uni2ts/model/seasonal_naive_moirai/module.py create mode 100644 src/uni2ts/transform/seasonal_naive.py diff --git a/cli/conf/seasonal_naive_eval/data/etth1_test.yaml b/cli/conf/seasonal_naive_eval/data/etth1_test.yaml new file mode 100644 index 0000000..789aff5 --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/etth1_test.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.eval_util.data.get_custom_eval_dataset +dataset_name: ETTh1_eval +offset: 14400 +windows: 2785 +distance: 1 +prediction_length: 96 +mode: null \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/etth1_val.yaml b/cli/conf/seasonal_naive_eval/data/etth1_val.yaml new file mode 100644 index 0000000..1a32d35 --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/etth1_val.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.eval_util.data.get_custom_eval_dataset +dataset_name: ETTh1_eval +offset: 11520 +windows: 2785 +distance: 1 +prediction_length: 96 +mode: null \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/gluonts_test.yaml b/cli/conf/seasonal_naive_eval/data/gluonts_test.yaml new file mode 100644 index 0000000..4e713bf --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/gluonts_test.yaml @@ -0,0 +1,4 @@ +_target_: uni2ts.eval_util.data.get_gluonts_test_dataset +dataset_name: ??? +prediction_length: null +mode: S \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/gluonts_val.yaml b/cli/conf/seasonal_naive_eval/data/gluonts_val.yaml new file mode 100644 index 0000000..d2079ce --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/gluonts_val.yaml @@ -0,0 +1,4 @@ +_target_: uni2ts.eval_util.data.get_gluonts_val_dataset +dataset_name: ??? +prediction_length: null +mode: S \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/lsf_test.yaml b/cli/conf/seasonal_naive_eval/data/lsf_test.yaml new file mode 100644 index 0000000..0c7c1dd --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/lsf_test.yaml @@ -0,0 +1,4 @@ +_target_: uni2ts.eval_util.data.get_lsf_test_dataset +dataset_name: ??? +prediction_length: ??? +mode: S \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/lsf_val.yaml b/cli/conf/seasonal_naive_eval/data/lsf_val.yaml new file mode 100644 index 0000000..cb5fd5b --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/lsf_val.yaml @@ -0,0 +1,4 @@ +_target_: uni2ts.eval_util.data.get_lsf_val_dataset +dataset_name: ??? +prediction_length: ??? +mode: S \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/data/monash.yaml b/cli/conf/seasonal_naive_eval/data/monash.yaml new file mode 100644 index 0000000..4e713bf --- /dev/null +++ b/cli/conf/seasonal_naive_eval/data/monash.yaml @@ -0,0 +1,4 @@ +_target_: uni2ts.eval_util.data.get_gluonts_test_dataset +dataset_name: ??? +prediction_length: null +mode: S \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/default.yaml b/cli/conf/seasonal_naive_eval/default.yaml new file mode 100644 index 0000000..441c274 --- /dev/null +++ b/cli/conf/seasonal_naive_eval/default.yaml @@ -0,0 +1,24 @@ +hydra: + run: + dir: outputs/eval/${hydra:runtime.choices.model}/${exp_name}/${data.dataset_name}/${data.mode}/cl${model.context_length}_pl${data.prediction_length} +defaults: + - model: ??? + - data: ??? + - _self_ +exp_name: ??? +metrics: + - _target_: gluonts.ev.metrics.MSE + - _target_: uni2ts.eval_util.metrics.MedianMSE + - _target_: gluonts.ev.metrics.MAE + - _target_: gluonts.ev.metrics.MASE + - _target_: gluonts.ev.metrics.MAPE + - _target_: gluonts.ev.metrics.SMAPE + - _target_: gluonts.ev.metrics.MSIS + - _target_: gluonts.ev.metrics.RMSE + - _target_: gluonts.ev.metrics.NRMSE + - _target_: gluonts.ev.metrics.ND + - _target_: gluonts.ev.metrics.MeanWeightedSumQuantileLoss + quantile_levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] +batch_size: 512 +min_batch_size: 16 +device: auto \ No newline at end of file diff --git a/cli/conf/seasonal_naive_eval/model/moirai_lightning_ckpt.yaml b/cli/conf/seasonal_naive_eval/model/moirai_lightning_ckpt.yaml new file mode 100644 index 0000000..1211bc4 --- /dev/null +++ b/cli/conf/seasonal_naive_eval/model/moirai_lightning_ckpt.yaml @@ -0,0 +1,6 @@ +_target_: uni2ts.model.seasonal_naive_moirai.MoiraiForecast.load_from_checkpoint +checkpoint_path: ... +pretrained_checkpoint_path: null +num_samples: 100 +patch_size: ??? +context_length: ??? diff --git a/cli/conf/seasonal_naive_finetune/data/electricity.yaml b/cli/conf/seasonal_naive_finetune/data/electricity.yaml new file mode 100644 index 0000000..4fb675b --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/electricity.yaml @@ -0,0 +1,2 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: electricity diff --git a/cli/conf/seasonal_naive_finetune/data/etth1.yaml b/cli/conf/seasonal_naive_finetune/data/etth1.yaml new file mode 100644 index 0000000..a5de611 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/etth1.yaml @@ -0,0 +1,3 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: ETTh1 +weight: 1000 \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/data/etth2.yaml b/cli/conf/seasonal_naive_finetune/data/etth2.yaml new file mode 100644 index 0000000..13d29d6 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/etth2.yaml @@ -0,0 +1,2 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: ETTh2 diff --git a/cli/conf/seasonal_naive_finetune/data/ettm1.yaml b/cli/conf/seasonal_naive_finetune/data/ettm1.yaml new file mode 100644 index 0000000..df066af --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/ettm1.yaml @@ -0,0 +1,2 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: ETTm1 diff --git a/cli/conf/seasonal_naive_finetune/data/ettm2.yaml b/cli/conf/seasonal_naive_finetune/data/ettm2.yaml new file mode 100644 index 0000000..5ffbcc5 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/ettm2.yaml @@ -0,0 +1,2 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: ETTm2 diff --git a/cli/conf/seasonal_naive_finetune/data/weather.yaml b/cli/conf/seasonal_naive_finetune/data/weather.yaml new file mode 100644 index 0000000..41d5b06 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/data/weather.yaml @@ -0,0 +1,2 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: weather diff --git a/cli/conf/seasonal_naive_finetune/default.yaml b/cli/conf/seasonal_naive_finetune/default.yaml new file mode 100644 index 0000000..c5083f2 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/default.yaml @@ -0,0 +1,85 @@ +hydra: + run: + dir: outputs/sn_finetune/${hydra:runtime.choices.model}/${exp_name}/${model.finetune_pattern}/${hydra:runtime.choices.data}/${run_name} +defaults: + - model: ??? + - data: ??? + - val_data: null + - _self_ +exp_name: ??? +run_name: ??? +seed: 0 +tf32: true +compile: false # set to mode: default, reduce-overhead, max-autotune +ckpt_path: null +trainer: + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32 + logger: + _target_: lightning.pytorch.loggers.TensorBoardLogger + save_dir: ${hydra:runtime.output_dir} + name: logs + callbacks: + - _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + monitor: val/PackedNLLLoss + save_weights_only: true + mode: min + save_top_k: 1 + every_n_epochs: 1 + - _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val/PackedNLLLoss + min_delta: 0.0 + patience: 30 + mode: min + strict: false + verbose: true + max_epochs: 1000 + enable_progress_bar: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm +train_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 32 + batch_size_factor: 2.0 + cycle: true + num_batches_per_epoch: 10 + shuffle: true + num_workers: 11 + collate_fn: + _target_: uni2ts.data.loader.PackCollate + max_length: ${model.module_kwargs.max_seq_len} + seq_fields: ${cls_getattr:${model._target_},seq_fields} + pad_func_map: ${cls_getattr:${model._target_},pad_func_map} + pin_memory: true + drop_last: false + fill_last: false + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true +val_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 32 + batch_size_factor: 2.0 + cycle: false + num_batches_per_epoch: null + shuffle: false + num_workers: 11 + collate_fn: + _target_: uni2ts.data.loader.PackCollate + max_length: ${model.module_kwargs.max_seq_len} + seq_fields: ${cls_getattr:${model._target_},seq_fields} + pad_func_map: ${cls_getattr:${model._target_},pad_func_map} + pin_memory: false + drop_last: false + fill_last: true + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/model/moirai_1.1_R_small.yaml b/cli/conf/seasonal_naive_finetune/model/moirai_1.1_R_small.yaml new file mode 100644 index 0000000..1d3a55a --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/model/moirai_1.1_R_small.yaml @@ -0,0 +1,43 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.seasonal_naive_moirai.MoiraiFinetune +module: + _target_: uni2ts.model.seasonal_naive_moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.1-R-small +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 384 + num_layers: 6 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 1e-5 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 +patch_size: null +context_length: null +prediction_length: null +finetune_pattern: full +replace_distr_output: False +apply_seasonal_naive: True \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/val_data/electricity.yaml b/cli/conf/seasonal_naive_finetune/val_data/electricity.yaml new file mode 100644 index 0000000..61e8c7e --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/electricity.yaml @@ -0,0 +1,13 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: electricity_eval + offset: 18412 # Same as _lsf_dataset.py + eval_length: 2630 # Same as _lsf_dataset.py, test_length=5260 + prediction_lengths: ??? + context_lengths: ??? + patch_sizes: ??? + +# prediction_lengths: [96, 192, 336, 720] +# context_lengths: [3000] +# patch_sizes: [32, 64] # freq='h' \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/val_data/etth1.yaml b/cli/conf/seasonal_naive_finetune/val_data/etth1.yaml new file mode 100644 index 0000000..00c462a --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/etth1.yaml @@ -0,0 +1,9 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: ETTh1_eval + offset: 11520 + eval_length: 2880 + prediction_lengths: [96, 192, 336, 720] + context_lengths: [1000, 2000, 3000, 4000, 5000] + patch_sizes: [32, 64] \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/val_data/etth1_multi.yaml b/cli/conf/seasonal_naive_finetune/val_data/etth1_multi.yaml new file mode 100644 index 0000000..56e19ae --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/etth1_multi.yaml @@ -0,0 +1,16 @@ +- _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1_eval + offset: 11520 + windows: 10 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 +- _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1_eval + offset: 11520 + windows: 10 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/val_data/etth2.yaml b/cli/conf/seasonal_naive_finetune/val_data/etth2.yaml new file mode 100644 index 0000000..5fc653c --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/etth2.yaml @@ -0,0 +1,13 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: ETTh2_eval + offset: 8640 # Same as _lsf_dataset.py + eval_length: 2880 # Same as _lsf_dataset.py + prediction_lengths: ??? + context_lengths: ??? + patch_sizes: ??? + +# prediction_lengths: [ 96, 192, 336, 720 ] +# context_lengths: [ 3000 ] +# patch_sizes: [ 32, 64 ] diff --git a/cli/conf/seasonal_naive_finetune/val_data/ettm1.yaml b/cli/conf/seasonal_naive_finetune/val_data/ettm1.yaml new file mode 100644 index 0000000..e6a15b4 --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/ettm1.yaml @@ -0,0 +1,14 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: ETTm1_eval + offset: 34560 # Same as _lsf_dataset.py + eval_length: 11520 # Same as _lsf_dataset.py + prediction_lengths: ??? + context_lengths: ??? + patch_sizes: ??? + + +# prediction_lengths: [96, 192, 336, 720] +# context_lengths: [ 3000 ] +# patch_sizes: [ 32, 64, 128 ] # freq="15T" \ No newline at end of file diff --git a/cli/conf/seasonal_naive_finetune/val_data/ettm2.yaml b/cli/conf/seasonal_naive_finetune/val_data/ettm2.yaml new file mode 100644 index 0000000..cb070fd --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/ettm2.yaml @@ -0,0 +1,14 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: ETTm2_eval + offset: 34560 # Same as _lsf_dataset.py + eval_length: 11520 # Same as _lsf_dataset.py + prediction_lengths: ??? + context_lengths: ??? + patch_sizes: ??? + + +# prediction_lengths: [96, 192, 336, 720] +# context_lengths: [3000] +# patch_sizes: [32, 64, 128] # "freq=15T" diff --git a/cli/conf/seasonal_naive_finetune/val_data/weather.yaml b/cli/conf/seasonal_naive_finetune/val_data/weather.yaml new file mode 100644 index 0000000..8f1973e --- /dev/null +++ b/cli/conf/seasonal_naive_finetune/val_data/weather.yaml @@ -0,0 +1,9 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: weather_eval + offset: 36887 # Same as _lsf_dataset.py + eval_length: 5269 # Same as _lsf_dataset.py; test_length=10539 + prediction_lengths: ??? + context_lengths: ??? + patch_sizes: ??? \ No newline at end of file diff --git a/src/uni2ts/model/seasonal_naive_moirai/__init__.py b/src/uni2ts/model/seasonal_naive_moirai/__init__.py new file mode 100644 index 0000000..10bbe5b --- /dev/null +++ b/src/uni2ts/model/seasonal_naive_moirai/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .finetune import MoiraiFinetune +from .forecast import MoiraiForecast +from .module import MoiraiModule + + +__all__ = [ + "MoiraiFinetune", + "MoiraiForecast", + "MoiraiModule", +] diff --git a/src/uni2ts/model/seasonal_naive_moirai/finetune.py b/src/uni2ts/model/seasonal_naive_moirai/finetune.py new file mode 100644 index 0000000..4f702bb --- /dev/null +++ b/src/uni2ts/model/seasonal_naive_moirai/finetune.py @@ -0,0 +1,790 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Any, Optional + +import lightning as L +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import nn +from torch.distributions import Distribution + +from src.uni2ts.model.moirai.module import MoiraiModule +from uni2ts.distribution import StudentTOutput +from uni2ts.loss.packed import ( + PackedDistributionLoss, + PackedLoss, + PackedNLLLoss, + PackedPointLoss, +) +from uni2ts.module.norm import RMSNorm +from uni2ts.module.position import ( + BinaryAttentionBias, + LearnedEmbedding, + LearnedProjection, +) +from uni2ts.module.ts_embed import MultiInSizeLinear, MultiOutSizeLinear +from uni2ts.optim import SchedulerType, get_scheduler +from uni2ts.transform import ( + AddObservedMask, + AddSeasonalNaiveTarget, + AddTimeIndex, + AddVariateIndex, + DefaultPatchSizeConstraints, + DummyValueImputation, + EvalMaskedPrediction, + EvalPad, + ExtendMask, + FixedPatchSizeConstraints, + FlatPackCollection, + FlatPackFields, + GetPatchSize, + GetSeasonalNaivePrediction, + Identity, + ImputeTimeSeries, + MaskedPrediction, + MaskedPredictionGivenFixedConfig, + MaskOutRangePaddedTokens, + PackFields, + PatchCrop, + PatchCropGivenFixedConfig, + Patchify, + SeasonalNaiveEvalCrop, + SelectFields, + SequencifyField, + Transformation, +) + + +class MoiraiFinetune(L.LightningModule): + seq_fields: tuple[str, ...] = ( + "target", + "observed_mask", + "time_id", + "variate_id", + "prediction_mask", + "patch_size", + "naive_target", + ) + pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] = { + "target": np.zeros, + "observed_mask": np.zeros, + "time_id": np.zeros, + "variate_id": np.zeros, + "prediction_mask": np.zeros, + "patch_size": np.zeros, + "naive_target": np.zeros, + } + + def __init__( + self, + min_patches: int, + min_mask_ratio: float, + max_mask_ratio: float, + max_dim: int, + num_training_steps: int, + num_warmup_steps: int, + module_kwargs: Optional[dict[str, Any]] = None, + module: Optional[MoiraiModule] = None, + num_samples: int = 100, + beta1: float = 0.9, + beta2: float = 0.98, + loss_func: PackedDistributionLoss = PackedNLLLoss(), + val_metric: Optional[PackedLoss | list[PackedLoss]] = None, + lr: float = 1e-3, + weight_decay: float = 1e-2, + log_on_step: bool = False, + context_length: Optional[int | list[int]] = None, + prediction_length: Optional[int | list[int]] = None, + patch_size: Optional[int] = None, + finetune_pattern: str | list[str] = "full", + replace_distr_output: bool = False, + apply_seasonal_naive: bool = False, + # full + # in_proj + # param_proj + # norm: norm1 norm2 + # mask_encoding + # self_attn: q_proj, k_proj, v_proj, q_norm, k_norm, var_attn_bias, out_proj + # ffn: 2 * fc + 1 fc_gating. + # No PE, implicitly included in q_proj & k_proj as RoPE + # Except in_proj & param_poj, other params only have weight, without bias. + ): + assert (module is not None) or ( + module_kwargs is not None + ), "if module is not provided, module_kwargs is required" + assert ( + num_warmup_steps <= num_training_steps + ), f"num_warmup_steps ({num_warmup_steps}) should be <= num_training_steps ({num_training_steps})." + super().__init__() + self.save_hyperparameters(ignore=["module"]) + self.module = ( + MoiraiModule(**module_kwargs) if module is None else module + ) # ToDo: revise masking in module + + self.context_length = context_length + self.prediction_length = prediction_length + self.patch_size = patch_size + self.finetune_pattern = finetune_pattern + self.apply_seasonal_naive = apply_seasonal_naive + + def replace_distr_output(self): + assert ( + "full" in self.finetune_pattern or "param_proj" in self.finetune_pattern + ), "Must finetune param_proj if replace distr_output" + pretraiend_param_proj = self.module.param_proj + pretraiend_param_proj_student_t = pretraiend_param_proj.proj["components"][0] + self.module.distr_output = StudentTOutput() + self.module.param_proj = self.module.distr_output.get_param_proj( + self.module.d_model, self.module.patch_sizes + ) + self.module.param_proj.proj = pretraiend_param_proj_student_t + + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len max_patch"], + observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + patch_size: Int[torch.Tensor, "*batch seq_len"], + ) -> Distribution: + distr = self.module( + target=target, + observed_mask=observed_mask, + sample_id=sample_id, + time_id=time_id, + variate_id=variate_id, + prediction_mask=prediction_mask, + patch_size=patch_size, + ) + return distr + + def training_step( + self, batch: dict[str, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + + if self.apply_seasonal_naive: + distr = self( + **{ + field: batch[field if field != "target" else "naive_target"] + for field in list(self.seq_fields) + ["sample_id"] + if field != "naive_target" + } + ) + else: + distr = self( + **{ + field: batch[field] + for field in list(self.seq_fields) + ["sample_id"] + } + ) + loss = self.hparams.loss_func( + pred=distr, + **{ + field: batch[field] + for field in [ + "target", + "prediction_mask", + "observed_mask", + "sample_id", + "variate_id", + ] + }, + ) + batch_size = ( + batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None + ) + self.log( + f"train/{self.hparams.loss_func.__class__.__name__}", + loss, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + ) + return loss + + def validation_step( + self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0 + ) -> torch.Tensor: + if self.apply_seasonal_naive: + distr = self( + **{ + field: batch[field if field != "target" else "naive_target"] + for field in list(self.seq_fields) + ["sample_id"] + if field != "naive_target" + } + ) + else: + distr = self( + **{ + field: batch[field] + for field in list(self.seq_fields) + ["sample_id"] + } + ) + val_loss = self.hparams.loss_func( + pred=distr, + **{ + field: batch[field] + for field in [ + "target", + "prediction_mask", + "observed_mask", + "sample_id", + "variate_id", + ] + }, + ) + batch_size = ( + batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None + ) + self.log( + f"val/{self.hparams.loss_func.__class__.__name__}", + val_loss, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + ) + + if self.hparams.val_metric is not None: + val_metrics = ( + self.hparams.val_metric + if isinstance(self.hparams.val_metric, list) + else [self.hparams.val_metric] + ) + for metric_func in val_metrics: + if isinstance(metric_func, PackedPointLoss): + pred = distr.sample(torch.Size((self.hparams.num_samples,))) + pred = torch.median(pred, dim=0).values + elif isinstance(metric_func, PackedDistributionLoss): + pred = distr + else: + raise ValueError(f"Unsupported loss function: {metric_func}") + + metric = metric_func( + pred=pred, + **{ + field: batch[field] + for field in [ + "target", + "prediction_mask", + "observed_mask", + "sample_id", + "variate_id", + ] + }, + ) + + self.log( + f"val/{metric_func.__class__.__name__}", + metric, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + ) + + return val_loss + + def configure_optimizers(self) -> dict: + decay = set() + no_decay = set() + + if self.finetune_pattern == "full": + pass + else: + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze the corresponding params + if "param_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "param_proj" in pn: + p.requires_grad = True + + if "in_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "in_proj" in pn: + p.requires_grad = True + + if "norm" in self.finetune_pattern: # + for pn, p in self.named_parameters(): + if "norm1" in pn or "norm2" in pn: + p.requires_grad = True + + if "mask" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "mask_encoding" in pn: + p.requires_grad = True + + if "ffn" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "ffn" in pn: + p.requires_grad = True + + if "q_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "q_proj" in pn: + p.requires_grad = True + + if "k_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "k_proj" in pn: + p.requires_grad = True + + if "v_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "v_proj" in pn: + p.requires_grad = True + + if "attn_norm" in self.finetune_pattern: # + for pn, p in self.named_parameters(): + if "self_attn.q_norm" in pn or "self_attn.k_norm" in pn: + p.requires_grad = True + + if "var_attn_bias" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "var_attn_bias" in pn: + p.requires_grad = True + + if "out_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "out_proj" in pn: + p.requires_grad = True + + if "studentT" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if ( + "param_proj.proj.components.0" in pn + or "param_proj.proj.weights_logits" in pn + ): + p.requires_grad = True + + whitelist_params = ( + LearnedProjection, + MultiInSizeLinear, + MultiOutSizeLinear, + nn.Linear, + ) + blacklist_params = ( + BinaryAttentionBias, + LearnedEmbedding, + RMSNorm, + nn.Embedding, + nn.LayerNorm, + ) + + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + if not p.requires_grad: + continue + + fpn = f"{mn}.{pn}" if mn else pn + if pn.endswith("bias"): + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_params): + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_params): + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} + self.trainable_params = param_dict + + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert ( + len(param_dict.keys() - union_params) == 0 + ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + + optim_groups = [ + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(decay))], + ), + "weight_decay": self.hparams.weight_decay, + }, + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(no_decay))], + ), + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW( + optim_groups, + lr=self.hparams.lr, + betas=(self.hparams.beta1, self.hparams.beta2), + eps=1e-6, + ) + scheduler = get_scheduler( + SchedulerType.COSINE_WITH_RESTARTS, + optimizer, + num_warmup_steps=self.hparams.num_warmup_steps, + num_training_steps=self.hparams.num_training_steps, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "train_loss", + "interval": "step", + }, + } + + @property + def train_transform_map( + self, + ) -> dict[str | type, Callable[..., Transformation]]: + def default_train_transform(): + return ( + GetPatchSize( + min_time_patches=self.hparams.min_patches, + target_field="target", + patch_sizes=self.module.patch_sizes, + patch_size_constraints=( + DefaultPatchSizeConstraints() + if self.patch_size is None + else FixedPatchSizeConstraints(self.patch_size) + ), + offset=True, + ) + + ( + PatchCrop( + min_time_patches=self.hparams.min_patches, + max_patches=self.module.max_seq_len, + will_flatten=True, + offset=True, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + if self.context_length is None or self.prediction_length is None + else PatchCropGivenFixedConfig( + context_length=self.context_length, + prediction_length=self.prediction_length, + will_flatten=True, + offset=True, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + ) + + PackFields( + output_field="target", + fields=("target",), + ) + + PackFields( + output_field="past_feat_dynamic_real", + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + ) + # Set the padded tokens in context (1st patch) and in prediction (the last patch) as Nan. + + ( + Identity() + if self.context_length is None or self.prediction_length is None + else MaskOutRangePaddedTokens( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + ) + + AddObservedMask( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + imputation_method=DummyValueImputation(value=0.0), + ) + + ( + GetSeasonalNaivePrediction( + naive_prediction_field="naive_prediction" + ) + if self.apply_seasonal_naive + else Identity() + ) + + Patchify( + max_patch_size=max(self.module.patch_sizes), + fields=("target", "observed_mask"), + optional_fields=("past_feat_dynamic_real",), + ) + + ( + AddSeasonalNaiveTarget( + max_patch_size=max(self.module.patch_sizes), + naive_target_field="naive_target", + naive_prediction_field="naive_prediction", + ) + if self.apply_seasonal_naive + else Identity() + ) + + AddVariateIndex( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + variate_id_field="variate_id", + expected_ndim=3, + max_dim=self.hparams.max_dim, + randomize=True, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + ( + MaskedPrediction( + min_mask_ratio=self.hparams.min_mask_ratio, + max_mask_ratio=self.hparams.max_mask_ratio, + target_field="target", + truncate_fields=("variate_id", "time_id", "observed_mask"), + optional_truncate_fields=("past_feat_dynamic_real",), + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + if self.context_length is None or self.prediction_length is None + else MaskedPredictionGivenFixedConfig( + target_field="target", + truncate_fields=("variate_id", "time_id", "observed_mask"), + optional_truncate_fields=("past_feat_dynamic_real",), + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + ) + + ExtendMask( + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + mask_field="prediction_mask", + expected_ndim=3, + ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + feat=True, + ) + + ( + FlatPackFields( + output_field="naive_target", + fields=("naive_target",), + feat=True, + ) + if self.apply_seasonal_naive + else Identity() + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields)) + ) + + return defaultdict(lambda: default_train_transform) + + @property + def val_transform_map( + self, + ) -> dict[str | type, Callable[..., Transformation]]: + def default_val_transform( + offset: int, + distance: int, + prediction_length: int, + context_length: int, + patch_size: int, + ): + return ( + GetPatchSize( + min_time_patches=2, + target_field="target", + patch_sizes=self.module.patch_sizes, + patch_size_constraints=FixedPatchSizeConstraints(patch_size), + offset=True, + ) + + SeasonalNaiveEvalCrop( + offset, + distance, + prediction_length, + context_length, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + + PackFields( + output_field="target", + fields=("target",), + ) + + PackFields( + output_field="past_feat_dynamic_real", + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + ) + + EvalPad( + prediction_pad=-prediction_length % patch_size, + context_pad=-context_length % patch_size, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + + AddObservedMask( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + imputation_method=DummyValueImputation(value=0.0), + ) + + ( + GetSeasonalNaivePrediction( + naive_prediction_field="naive_prediction" + ) + if self.apply_seasonal_naive + else Identity() + ) + + Patchify( + max_patch_size=max(self.module.patch_sizes), + fields=("target", "observed_mask"), + optional_fields=("past_feat_dynamic_real",), + ) + + ( + AddSeasonalNaiveTarget( + max_patch_size=max(self.module.patch_sizes), + naive_target_field="naive_target", + naive_prediction_field="naive_prediction", + ) + if self.apply_seasonal_naive + else Identity() + ) + + AddVariateIndex( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + variate_id_field="variate_id", + expected_ndim=3, + max_dim=self.hparams.max_dim, + randomize=True, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + EvalMaskedPrediction( + mask_length=math.ceil(prediction_length / patch_size), + target_field="target", + truncate_fields=("variate_id", "time_id", "observed_mask"), + optional_truncate_fields=("past_feat_dynamic_real",), + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + + ExtendMask( + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + mask_field="prediction_mask", + expected_ndim=3, + ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + feat=True, + ) + + ( + FlatPackFields( + output_field="naive_target", + fields=("naive_target",), + feat=True, + ) + if self.apply_seasonal_naive + else Identity() + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields)) + ) + + return defaultdict(lambda: default_val_transform) + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + """ + Modify state_dict to only save trainable params. + Note the default state_dict saved by PL converts all params to require_grads=False + """ + state = super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + filtered_state = { + name: tensor + for name, tensor in state.items() + if name in self.trainable_params + } + return filtered_state + + +class MoiraiLinearProbe(MoiraiFinetune): ... diff --git a/src/uni2ts/model/seasonal_naive_moirai/forecast.py b/src/uni2ts/model/seasonal_naive_moirai/forecast.py new file mode 100644 index 0000000..a7c235d --- /dev/null +++ b/src/uni2ts/model/seasonal_naive_moirai/forecast.py @@ -0,0 +1,1056 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from contextlib import contextmanager +from copy import deepcopy +from typing import Any, Generator, Optional + +import lightning as L +import numpy as np +import torch +from einops import rearrange, reduce, repeat +from gluonts.model import Input, InputSpec +from gluonts.torch import PyTorchPredictor +from gluonts.transform import ( + AddObservedValuesIndicator, + AsNumpyArray, + ExpandDimArray, + TestSplitSampler, + Transformation, +) +from gluonts.transform.split import TFTInstanceSplitter +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution + +from src.uni2ts.model.moirai.module import MoiraiModule +from uni2ts.common.torch_util import safe_div +from uni2ts.distribution import StudentTOutput +from uni2ts.loss.packed import PackedNLLLoss as _PackedNLLLoss + + +def seasonal_naive_predict_torch( + past_target: torch.Tensor, prediction_length: int +) -> torch.Tensor: + """ + Apply seasonal naive prediction using the past target time series (Torch version). + + Args: + past_target (torch.Tensor): A tensor of shape (batch, past_time, tgt) containing the historical data. + prediction_length (int): The length of the prediction horizon. + + Returns: + torch.Tensor: Forecasted values of shape (batch, prediction_length, tgt) using seasonal naive method. + """ + batch_size, past_time, tgt = ( + past_target.shape + ) # Extract batch shape and target dimensions + + # Initialize future_target tensor + future_target = torch.zeros( + (batch_size, prediction_length, tgt), + dtype=past_target.dtype, + device=past_target.device, + ) + + # Iterate over the batch + for b in range(batch_size): + for i in range(tgt): + # Convert the past time series into numpy array for FFT processing + past_series_np = past_target[b, :, i].cpu().numpy() + + # Compute FFT on the past series + fft_vals = np.fft.fft(past_series_np) + freqs = np.fft.fftfreq(past_time) + + # Discard freq=0 (the mean component) and find the dominant frequency + fft_vals = fft_vals[1:] + freqs = freqs[1:] + dominant_freq = freqs[np.argmax(np.abs(fft_vals))] + + # Compute the period length from the dominant frequency + period = int(np.abs(1 / dominant_freq)) + + # If no periodicity in context, use the last time points for forecasting. + # ToDo: For now, we only consider the case that context is longer than predicton + # Generate the future_target using seasonal naive prediction + if period == past_time: + # No clear seasonality, copy the last `prediction_length` points + future_target[b, :, i] = past_target[b, -prediction_length:, i] + else: + # Apply seasonal naive prediction logic for `prediction_length` time steps + for t in range(prediction_length): + future_target[b, t, i] = past_target[ + b, (past_time - period + (t % period)) % past_time, i + ] + + return future_target + + +class SampleNLLLoss(_PackedNLLLoss): + def reduce_loss( + self, + loss: Float[torch.Tensor, "batch seq_len #dim"], + prediction_mask: Optional[Bool[torch.Tensor, "batch seq_len"]], + observed_mask: Optional[Bool[torch.Tensor, "batch seq_len #dim"]], + sample_id: Optional[Int[torch.Tensor, "batch seq_len"]], + variate_id: Optional[Int[torch.Tensor, "batch seq_len"]], + ) -> Float[torch.Tensor, "batch"]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + mask = prediction_mask.unsqueeze(-1) * observed_mask + tobs = reduce( + id_mask + * reduce( + mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loss = safe_div(loss, tobs) + return (loss * mask).sum(dim=(-1, -2)) + + +class MoiraiForecast(L.LightningModule): + def __init__( + self, + prediction_length: int, + target_dim: int, + feat_dynamic_real_dim: int, + past_feat_dynamic_real_dim: int, + context_length: int, + module_kwargs: Optional[dict[str, Any]] = None, + module: Optional[MoiraiModule] = None, + patch_size: int | str = "auto", + num_samples: int = 100, + pretrained_checkpoint_path: str = None, + replace_distr_output: bool = False, + ): + assert (module is not None) or ( + module_kwargs is not None + ), "if module is not provided, module_kwargs is required" + super().__init__() + self.save_hyperparameters(ignore=["module"]) + self.module = MoiraiModule(**module_kwargs) if module is None else module + self.per_sample_loss_func = SampleNLLLoss() + self.strict_loading = False + + if replace_distr_output: + self.replace_distr_output() + + def replace_distr_output(self): + pretraiend_param_proj = self.module.param_proj + pretraiend_param_proj_student_t = pretraiend_param_proj.proj["components"][0] + self.module.distr_output = StudentTOutput() + self.module.param_proj = self.module.distr_output.get_param_proj( + self.module.d_model, self.module.patch_sizes + ) + self.module.param_proj.proj = pretraiend_param_proj_student_t + + @contextmanager + def hparams_context( + self, + prediction_length: Optional[int] = None, + target_dim: Optional[int] = None, + feat_dynamic_real_dim: Optional[int] = None, + past_feat_dynamic_real_dim: Optional[int] = None, + context_length: Optional[int] = None, + patch_size: Optional[int | str] = None, + num_samples: Optional[int] = None, + ) -> Generator["MoiraiForecast", None, None]: + kwargs = { + "prediction_length": prediction_length, + "target_dim": target_dim, + "feat_dynamic_real_dim": feat_dynamic_real_dim, + "past_feat_dynamic_real_dim": past_feat_dynamic_real_dim, + "context_length": context_length, + "patch_size": patch_size, + "num_samples": num_samples, + } + old_hparams = deepcopy(self.hparams) + for kw, arg in kwargs.items(): + if arg is not None: + self.hparams[kw] = arg + + yield self + + for kw in kwargs: + self.hparams[kw] = old_hparams[kw] + + def create_predictor( + self, + batch_size: int, + device: str = "auto", + ) -> PyTorchPredictor: + ts_fields = [] + if self.hparams.feat_dynamic_real_dim > 0: + ts_fields.append("feat_dynamic_real") + ts_fields.append("observed_feat_dynamic_real") + past_ts_fields = [] + if self.hparams.past_feat_dynamic_real_dim > 0: + past_ts_fields.append("past_feat_dynamic_real") + past_ts_fields.append("past_observed_feat_dynamic_real") + instance_splitter = TFTInstanceSplitter( + instance_sampler=TestSplitSampler(), + past_length=self.past_length, + future_length=self.hparams.prediction_length, + observed_value_field="observed_target", + time_series_fields=ts_fields, + past_time_series_fields=past_ts_fields, + ) + return PyTorchPredictor( + input_names=self.prediction_input_names, + prediction_net=self, + batch_size=batch_size, + prediction_length=self.hparams.prediction_length, + input_transform=self.get_default_transform() + instance_splitter, + device=device, + ) + + def describe_inputs(self, batch_size: int = 1) -> InputSpec: + data = { + "past_target": Input( + shape=( + batch_size, + self.past_length, + self.hparams.target_dim, + ), + dtype=torch.float, + ), + "past_observed_target": Input( + shape=( + batch_size, + self.past_length, + self.hparams.target_dim, + ), + dtype=torch.bool, + ), + "past_is_pad": Input( + shape=(batch_size, self.past_length), + dtype=torch.bool, + ), + } + if self.hparams.feat_dynamic_real_dim > 0: + data["feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length + self.hparams.prediction_length, + self.hparams.feat_dynamic_real_dim, + ), + dtype=torch.float, + ) + data["observed_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length + self.hparams.prediction_length, + self.hparams.feat_dynamic_real_dim, + ), + dtype=torch.bool, + ) + if self.hparams.past_feat_dynamic_real_dim > 0: + data["past_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length, + self.hparams.past_feat_dynamic_real_dim, + ), + dtype=torch.float, + ) + data["past_observed_feat_dynamic_real"] = Input( + shape=( + batch_size, + self.past_length, + self.hparams.past_feat_dynamic_real_dim, + ), + dtype=torch.bool, + ) + return InputSpec(data=data, zeros_fn=torch.zeros) + + @property + def prediction_input_names(self) -> list[str]: + return list(self.describe_inputs()) + + @property + def training_input_names(self): + return self.prediction_input_names + ["future_target", "future_observed_values"] + + @property + def past_length(self) -> int: + return ( + self.hparams.context_length + self.hparams.prediction_length + if self.hparams.patch_size == "auto" + else self.hparams.context_length + ) + + def context_token_length(self, patch_size: int) -> int: + return math.ceil(self.hparams.context_length / patch_size) + + def prediction_token_length(self, patch_size) -> int: + return math.ceil(self.hparams.prediction_length / patch_size) + + @property + def max_patch_size(self) -> int: + return max(self.module.patch_sizes) + + def forward( + self, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + num_samples: Optional[int] = None, + ) -> Float[torch.Tensor, "batch sample future_time *tgt"]: + if self.hparams.patch_size == "auto": + val_loss = [] + preds = [] + for patch_size in self.module.patch_sizes: + val_loss.append( + self._val_loss( + patch_size=patch_size, + target=past_target[..., : self.past_length, :], + observed_target=past_observed_target[ + ..., : self.past_length, : + ], + is_pad=past_is_pad[..., : self.past_length], + feat_dynamic_real=( + feat_dynamic_real[..., : self.past_length, :] + if feat_dynamic_real is not None + else None + ), + observed_feat_dynamic_real=( + observed_feat_dynamic_real[..., : self.past_length, :] + if observed_feat_dynamic_real is not None + else None + ), + past_feat_dynamic_real=( + past_feat_dynamic_real[ + ..., : self.hparams.context_length, : + ] + if past_feat_dynamic_real is not None + else None + ), + past_observed_feat_dynamic_real=( + past_observed_feat_dynamic_real[ + ..., : self.hparams.context_length, : + ] + if past_observed_feat_dynamic_real is not None + else None + ), + ) + ) + distr = self._get_distr( + patch_size, + past_target[..., -self.hparams.context_length :, :], + past_observed_target[..., -self.hparams.context_length :, :], + past_is_pad[..., -self.hparams.context_length :], + ( + feat_dynamic_real[..., -self.past_length :, :] + if feat_dynamic_real is not None + else None + ), + ( + observed_feat_dynamic_real[..., -self.past_length :, :] + if observed_feat_dynamic_real is not None + else None + ), + ( + past_feat_dynamic_real[..., -self.hparams.context_length :, :] + if past_feat_dynamic_real is not None + else None + ), + ( + past_observed_feat_dynamic_real[ + ..., -self.hparams.context_length :, : + ] + if past_observed_feat_dynamic_real is not None + else None + ), + ) + preds.append( + self._format_preds( + patch_size, + distr.sample( + torch.Size((num_samples or self.hparams.num_samples,)) + ), + past_target.shape[-1], + ) + ) + val_loss = torch.stack(val_loss) + preds = torch.stack(preds) + idx = val_loss.argmin(dim=0) + return preds[idx, torch.arange(len(idx), device=idx.device)] + else: + distr = self._get_distr( + self.hparams.patch_size, + past_target, + past_observed_target, + past_is_pad, + feat_dynamic_real, + observed_feat_dynamic_real, + past_feat_dynamic_real, + past_observed_feat_dynamic_real, + ) + preds = distr.sample(torch.Size((num_samples or self.hparams.num_samples,))) + return self._format_preds( + self.hparams.patch_size, preds, past_target.shape[-1] + ) + + # ToDo: Need to revise _val_loss in the future. + def _val_loss( + self, + patch_size: int, + target: Float[torch.Tensor, "batch time tgt"], + observed_target: Bool[torch.Tensor, "batch time tgt"], + is_pad: Bool[torch.Tensor, "batch time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> Float[torch.Tensor, "batch"]: + # convert format + ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) = self._convert( + patch_size, + past_target=target[..., : self.hparams.context_length, :], + past_observed_target=observed_target[..., : self.hparams.context_length, :], + past_is_pad=is_pad[..., : self.hparams.context_length], + future_target=target[..., self.hparams.context_length :, :], + future_observed_target=observed_target[ + ..., self.hparams.context_length :, : + ], + future_is_pad=is_pad[..., self.hparams.context_length :], + feat_dynamic_real=feat_dynamic_real, + observed_feat_dynamic_real=observed_feat_dynamic_real, + past_feat_dynamic_real=past_feat_dynamic_real, + past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, + ) + # get predictions + distr = self.module( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + torch.ones_like(time_id, dtype=torch.long) * patch_size, + ) + val_loss = self.per_sample_loss_func( + pred=distr, + target=target, + prediction_mask=prediction_mask, + observed_mask=observed_mask, + sample_id=sample_id, + variate_id=variate_id, + ) + return val_loss + + def _get_distr( + self, + patch_size: int, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> Distribution: + # convert format. QZ: Target here is the seasonal naive target + ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) = self._convert( + patch_size, + past_target, + past_observed_target, + past_is_pad, + feat_dynamic_real=feat_dynamic_real, + observed_feat_dynamic_real=observed_feat_dynamic_real, + past_feat_dynamic_real=past_feat_dynamic_real, + past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, + ) + + # get predictions + distr = self.module( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + torch.ones_like(time_id, dtype=torch.long) * patch_size, + ) + return distr + + @staticmethod + def _patched_seq_pad( + patch_size: int, + x: torch.Tensor, + dim: int, + left: bool = True, + value: Optional[float] = None, + ) -> torch.Tensor: + if dim >= 0: + dim = -x.ndim + dim + pad_length = -x.size(dim) % patch_size + if left: + pad = (pad_length, 0) + else: + pad = (0, pad_length) + pad = (0, 0) * (abs(dim) - 1) + pad + return torch.nn.functional.pad(x, pad, value=value) + + def _generate_time_id( + self, + patch_size: int, + past_observed_target: Bool[torch.Tensor, "batch past_seq tgt"], + ) -> tuple[ + Int[torch.Tensor, "batch past_token"], Int[torch.Tensor, "batch future_token"] + ]: + past_seq_id = reduce( + self._patched_seq_pad(patch_size, past_observed_target, -2, left=True), + "... (seq patch) dim -> ... seq", + "max", + patch=patch_size, + ) + past_seq_id = torch.clamp(past_seq_id.cumsum(dim=-1) - 1, min=0) + batch_shape = " ".join(map(str, past_observed_target.shape[:-2])) + future_seq_id = ( + repeat( + torch.arange( + self.prediction_token_length(patch_size), + device=past_observed_target.device, + ), + f"prediction -> {batch_shape} prediction", + ) + + past_seq_id.max(dim=-1, keepdim=True).values + + 1 + ) + return past_seq_id, future_seq_id + + def _convert( + self, + patch_size: int, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + future_target: Optional[Float[torch.Tensor, "batch future_time tgt"]] = None, + future_observed_target: Optional[ + Bool[torch.Tensor, "batch future_time tgt"] + ] = None, + future_is_pad: Optional[Bool[torch.Tensor, "batch future_time"]] = None, + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> tuple[ + Float[torch.Tensor, "batch combine_seq patch"], # target + Bool[torch.Tensor, "batch combine_seq patch"], # observed_mask + Int[torch.Tensor, "batch combine_seq"], # sample_id + Int[torch.Tensor, "batch combine_seq"], # time_id + Int[torch.Tensor, "batch combine_seq"], # variate_id + Bool[torch.Tensor, "batch combine_seq"], # prediction_mask + ]: + batch_shape = past_target.shape[:-2] + device = past_target.device + + target = [] + observed_mask = [] + sample_id = [] + time_id = [] + variate_id = [] + prediction_mask = [] + dim_count = 0 + + past_seq_id, future_seq_id = self._generate_time_id( + patch_size, past_observed_target + ) + + if ( + future_target is None + ): # QZ: Revise here to change future_target as seasonal naive prediction + # future_target = torch.zeros( + # batch_shape + # + ( + # self.hparams.prediction_length, + # past_target.shape[-1], + # ), + # dtype=past_target.dtype, + # device=device, + # ) + future_target = seasonal_naive_predict_torch( + past_target, self.hparams.prediction_length + ) + + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad(patch_size, past_target, -2, left=True), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + if future_observed_target is None: + future_observed_target = torch.ones( + batch_shape + + ( + self.hparams.prediction_length, + past_observed_target.shape[-1], + ), + dtype=torch.bool, + device=device, + ) + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_target, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_observed_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + if future_is_pad is None: + future_is_pad = torch.zeros( + batch_shape + (self.hparams.prediction_length,), + dtype=torch.long, + device=device, + ) + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_target.shape[-1], + ), + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, future_is_pad, -1, left=False, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_target.shape[-1], + ), + ] + ) + time_id.extend( + [past_seq_id] * past_target.shape[-1] + + [future_seq_id] * past_target.shape[-1] + ) + variate_id.extend( + [ + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += past_target.shape[-1] + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + (self.context_token_length(patch_size) * past_target.shape[-1],), + dtype=torch.bool, + device=device, + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * past_target.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + if feat_dynamic_real is not None: + if observed_feat_dynamic_real is None: + raise ValueError( + "observed_feat_dynamic_real must be provided if feat_dynamic_real is provided" + ) + + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[ + ..., : self.hparams.context_length, : + ], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[ + ..., self.hparams.context_length :, : + ], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., : self.hparams.context_length, : + ], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., self.hparams.context_length :, : + ], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ), + ] + ) + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=feat_dynamic_real.shape[-1], + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.long, + device=device, + ), + ] + ) + time_id.extend( + [past_seq_id] * feat_dynamic_real.shape[-1] + + [future_seq_id] * feat_dynamic_real.shape[-1] + ) + variate_id.extend( + [ + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += feat_dynamic_real.shape[-1] + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + torch.zeros( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + if past_feat_dynamic_real is not None: + if past_observed_feat_dynamic_real is None: + raise ValueError( + "past_observed_feat_dynamic_real must be provided if past_feat_dynamic_real is provided" + ) + target.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ) + ) + observed_mask.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, self.max_patch_size - patch_size), + ) + ) + sample_id.append( + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_feat_dynamic_real.shape[-1], + ) + ) + time_id.extend([past_seq_id] * past_feat_dynamic_real.shape[-1]) + + variate_id.append( + repeat( + torch.arange(past_feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ) + ) + dim_count += past_feat_dynamic_real.shape[-1] + prediction_mask.append( + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * past_feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ) + ) + + target = torch.cat(target, dim=-2) + observed_mask = torch.cat(observed_mask, dim=-2) + sample_id = torch.cat(sample_id, dim=-1) + time_id = torch.cat(time_id, dim=-1) + variate_id = torch.cat(variate_id, dim=-1) + prediction_mask = torch.cat(prediction_mask, dim=-1) + return ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) + + def _format_preds( + self, + patch_size: int, + preds: Float[torch.Tensor, "sample batch combine_seq patch"], + target_dim: int, + ) -> Float[torch.Tensor, "batch sample future_time *tgt"]: + start = target_dim * self.context_token_length(patch_size) + end = start + target_dim * self.prediction_token_length(patch_size) + preds = preds[..., start:end, :patch_size] + preds = rearrange( + preds, + "sample ... (dim seq) patch -> ... sample (seq patch) dim", + dim=target_dim, + )[..., : self.hparams.prediction_length, :] + return preds.squeeze(-1) + + def get_default_transform(self) -> Transformation: + transform = AsNumpyArray( + field="target", + expected_ndim=1 if self.hparams.target_dim == 1 else 2, + dtype=np.float32, + ) + if self.hparams.target_dim == 1: + transform += ExpandDimArray(field="target", axis=0) + transform += AddObservedValuesIndicator( + target_field="target", + output_field="observed_target", + dtype=bool, + ) + + if self.hparams.feat_dynamic_real_dim > 0: + transform += AsNumpyArray( + field="feat_dynamic_real", + expected_ndim=2, + dtype=np.float32, + ) + transform += AddObservedValuesIndicator( + target_field="feat_dynamic_real", + output_field="observed_feat_dynamic_real", + dtype=bool, + ) + + if self.hparams.past_feat_dynamic_real_dim > 0: + transform += AsNumpyArray( + field="past_feat_dynamic_real", + expected_ndim=2, + dtype=np.float32, + ) + transform += AddObservedValuesIndicator( + target_field="past_feat_dynamic_real", + output_field="past_observed_feat_dynamic_real", + dtype=bool, + ) + return transform diff --git a/src/uni2ts/model/seasonal_naive_moirai/module.py b/src/uni2ts/model/seasonal_naive_moirai/module.py new file mode 100644 index 0000000..34f170a --- /dev/null +++ b/src/uni2ts/model/seasonal_naive_moirai/module.py @@ -0,0 +1,178 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import torch +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin +from hydra.utils import instantiate +from jaxtyping import Bool, Float, Int +from torch import nn +from torch.distributions import Distribution +from torch.utils._pytree import tree_map + +from uni2ts.common.torch_util import mask_fill, packed_attention_mask +from uni2ts.distribution import DistributionOutput +from uni2ts.module.norm import RMSNorm +from uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler +from uni2ts.module.position import ( + BinaryAttentionBias, + QueryKeyProjection, + RotaryProjection, +) +from uni2ts.module.transformer import TransformerEncoder +from uni2ts.module.ts_embed import MultiInSizeLinear + + +def encode_distr_output( + distr_output: DistributionOutput, +) -> dict[str, str | float | int]: + """Serialization function for DistributionOutput""" + + def _encode(val): + if not isinstance(val, DistributionOutput): + return val + + return { + "_target_": f"{val.__class__.__module__}.{val.__class__.__name__}", + **tree_map(_encode, val.__dict__), + } + + return _encode(distr_output) + + +def decode_distr_output(config: dict[str, str | float | int]) -> DistributionOutput: + """Deserialization function for DistributionOutput""" + return instantiate(config, _convert_="all") + + +class MoiraiModule( + nn.Module, + PyTorchModelHubMixin, + coders={DistributionOutput: (encode_distr_output, decode_distr_output)}, +): + """ + Contains components of Moirai, to ensure implementation is identical across models. + Subclasses huggingface_hub.PyTorchModelHubMixin to support loading from HuggingFace Hub. + """ + + def __init__( + self, + distr_output: DistributionOutput, + d_model: int, + num_layers: int, + patch_sizes: tuple[int, ...], # tuple[int, ...] | list[int] + max_seq_len: int, + attn_dropout_p: float, + dropout_p: float, + scaling: bool = True, + ): + """ + :param distr_output: distribution output object + :param d_model: model hidden dimensions + :param num_layers: number of transformer layers + :param patch_sizes: sequence of patch sizes + :param max_seq_len: maximum sequence length for inputs + :param attn_dropout_p: dropout probability for attention layers + :param dropout_p: dropout probability for all other layers + :param scaling: whether to apply scaling (standardization) + """ + super().__init__() + self.d_model = d_model + self.num_layers = num_layers + self.patch_sizes = patch_sizes + self.max_seq_len = max_seq_len + self.scaling = scaling + + self.mask_encoding = nn.Embedding(num_embeddings=1, embedding_dim=d_model) + self.scaler = PackedStdScaler() if scaling else PackedNOPScaler() + self.in_proj = MultiInSizeLinear( + in_features_ls=patch_sizes, + out_features=d_model, + ) + self.encoder = TransformerEncoder( + d_model, + num_layers, + num_heads=None, + pre_norm=True, + attn_dropout_p=attn_dropout_p, + dropout_p=dropout_p, + norm_layer=RMSNorm, + activation=F.silu, + use_glu=True, + use_qk_norm=True, + var_attn_bias_layer=partial(BinaryAttentionBias), + time_qk_proj_layer=partial( + QueryKeyProjection, + proj_layer=RotaryProjection, + kwargs=dict(max_len=max_seq_len), + partial_factor=(0.0, 0.5), + ), + shared_var_attn_bias=False, + shared_time_qk_proj=True, + d_ff=None, + ) + self.distr_output = distr_output + self.param_proj = self.distr_output.get_param_proj(d_model, patch_sizes) + + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len max_patch"], + observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + patch_size: Int[torch.Tensor, "*batch seq_len"], + ) -> Distribution: + """ + Defines the forward pass of MoiraiModule. + This method expects processed inputs. + + 1. Apply scaling to observations + 2. Project from observations to representations + 3. Replace prediction window with learnable mask + 4. Apply transformer layers + 5. Project from representations to distribution parameters + 6. Return distribution object + + :param target: input data + :param observed_mask: binary mask for missing values, 1 if observed, 0 otherwise + :param sample_id: indices indicating the sample index (for packing) + :param time_id: indices indicating the time index + :param variate_id: indices indicating the variate index + :param prediction_mask: binary mask for prediction horizon, 1 if part of the horizon, 0 otherwise + :param patch_size: patch size for each token + :return: predictive distribution + """ + loc, scale = self.scaler( + target, + observed_mask * ~prediction_mask.unsqueeze(-1), + sample_id, + variate_id, + ) + scaled_target = (target - loc) / scale + reprs = self.in_proj(scaled_target, patch_size) + masked_reprs = reprs + reprs = self.encoder( + masked_reprs, + packed_attention_mask(sample_id), + time_id=time_id, + var_id=variate_id, + ) + distr_param = self.param_proj(reprs, patch_size) + distr = self.distr_output.distribution(distr_param, loc=loc, scale=scale) + return distr diff --git a/src/uni2ts/transform/__init__.py b/src/uni2ts/transform/__init__.py index 349c201..0f92966 100644 --- a/src/uni2ts/transform/__init__.py +++ b/src/uni2ts/transform/__init__.py @@ -41,6 +41,11 @@ SequencifyField, Transpose, ) +from .seasonal_naive import ( + AddSeasonalNaiveTarget, + GetSeasonalNaivePrediction, + SeasonalNaiveEvalCrop, +) from .task import ( EvalMaskedPrediction, ExtendMask, @@ -89,4 +94,7 @@ "SetValue", "Transformation", "Transpose", + "GetSeasonalNaivePrediction", + "AddSeasonalNaiveTarget", + "SeasonalNaiveEvalCrop", ] diff --git a/src/uni2ts/transform/seasonal_naive.py b/src/uni2ts/transform/seasonal_naive.py new file mode 100644 index 0000000..26291ad --- /dev/null +++ b/src/uni2ts/transform/seasonal_naive.py @@ -0,0 +1,193 @@ +import math +import re +from collections.abc import Sequence +from dataclasses import dataclass +from functools import partial +from typing import Any + +import numpy as np +from einops import pack, rearrange +from jaxtyping import Bool, Float, Num + +from uni2ts.common.typing import UnivarTimeSeries + +from ._base import Transformation +from ._mixin import ( + AddNewArrMixin, + ApplyFuncMixin, + CheckArrNDimMixin, + CollectFuncMixin, + MapFuncMixin, +) + + +def seasonal_naive_predict(context: np.ndarray, prediction: np.ndarray) -> np.ndarray: + """ + Apply seasonal naive prediction to forecast the prediction time series based on the context. + + Args: + context (np.ndarray): Time series of shape (feat, context_len) that provides the historical data. + prediction (np.ndarray): Time series of shape (feat, prediction_len) to be predicted using seasonal naive method. + + Returns: + np.ndarray: Forecasted time series of shape (feat, prediction_len) using seasonal naive method. + """ + + # Get the number of features and lengths for both context and prediction + feat, context_len = context.shape + _, prediction_len = prediction.shape + + # Initialize the forecast array with the same shape as the prediction array + forecast = np.zeros_like(prediction) + + # Iterate through each feature separately + for i in range(feat): + # Compute FFT on the context to find the dominant period + fft_vals = np.fft.fft(context[i]) + freqs = np.fft.fftfreq(context_len) + + # Discard the freq=0 component by starting from index 1 + fft_vals = fft_vals[1:] + freqs = freqs[1:] + + # Identify the period by finding the frequency with the highest power + dominant_freq = freqs[np.argmax(np.abs(fft_vals))] + + # Compute the period length from the dominant frequency + period = int(np.abs(1 / dominant_freq)) + + # ToDo: For now, we only consider the case that context is longer than prediction + # If no periodicity in context, use the last time points for forecasting. + if period == context_len: + forecast = context[i, -prediction_len:] + else: + # Apply the seasonal naive method to forecast + for t in range(prediction_len): + # Forecast based on repeating the seasonal pattern + forecast[i, t] = context[ + i, (context_len - period + (t % period)) % context_len + ] + + return forecast + + +@dataclass +class GetSeasonalNaivePrediction(Transformation): + """ + Forecast the prediction range with SeasonalityNaive + """ + + naive_prediction_field: str + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + """ + target: ndarray of shape (feat, time). It has been padded to make sure it can be divided by patch_size. + """ + target = data_entry["target"] + context_length = data_entry["context_length"] + prediction_length = data_entry["prediction_length"] + patch_size = data_entry["patch_size"] + + context_pad = -context_length % patch_size + prediction_pad = -prediction_length % patch_size + + context = target[:, context_pad : context_pad + context_length] + + if prediction_pad == 0: + prediction = target[:, -prediction_length:] + else: + prediction = target[ + :, -(prediction_pad + prediction_length) : -prediction_pad + ] + + season_naive_prediction = seasonal_naive_predict(context, prediction) + + data_entry[self.naive_prediction_field] = season_naive_prediction + + return data_entry + + +@dataclass +class AddSeasonalNaiveTarget(Transformation): + max_patch_size: int + naive_target_field: str + naive_prediction_field: str + pad_value: int | float = 0 + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + target = data_entry["target"] + patch_size = data_entry["patch_size"] + num_pred_patches = data_entry["num_pred_patches"] + prediction_length = data_entry["prediction_length"] + season_naive_prediction = data_entry[self.naive_prediction_field] + + prediction_pad = -prediction_length % patch_size + + # Pad with zeros to get prediction patches + pad_width = [(0, 0) for _ in range(season_naive_prediction.ndim)] + pad_width[-1] = (0, prediction_pad) + season_naive_prediction = np.pad( + season_naive_prediction, pad_width, mode="constant", constant_values=0 + ) + + season_naive_prediction_patches = self._patchify_arr( + season_naive_prediction, patch_size + ) + target[:, -num_pred_patches:, :] = season_naive_prediction_patches + data_entry[self.naive_target_field] = target + + return data_entry + + def _patchify_arr( + self, arr: Num[np.ndarray, "var time*patch"], patch_size: int + ) -> Num[np.ndarray, "var time max_patch"]: + assert arr.shape[-1] % patch_size == 0 + arr = rearrange(arr, "... (time patch) -> ... time patch", patch=patch_size) + pad_width = [(0, 0) for _ in range(arr.ndim)] + pad_width[-1] = (0, self.max_patch_size - patch_size) + arr = np.pad(arr, pad_width, mode="constant", constant_values=self.pad_value) + return arr + + +@dataclass +class SeasonalNaiveEvalCrop(MapFuncMixin, Transformation): + offset: int + distance: int + prediction_length: int + context_length: int + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + a, b = self._get_boundaries(data_entry) + self.map_func( + partial(self._crop, a=a, b=b), # noqa + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + data_entry["context_length"] = self.context_length + data_entry["prediction_length"] = self.prediction_length + data_entry["num_pred_patches"] = math.ceil( + self.prediction_length / data_entry["patch_size"] + ) + return data_entry + + @staticmethod + def _crop(data_entry: dict[str, Any], field: str, a: int, b: int) -> Sequence: + return [ts[a : b or None] for ts in data_entry[field]] + + def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: + field: list[UnivarTimeSeries] = data_entry[self.fields[0]] + time = field[0].shape[0] + window = data_entry["window"] + fcst_start = self.offset + window * self.distance + a = fcst_start - self.context_length + b = fcst_start + self.prediction_length + + if self.offset >= 0: + assert time >= b > a >= 0 + else: + assert 0 >= b > a >= -time + + return a, b