From e0a8eed531f5c068786519a3678a1daadba71a2c Mon Sep 17 00:00:00 2001 From: Qiao Zhongzheng Date: Tue, 3 Dec 2024 10:27:51 +0800 Subject: [PATCH] Revise learnable time id. Add 2 stage multiscale finetune. --- .../finetune_two_stage/data/electricity.yaml | 7 + .../finetune_two_stage/data/etth1.yaml | 7 + .../finetune_two_stage/data/etth2.yaml | 7 + .../finetune_two_stage/data/ettm1.yaml | 7 + .../finetune_two_stage/data/ettm2.yaml | 7 + .../finetune_two_stage/data/weather.yaml | 7 + .../finetune_two_stage/default.yaml | 115 +++ .../model/moirai_1.0_R_base.yaml | 51 ++ .../model/moirai_1.0_R_small.yaml | 51 ++ .../model/moirai_1.1_R_small.yaml | 51 ++ .../val_data/electricity.yaml | 8 + .../finetune_two_stage/val_data/etth1.yaml | 8 + .../finetune_two_stage/val_data/etth2.yaml | 8 + .../finetune_two_stage/val_data/ettm1.yaml | 8 + .../finetune_two_stage/val_data/ettm2.yaml | 8 + .../finetune_two_stage/val_data/weather.yaml | 8 + cli/train_two_stage.py | 204 +++++ .../lsf-setup/lsf/finetune/base/weather.sh | 4 +- .../lsf-setup/lsf/finetune/small/weather.sh | 4 +- .../lsf-setup/multi_scale/eval/small/ettm1.sh | 8 +- .../lsf-setup/multi_scale/eval/small/ettm2.sh | 10 +- .../multi_scale/eval/small/weather.sh | 17 +- .../multi_scale/finetune/base/weather.sh | 4 +- .../multi_scale/finetune/small/ettm2.sh | 2 +- .../multi_scale/finetune/small/weather.sh | 6 +- .../finetune_two_stage/base/electricity.sh | 36 + .../finetune_two_stage/base/etth1.sh | 35 + .../finetune_two_stage/base/etth2.sh | 35 + .../finetune_two_stage/base/ettm1.sh | 35 + .../finetune_two_stage/base/ettm2.sh | 35 + .../finetune_two_stage/base/run_multi.sh | 7 + .../finetune_two_stage/base/weather.sh | 35 + .../finetune_two_stage/small/electricity.sh | 36 + .../finetune_two_stage/small/etth1.sh | 35 + .../finetune_two_stage/small/etth2.sh | 35 + .../finetune_two_stage/small/ettm1.sh | 35 + .../finetune_two_stage/small/ettm2.sh | 35 + .../finetune_two_stage/small/run_multi.sh | 7 + .../finetune_two_stage/small/weather.sh | 37 + .../model/multi_scale_moirai/__init__.py | 2 + .../model/multi_scale_moirai/finetune.py | 18 +- .../multi_scale_moirai/finetune_two_stage.py | 823 ++++++++++++++++++ .../model/multi_scale_moirai/forecast.py | 2 + src/uni2ts/model/multi_scale_moirai/module.py | 75 +- src/uni2ts/module/multi_scale/attention.py | 1 + src/uni2ts/module/position/attn_projection.py | 68 +- 46 files changed, 1941 insertions(+), 103 deletions(-) create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/electricity.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth1.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth2.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm1.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm2.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/weather.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/electricity.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth1.yaml create mode 100644 
cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth2.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm1.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm2.yaml create mode 100644 cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/weather.yaml create mode 100644 cli/train_two_stage.py create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/electricity.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/etth1.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/etth2.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/ettm1.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/ettm2.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/run_multi.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/base/weather.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/electricity.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/etth1.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/etth2.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/run_multi.sh create mode 100644 project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh create mode 100644 src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/electricity.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/electricity.yaml new file mode 100644 index 0000000..73e7350 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/electricity.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: electricity +train_length: 18412 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth1.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth1.yaml new file mode 100644 index 0000000..bd54733 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth1.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: ETTh1 +train_length: 8640 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth2.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth2.yaml new file mode 100644 index 0000000..6e3eede --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/etth2.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: ETTh2 +train_length: 8640 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm1.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm1.yaml new file mode 100644 index 0000000..2f84768 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm1.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: ETTm1 +train_length: 34560 +prediction_length: ??? +context_length: ??? 
+patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm2.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm2.yaml new file mode 100644 index 0000000..dee2561 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/ettm2.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: ETTm2 +train_length: 34560 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/weather.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/weather.yaml new file mode 100644 index 0000000..86b5bcc --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/data/weather.yaml @@ -0,0 +1,7 @@ +_target_: uni2ts.data.builder.simple.generate_finetune_builder +dataset: weather +train_length: 36887 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml new file mode 100644 index 0000000..f35ddd8 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml @@ -0,0 +1,115 @@ +hydra: + run: + dir: outputs/lsf-setup/multi_scale/finetune_two_stage/${hydra:runtime.choices.model}/${exp_name}/${model.finetune_pattern}/${hydra:runtime.choices.data}/${data.mode}/${run_name} +defaults: + - model: ??? + - data: ??? + - val_data: null + - _self_ +exp_name: ??? +run_name: ??? +seed: 0 +tf32: true +compile: false # set to mode: default, reduce-overhead, max-autotune +ckpt_path: null + +trainer_warmup: + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32 + logger: + _target_: lightning.pytorch.loggers.TensorBoardLogger + save_dir: ${hydra:runtime.output_dir} + name: logs + callbacks: + - _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints_warmup + monitor: val/PackedNLLLoss + save_weights_only: true + mode: min + save_top_k: 1 + every_n_epochs: 1 + - _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val/PackedNLLLoss + min_delta: 0.0 + patience: 3 + mode: min + strict: false + verbose: true + max_epochs: 30 + enable_progress_bar: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm + + +trainer: + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32 + logger: + _target_: lightning.pytorch.loggers.TensorBoardLogger + save_dir: ${hydra:runtime.output_dir} + name: logs + callbacks: + - _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + monitor: val/PackedNLLLoss + save_weights_only: true + mode: min + save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one. 
+ every_n_epochs: 1 + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + save_weights_only: true + - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping + monitor: val/PackedNLLLoss + min_delta: 0.0 + patience: 3 # Set to a small value as now each epoch has many batches. + mode: min + strict: false + verbose: true + # warmup_steps: 1 + max_epochs: 1000 + enable_progress_bar: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm +train_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 512 # Can use a large batch size after disabling sequence packing. + batch_size_factor: 2.0 + cycle: false # Set it as false to loop over all batches per epoch + num_batches_per_epoch: null + shuffle: true + num_workers: 11 + pin_memory: true + drop_last: false + fill_last: false + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true +val_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 32 + batch_size_factor: 2.0 + cycle: false + num_batches_per_epoch: null + shuffle: false + num_workers: 11 + pin_memory: false + drop_last: false + fill_last: false + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml new file mode 100644 index 0000000..75cef7c --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml @@ -0,0 +1,51 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.multi_scale_moirai.TwoStageMoiraiFinetune +module: + _target_: uni2ts.model.multi_scale_moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.0-R-base +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 768 + num_layers: 12 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 5e-7 # On ETT dataset, using 1e-6/5e-7 converge within 1-2 epochs. 
1e-7 converge in tens of epochs +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: null +num_warmup_steps: 0 +patch_size: null +context_length: null +prediction_length: null +finetune_pattern: full +num_new_scales: 3 +ds_factor: 2 + +use_lora: True +lora_kwargs: + _target_: builtins.dict + r: 16 + target_modules: ["q_proj", "k_proj", "v_proj"] + lora_alpha: 32 + lora_dropout: 0.05 \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml new file mode 100644 index 0000000..39ae9d7 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml @@ -0,0 +1,51 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.multi_scale_moirai.TwoStageMoiraiFinetune +module: + _target_: uni2ts.model.multi_scale_moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.0-R-small +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 384 + num_layers: 6 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 5e-7 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: null +num_warmup_steps: 0 +patch_size: null +context_length: null +prediction_length: null +finetune_pattern: full +num_new_scales: 3 +ds_factor: 2 + +use_lora: False +lora_kwargs: + _target_: builtins.dict + r: 16 + target_modules: ["q_proj", "k_proj", "v_proj"] + lora_alpha: 32 + lora_dropout: 0.05 \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml new file mode 100644 index 0000000..f4ddcc6 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml @@ -0,0 +1,51 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.multi_scale_moirai.MoiraiFinetune +module: + _target_: uni2ts.model.multi_scale_moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.1-R-small +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 384 + num_layers: 6 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 5e-7 
# On ETT dataset, using 1e-6/5e-7 converge within 1-2 epochs. 1e-7 converge in tens of epochs +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: null +num_warmup_steps: 0 +patch_size: null +context_length: null +prediction_length: null +finetune_pattern: full +num_new_scales: 3 +ds_factor: 2 + +use_lora: False +lora_kwargs: + _target_: builtins.dict + r: 16 + target_modules: ["q_proj", "k_proj", "v_proj"] + lora_alpha: 32 + lora_dropout: 0.05 \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/electricity.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/electricity.yaml new file mode 100644 index 0000000..a20c574 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/electricity.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: electricity_eval +offset: 18412 # Same as _lsf_dataset.py +eval_length: 2630 # Same as _lsf_dataset.py, test_length=5260 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth1.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth1.yaml new file mode 100644 index 0000000..2a379ab --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth1.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTh1_eval +offset: 8640 +eval_length: 2880 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth2.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth2.yaml new file mode 100644 index 0000000..90e8296 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/etth2.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTh2_eval +offset: 8640 # Same as _lsf_dataset.py +eval_length: 2880 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm1.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm1.yaml new file mode 100644 index 0000000..3cdf94b --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm1.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTm1_eval +offset: 34560 # Same as _lsf_dataset.py +eval_length: 11520 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm2.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm2.yaml new file mode 100644 index 0000000..74ae64c --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/ettm2.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTm2_eval +offset: 34560 # Same as _lsf_dataset.py +eval_length: 11520 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? 
\ No newline at end of file diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/weather.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/weather.yaml new file mode 100644 index 0000000..1d4e331 --- /dev/null +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/val_data/weather.yaml @@ -0,0 +1,8 @@ +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: weather_eval +offset: 36887 # Same as _lsf_dataset.py +eval_length: 5269 # Same as _lsf_dataset.py; test_length=10539 +prediction_length: ??? +context_length: ??? +patch_size: ??? +mode: ??? \ No newline at end of file diff --git a/cli/train_two_stage.py b/cli/train_two_stage.py new file mode 100644 index 0000000..d0e94e9 --- /dev/null +++ b/cli/train_two_stage.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Callable, Optional + +import hydra +import lightning as L +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils._pytree import tree_map +from torch.utils.data import Dataset, DistributedSampler + +from uni2ts.common import hydra_util # noqa: hydra resolvers +from uni2ts.data.loader import DataLoader + + +class DataModule(L.LightningDataModule): + def __init__( + self, + cfg: DictConfig, + train_dataset: Dataset, + val_dataset: Optional[Dataset | list[Dataset]], + ): + super().__init__() + self.cfg = cfg + self.train_dataset = train_dataset + + if val_dataset is not None: + self.val_dataset = val_dataset + self.val_dataloader = self._val_dataloader + + @staticmethod + def get_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + num_batches_per_epoch=num_batches_per_epoch, + ) + + def train_dataloader(self) -> DataLoader: + return self.get_dataloader( + self.train_dataset, + instantiate(self.cfg.train_dataloader, _partial_=True), + self.cfg.train_dataloader.shuffle, + self.trainer.world_size, + self.train_batch_size, + num_batches_per_epoch=self.train_num_batches_per_epoch, + ) + + def _val_dataloader(self) -> DataLoader | list[DataLoader]: + return tree_map( + partial( + self.get_dataloader, + dataloader_func=instantiate(self.cfg.val_dataloader, _partial_=True), + shuffle=self.cfg.val_dataloader.shuffle, + world_size=self.trainer.world_size, + batch_size=self.val_batch_size, + num_batches_per_epoch=None, + ), + self.val_dataset, + ) + + @property + def train_batch_size(self) -> int: + return self.cfg.train_dataloader.batch_size // ( + 
self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def val_batch_size(self) -> int: + return self.cfg.val_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def train_num_batches_per_epoch(self) -> int: + if self.cfg.train_dataloader.num_batches_per_epoch is not None: + return ( + self.cfg.train_dataloader.num_batches_per_epoch + * self.trainer.accumulate_grad_batches + ) + + else: + return None + + +@hydra.main(version_base="1.3", config_name="default.yaml") +def main(cfg: DictConfig): + if cfg.tf32: + assert cfg.trainer.precision == 32 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model: L.LightningModule = instantiate(cfg.model, _convert_="all") + + if hasattr(model, 'post_init') and callable(getattr(model, 'post_init')): + model.post_init() + + if "collate_fn" not in cfg.train_dataloader: + model.seq_fields = model.seq_fields + ("sample_id",) + + if cfg.compile: + model.module.compile(mode=cfg.compile) + + # ToDo: Write the config for trainer_warmup. + trainer_warmup: L.Trainer = instantiate(cfg.trainer_warmup) + + + trainer: L.Trainer = instantiate(cfg.trainer) + + # An '=' in the ckpt name prevents it from being loaded directly with hydra. Change it to '_'. + trainer.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_" + trainer.callbacks[-2].CHECKPOINT_EQUALS_CHAR = "_" + + train_dataset: Dataset = instantiate(cfg.data).load_dataset( + model.train_transform_map + ) + val_dataset: Optional[Dataset | list[Dataset]] = ( + tree_map( + lambda ds: ds.load_dataset(model.val_transform_map), + instantiate(cfg.val_data, _convert_="all"), + ) + if "val_data" in cfg + else None + ) + L.seed_everything(cfg.seed + trainer.logger.version, workers=True) + + print( + "Number of windows in finetune: ", + train_dataset.dataset_weight * train_dataset.num_ts, + ) + print("Batch size for finetune: ", cfg.train_dataloader.batch_size) + print( + "Number of batches in an epoch: ", + train_dataset.dataset_weight + * train_dataset.num_ts + // cfg.train_dataloader.batch_size, + ) + + print("Number of windows in val: ", val_dataset.dataset_weight * val_dataset.num_ts) + print("Batch size for val: ", cfg.val_dataloader.batch_size) + print( + "Number of batches in an epoch: ", + val_dataset.dataset_weight + * val_dataset.num_ts + // cfg.val_dataloader.batch_size, + ) + + # Validate before training to check the performance of the original pretrained model. + # trainer.validate(model, datamodule=DataModule(cfg, train_dataset, val_dataset)) + + trainer_warmup.fit( + model, + datamodule=DataModule(cfg, train_dataset, val_dataset), + ckpt_path=cfg.ckpt_path, + ) + + print("Finished warmup stage.
Now finetuning the whole model...") + model.current_stage = 2 + + trainer.fit( + model, + datamodule=DataModule(cfg, train_dataset, val_dataset), + ckpt_path=cfg.ckpt_path, + ) + + +if __name__ == "__main__": + main() diff --git a/project/lsf-setup/lsf/finetune/base/weather.sh b/project/lsf-setup/lsf/finetune/base/weather.sh index 39ef864..e3f24c0 100644 --- a/project/lsf-setup/lsf/finetune/base/weather.sh +++ b/project/lsf-setup/lsf/finetune/base/weather.sh @@ -32,5 +32,7 @@ for pl in 96 192 336 720; do val_data.context_length=$cl \ val_data.prediction_length=$pl \ train_dataloader.batch_size=256 \ - val_data.mode=${mode} + val_data.mode=${mode} \ + trainer.callbacks."1".monitor=val/PackedMSELoss \ + trainer.callbacks."3".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/lsf/finetune/small/weather.sh b/project/lsf-setup/lsf/finetune/small/weather.sh index 6d681d9..2468ddb 100644 --- a/project/lsf-setup/lsf/finetune/small/weather.sh +++ b/project/lsf-setup/lsf/finetune/small/weather.sh @@ -31,5 +31,7 @@ for pl in 720; do # 96 192 336 val_data.patch_size=${ps} \ val_data.context_length=$cl \ val_data.prediction_length=$pl \ - val_data.mode=${mode} + val_data.mode=${mode} \ + trainer.callbacks."1".monitor=val/PackedMSELoss \ + trainer.callbacks."3".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/eval/small/ettm1.sh b/project/lsf-setup/multi_scale/eval/small/ettm1.sh index 69a7bcd..4f25b21 100644 --- a/project/lsf-setup/multi_scale/eval/small/ettm1.sh +++ b/project/lsf-setup/multi_scale/eval/small/ettm1.sh @@ -9,10 +9,10 @@ exp_name=lsf cl=4000 model=moirai_lightning_ckpt -cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl96/checkpoints/epoch_3-step_1668.ckpt' -cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl96/checkpoints/epoch_4-step_2085.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt' +cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt' index=1 for pl in 96 192 336 720; do diff --git a/project/lsf-setup/multi_scale/eval/small/ettm2.sh b/project/lsf-setup/multi_scale/eval/small/ettm2.sh index f41b129..4dbbc44 100644 --- a/project/lsf-setup/multi_scale/eval/small/ettm2.sh +++ b/project/lsf-setup/multi_scale/eval/small/ettm2.sh @@ -1,7 +1,7 @@ #!/bin/bash export HYDRA_FULL_ERROR=1 -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=3 mode=S cp=conf/lsf-setup/multi_scale/eval @@ -9,10 +9,10 @@ exp_name=lsf cl=3000 model=moirai_lightning_ckpt 
-cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl96/checkpoints/epoch_6-step_3017.ckpt' -cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl96/checkpoints/epoch_13-step_6034.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl192/checkpoints/epoch_4-step_2145.ckpt' +cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt' index=1 for pl in 96 192 336 720; do diff --git a/project/lsf-setup/multi_scale/eval/small/weather.sh b/project/lsf-setup/multi_scale/eval/small/weather.sh index a0895af..a192c5d 100644 --- a/project/lsf-setup/multi_scale/eval/small/weather.sh +++ b/project/lsf-setup/multi_scale/eval/small/weather.sh @@ -1,7 +1,7 @@ #!/bin/bash export HYDRA_FULL_ERROR=1 -export CUDA_VISIBLE_DEVICES=3 +export CUDA_VISIBLE_DEVICES=1 mode=S cp=conf/lsf-setup/multi_scale/eval @@ -9,16 +9,17 @@ exp_name=lsf cl=2000 model=moirai_lightning_ckpt -#cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl96/checkpoints/epoch_8-step_12852.ckpt' -#cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl720/checkpoints/epoch_3-step_5608.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl96/checkpoints/epoch_7-step_11424.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt' +cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl336/checkpoints/epoch_2-step_4254.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt' index=1 -for pl in 336 720; do # 96 192 +for pl in 192 336 720; do # 96 case $index in - 1) cpp=$cpp3 ;; - 2) cpp=$cpp4 ;; + 1) cpp=$cpp2 ;; + 2) cpp=$cpp3 ;; + 3) cpp=$cpp4 ;; # 1) cpp=$cpp1 ;; # 2) cpp=$cpp2 ;; diff --git a/project/lsf-setup/multi_scale/finetune/base/weather.sh b/project/lsf-setup/multi_scale/finetune/base/weather.sh index 7694b2d..62ec562 100644 --- a/project/lsf-setup/multi_scale/finetune/base/weather.sh +++ b/project/lsf-setup/multi_scale/finetune/base/weather.sh @@ -31,5 +31,7 @@ for pl in 96 192 336 
720; do val_data.patch_size=${ps} \ val_data.context_length=$cl \ val_data.prediction_length=$pl \ - val_data.mode=${mode} + val_data.mode=${mode} \ + trainer.callbacks."1".monitor=val/PackedMSELoss \ + trainer.callbacks."3".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune/small/ettm2.sh b/project/lsf-setup/multi_scale/finetune/small/ettm2.sh index c21fdaf..97f05d6 100644 --- a/project/lsf-setup/multi_scale/finetune/small/ettm2.sh +++ b/project/lsf-setup/multi_scale/finetune/small/ettm2.sh @@ -1,6 +1,6 @@ #!/bin/bash -export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3; +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; model=moirai_1.0_R_small cp=conf/lsf-setup/multi_scale/finetune diff --git a/project/lsf-setup/multi_scale/finetune/small/weather.sh b/project/lsf-setup/multi_scale/finetune/small/weather.sh index 102064c..ce25056 100644 --- a/project/lsf-setup/multi_scale/finetune/small/weather.sh +++ b/project/lsf-setup/multi_scale/finetune/small/weather.sh @@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; model=moirai_1.0_R_small cp=conf/lsf-setup/multi_scale/finetune -exp_name=learned_time_id +exp_name=learned_time_id_valMSE data=weather cl=2000 ps=128 @@ -31,5 +31,7 @@ for pl in 96 192 336 720; do val_data.patch_size=${ps} \ val_data.context_length=$cl \ val_data.prediction_length=$pl \ - val_data.mode=${mode} + val_data.mode=${mode} \ + trainer.callbacks."1".monitor=val/PackedMSELoss \ + trainer.callbacks."3".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/electricity.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/electricity.sh new file mode 100644 index 0000000..6fd9baa --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/electricity.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=electricity +cl=5000 +ps=32 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} \ + model.lr=5e-6 +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/etth1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/etth1.sh new file mode 100644 index 0000000..0295966 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/etth1.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=etth1 +cl=5000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + 
data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/etth2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/etth2.sh new file mode 100644 index 0000000..8600a3e --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/etth2.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=etth2 +cl=5000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm1.sh new file mode 100644 index 0000000..4e5ed21 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm1.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=ms_qkv_1.0 +data=ettm1 +cl=5000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm2.sh new file mode 100644 index 0000000..bc54333 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/ettm2.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=ms_qkv_1.0 +data=ettm2 +cl=5000 +ps=128 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git 
a/project/lsf-setup/multi_scale/finetune_two_stage/base/run_multi.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/run_multi.sh new file mode 100644 index 0000000..bc8e2be --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/run_multi.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +bash project/multi_scale/finetune/small/etth1.sh +bash project/multi_scale/finetune/small/etth2.sh +#bash project/multi_scale/finetune/small/ettm1.sh +#bash project/multi_scale/finetune/small/ettm2.sh +#bash project/multi_scale/finetune/small/weather.sh \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/base/weather.sh b/project/lsf-setup/multi_scale/finetune_two_stage/base/weather.sh new file mode 100644 index 0000000..7813add --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/base/weather.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=ms_qkv_1.0 +data=weather +cl=5000 +ps=128 +mode=S # M +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/electricity.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/electricity.sh new file mode 100644 index 0000000..6c8f67b --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/electricity.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=electricity +cl=5000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} \ + model.lr=5e-6 +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/etth1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/etth1.sh new file mode 100644 index 0000000..0295966 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/etth1.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=etth1 +cl=5000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + 
model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/etth2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/etth2.sh new file mode 100644 index 0000000..7b03769 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/etth2.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=lsf +data=etth2 +cl=3000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh new file mode 100644 index 0000000..f2bd104 --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=learned_time_id_2stage +data=ettm1 +cl=4000 +ps=128 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh new file mode 100644 index 0000000..f9282ef --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=2; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=learned_time_id_2stage +data=ettm2 +cl=3000 +ps=64 +mode=S +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + 
val_data.prediction_length=$pl \ + val_data.mode=${mode} +done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/run_multi.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/run_multi.sh new file mode 100644 index 0000000..bc8e2be --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/run_multi.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +bash project/multi_scale/finetune/small/etth1.sh +bash project/multi_scale/finetune/small/etth2.sh +#bash project/multi_scale/finetune/small/ettm1.sh +#bash project/multi_scale/finetune/small/ettm2.sh +#bash project/multi_scale/finetune/small/weather.sh \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh new file mode 100644 index 0000000..4c10e1e --- /dev/null +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.0_R_small +cp=conf/lsf-setup/multi_scale/finetune_two_stage +exp_name=learned_time_id_2stage_valMSE +data=weather +cl=2000 +ps=128 +mode=S # M +ft_pattern=full + + +for pl in 96 192 336 720; do + python -m cli.train_two_stage \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + data.mode=${mode} \ + val_data=${data} \ + val_data.patch_size=${ps} \ + val_data.context_length=$cl \ + val_data.prediction_length=$pl \ + val_data.mode=${mode} \ + trainer.callbacks."1".monitor=val/PackedMSELoss \ + trainer.callbacks."3".monitor=val/PackedMSELoss +done \ No newline at end of file diff --git a/src/uni2ts/model/multi_scale_moirai/__init__.py b/src/uni2ts/model/multi_scale_moirai/__init__.py index 204b6ce..6af10fd 100644 --- a/src/uni2ts/model/multi_scale_moirai/__init__.py +++ b/src/uni2ts/model/multi_scale_moirai/__init__.py @@ -16,9 +16,11 @@ from .finetune import MoiraiFinetune from .forecast import MoiraiForecast from .module import MoiraiModule +from .finetune_two_stage import TwoStageMoiraiFinetune __all__ = [ "MoiraiFinetune", "MoiraiForecast", "MoiraiModule", + "TwoStageMoiraiFinetune" ] diff --git a/src/uni2ts/model/multi_scale_moirai/finetune.py b/src/uni2ts/model/multi_scale_moirai/finetune.py index 9ac1fd8..19c151a 100644 --- a/src/uni2ts/model/multi_scale_moirai/finetune.py +++ b/src/uni2ts/model/multi_scale_moirai/finetune.py @@ -141,6 +141,8 @@ def post_init(self): Initialize the new params added for Multi Scale. 
""" + self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) + # for layer in self.module.encoder.layers: # # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention # if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): @@ -356,7 +358,7 @@ def configure_optimizers(self) -> dict: if "pe_weights" in pn: # Learnable RoPE for time id proj p.requires_grad = True - if "seq_id_q_proj" in pn or "seq_id_k_proj" in pn: + if "time_id_q_proj" in pn or "time_id_k_proj" in pn: p.requires_grad = True # Unfreeze the corresponding params @@ -443,9 +445,9 @@ def configure_optimizers(self) -> dict: continue fpn = f"{mn}.{pn}" if mn else pn - if pn.endswith("bias") and 'time_qk_proj' not in pn: + if pn.endswith("bias"): no_decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, whitelist_params) and 'time_qk_proj' not in pn: + elif pn.endswith("weight") and isinstance(m, whitelist_params): decay.add(fpn) elif pn.endswith("weight") and isinstance(m, blacklist_params): no_decay.add(fpn) @@ -454,16 +456,6 @@ def configure_optimizers(self) -> dict: elif 'pe_weights' in pn: decay.add(fpn) - elif 'layers.0.self_attn.time_qk_proj.seq_id_q_proj' in pn and isinstance(m, whitelist_params): - decay.add(fpn) - elif 'layers.0.self_attn.time_qk_proj.seq_id_k_proj' in pn and isinstance(m, whitelist_params): - decay.add(fpn) - - elif 'layers.0.self_attn.time_qk_proj.seq_id_q_proj' in pn and pn.endswith("bias"): - no_decay.add(fpn) - elif 'layers.0.self_attn.time_qk_proj.seq_id_k_proj' in pn and pn.endswith("bias"): - no_decay.add(fpn) - # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj # decay.add(fpn) diff --git a/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py new file mode 100644 index 0000000..1a1186e --- /dev/null +++ b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py @@ -0,0 +1,823 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Any, Optional + +import lightning as L +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import nn +from torch.distributions import Distribution + +from uni2ts.loss.packed import ( + PackedDistributionLoss, + PackedLoss, + PackedNLLLoss, + PackedPointLoss, +) +from uni2ts.module.norm import RMSNorm +from uni2ts.module.position import ( + BinaryAttentionBias, + LearnedEmbedding, + LearnedProjection, + MultiScaleRotaryProjection +) +from uni2ts.module.ts_embed import MultiInSizeLinear, MultiOutSizeLinear +from uni2ts.optim import SchedulerType, get_scheduler +from uni2ts.transform import ( + AddNewScaleContextSeries, + AddObservedMask, + AddSampleIndex, + AddTimeIndex, + AddVariateIndex, + DefaultPatchSizeConstraints, + DummyValueImputation, + EvalPad, + ExtendMask, + FinetunePatchCrop, + FixedPatchSizeConstraints, + FlatPackCollection, + FlatPackFields, + GetPatchSize, + Identity, + ImputeTimeSeries, + MaskOutRangePaddedTokens, + MultiScaleEvalCrop, + MultiScaleMaskedPredictionGivenFixedConfig, + PackFields, + PadNewScaleSeries, + PatchCrop, + PatchCropGivenFixedConfig, + Patchify, + SelectFields, + SequencifyField, + Transformation, +) + +from .module import MoiraiModule +from uni2ts.module.multi_scale.attention import GroupedQueryAttention +from peft import LoraConfig, LoraModel + + +class TwoStageMoiraiFinetune(L.LightningModule): + seq_fields: tuple[str, ...] = ( + "target", + "observed_mask", + "time_id", + "variate_id", + "prediction_mask", + "patch_size", + ) + pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] = { + "target": np.zeros, + "observed_mask": np.zeros, + "time_id": np.zeros, + "variate_id": np.zeros, + "prediction_mask": np.zeros, + "patch_size": np.zeros, + } + + def __init__( + self, + min_patches: int, + min_mask_ratio: float, + max_mask_ratio: float, + max_dim: int, + num_training_steps: int, + num_warmup_steps: int, + module_kwargs: Optional[dict[str, Any]] = None, + module: Optional[MoiraiModule] = None, + num_samples: int = 100, + beta1: float = 0.9, + beta2: float = 0.98, + loss_func: PackedDistributionLoss = PackedNLLLoss(), + val_metric: Optional[PackedLoss | list[PackedLoss]] = None, + lr: float = 1e-3, + weight_decay: float = 1e-2, + log_on_step: bool = False, + context_length: Optional[int | list[int]] = None, + prediction_length: Optional[int | list[int]] = None, + patch_size: Optional[int] = None, + finetune_pattern: str | list[str] = "full", + num_new_scales: Optional[int] = None, + ds_factor: int = 2, + use_lora: bool = False, + lora_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__() + self.save_hyperparameters(ignore=["module"]) + self.module = MoiraiModule(**module_kwargs) if module is None else module + + self.context_length = context_length + self.prediction_length = prediction_length + self.patch_size = patch_size + self.finetune_pattern = finetune_pattern + self.num_new_scales = num_new_scales + self.ds_factor = ds_factor + + self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale() + + # Lora config + self.lora_config = LoraConfig(**lora_kwargs) if use_lora else None + + self.current_stage = 1 # Used to switch between the two stages + + def post_init(self): + """ + Initialize the new params added for Multi Scale.
+ """ + + self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) + + # for layer in self.module.encoder.layers: + # # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention + # if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): + # # Call post_init() method of the GroupedQueryAttention object + # layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor) + + for module in self.module.encoder.modules(): + if isinstance(module, MultiScaleRotaryProjection): + module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) + + if self.lora_config is not None: + self.module = LoraModel(self.module, self.lora_config, "default") + # Params not used in Lora are set as requires_grad=False automatically. + # Activate some of those params manually. FFN and out_proj are kept as frozen. + for pn, p in self.named_parameters(): + if "param_proj" in pn or "in_proj" in pn: + p.requires_grad = True + if "norm" in pn: + p.requires_grad = True + if "mask_encoding" in pn or "var_attn_bias" in pn: + p.requires_grad = True + # ToDo: Note to include new learnable params introduced in MS + + def _get_token_idx_per_scale(self): + base_token_len = math.ceil(self.context_length / self.patch_size) + math.ceil(self.prediction_length / self.patch_size) + ctx_len = self.context_length + new_scale_token_len = [] + + # New scales only include context part. + for i in range(self.num_new_scales): + ctx_len = math.ceil(ctx_len / self.ds_factor) + ctx_token_len = math.ceil(ctx_len / self.patch_size) + + new_scale_token_len.append(ctx_token_len) + + token_idx_per_scale = [list(range(base_token_len))] + + for i in range(self.num_new_scales): + start = base_token_len if i == 0 else end + end = start + new_scale_token_len[i] + + index = list(range(start, end)) + token_idx_per_scale.append(index) + + base_ctx_token_len = math.ceil(self.context_length / self.patch_size) + base_ctx_token_idx = list(range(base_ctx_token_len)) + + return token_idx_per_scale, base_ctx_token_idx + + + + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len max_patch"], + observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + patch_size: Int[torch.Tensor, "*batch seq_len"], + ) -> Distribution: + distr = self.module( + target=target, + observed_mask=observed_mask, + sample_id=sample_id, + time_id=time_id, + variate_id=variate_id, + prediction_mask=prediction_mask, + patch_size=patch_size, + ) + return distr + + def training_step( + self, batch: dict[str, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + distr = self( + **{field: batch[field] for field in list(self.seq_fields) + ["sample_id"]} + ) + loss = self.hparams.loss_func( + pred=distr, + **{ + field: batch[field] + for field in [ + "target", + "prediction_mask", + "observed_mask", + "sample_id", + "variate_id", + ] + }, + ) + batch_size = ( + batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None + ) + self.log( + f"train/{self.hparams.loss_func.__class__.__name__}", + loss, + on_step=self.hparams.log_on_step, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + batch_size=batch_size, + rank_zero_only=True, + ) + return loss + + def validation_step( + 
self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
+    ) -> torch.Tensor:
+        distr = self(
+            **{field: batch[field] for field in list(self.seq_fields) + ["sample_id"]}
+        )
+        val_loss = self.hparams.loss_func(
+            pred=distr,
+            **{
+                field: batch[field]
+                for field in [
+                    "target",
+                    "prediction_mask",
+                    "observed_mask",
+                    "sample_id",
+                    "variate_id",
+                ]
+            },
+        )
+        batch_size = (
+            batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None
+        )
+        self.log(
+            f"val/{self.hparams.loss_func.__class__.__name__}",
+            val_loss,
+            on_step=self.hparams.log_on_step,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+            batch_size=batch_size,
+            rank_zero_only=True,
+        )
+
+        if self.hparams.val_metric is not None:
+            val_metrics = (
+                self.hparams.val_metric
+                if isinstance(self.hparams.val_metric, list)
+                else [self.hparams.val_metric]
+            )
+            for metric_func in val_metrics:
+                if isinstance(metric_func, PackedPointLoss):
+                    pred = distr.sample(torch.Size((self.hparams.num_samples,)))
+                    pred = torch.median(pred, dim=0).values
+                elif isinstance(metric_func, PackedDistributionLoss):
+                    pred = distr
+                else:
+                    raise ValueError(f"Unsupported loss function: {metric_func}")
+
+                metric = metric_func(
+                    pred=pred,
+                    **{
+                        field: batch[field]
+                        for field in [
+                            "target",
+                            "prediction_mask",
+                            "observed_mask",
+                            "sample_id",
+                            "variate_id",
+                        ]
+                    },
+                )
+
+                self.log(
+                    f"val/{metric_func.__class__.__name__}",
+                    metric,
+                    on_step=self.hparams.log_on_step,
+                    on_epoch=True,
+                    prog_bar=True,
+                    logger=True,
+                    sync_dist=True,
+                    batch_size=batch_size,
+                    rank_zero_only=True,
+                )
+
+        return val_loss
+
+    def configure_optimizers(self) -> dict:
+
+        if self.current_stage == 1:
+            warmup_pn = ['param_proj', 'time_id_q_proj', 'time_id_k_proj']
+            warmup_params = {
+                pn: p for pn, p in self.named_parameters()
+                if any(keyword in pn for keyword in warmup_pn)  # check whether pn contains any of the warm-up keywords
+            }
+            self.trainable_params = warmup_params
+            optimizer = torch.optim.AdamW(
+                warmup_params.values(),
+                lr=1e-4,
+                betas=(self.hparams.beta1, self.hparams.beta2),
+                eps=1e-6,
+            )
+            scheduler = get_scheduler(
+                SchedulerType.CONSTANT,  # Use constant lr scheduler
+                optimizer,
+                num_warmup_steps=self.hparams.num_warmup_steps,
+                num_training_steps=self.hparams.num_training_steps,
+            )
+
+        if self.current_stage == 2:
+            decay = set()
+            no_decay = set()
+
+            if "full" in self.finetune_pattern:
+                pass
+            else:
+                for param in self.parameters():
+                    param.requires_grad = False
+
+            for pn, p in self.named_parameters():
+                if "film" in pn:
+                    p.requires_grad = True
+
+                if "adapt_weight" in pn:
+                    p.requires_grad = True
+
+                if "adapt_bias" in pn:
+                    p.requires_grad = True
+
+                if "var_attn_bias.emb" in pn:
+                    p.requires_grad = True
+
+                if "pe_weights" in pn:  # Learnable RoPE for time id proj
+                    p.requires_grad = True
+
+                if "time_id_q_proj" in pn or "time_id_k_proj" in pn:
+                    p.requires_grad = True
+
+            # Unfreeze the corresponding params
+            if "param_proj" in self.finetune_pattern:
+                for pn, p in self.named_parameters():
+                    if "param_proj" in pn:
+                        p.requires_grad = True
+
+            if "in_proj" in self.finetune_pattern:
+                for pn, p in self.named_parameters():
+                    if "in_proj" in pn:
+                        p.requires_grad = True
+
+            if "norm" in self.finetune_pattern:  #
+                for pn, p in self.named_parameters():
+                    if "norm1" in pn or "norm2" in pn:
+                        p.requires_grad = True
+
+            if "mask" in self.finetune_pattern:
+                for pn, p in self.named_parameters():
+                    if "mask_encoding" in pn:
+                        p.requires_grad = True
+
+            if "ffn" in self.finetune_pattern:
+                for pn, p in 
self.named_parameters(): + if "ffn" in pn: + p.requires_grad = True + + if "q_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "q_proj" in pn: + p.requires_grad = True + + if "k_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "k_proj" in pn: + p.requires_grad = True + + if "v_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "v_proj" in pn: + p.requires_grad = True + + if "attn_norm" in self.finetune_pattern: # + for pn, p in self.named_parameters(): + if "self_attn.q_norm" in pn or "self_attn.k_norm" in pn: + p.requires_grad = True + + if "var_attn_bias" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "var_attn_bias" in pn: + p.requires_grad = True + + if "out_proj" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if "out_proj" in pn: + p.requires_grad = True + + if "studentT" in self.finetune_pattern: + for pn, p in self.named_parameters(): + if ( + "param_proj.proj.components.0" in pn + or "param_proj.proj.weights_logits" in pn + ): + p.requires_grad = True + + whitelist_params = ( + LearnedProjection, + MultiInSizeLinear, + MultiOutSizeLinear, + nn.Linear, + ) + blacklist_params = ( + BinaryAttentionBias, + LearnedEmbedding, + RMSNorm, + nn.Embedding, + nn.LayerNorm, + ) + + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + if not p.requires_grad: + continue + + fpn = f"{mn}.{pn}" if mn else pn + if pn.endswith("bias"): + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_params): + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_params): + no_decay.add(fpn) + elif "adapt_weight" in pn or "adapt_bias" in pn: + decay.add(fpn) + elif 'pe_weights' in pn: + decay.add(fpn) + + # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj + # decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} + self.trainable_params = param_dict + + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert ( + len(param_dict.keys() - union_params) == 0 + ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + assert ( + len(union_params - param_dict.keys()) == 0 + ), f"parameters {str(union_params - param_dict.keys())} were not included in param_dict!" 
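+
+            # Recap of the grouping above: weights of Linear-like modules (in_proj, ffn,
+            # q_proj/k_proj/v_proj/out_proj, param_proj) go into `decay`; biases, norm and
+            # embedding weights, and the variate attention bias go into `no_decay`; the
+            # multi-scale additions (adapt_weight, adapt_bias, pe_weights) are explicitly
+            # assigned to `decay`.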
+ + + optim_groups = [ + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(decay))], + ), + "weight_decay": self.hparams.weight_decay, + }, + { + "params": filter( + lambda p: p.requires_grad, + [param_dict[pn] for pn in sorted(list(no_decay))], + ), + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW( + optim_groups, + lr=self.hparams.lr, + betas=(self.hparams.beta1, self.hparams.beta2), + eps=1e-6, + ) + scheduler = get_scheduler( + SchedulerType.CONSTANT, # Use constant lr scheduler + optimizer, + num_warmup_steps=self.hparams.num_warmup_steps, + num_training_steps=self.hparams.num_training_steps, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "train_loss", + "interval": "step", + }, + } + + @property + def train_transform_map( + self, + ) -> dict[str | type, Callable[..., Transformation]]: + def default_train_transform( + distance: int, + prediction_length: int, + context_length: int, + patch_size: int, + ): + return ( + GetPatchSize( + min_time_patches=self.hparams.min_patches, + target_field="target", + patch_sizes=self.module.patch_sizes, + patch_size_constraints=FixedPatchSizeConstraints(patch_size), + offset=True, + ) + + FinetunePatchCrop( + distance, + prediction_length, + context_length, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + + PackFields( + output_field="target", + fields=("target",), + ) + + PackFields( + output_field="past_feat_dynamic_real", + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + ) + + EvalPad( + prediction_pad=-prediction_length % patch_size, + context_pad=-context_length % patch_size, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + # QZ: Apply downsample to target. Create a new field 'target{i}' for each scale. + + AddNewScaleContextSeries( + target_field="target", + ds_factor=self.ds_factor, + num_new_scales_fields=self.new_scales_target_fields, + expected_ndim=2, + ) + # Pad down-sampled scales. 
Make sure their context and prediction are dividable by patch_size + + PadNewScaleSeries( + fields=self.new_scales_target_fields, + ) + + AddObservedMask( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + imputation_method=DummyValueImputation(value=0.0), + ) + + Patchify( + max_patch_size=max(self.module.patch_sizes), + fields=( + "target", + "observed_mask", + ) + + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + ) + + AddVariateIndex( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + variate_id_field="variate_id", + expected_ndim=3, + max_dim=self.hparams.max_dim, + randomize=False, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + AddSampleIndex( + fields=("target",), + optional_fields=("past_feat_dynamic_real",) + + self.new_scales_target_fields, + sample_id_field="sample_id", + expected_ndim=3, + collection_type=dict, + ) + + MultiScaleMaskedPredictionGivenFixedConfig( + target_fields=("target",) + self.new_scales_target_fields, + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + # + ExtendMask( + # fields=tuple(), + # optional_fields=("past_feat_dynamic_real",), + # mask_field="prediction_mask", + # expected_ndim=3, + # ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="sample_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + feat=True, + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields)) + ) + + return defaultdict(lambda: default_train_transform) + + @property + def val_transform_map( + self, + ) -> dict[str | type, Callable[..., Transformation]]: + def default_val_transform( + offset: int, + distance: int, + prediction_length: int, + context_length: int, + patch_size: int, + ): + return ( + GetPatchSize( + min_time_patches=2, + target_field="target", + patch_sizes=self.module.patch_sizes, + patch_size_constraints=FixedPatchSizeConstraints(patch_size), + offset=True, + ) + + MultiScaleEvalCrop( + offset, + distance, + prediction_length, + context_length, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + + PackFields( + output_field="target", + fields=("target",), + ) + + PackFields( + output_field="past_feat_dynamic_real", + fields=tuple(), + optional_fields=("past_feat_dynamic_real",), + ) + + EvalPad( + prediction_pad=-prediction_length % patch_size, + context_pad=-context_length % patch_size, + fields=("target",), + optional_fields=("past_feat_dynamic_real",), + ) + + AddNewScaleContextSeries( + target_field="target", + ds_factor=self.ds_factor, + num_new_scales_fields=self.new_scales_target_fields, + expected_ndim=2, + ) + + PadNewScaleSeries( + fields=self.new_scales_target_fields, + 
optional_fields=("past_feat_dynamic_real",), + ) + + AddObservedMask( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + observed_mask_field="observed_mask", + collection_type=dict, + ) + + ImputeTimeSeries( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + imputation_method=DummyValueImputation(value=0.0), + ) + + Patchify( + max_patch_size=max(self.module.patch_sizes), + fields=( + "target", + "observed_mask", + ) + + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + ) + + AddVariateIndex( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + variate_id_field="variate_id", + expected_ndim=3, + max_dim=self.hparams.max_dim, + randomize=False, + collection_type=dict, + ) + + AddTimeIndex( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + time_id_field="time_id", + expected_ndim=3, + collection_type=dict, + ) + + AddSampleIndex( + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + sample_id_field="sample_id", + expected_ndim=3, + collection_type=dict, + ) + + MultiScaleMaskedPredictionGivenFixedConfig( + target_fields=("target",) + self.new_scales_target_fields, + prediction_mask_field="prediction_mask", + expected_ndim=3, + ) + + FlatPackCollection( + field="variate_id", + feat=False, + ) + + FlatPackCollection( + field="time_id", + feat=False, + ) + + FlatPackCollection( + field="sample_id", + feat=False, + ) + + FlatPackCollection( + field="prediction_mask", + feat=False, + ) + + FlatPackCollection( + field="observed_mask", + feat=True, + ) + + FlatPackFields( + output_field="target", + fields=("target",) + self.new_scales_target_fields, + optional_fields=("past_feat_dynamic_real",), + feat=True, + ) + + SequencifyField(field="patch_size", target_field="target") + + SelectFields(fields=list(self.seq_fields)) + ) + + return defaultdict(lambda: default_val_transform) + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + """ + Modify state_dict to only save trainable params. + Note the default state_dict saved by PL converts all params to require_grads=False + """ + state = super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + filtered_state = { + name: tensor + for name, tensor in state.items() + if name in self.trainable_params + } + return filtered_state + + @property + def new_scales_target_fields(self): + return tuple(f"target{i}" for i in range(self.num_new_scales)) diff --git a/src/uni2ts/model/multi_scale_moirai/forecast.py b/src/uni2ts/model/multi_scale_moirai/forecast.py index ba4cd63..f5cba9e 100644 --- a/src/uni2ts/model/multi_scale_moirai/forecast.py +++ b/src/uni2ts/model/multi_scale_moirai/forecast.py @@ -141,6 +141,8 @@ def post_init(self): Initialize the new params added for Multi Scale. 
""" + self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.hparams.patch_size) + # for layer in self.module.encoder.layers: # # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention # if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): diff --git a/src/uni2ts/model/multi_scale_moirai/module.py b/src/uni2ts/model/multi_scale_moirai/module.py index 14af523..f364188 100644 --- a/src/uni2ts/model/multi_scale_moirai/module.py +++ b/src/uni2ts/model/multi_scale_moirai/module.py @@ -140,7 +140,8 @@ def __init__( self.distr_output = distr_output self.param_proj = self.distr_output.get_param_proj(d_model, patch_sizes) - # self.num_new_scales = num_new_scales + self.time_id_q_proj = nn.ParameterList() + self.time_id_k_proj = nn.ParameterList() def forward( self, @@ -172,13 +173,42 @@ def forward( :param patch_size: patch size for each token :return: predictive distribution """ + + # Map time id for new scales. + # Key: base scale context tokens; Value: base scale time id + time_id = time_id.to(torch.float) + idx_kv = self.base_ctx_token_idx + key = target[..., idx_kv, :self.ps].clone() # (bs, len0, dim) + value = time_id[..., idx_kv].clone().unsqueeze(-1).to(dtype=torch.float) # (bs, len0, 1) + + for i in range(1, self.num_scales): + idx_scale_i = self.token_idx_per_scale[i] + + query = target[..., idx_scale_i, :self.ps].clone() # (bs, leni, dim) + query = self.time_id_q_proj[i - 1](query) + key = self.time_id_k_proj[i - 1](key) # (bs, len0, dim) + + # Generate attn_mask. Make sure each query only attend to the keys in its down-sampling range. + attn_mask = self.generate_segmented_attn_mask(query, key, 2**i) + + # mapped_time_id is float time id on the original scale. (bs, len_i, 1) + mapped_time_id = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + ) + + time_id[..., idx_scale_i] = mapped_time_id.squeeze() + loc, scale = self.scaler( target, observed_mask * ~prediction_mask.unsqueeze(-1), sample_id, variate_id, ) - scaled_target = (target - loc) / scale + scaled_target = (target - loc) / scale # ToDo: If use conv for DS, consider to modify here? 
+
         reprs = self.in_proj(scaled_target, patch_size)
         masked_reprs = mask_fill(reprs, prediction_mask, self.mask_encoding.weight)
@@ -191,3 +221,44 @@ def forward(
         distr_param = self.param_proj(reprs, patch_size)
         distr = self.distr_output.distribution(distr_param, loc=loc, scale=scale)
         return distr
+
+    def post_init(self, token_idx_per_scale, base_ctx_token_idx, patch_size):
+        self.token_idx_per_scale = token_idx_per_scale
+        self.base_ctx_token_idx = base_ctx_token_idx
+
+        self.num_scales = len(token_idx_per_scale)
+
+        self.ps = patch_size
+
+        # Assign Q and K projections for each new scale
+        for scale in range(1, self.num_scales):
+            self.time_id_q_proj.append(nn.Linear(self.ps, self.ps))
+            self.time_id_k_proj.append(nn.Linear(self.ps, self.ps))
+
+    def generate_segmented_attn_mask(self, query, key, k):
+        """
+        Build an attention mask so that query position i can only attend to keys in the range [k*i, k*(i+1)-1].
+
+        Args:
+            query: query tensor of shape (bs, len_q, dim)
+            key: key tensor of shape (bs, len_k, dim)
+            k: span of key positions covered by each query position
+
+        Returns:
+            attn_mask: BoolTensor of shape (bs, len_q, len_k)
+        """
+        bs, len_q, len_k = query.shape[0], query.shape[1], key.shape[1]
+        # Build the base (len_q, len_k) mask
+        attn_mask = torch.zeros(len_q, len_k, dtype=torch.bool)
+
+        for i in range(len_q):
+            # Key range that query position i is allowed to attend to
+            start = i * k
+            end = min((i + 1) * k, len_k)  # do not go past len_k
+            attn_mask[i, start:end] = True  # allowed positions
+
+        # Expand to the batch dimension
+        attn_mask = attn_mask.unsqueeze(0).expand(bs, -1, -1)
+
+        return attn_mask.to(query.device)
diff --git a/src/uni2ts/module/multi_scale/attention.py b/src/uni2ts/module/multi_scale/attention.py
index fc0d0ac..e08145b 100644
--- a/src/uni2ts/module/multi_scale/attention.py
+++ b/src/uni2ts/module/multi_scale/attention.py
@@ -345,6 +345,7 @@ def forward(
         # value = init_value.clone()
         #
         # # ToDo: Plan B: Directly apply different Film on query / key to different scales. W.o revising RoPE
+        # Clone to avoid in-place modification of sliced tensors
         # if self.num_new_scales is not None:
         #     index_by_variate = self.get_token_index_by_variate(query_var_id)
         #
diff --git a/src/uni2ts/module/position/attn_projection.py b/src/uni2ts/module/position/attn_projection.py
index 4a0bb5e..a6b68a5 100644
--- a/src/uni2ts/module/position/attn_projection.py
+++ b/src/uni2ts/module/position/attn_projection.py
@@ -142,25 +142,11 @@ def __init__(
         self.max_len = max_len
-        self.seq_id_q_proj = nn.ParameterList()
-        self.seq_id_k_proj = nn.ParameterList()
-
     def post_init(self, token_idx_per_scale, base_ctx_token_idx):
         self.token_idx_per_scale = token_idx_per_scale
         self.base_ctx_token_idx = base_ctx_token_idx
-        self.num_scales = len(token_idx_per_scale)
-        dim = self.proj_width * self.num_heads  # ToDo: dim是partial ratio后展开的dim
-
-        # Assign Q and K for each new scale
-        for scale in range(1, self.num_scales):
-            self.seq_id_q_proj.append(nn.Linear(dim, dim))
-            self.seq_id_k_proj.append(nn.Linear(dim, dim))
-
-        # # Todo: norm
-        # q_norm = nn.LayerNorm(dim)
-
     def _init_freq(self, max_len: int):
         if self.cos is None or self.cos.size(-2) < max_len:
             position = torch.arange(
@@ -187,40 +173,18 @@ def forward(
         rot_cos = torch.empty(rot_shape, device=seq_id.device, dtype=torch.float)
         rot_sin = torch.empty(rot_shape, device=seq_id.device, dtype=torch.float)
-        # Key: base scale context tokens; Value: base scale time id
-        idx_kv = self.base_ctx_token_idx
-        x_flat = rearrange(x, "... group hpg q_len dim -> ... q_len (group hpg dim)")  # flat multi-head
-        seq_id_flat = rearrange(seq_id, "... group hpg q_len -> ... 
q_len (group hpg) ") - key = x_flat[..., idx_kv, :] # (bs, len0, dim) - value = seq_id_flat[..., idx_kv, :].to(dtype=torch.float) # (bs, len0, 1) - for i in range(self.num_scales): idx_scale_i = self.token_idx_per_scale[i] + mapped_seq_id = seq_id[..., :, :, idx_scale_i] # (bs, 1, 1, len) # Directly use original time_id to obtain sin/cos for base scale if i == 0: - mapped_seq_id = seq_id[..., :, :, idx_scale_i] # (bs, 1, 1, len) + mapped_seq_id = mapped_seq_id.to(torch.int) rot_cos[..., :, :, idx_scale_i, :] = self.cos[mapped_seq_id] # (bs, 1, 1, len0, proj_width) rot_sin[..., :, :, idx_scale_i, :] = self.sin[mapped_seq_id] - # For new scales, need to map their time_id to the original scale before computing sin/cos. + # For new scales, compute the theta for their float mapped id. And their cos/sin. else: - query = x_flat[..., idx_scale_i, :] # (bs, leni, dim) - query = self.seq_id_q_proj[i - 1](query) - key = self.seq_id_k_proj[i - 1](key) # (bs, len0, dim) - - # Generate attn_mask. Make sure each query only attend to the keys in its down-sampling range. - attn_mask = self.generate_segmented_attn_mask(query.shape[0], query.shape[1], key.shape[1], 2**i).to(x.device) - - # mapped_seq_id is float time id on the original scale. (bs, len_i, 1) - mapped_seq_id = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_mask, - ) - - # Compute the theta for these float id. And their cos/sin. m_theta = einsum(mapped_seq_id.squeeze(), self.theta, "bs length, width -> bs length width") m_theta = repeat(m_theta, "bs length width -> bs length (width 2)") rot_cos[..., :, :, idx_scale_i, :] = torch.cos(m_theta).unsqueeze(1).unsqueeze(2) @@ -228,32 +192,6 @@ def forward( return rot_cos * x + rot_sin * self._rotate(x) # QZ: Eq 34 in the paper - def generate_segmented_attn_mask(self, bs, len_q, len_k, k): - """ - 生成一个 attention mask,使得 query 的位置 i 只能注意到 key 的范围 [k*i, k*(i+1)-1]。 - - 参数: - bs: batch size - len_q: query 的序列长度 - len_k: key 的序列长度 - k: 每个 query 索引范围内 key 的跨度 - - 返回: - attn_mask: BoolTensor,shape = (bs, len_q, len_k) - """ - # 创建基础的 mask - attn_mask = torch.zeros(len_q, len_k, dtype=torch.bool) - - for i in range(len_q): - # 定义 query 的位置 i 对应的 key 范围 - start = i * k - end = min((i + 1) * k, len_k) # 防止超出 len_k - attn_mask[i, start:end] = True # 允许注意的范围 - - # 扩展到 batch 维度 - attn_mask = attn_mask.unsqueeze(0).expand(bs, -1, -1) - - return attn_mask # class MultiScaleRotaryProjection(Projection): # def __init__(