From 1bc73f2254827cc83e2943252cb82ad364b3f778 Mon Sep 17 00:00:00 2001 From: Qiao Zhongzheng Date: Tue, 10 Dec 2024 21:18:21 +0800 Subject: [PATCH] Add r and alpha for attn_lora --- .../multi_scale/finetune/default.yaml | 5 +- .../finetune/model/moirai_1.0_R_base.yaml | 2 + .../finetune/model/moirai_1.0_R_small.yaml | 2 + .../finetune/model/moirai_1.1_R_small.yaml | 2 + .../finetune_two_stage/default.yaml | 7 +- .../model/moirai_1.0_R_base.yaml | 2 + .../model/moirai_1.0_R_small.yaml | 2 + .../model/moirai_1.1_R_small.yaml | 2 + cli/train_two_stage.py | 23 +- .../lsf-setup/multi_scale/eval/small/ettm1.sh | 10 +- .../lsf-setup/multi_scale/eval/small/ettm2.sh | 8 +- .../multi_scale/eval/small/weather.sh | 8 +- .../finetune_two_stage/small/ettm1.sh | 6 +- .../finetune_two_stage/small/ettm2.sh | 4 +- .../finetune_two_stage/small/weather.sh | 4 +- src/uni2ts/model/lsf_moirai/finetune.py | 4 +- src/uni2ts/model/lsf_moirai_point/finetune.py | 4 +- src/uni2ts/model/moirai/finetune.py | 4 +- .../model/multi_scale_moirai/finetune.py | 62 +- .../finetune_cov_two_stage.py | 823 ------------------ .../multi_scale_moirai/finetune_two_stage.py | 96 +- .../model/multi_scale_moirai/forecast.py | 20 +- src/uni2ts/model/multi_scale_moirai/module.py | 97 +-- .../model/multi_scale_moirai/module_conv.py | 287 ------ .../model/seasonal_naive_moirai/finetune.py | 6 +- src/uni2ts/module/multi_scale/attention.py | 157 +--- 26 files changed, 179 insertions(+), 1468 deletions(-) delete mode 100644 src/uni2ts/model/multi_scale_moirai/finetune_cov_two_stage.py delete mode 100644 src/uni2ts/model/multi_scale_moirai/module_conv.py diff --git a/cli/conf/lsf-setup/multi_scale/finetune/default.yaml b/cli/conf/lsf-setup/multi_scale/finetune/default.yaml index 83f4320..7243bbf 100644 --- a/cli/conf/lsf-setup/multi_scale/finetune/default.yaml +++ b/cli/conf/lsf-setup/multi_scale/finetune/default.yaml @@ -33,13 +33,10 @@ trainer: mode: min save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one. every_n_epochs: 1 - - _target_: lightning.pytorch.callbacks.ModelCheckpoint - dirpath: ${hydra:runtime.output_dir}/checkpoints - save_weights_only: true - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping monitor: val/PackedNLLLoss min_delta: 0.0 - patience: 3 # Set to a small value as now each epoch has many batches. 
+ patience: 3 mode: min strict: false verbose: true diff --git a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_base.yaml b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_base.yaml index ea10165..f4f3fe8 100644 --- a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_base.yaml +++ b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_base.yaml @@ -41,6 +41,8 @@ prediction_length: null finetune_pattern: full num_new_scales: 3 ds_factor: 2 +r: 16 +alpha: 16 use_lora: True lora_kwargs: diff --git a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_small.yaml index 1237d36..ac37542 100644 --- a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_small.yaml +++ b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.0_R_small.yaml @@ -41,6 +41,8 @@ prediction_length: null finetune_pattern: full num_new_scales: 3 ds_factor: 2 +r: 16 +alpha: 16 use_lora: False lora_kwargs: diff --git a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.1_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.1_R_small.yaml index f4ddcc6..d7fead4 100644 --- a/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.1_R_small.yaml +++ b/cli/conf/lsf-setup/multi_scale/finetune/model/moirai_1.1_R_small.yaml @@ -41,6 +41,8 @@ prediction_length: null finetune_pattern: full num_new_scales: 3 ds_factor: 2 +r: 16 +alpha: 16 use_lora: False lora_kwargs: diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml index 6243e3a..de11117 100644 --- a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml +++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml @@ -69,13 +69,10 @@ trainer: mode: min save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one. every_n_epochs: 1 - - _target_: lightning.pytorch.callbacks.ModelCheckpoint - dirpath: ${hydra:runtime.output_dir}/checkpoints - save_weights_only: true - - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping + - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping monitor: val/PackedNLLLoss min_delta: 0.0 - patience: 3 # Set to a small value as now each epoch has many batches. 
+      patience: 3
       mode: min
       strict: false
       verbose: true
diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml
index 75cef7c..9aedc02 100644
--- a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml
+++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_base.yaml
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 use_lora: True
 lora_kwargs:
diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml
index 39ae9d7..51cf998 100644
--- a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml
+++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.0_R_small.yaml
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 use_lora: False
 lora_kwargs:
diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml
index f4ddcc6..d7fead4 100644
--- a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml
+++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/model/moirai_1.1_R_small.yaml
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 use_lora: False
 lora_kwargs:
diff --git a/cli/train_two_stage.py b/cli/train_two_stage.py
index d0e94e9..08d43e6 100644
--- a/cli/train_two_stage.py
+++ b/cli/train_two_stage.py
@@ -27,6 +27,21 @@ from uni2ts.common import hydra_util  # noqa: hydra resolvers
 from uni2ts.data.loader import DataLoader
 
+import os
+import glob
+
+
+def get_best_checkpoint_path(checkpoint_dir: str):
+    # list all .ckpt files
+    ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
+
+    if len(ckpt_files) == 1:
+        return ckpt_files[0]  # Return the path of the only .ckpt file
+    elif len(ckpt_files) == 0:
+        raise FileNotFoundError(f"No .ckpt file found in {checkpoint_dir}")
+    else:
+        raise ValueError(f"Multiple .ckpt files found in {checkpoint_dir}. Expected only one.")
+
 
 class DataModule(L.LightningDataModule):
     def __init__(
@@ -139,13 +154,12 @@ def main(cfg: DictConfig):
 
     # ToDo: write the training_warmup config
     trainer_warmup: L.Trainer = instantiate(cfg.trainer_warmup)
-
+    trainer_warmup.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
     trainer: L.Trainer = instantiate(cfg.trainer)
 
     # '=' in the ckpt name means it cannot be directly loaded with hydra. Change it to '_'.
     trainer.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
-    trainer.callbacks[-2].CHECKPOINT_EQUALS_CHAR = "_"
 
     train_dataset: Dataset = instantiate(cfg.data).load_dataset(
         model.train_transform_map
@@ -190,7 +204,10 @@ def main(cfg: DictConfig):
         ckpt_path=cfg.ckpt_path,
     )
 
-    print("Finished warmup stage. Now finetuning the whole model...")
+    # Load the saved ckpt of the best model in stage 1
+    print("Finished warmup stage. 
Now loading the saved model and finetuning the whole model...") + checkpoint = torch.load(get_best_checkpoint_path(trainer_warmup.callbacks[-1].dirpath)) + model.load_state_dict(checkpoint["state_dict"]) model.current_stage = 2 trainer.fit( diff --git a/project/lsf-setup/multi_scale/eval/small/ettm1.sh b/project/lsf-setup/multi_scale/eval/small/ettm1.sh index 1b79a84..098e73b 100644 --- a/project/lsf-setup/multi_scale/eval/small/ettm1.sh +++ b/project/lsf-setup/multi_scale/eval/small/ettm1.sh @@ -1,7 +1,7 @@ #!/bin/bash export HYDRA_FULL_ERROR=1 -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=2 mode=S cp=conf/lsf-setup/multi_scale/eval @@ -9,10 +9,10 @@ exp_name=lsf cl=4000 model=moirai_lightning_ckpt -cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl96/checkpoints/epoch_5-step_2502.ckpt' -cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl96/checkpoints_warmup/epoch_1-step_834-v2.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl192/checkpoints_warmup/epoch_3-step_1664.ckpt' +cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl336/checkpoints_warmup/epoch_2-step_1242.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl720/checkpoints_warmup/epoch_2-step_1224.ckpt' index=1 for pl in 96 192 336 720; do diff --git a/project/lsf-setup/multi_scale/eval/small/ettm2.sh b/project/lsf-setup/multi_scale/eval/small/ettm2.sh index fb091af..b1f1767 100644 --- a/project/lsf-setup/multi_scale/eval/small/ettm2.sh +++ b/project/lsf-setup/multi_scale/eval/small/ettm2.sh @@ -9,10 +9,10 @@ exp_name=lsf cl=3000 model=moirai_lightning_ckpt -cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl96/checkpoints/epoch_16-step_7327.ckpt' -cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl192/checkpoints/epoch_3-step_1716.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl96/checkpoints_warmup/epoch_1-step_862-v1.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl192/checkpoints_warmup/epoch_0-step_429-v1.ckpt' 
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl336/checkpoints_warmup/epoch_0-step_427.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl720/checkpoints_warmup/epoch_26-step_11394.ckpt' index=1 for pl in 96 192 336 720; do diff --git a/project/lsf-setup/multi_scale/eval/small/weather.sh b/project/lsf-setup/multi_scale/eval/small/weather.sh index 072fcaa..3fe0026 100644 --- a/project/lsf-setup/multi_scale/eval/small/weather.sh +++ b/project/lsf-setup/multi_scale/eval/small/weather.sh @@ -9,10 +9,10 @@ exp_name=lsf cl=2000 model=moirai_lightning_ckpt -cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl96/checkpoints/epoch_14-step_21420.ckpt' -cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl192/checkpoints/epoch_9-step_14240.ckpt' -cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt' -cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt' +cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl96/checkpoints_warmup/epoch_2-step_4284.ckpt' +cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl192/checkpoints_warmup/epoch_1-step_2848.ckpt' +cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl336/checkpoints_warmup/epoch_1-step_2836.ckpt' +cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl720/checkpoints_warmup/epoch_1-step_2804.ckpt' index=1 for pl in 96 192 336 720; do diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh index 631098b..2baf59f 100644 --- a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh @@ -1,10 +1,10 @@ #!/bin/bash -export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1; +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3; model=moirai_1.0_R_small cp=conf/lsf-setup/multi_scale/finetune_two_stage -exp_name=1tid_2inproj_all_scale_lora_freezeqkv +exp_name=direct_1full_2head data=ettm1 cl=4000 ps=128 @@ -32,4 +32,6 @@ for pl in 96 192 336 720; do val_data.context_length=$cl \ val_data.prediction_length=$pl \ val_data.mode=${mode} +# trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \ +# trainer_warmup.callbacks."2".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh index 59ef855..357b8c7 100644 --- a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh @@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3; model=moirai_1.0_R_small cp=conf/lsf-setup/multi_scale/finetune_two_stage -exp_name=1tid_2inproj_all_scale_lora_freezeqkv 
+exp_name=direct_1full_2head data=ettm2 cl=3000 ps=64 @@ -32,4 +32,6 @@ for pl in 96 192 336 720; do val_data.context_length=$cl \ val_data.prediction_length=$pl \ val_data.mode=${mode} +# trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \ +# trainer_warmup.callbacks."2".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh index 6412134..30b2c11 100644 --- a/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh +++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh @@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; model=moirai_1.0_R_small cp=conf/lsf-setup/multi_scale/finetune_two_stage -exp_name=1tid_2inproj_all_scale_lora_freezeqkv +exp_name=direct_1full_2head data=weather cl=2000 ps=128 @@ -33,7 +33,7 @@ for pl in 96 192 336 720; do val_data.prediction_length=$pl \ val_data.mode=${mode} \ trainer.callbacks."1".monitor=val/PackedMSELoss \ - trainer.callbacks."3".monitor=val/PackedMSELoss \ + trainer.callbacks."2".monitor=val/PackedMSELoss \ trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \ trainer_warmup.callbacks."2".monitor=val/PackedMSELoss done \ No newline at end of file diff --git a/src/uni2ts/model/lsf_moirai/finetune.py b/src/uni2ts/model/lsf_moirai/finetune.py index 0dcb5e1..160595d 100644 --- a/src/uni2ts/model/lsf_moirai/finetune.py +++ b/src/uni2ts/model/lsf_moirai/finetune.py @@ -414,7 +414,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = param_dict inter_params = decay & no_decay union_params = decay | no_decay @@ -736,7 +736,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/model/lsf_moirai_point/finetune.py b/src/uni2ts/model/lsf_moirai_point/finetune.py index ef6ec63..9603095 100644 --- a/src/uni2ts/model/lsf_moirai_point/finetune.py +++ b/src/uni2ts/model/lsf_moirai_point/finetune.py @@ -343,7 +343,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = param_dict inter_params = decay & no_decay union_params = decay | no_decay @@ -665,6 +665,6 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/model/moirai/finetune.py b/src/uni2ts/model/moirai/finetune.py index e729fa5..177467b 100644 --- a/src/uni2ts/model/moirai/finetune.py +++ b/src/uni2ts/model/moirai/finetune.py @@ -372,7 +372,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = param_dict inter_params = decay & no_decay union_params = decay | no_decay @@ -702,7 +702,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { 
name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/model/multi_scale_moirai/finetune.py b/src/uni2ts/model/multi_scale_moirai/finetune.py index 6429b3a..274401c 100644 --- a/src/uni2ts/model/multi_scale_moirai/finetune.py +++ b/src/uni2ts/model/multi_scale_moirai/finetune.py @@ -117,6 +117,8 @@ def __init__( finetune_pattern: str | list[str] = "full", num_new_scales: Optional[int] = None, ds_factor: int = 2, + r: int = 16, + alpha: int = 16, use_lora: bool = False, lora_kwargs: Optional[dict[str, Any]] = None, ): @@ -130,6 +132,8 @@ def __init__( self.finetune_pattern = finetune_pattern self.num_new_scales = num_new_scales self.ds_factor = ds_factor + self.r = r + self.alpha = alpha self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale() @@ -140,18 +144,19 @@ def post_init(self): """ Initialize the new params added for Multi Scale. """ - - self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) + # ToDo: for time id & in_proj + # self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) for layer in self.module.encoder.layers: # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): # Call post_init() method of the GroupedQueryAttention object - layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor) + layer.self_attn.init_multi_scale_modules(self.num_new_scales, self.r, self.alpha) - for module in self.module.encoder.modules(): - if isinstance(module, MultiScaleRotaryProjection): - module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) + # ToDo: for time id + # for module in self.module.encoder.modules(): + # if isinstance(module, MultiScaleRotaryProjection): + # module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) if self.lora_config is not None: self.module = LoraModel(self.module, self.lora_config, "default") @@ -192,8 +197,6 @@ def _get_token_idx_per_scale(self): return token_idx_per_scale, base_ctx_token_idx - - def forward( self, target: Float[torch.Tensor, "*batch seq_len max_patch"], @@ -337,27 +340,10 @@ def configure_optimizers(self) -> dict: for param in self.parameters(): param.requires_grad = False - # # Always learn the scale embedding - # for pn, p in self.named_parameters(): - # if "new_scale_encoding" in pn: - # p.requires_grad = True - for pn, p in self.named_parameters(): - if "film" in pn: - p.requires_grad = True - - if "adapt_weight" in pn: - p.requires_grad = True - - if "adapt_bias" in pn: - p.requires_grad = True - if "var_attn_bias.emb" in pn: p.requires_grad = True - if "pe_weights" in pn: # Learnable RoPE for time id proj - p.requires_grad = True - if "time_id_q_proj" in pn or "time_id_k_proj" in pn: p.requires_grad = True @@ -372,7 +358,7 @@ def configure_optimizers(self) -> dict: if "in_proj" in pn: p.requires_grad = True - if "norm" in self.finetune_pattern: # + if "rms_norm" in self.finetune_pattern: for pn, p in self.named_parameters(): if "norm1" in pn or "norm2" in pn: p.requires_grad = True @@ -417,14 +403,6 @@ def configure_optimizers(self) -> dict: if "out_proj" in pn: p.requires_grad = True - if "studentT" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if ( - "param_proj.proj.components.0" in pn - or 
"param_proj.proj.weights_logits" in pn - ): - p.requires_grad = True - whitelist_params = ( LearnedProjection, MultiInSizeLinear, @@ -451,15 +429,11 @@ def configure_optimizers(self) -> dict: decay.add(fpn) elif pn.endswith("weight") and isinstance(m, blacklist_params): no_decay.add(fpn) - elif "adapt_weight" in pn or "adapt_bias" in pn: - decay.add(fpn) - elif 'pe_weights' in pn: - decay.add(fpn) - elif 'q_A' in pn or 'q_B' in pn or 'q_bias' in pn: + elif 'q_A' in pn or 'q_B' in pn: decay.add(fpn) - elif 'k_A' in pn or 'k_B' in pn or 'k_bias' in pn: + elif 'k_A' in pn or 'k_B' in pn: decay.add(fpn) - elif 'v_A' in pn or 'v_B' in pn or 'v_bias' in pn: + elif 'v_A' in pn or 'v_B' in pn: decay.add(fpn) # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj @@ -467,7 +441,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = param_dict inter_params = decay & no_decay union_params = decay | no_decay @@ -792,7 +766,7 @@ def default_val_transform( def state_dict(self, *args, destination=None, prefix="", keep_vars=False): """ - Modify state_dict to only save trainable params. + Modify state_dict to only save updated params. Note the default state_dict saved by PL converts all params to require_grads=False """ state = super().state_dict( @@ -801,7 +775,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/model/multi_scale_moirai/finetune_cov_two_stage.py b/src/uni2ts/model/multi_scale_moirai/finetune_cov_two_stage.py deleted file mode 100644 index 45d216d..0000000 --- a/src/uni2ts/model/multi_scale_moirai/finetune_cov_two_stage.py +++ /dev/null @@ -1,823 +0,0 @@ -# Copyright (c) 2024, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -from collections import defaultdict -from collections.abc import Callable, Sequence -from typing import Any, Optional - -import lightning as L -import numpy as np -import torch -from jaxtyping import Bool, Float, Int -from torch import nn -from torch.distributions import Distribution - -from uni2ts.loss.packed import ( - PackedDistributionLoss, - PackedLoss, - PackedNLLLoss, - PackedPointLoss, -) -from uni2ts.module.norm import RMSNorm -from uni2ts.module.position import ( - BinaryAttentionBias, - LearnedEmbedding, - LearnedProjection, - MultiScaleRotaryProjection -) -from uni2ts.module.ts_embed import MultiInSizeLinear, MultiOutSizeLinear -from uni2ts.optim import SchedulerType, get_scheduler -from uni2ts.transform import ( - AddNewScaleContextSeries, - AddObservedMask, - AddSampleIndex, - AddTimeIndex, - AddVariateIndex, - DefaultPatchSizeConstraints, - DummyValueImputation, - EvalPad, - ExtendMask, - FinetunePatchCrop, - FixedPatchSizeConstraints, - FlatPackCollection, - FlatPackFields, - GetPatchSize, - Identity, - ImputeTimeSeries, - MaskOutRangePaddedTokens, - MultiScaleEvalCrop, - MultiScaleMaskedPredictionGivenFixedConfig, - PackFields, - PadNewScaleSeries, - PatchCrop, - PatchCropGivenFixedConfig, - Patchify, - SelectFields, - SequencifyField, - Transformation, -) - -from .module import MoiraiModule -from uni2ts.module.multi_scale.attention import GroupedQueryAttention -from peft import LoraConfig, LoraModel - - -class TwoStageMoiraiFinetune(L.LightningModule): - seq_fields: tuple[str, ...] = ( - "target", - "observed_mask", - "time_id", - "variate_id", - "prediction_mask", - "patch_size", - ) - pad_func_map: dict[str, Callable[[Sequence[int], np.dtype], np.ndarray]] = { - "target": np.zeros, - "observed_mask": np.zeros, - "time_id": np.zeros, - "variate_id": np.zeros, - "prediction_mask": np.zeros, - "patch_size": np.zeros, - } - - def __init__( - self, - min_patches: int, - min_mask_ratio: float, - max_mask_ratio: float, - max_dim: int, - num_training_steps: int, - num_warmup_steps: int, - module_kwargs: Optional[dict[str, Any]] = None, - module: Optional[MoiraiModule] = None, - num_samples: int = 100, - beta1: float = 0.9, - beta2: float = 0.98, - loss_func: PackedDistributionLoss = PackedNLLLoss(), - val_metric: Optional[PackedLoss | list[PackedLoss]] = None, - lr: float = 1e-3, - weight_decay: float = 1e-2, - log_on_step: bool = False, - context_length: Optional[int | list[int]] = None, - prediction_length: Optional[int | list[int]] = None, - patch_size: Optional[int] = None, - finetune_pattern: str | list[str] = "full", - num_new_scales: Optional[int] = None, - ds_factor: int = 2, - use_lora: bool = False, - lora_kwargs: Optional[dict[str, Any]] = None, - ): - super().__init__() - self.save_hyperparameters(ignore=["module"]) - self.module = MoiraiModule(**module_kwargs) if module is None else module - - self.context_length = context_length - self.prediction_length = prediction_length - self.patch_size = patch_size - self.finetune_pattern = finetune_pattern - self.num_new_scales = num_new_scales - self.ds_factor = ds_factor - - self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale() - - # Lora config - self.lora_config = LoraConfig(**lora_kwargs) if use_lora else None - - self.current_stage = 1 # 用于切换阶段 - - def post_init(self): - """ - Initialize the new params added for Multi Scale. 
- """ - - self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) - - # for layer in self.module.encoder.layers: - # # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention - # if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): - # # Call post_init() method of the GroupedQueryAttention object - # layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor) - - for module in self.module.encoder.modules(): - if isinstance(module, MultiScaleRotaryProjection): - module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) - - if self.lora_config is not None: - self.module = LoraModel(self.module, self.lora_config, "default") - # Params not used in Lora are set as requires_grad=False automatically. - # Activate some of those params manually. FFN and out_proj are kept as frozen. - for pn, p in self.named_parameters(): - if "param_proj" in pn or "in_proj" in pn: - p.requires_grad = True - if "norm" in pn: - p.requires_grad = True - if "mask_encoding" in pn or "var_attn_bias" in pn: - p.requires_grad = True - # ToDo: Note to include new learnable params introduced in MS - - def _get_token_idx_per_scale(self): - base_token_len = math.ceil(self.context_length / self.patch_size) + math.ceil(self.prediction_length / self.patch_size) - ctx_len = self.context_length - new_scale_token_len = [] - - # New scales only include context part. - for i in range(self.num_new_scales): - ctx_len = math.ceil(ctx_len / self.ds_factor) - ctx_token_len = math.ceil(ctx_len / self.patch_size) - - new_scale_token_len.append(ctx_token_len) - - token_idx_per_scale = [list(range(base_token_len))] - - for i in range(self.num_new_scales): - start = base_token_len if i == 0 else end - end = start + new_scale_token_len[i] - - index = list(range(start, end)) - token_idx_per_scale.append(index) - - base_ctx_token_len = math.ceil(self.context_length / self.patch_size) - base_ctx_token_idx = list(range(base_ctx_token_len)) - - return token_idx_per_scale, base_ctx_token_idx - - - - def forward( - self, - target: Float[torch.Tensor, "*batch seq_len max_patch"], - observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], - sample_id: Int[torch.Tensor, "*batch seq_len"], - time_id: Int[torch.Tensor, "*batch seq_len"], - variate_id: Int[torch.Tensor, "*batch seq_len"], - prediction_mask: Bool[torch.Tensor, "*batch seq_len"], - patch_size: Int[torch.Tensor, "*batch seq_len"], - ) -> Distribution: - distr = self.module( - target=target, - observed_mask=observed_mask, - sample_id=sample_id, - time_id=time_id, - variate_id=variate_id, - prediction_mask=prediction_mask, - patch_size=patch_size, - ) - return distr - - def training_step( - self, batch: dict[str, torch.Tensor], batch_idx: int - ) -> torch.Tensor: - distr = self( - **{field: batch[field] for field in list(self.seq_fields) + ["sample_id"]} - ) - loss = self.hparams.loss_func( - pred=distr, - **{ - field: batch[field] - for field in [ - "target", - "prediction_mask", - "observed_mask", - "sample_id", - "variate_id", - ] - }, - ) - batch_size = ( - batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None - ) - self.log( - f"train/{self.hparams.loss_func.__class__.__name__}", - loss, - on_step=self.hparams.log_on_step, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - batch_size=batch_size, - rank_zero_only=True, - ) - return loss - - def validation_step( - 
self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: - distr = self( - **{field: batch[field] for field in list(self.seq_fields) + ["sample_id"]} - ) - val_loss = self.hparams.loss_func( - pred=distr, - **{ - field: batch[field] - for field in [ - "target", - "prediction_mask", - "observed_mask", - "sample_id", - "variate_id", - ] - }, - ) - batch_size = ( - batch["sample_id"].max(dim=1).values.sum() if "sample_id" in batch else None - ) - self.log( - f"val/{self.hparams.loss_func.__class__.__name__}", - val_loss, - on_step=self.hparams.log_on_step, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - batch_size=batch_size, - rank_zero_only=True, - ) - - if self.hparams.val_metric is not None: - val_metrics = ( - self.hparams.val_metric - if isinstance(self.hparams.val_metric, list) - else [self.hparams.val_metric] - ) - for metric_func in val_metrics: - if isinstance(metric_func, PackedPointLoss): - pred = distr.sample(torch.Size((self.hparams.num_samples,))) - pred = torch.median(pred, dim=0).values - elif isinstance(metric_func, PackedDistributionLoss): - pred = distr - else: - raise ValueError(f"Unsupported loss function: {metric_func}") - - metric = metric_func( - pred=pred, - **{ - field: batch[field] - for field in [ - "target", - "prediction_mask", - "observed_mask", - "sample_id", - "variate_id", - ] - }, - ) - - self.log( - f"val/{metric_func.__class__.__name__}", - metric, - on_step=self.hparams.log_on_step, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - batch_size=batch_size, - rank_zero_only=True, - ) - - return val_loss - - def configure_optimizers(self) -> dict: - - if self.current_stage == 1: - warmup_pn = ['param_proj', 'time_id_q_proj', 'time_id_k_proj'] - warmup_params = { - pn: p for pn, p in self.named_parameters() - if any(keyword in pn for keyword in warmup_pn) # 检查pn是否包含warmup_pn中的任意字段 - } - self.trainable_params = warmup_params - optimizer = torch.optim.AdamW( - warmup_params.values(), - lr=5e-4, - betas=(self.hparams.beta1, self.hparams.beta2), - eps=1e-6, - ) - scheduler = get_scheduler( - SchedulerType.CONSTANT, # Use constant lr scheduler - optimizer, - num_warmup_steps=self.hparams.num_warmup_steps, - num_training_steps=self.hparams.num_training_steps, - ) - - if self.current_stage == 2: - decay = set() - no_decay = set() - - if "full" in self.finetune_pattern: - pass - else: - for param in self.parameters(): - param.requires_grad = False - - for pn, p in self.named_parameters(): - if "film" in pn: - p.requires_grad = True - - if "adapt_weight" in pn: - p.requires_grad = True - - if "adapt_bias" in pn: - p.requires_grad = True - - if "var_attn_bias.emb" in pn: - p.requires_grad = True - - if "pe_weights" in pn: # Learnable RoPE for time id proj - p.requires_grad = True - - if "time_id_q_proj" in pn or "time_id_k_proj" in pn: - p.requires_grad = True - - # Unfreeze the corresponding params - if "param_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "param_proj" in pn: - p.requires_grad = True - - if "in_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "in_proj" in pn: - p.requires_grad = True - - if "norm" in self.finetune_pattern: # - for pn, p in self.named_parameters(): - if "norm1" in pn or "norm2" in pn: - p.requires_grad = True - - if "mask" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "mask_encoding" in pn: - p.requires_grad = True - - if "ffn" in self.finetune_pattern: - for pn, p in 
self.named_parameters(): - if "ffn" in pn: - p.requires_grad = True - - if "q_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "q_proj" in pn: - p.requires_grad = True - - if "k_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "k_proj" in pn: - p.requires_grad = True - - if "v_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "v_proj" in pn: - p.requires_grad = True - - if "attn_norm" in self.finetune_pattern: # - for pn, p in self.named_parameters(): - if "self_attn.q_norm" in pn or "self_attn.k_norm" in pn: - p.requires_grad = True - - if "var_attn_bias" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "var_attn_bias" in pn: - p.requires_grad = True - - if "out_proj" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if "out_proj" in pn: - p.requires_grad = True - - if "studentT" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if ( - "param_proj.proj.components.0" in pn - or "param_proj.proj.weights_logits" in pn - ): - p.requires_grad = True - - whitelist_params = ( - LearnedProjection, - MultiInSizeLinear, - MultiOutSizeLinear, - nn.Linear, - ) - blacklist_params = ( - BinaryAttentionBias, - LearnedEmbedding, - RMSNorm, - nn.Embedding, - nn.LayerNorm, - ) - - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(): - if not p.requires_grad: - continue - - fpn = f"{mn}.{pn}" if mn else pn - if pn.endswith("bias"): - no_decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, whitelist_params): - decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, blacklist_params): - no_decay.add(fpn) - elif "adapt_weight" in pn or "adapt_bias" in pn: - decay.add(fpn) - elif 'pe_weights' in pn: - decay.add(fpn) - - # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj - # decay.add(fpn) - - # validate that we considered every parameter - param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict - - inter_params = decay & no_decay - union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert ( - len(param_dict.keys() - union_params) == 0 - ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" - assert ( - len(union_params - param_dict.keys()) == 0 - ), f"parameters {str(union_params - param_dict.keys())} were not included in param_dict!" 
- - - optim_groups = [ - { - "params": filter( - lambda p: p.requires_grad, - [param_dict[pn] for pn in sorted(list(decay))], - ), - "weight_decay": self.hparams.weight_decay, - }, - { - "params": filter( - lambda p: p.requires_grad, - [param_dict[pn] for pn in sorted(list(no_decay))], - ), - "weight_decay": 0.0, - }, - ] - - optimizer = torch.optim.AdamW( - optim_groups, - lr=self.hparams.lr, - betas=(self.hparams.beta1, self.hparams.beta2), - eps=1e-6, - ) - scheduler = get_scheduler( - SchedulerType.CONSTANT, # Use constant lr scheduler - optimizer, - num_warmup_steps=self.hparams.num_warmup_steps, - num_training_steps=self.hparams.num_training_steps, - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "monitor": "train_loss", - "interval": "step", - }, - } - - @property - def train_transform_map( - self, - ) -> dict[str | type, Callable[..., Transformation]]: - def default_train_transform( - distance: int, - prediction_length: int, - context_length: int, - patch_size: int, - ): - return ( - GetPatchSize( - min_time_patches=self.hparams.min_patches, - target_field="target", - patch_sizes=self.module.patch_sizes, - patch_size_constraints=FixedPatchSizeConstraints(patch_size), - offset=True, - ) - + FinetunePatchCrop( - distance, - prediction_length, - context_length, - fields=("target",), - optional_fields=("past_feat_dynamic_real",), - ) - + PackFields( - output_field="target", - fields=("target",), - ) - + PackFields( - output_field="past_feat_dynamic_real", - fields=tuple(), - optional_fields=("past_feat_dynamic_real",), - ) - + EvalPad( - prediction_pad=-prediction_length % patch_size, - context_pad=-context_length % patch_size, - fields=("target",), - optional_fields=("past_feat_dynamic_real",), - ) - # QZ: Apply downsample to target. Create a new field 'target{i}' for each scale. - + AddNewScaleContextSeries( - target_field="target", - ds_factor=self.ds_factor, - num_new_scales_fields=self.new_scales_target_fields, - expected_ndim=2, - ) - # Pad down-sampled scales. 
Make sure their context and prediction are dividable by patch_size - + PadNewScaleSeries( - fields=self.new_scales_target_fields, - ) - + AddObservedMask( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - observed_mask_field="observed_mask", - collection_type=dict, - ) - + ImputeTimeSeries( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - imputation_method=DummyValueImputation(value=0.0), - ) - + Patchify( - max_patch_size=max(self.module.patch_sizes), - fields=( - "target", - "observed_mask", - ) - + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - ) - + AddVariateIndex( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - variate_id_field="variate_id", - expected_ndim=3, - max_dim=self.hparams.max_dim, - randomize=False, - collection_type=dict, - ) - + AddTimeIndex( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - time_id_field="time_id", - expected_ndim=3, - collection_type=dict, - ) - + AddSampleIndex( - fields=("target",), - optional_fields=("past_feat_dynamic_real",) - + self.new_scales_target_fields, - sample_id_field="sample_id", - expected_ndim=3, - collection_type=dict, - ) - + MultiScaleMaskedPredictionGivenFixedConfig( - target_fields=("target",) + self.new_scales_target_fields, - prediction_mask_field="prediction_mask", - expected_ndim=3, - ) - # + ExtendMask( - # fields=tuple(), - # optional_fields=("past_feat_dynamic_real",), - # mask_field="prediction_mask", - # expected_ndim=3, - # ) - + FlatPackCollection( - field="variate_id", - feat=False, - ) - + FlatPackCollection( - field="time_id", - feat=False, - ) - + FlatPackCollection( - field="sample_id", - feat=False, - ) - + FlatPackCollection( - field="prediction_mask", - feat=False, - ) - + FlatPackCollection( - field="observed_mask", - feat=True, - ) - + FlatPackFields( - output_field="target", - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - feat=True, - ) - + SequencifyField(field="patch_size", target_field="target") - + SelectFields(fields=list(self.seq_fields)) - ) - - return defaultdict(lambda: default_train_transform) - - @property - def val_transform_map( - self, - ) -> dict[str | type, Callable[..., Transformation]]: - def default_val_transform( - offset: int, - distance: int, - prediction_length: int, - context_length: int, - patch_size: int, - ): - return ( - GetPatchSize( - min_time_patches=2, - target_field="target", - patch_sizes=self.module.patch_sizes, - patch_size_constraints=FixedPatchSizeConstraints(patch_size), - offset=True, - ) - + MultiScaleEvalCrop( - offset, - distance, - prediction_length, - context_length, - fields=("target",), - optional_fields=("past_feat_dynamic_real",), - ) - + PackFields( - output_field="target", - fields=("target",), - ) - + PackFields( - output_field="past_feat_dynamic_real", - fields=tuple(), - optional_fields=("past_feat_dynamic_real",), - ) - + EvalPad( - prediction_pad=-prediction_length % patch_size, - context_pad=-context_length % patch_size, - fields=("target",), - optional_fields=("past_feat_dynamic_real",), - ) - + AddNewScaleContextSeries( - target_field="target", - ds_factor=self.ds_factor, - num_new_scales_fields=self.new_scales_target_fields, - expected_ndim=2, - ) - + PadNewScaleSeries( - fields=self.new_scales_target_fields, - 
optional_fields=("past_feat_dynamic_real",), - ) - + AddObservedMask( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - observed_mask_field="observed_mask", - collection_type=dict, - ) - + ImputeTimeSeries( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - imputation_method=DummyValueImputation(value=0.0), - ) - + Patchify( - max_patch_size=max(self.module.patch_sizes), - fields=( - "target", - "observed_mask", - ) - + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - ) - + AddVariateIndex( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - variate_id_field="variate_id", - expected_ndim=3, - max_dim=self.hparams.max_dim, - randomize=False, - collection_type=dict, - ) - + AddTimeIndex( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - time_id_field="time_id", - expected_ndim=3, - collection_type=dict, - ) - + AddSampleIndex( - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - sample_id_field="sample_id", - expected_ndim=3, - collection_type=dict, - ) - + MultiScaleMaskedPredictionGivenFixedConfig( - target_fields=("target",) + self.new_scales_target_fields, - prediction_mask_field="prediction_mask", - expected_ndim=3, - ) - + FlatPackCollection( - field="variate_id", - feat=False, - ) - + FlatPackCollection( - field="time_id", - feat=False, - ) - + FlatPackCollection( - field="sample_id", - feat=False, - ) - + FlatPackCollection( - field="prediction_mask", - feat=False, - ) - + FlatPackCollection( - field="observed_mask", - feat=True, - ) - + FlatPackFields( - output_field="target", - fields=("target",) + self.new_scales_target_fields, - optional_fields=("past_feat_dynamic_real",), - feat=True, - ) - + SequencifyField(field="patch_size", target_field="target") - + SelectFields(fields=list(self.seq_fields)) - ) - - return defaultdict(lambda: default_val_transform) - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - """ - Modify state_dict to only save trainable params. - Note the default state_dict saved by PL converts all params to require_grads=False - """ - state = super().state_dict( - destination=destination, prefix=prefix, keep_vars=keep_vars - ) - filtered_state = { - name: tensor - for name, tensor in state.items() - if name in self.trainable_params - } - return filtered_state - - @property - def new_scales_target_fields(self): - return tuple(f"target{i}" for i in range(self.num_new_scales)) diff --git a/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py index 84830c6..cd7d877 100644 --- a/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py +++ b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py @@ -117,6 +117,8 @@ def __init__( finetune_pattern: str | list[str] = "full", num_new_scales: Optional[int] = None, ds_factor: int = 2, + r: int = 16, + alpha: int = 16, use_lora: bool = False, lora_kwargs: Optional[dict[str, Any]] = None, ): @@ -130,6 +132,8 @@ def __init__( self.finetune_pattern = finetune_pattern self.num_new_scales = num_new_scales self.ds_factor = ds_factor + self.r = r + self.alpha = alpha self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale() @@ -143,17 +147,19 @@ def post_init(self): Initialize the new params added for Multi Scale. 
""" - self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) + # ToDo: for time id & in_proj + # self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size) for layer in self.module.encoder.layers: # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): # Call post_init() method of the GroupedQueryAttention object - layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor) + layer.self_attn.init_multi_scale_modules(self.num_new_scales, self.r, self.alpha) - for module in self.module.encoder.modules(): - if isinstance(module, MultiScaleRotaryProjection): - module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) + # ToDo: for time id + # for module in self.module.encoder.modules(): + # if isinstance(module, MultiScaleRotaryProjection): + # module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) if self.lora_config is not None: self.module = LoraModel(self.module, self.lora_config, "default") @@ -194,8 +200,6 @@ def _get_token_idx_per_scale(self): return token_idx_per_scale, base_ctx_token_idx - - def forward( self, target: Float[torch.Tensor, "*batch seq_len max_patch"], @@ -331,55 +335,13 @@ def validation_step( def configure_optimizers(self) -> dict: - # if self.current_stage == 1: - # warmup_pn_group1 = ['param_proj', 'time_id_q_proj', 'time_id_k_proj'] - # warmup_pn_group2 = ['in_proj_new_scales'] - # - # optim_groups = [ - # { - # "params": [ - # p for pn, p in self.named_parameters() - # if any(keyword in pn for keyword in warmup_pn_group1) - # ], - # "lr": 5e-4, - # "weight_decay": self.hparams.weight_decay, - # }, - # { - # "params": [ - # p for pn, p in self.named_parameters() - # if any(keyword in pn for keyword in warmup_pn_group2) - # ], - # "lr": 5e-6, - # "weight_decay": self.hparams.weight_decay, - # }, - # ] - # - # optimizer = torch.optim.AdamW( - # optim_groups, - # betas=(self.hparams.beta1, self.hparams.beta2), - # eps=1e-6, - # ) - # - # warmup_params_all = { - # pn: p for pn, p in self.named_parameters() - # if any(keyword in pn for keyword in warmup_pn_group1 + warmup_pn_group2) - # } - # self.trainable_params = warmup_params_all - # - # scheduler = get_scheduler( - # SchedulerType.CONSTANT, # Use constant lr scheduler - # optimizer, - # num_warmup_steps=self.hparams.num_warmup_steps, - # num_training_steps=self.hparams.num_training_steps, - # ) - if self.current_stage == 1: - warmup_pn = ['param_proj', 'time_id_q_proj', 'time_id_k_proj'] + warmup_pn = ['param_proj'] # 'time_id_q_proj', 'time_id_k_proj' warmup_params = { pn: p for pn, p in self.named_parameters() if any(keyword in pn for keyword in warmup_pn) # 检查pn是否包含warmup_pn中的任意字段 } - self.trainable_params = warmup_params + self.updated_params = warmup_params optimizer = torch.optim.AdamW( warmup_params.values(), lr=5e-4, @@ -404,21 +366,9 @@ def configure_optimizers(self) -> dict: param.requires_grad = False for pn, p in self.named_parameters(): - if "film" in pn: - p.requires_grad = True - - if "adapt_weight" in pn: - p.requires_grad = True - - if "adapt_bias" in pn: - p.requires_grad = True - if "var_attn_bias.emb" in pn: p.requires_grad = True - if "pe_weights" in pn: # Learnable RoPE for time id proj - p.requires_grad = True - if "time_id_q_proj" in pn or "time_id_k_proj" in pn: p.requires_grad = True @@ -433,7 +383,7 @@ def 
configure_optimizers(self) -> dict: if "in_proj" in pn: p.requires_grad = True - if "norm" in self.finetune_pattern: # + if "rms_norm" in self.finetune_pattern: # for pn, p in self.named_parameters(): if "norm1" in pn or "norm2" in pn: p.requires_grad = True @@ -478,14 +428,6 @@ def configure_optimizers(self) -> dict: if "out_proj" in pn: p.requires_grad = True - if "studentT" in self.finetune_pattern: - for pn, p in self.named_parameters(): - if ( - "param_proj.proj.components.0" in pn - or "param_proj.proj.weights_logits" in pn - ): - p.requires_grad = True - whitelist_params = ( LearnedProjection, MultiInSizeLinear, @@ -512,10 +454,6 @@ def configure_optimizers(self) -> dict: decay.add(fpn) elif pn.endswith("weight") and isinstance(m, blacklist_params): no_decay.add(fpn) - elif "adapt_weight" in pn or "adapt_bias" in pn: - decay.add(fpn) - elif 'pe_weights' in pn: - decay.add(fpn) elif 'q_A' in pn or 'q_B' in pn: decay.add(fpn) elif 'k_A' in pn or 'k_B' in pn: @@ -528,7 +466,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = {**self.updated_params, **param_dict} inter_params = decay & no_decay union_params = decay | no_decay @@ -853,7 +791,7 @@ def default_val_transform( def state_dict(self, *args, destination=None, prefix="", keep_vars=False): """ - Modify state_dict to only save trainable params. + Modify state_dict to only save updated params. Note the default state_dict saved by PL converts all params to require_grads=False """ state = super().state_dict( @@ -862,7 +800,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/model/multi_scale_moirai/forecast.py b/src/uni2ts/model/multi_scale_moirai/forecast.py index 23c9a54..2208982 100644 --- a/src/uni2ts/model/multi_scale_moirai/forecast.py +++ b/src/uni2ts/model/multi_scale_moirai/forecast.py @@ -109,6 +109,8 @@ def __init__( pretrained_checkpoint_path: str = None, num_new_scales: int = 1, ds_factor: int = 2, + r: int = 16, + alpha: int = 16, use_lora: bool = False, lora_kwargs: Optional[dict[str, Any]] = None, ): @@ -122,6 +124,9 @@ def __init__( self.num_new_scales = num_new_scales self.ds_factor = ds_factor + self.r = r + self.alpha = alpha + self.strict_loading = False self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale() @@ -132,7 +137,7 @@ def __init__( self.module = LoraModel(self.module, self.lora_config, "default") - self.post_init() # ToDO: Make it optional. + self.post_init() @@ -140,18 +145,19 @@ def post_init(self): """ Initialize the new params added for Multi Scale. 
""" - - self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.hparams.patch_size) + # ToDo: for time id & in_proj + # self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.hparams.patch_size) for layer in self.module.encoder.layers: # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention): # Call post_init() method of the GroupedQueryAttention object - layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor) + layer.self_attn.init_multi_scale_modules(self.num_new_scales, self.r, self.alpha) - for module in self.module.encoder.modules(): - if isinstance(module, MultiScaleRotaryProjection): - module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) + # ToDo: for time id + # for module in self.module.encoder.modules(): + # if isinstance(module, MultiScaleRotaryProjection): + # module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx) pass diff --git a/src/uni2ts/model/multi_scale_moirai/module.py b/src/uni2ts/model/multi_scale_moirai/module.py index c8fbb90..213100b 100644 --- a/src/uni2ts/model/multi_scale_moirai/module.py +++ b/src/uni2ts/model/multi_scale_moirai/module.py @@ -124,15 +124,15 @@ def __init__( # num_vars=4 # ToDo: 这个num_vars得提供外部接口 ), time_qk_proj_layer=partial( - QueryKeyProjection, - proj_layer=MultiScaleRotaryProjection, - kwargs=dict(max_len=max_seq_len), - partial_factor=(0.0, 0.5), # 之前的partial factor是0-0.5 - # QueryKeyProjection, - # proj_layer=RotaryProjection, # ToDo: 可以改 + # proj_layer=MultiScaleRotaryProjection, # kwargs=dict(max_len=max_seq_len), # partial_factor=(0.0, 0.5), # 之前的partial factor是0-0.5 + + QueryKeyProjection, + proj_layer=RotaryProjection, # ToDo: 可以改 + kwargs=dict(max_len=max_seq_len), + partial_factor=(0.0, 0.5), # 之前的partial factor是0-0.5 ), shared_var_attn_bias=False, shared_time_qk_proj=True, # True by default @@ -177,32 +177,32 @@ def forward( :return: predictive distribution """ - # Map time id for new scales. - # Key: base scale context tokens; Value: base scale time id - time_id = time_id.to(torch.float) - idx_kv = self.base_ctx_token_idx - key = target[..., idx_kv, :self.ps].clone() # (bs, len0, dim) - value = time_id[..., idx_kv].clone().unsqueeze(-1).to(dtype=torch.float) # (bs, len0, 1) - - for i in range(1, self.num_scales): - idx_scale_i = self.token_idx_per_scale[i] - - query = target[..., idx_scale_i, :self.ps].clone() # (bs, leni, dim) - query = self.time_id_q_proj[i - 1](query) - key = self.time_id_k_proj[i - 1](key) # (bs, len0, dim) - - # Generate attn_mask. Make sure each query only attend to the keys in its down-sampling range. - attn_mask = self.generate_segmented_attn_mask(query, key, 2**i) - - # mapped_time_id is float time id on the original scale. (bs, len_i, 1) - mapped_time_id = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_mask, - ) - - time_id[..., idx_scale_i] = mapped_time_id.squeeze() + # ToDo: Map time id for new scales. 
+ # # Key: base scale context tokens; Value: base scale time id + # time_id = time_id.to(torch.float) + # idx_kv = self.base_ctx_token_idx + # key = target[..., idx_kv, :self.ps].clone() # (bs, len0, dim) + # value = time_id[..., idx_kv].clone().unsqueeze(-1).to(dtype=torch.float) # (bs, len0, 1) + # + # for i in range(1, self.num_scales): + # idx_scale_i = self.token_idx_per_scale[i] + # + # query = target[..., idx_scale_i, :self.ps].clone() # (bs, leni, dim) + # query = self.time_id_q_proj[i - 1](query) + # key = self.time_id_k_proj[i - 1](key) # (bs, len0, dim) + # + # # Generate attn_mask. Make sure each query only attend to the keys in its down-sampling range. + # attn_mask = self.generate_segmented_attn_mask(query, key, 2**i) + # + # # mapped_time_id is float time id on the original scale. (bs, len_i, 1) + # mapped_time_id = F.scaled_dot_product_attention( + # query, + # key, + # value, + # attn_mask=attn_mask, + # ) + # + # time_id[..., idx_scale_i] = mapped_time_id.squeeze() loc, scale = self.scaler( target, @@ -212,23 +212,22 @@ def forward( ) scaled_target = (target - loc) / scale # ToDo: If use conv for DS, consider to modify here? - - - # reprs = self.in_proj(scaled_target, patch_size) - - reprs_all_scales = [] - for i in range(0, self.num_scales): - idx_scale_i = self.token_idx_per_scale[i] - - if i == 0: - reprs_base = self.in_proj(scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) - reprs_all_scales.append(reprs_base) - - else: - reprs_new_scale = self.in_proj_new_scales[i - 1](scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) - reprs_all_scales.append(reprs_new_scale) - - reprs = torch.cat(reprs_all_scales, dim=-2) + reprs = self.in_proj(scaled_target, patch_size) + + # ToDo: Add a specific in_proj for each scale + # reprs_all_scales = [] + # for i in range(0, self.num_scales): + # idx_scale_i = self.token_idx_per_scale[i] + # + # if i == 0: + # reprs_base = self.in_proj(scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) + # reprs_all_scales.append(reprs_base) + # + # else: + # reprs_new_scale = self.in_proj_new_scales[i - 1](scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) + # reprs_all_scales.append(reprs_new_scale) + # + # reprs = torch.cat(reprs_all_scales, dim=-2) masked_reprs = mask_fill(reprs, prediction_mask, self.mask_encoding.weight) diff --git a/src/uni2ts/model/multi_scale_moirai/module_conv.py b/src/uni2ts/model/multi_scale_moirai/module_conv.py deleted file mode 100644 index c8fbb90..0000000 --- a/src/uni2ts/model/multi_scale_moirai/module_conv.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) 2024, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial - -import torch -import torch.nn.functional as F -from einops import reduce -from huggingface_hub import PyTorchModelHubMixin -from hydra.utils import instantiate -from jaxtyping import Bool, Float, Int -from torch import nn -from torch.distributions import Distribution -from torch.utils._pytree import tree_map - -from uni2ts.common.torch_util import mask_fill, packed_attention_mask -from uni2ts.distribution import DistributionOutput -from uni2ts.module.multi_scale.transformer import TransformerEncoder -from uni2ts.module.norm import RMSNorm -from uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler -from uni2ts.module.position import ( - BinaryAttentionBias, - CrossVariateAttentionBias, - QueryKeyProjection, - RotaryProjection, - MultiScaleRotaryProjection -) -from uni2ts.module.ts_embed import MultiInSizeLinear - -import copy - -def encode_distr_output( - distr_output: DistributionOutput, -) -> dict[str, str | float | int]: - """Serialization function for DistributionOutput""" - - def _encode(val): - if not isinstance(val, DistributionOutput): - return val - - return { - "_target_": f"{val.__class__.__module__}.{val.__class__.__name__}", - **tree_map(_encode, val.__dict__), - } - - return _encode(distr_output) - - -def decode_distr_output(config: dict[str, str | float | int]) -> DistributionOutput: - """Deserialization function for DistributionOutput""" - return instantiate(config, _convert_="all") - - -class MoiraiModule( - nn.Module, - PyTorchModelHubMixin, - coders={DistributionOutput: (encode_distr_output, decode_distr_output)}, -): - """ - Contains components of Moirai, to ensure implementation is identical across models. - Subclasses huggingface_hub.PyTorchModelHubMixin to support loading from HuggingFace Hub. - """ - - def __init__( - self, - distr_output: DistributionOutput, - d_model: int, - num_layers: int, - patch_sizes: tuple[int, ...], # tuple[int, ...] 
| list[int] - max_seq_len: int, - attn_dropout_p: float, - dropout_p: float, - scaling: bool = True, - ): - """ - :param distr_output: distribution output object - :param d_model: model hidden dimensions - :param num_layers: number of transformer layers - :param patch_sizes: sequence of patch sizes - :param max_seq_len: maximum sequence length for inputs - :param attn_dropout_p: dropout probability for attention layers - :param dropout_p: dropout probability for all other layers - :param scaling: whether to apply scaling (standardization) - """ - super().__init__() - self.d_model = d_model - self.num_layers = num_layers - self.patch_sizes = patch_sizes - self.max_seq_len = max_seq_len - self.scaling = scaling - - self.mask_encoding = nn.Embedding(num_embeddings=1, embedding_dim=d_model) - self.scaler = PackedStdScaler() if scaling else PackedNOPScaler() - self.in_proj = MultiInSizeLinear( - in_features_ls=patch_sizes, - out_features=d_model, - ) - self.encoder = TransformerEncoder( - d_model, - num_layers, - num_heads=None, - pre_norm=True, - attn_dropout_p=attn_dropout_p, - dropout_p=dropout_p, - norm_layer=RMSNorm, - activation=F.silu, - use_glu=True, - use_qk_norm=True, - var_attn_bias_layer=partial( - BinaryAttentionBias - # CrossVariateAttentionBias, - # num_vars=4 # ToDo: 这个num_vars得提供外部接口 - ), - time_qk_proj_layer=partial( - QueryKeyProjection, - proj_layer=MultiScaleRotaryProjection, - kwargs=dict(max_len=max_seq_len), - partial_factor=(0.0, 0.5), # 之前的partial factor是0-0.5 - - # QueryKeyProjection, - # proj_layer=RotaryProjection, # ToDo: 可以改 - # kwargs=dict(max_len=max_seq_len), - # partial_factor=(0.0, 0.5), # 之前的partial factor是0-0.5 - ), - shared_var_attn_bias=False, - shared_time_qk_proj=True, # True by default - d_ff=None, - ) - self.distr_output = distr_output - self.param_proj = self.distr_output.get_param_proj(d_model, patch_sizes) - - self.time_id_q_proj = nn.ParameterList() - self.time_id_k_proj = nn.ParameterList() - - self.in_proj_new_scales = nn.ParameterList() - - def forward( - self, - target: Float[torch.Tensor, "*batch seq_len max_patch"], - observed_mask: Bool[torch.Tensor, "*batch seq_len max_patch"], - sample_id: Int[torch.Tensor, "*batch seq_len"], - time_id: Int[torch.Tensor, "*batch seq_len"], - variate_id: Int[torch.Tensor, "*batch seq_len"], - prediction_mask: Bool[torch.Tensor, "*batch seq_len"], - patch_size: Int[torch.Tensor, "*batch seq_len"], - ) -> Distribution: - """ - Defines the forward pass of MoiraiModule. - This method expects processed inputs. - - 1. Apply scaling to observations - 2. Project from observations to representations - 3. Replace prediction window with learnable mask - 4. Apply transformer layers - 5. Project from representations to distribution parameters - 6. Return distribution object - - :param target: input data - :param observed_mask: binary mask for missing values, 1 if observed, 0 otherwise - :param sample_id: indices indicating the sample index (for packing) - :param time_id: indices indicating the time index - :param variate_id: indices indicating the variate index - :param prediction_mask: binary mask for prediction horizon, 1 if part of the horizon, 0 otherwise - :param patch_size: patch size for each token - :return: predictive distribution - """ - - # Map time id for new scales. 
- # Key: base scale context tokens; Value: base scale time id - time_id = time_id.to(torch.float) - idx_kv = self.base_ctx_token_idx - key = target[..., idx_kv, :self.ps].clone() # (bs, len0, dim) - value = time_id[..., idx_kv].clone().unsqueeze(-1).to(dtype=torch.float) # (bs, len0, 1) - - for i in range(1, self.num_scales): - idx_scale_i = self.token_idx_per_scale[i] - - query = target[..., idx_scale_i, :self.ps].clone() # (bs, leni, dim) - query = self.time_id_q_proj[i - 1](query) - key = self.time_id_k_proj[i - 1](key) # (bs, len0, dim) - - # Generate attn_mask. Make sure each query only attend to the keys in its down-sampling range. - attn_mask = self.generate_segmented_attn_mask(query, key, 2**i) - - # mapped_time_id is float time id on the original scale. (bs, len_i, 1) - mapped_time_id = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_mask, - ) - - time_id[..., idx_scale_i] = mapped_time_id.squeeze() - - loc, scale = self.scaler( - target, - observed_mask * ~prediction_mask.unsqueeze(-1), - sample_id, - variate_id, - ) - scaled_target = (target - loc) / scale # ToDo: If use conv for DS, consider to modify here? - - - - # reprs = self.in_proj(scaled_target, patch_size) - - reprs_all_scales = [] - for i in range(0, self.num_scales): - idx_scale_i = self.token_idx_per_scale[i] - - if i == 0: - reprs_base = self.in_proj(scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) - reprs_all_scales.append(reprs_base) - - else: - reprs_new_scale = self.in_proj_new_scales[i - 1](scaled_target[..., idx_scale_i, :], patch_size[..., idx_scale_i]) - reprs_all_scales.append(reprs_new_scale) - - reprs = torch.cat(reprs_all_scales, dim=-2) - - - masked_reprs = mask_fill(reprs, prediction_mask, self.mask_encoding.weight) - - reprs = self.encoder( - masked_reprs, - packed_attention_mask(sample_id), - time_id=time_id, - var_id=variate_id, - ) # (bs, seq_len, max_patch) - distr_param = self.param_proj(reprs, patch_size) - distr = self.distr_output.distribution(distr_param, loc=loc, scale=scale) - return distr - - def post_init(self, token_idx_per_scale, base_ctx_token_idx, patch_size): - self.token_idx_per_scale = token_idx_per_scale - self.base_ctx_token_idx = base_ctx_token_idx - - self.num_scales = len(token_idx_per_scale) - - self.ps = patch_size - - # Assign Q and K for each new scale - for scale in range(1, self.num_scales): - self.time_id_q_proj.append(nn.Linear(self.ps, self.ps)) - self.time_id_k_proj.append(nn.Linear(self.ps, self.ps)) - - self.in_proj_new_scales.append(copy.deepcopy(self.in_proj)) - - def generate_segmented_attn_mask(self, query, key, k): - """ - 生成一个 attention mask,使得 query 的位置 i 只能注意到 key 的范围 [k*i, k*(i+1)-1]。 - - 参数: - bs: batch size - len_q: query 的序列长度 - len_k: key 的序列长度 - k: 每个 query 索引范围内 key 的跨度 - - 返回: - attn_mask: BoolTensor,shape = (bs, len_q, len_k) - """ - bs, len_q, len_k = query.shape[0], query.shape[1], key.shape[1] - # 创建基础的 mask - attn_mask = torch.zeros(len_q, len_k, dtype=torch.bool) - - for i in range(len_q): - # 定义 query 的位置 i 对应的 key 范围 - start = i * k - end = min((i + 1) * k, len_k) # 防止超出 len_k - attn_mask[i, start:end] = True # 允许注意的范围 - - # 扩展到 batch 维度 - attn_mask = attn_mask.unsqueeze(0).expand(bs, -1, -1) - - return attn_mask.to(query.device) diff --git a/src/uni2ts/model/seasonal_naive_moirai/finetune.py b/src/uni2ts/model/seasonal_naive_moirai/finetune.py index 542f7f9..186c0fa 100644 --- a/src/uni2ts/model/seasonal_naive_moirai/finetune.py +++ b/src/uni2ts/model/seasonal_naive_moirai/finetune.py @@ -293,7 
+293,7 @@ def configure_optimizers(self) -> dict: if "in_proj" in pn: p.requires_grad = True - if "norm" in self.finetune_pattern: + if "rms_norm" in self.finetune_pattern: for pn, p in self.named_parameters(): if "norm1" in pn or "norm2" in pn: p.requires_grad = True @@ -375,7 +375,7 @@ def configure_optimizers(self) -> dict: # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - self.trainable_params = param_dict + self.updated_params = param_dict inter_params = decay & no_decay union_params = decay | no_decay @@ -720,6 +720,6 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): filtered_state = { name: tensor for name, tensor in state.items() - if name in self.trainable_params + if name in self.updated_params } return filtered_state diff --git a/src/uni2ts/module/multi_scale/attention.py b/src/uni2ts/module/multi_scale/attention.py index 1f07797..e817e2f 100644 --- a/src/uni2ts/module/multi_scale/attention.py +++ b/src/uni2ts/module/multi_scale/attention.py @@ -101,119 +101,36 @@ def __init__( self.num_new_scales = None - def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor, ): + def init_multi_scale_modules(self, num_new_scales, r, alpha): self.num_new_scales = num_new_scales - rank = 16 + self.r = r + self.alpha = alpha + + print(" ******** r={}, alpha={} *********".format(r, alpha)) # Initialize parameter lists self.q_A = nn.ParameterList() self.q_B = nn.ParameterList() - self.q_bias = nn.ParameterList() - self.k_A = nn.ParameterList() self.k_B = nn.ParameterList() - self.k_bias = nn.ParameterList() - self.v_A = nn.ParameterList() self.v_B = nn.ParameterList() - self.v_bias = nn.ParameterList() - # 包括origin scale,也用lora;冻结住q k v + # freeze q k v self.q_proj.requires_grad_(False) self.k_proj.requires_grad_(False) self.v_proj.requires_grad_(False) for _ in range(1+num_new_scales): # Append the new parameters for the current scale - self.q_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - self.k_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - self.v_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - - self.q_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - self.k_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - self.v_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - - - # for _ in range(num_new_scales): - # # Append the new parameters for the current scale - # self.q_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - # self.k_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - # self.v_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01)) - # - # self.q_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - # self.k_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - # self.v_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float))) - - - - - - - # base_len = math.ceil(context_length / patch_size) # num context patches in base scale - # scale_len = math.ceil(base_len / ds_factor) - - # # Initialize parameter lists - # self.query_adapt_weight = nn.ParameterList() - # self.key_adapt_weight = nn.ParameterList() - # self.value_adapt_weight = nn.ParameterList() - # self.query_adapt_bias = nn.ParameterList() - # self.key_adapt_bias = 
nn.ParameterList() - # self.value_adapt_bias = nn.ParameterList() - # - # for _ in range(num_new_scales): - # # Append the new parameters for the current scale - # self.query_adapt_weight.append( - # nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # self.key_adapt_weight.append( - # nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # self.value_adapt_weight.append( - # nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # - # self.query_adapt_bias.append( - # nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # self.key_adapt_bias.append( - # nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # self.value_adapt_bias.append( - # nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True)) - # - # # Update scale length for the next iteration - # scale_len = math.ceil(scale_len / ds_factor) - - - # def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor): - # - # self.num_new_scales = num_new_scales - # - # nh = self.dim//4 - # self.film_controller = nn.Sequential(nn.Linear(self.dim, nh), nn.SiLU()) - # - # self.query_film_generator = nn.ModuleList([ - # nn.Linear(in_features=nh, out_features=self.dim) for _ in range(num_new_scales) - # ]) - # - # self.key_film_generator = nn.ModuleList([ - # nn.Linear(in_features=nh, out_features=self.dim) for _ in range(num_new_scales) - # ]) - - # def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor): - # - # self.num_new_scales = num_new_scales - # - # base_len = math.ceil(context_length / patch_size) # num context patches in base scale - # scale_len = math.ceil(base_len / ds_factor) - # - # self.query_film_generator = nn.ModuleList() - # self.key_film_generator = nn.ModuleList() - # - # for _ in range(num_new_scales): - # self.query_film_generator.append( - # nn.Linear(in_features=self.dim, out_features=2 * scale_len) - # ) - # self.key_film_generator.append( - # nn.Linear(in_features=self.dim, out_features=2 * scale_len) - # ) - # scale_len = math.ceil(scale_len / ds_factor) + self.q_A.append(nn.Parameter(torch.randn((r, self.dim), dtype=torch.float) * 0.01)) + self.k_A.append(nn.Parameter(torch.randn((r, self.dim), dtype=torch.float) * 0.01)) + self.v_A.append(nn.Parameter(torch.randn((r, self.dim), dtype=torch.float) * 0.01)) + + self.q_B.append(nn.Parameter(torch.zeros((self.dim, r), dtype=torch.float))) + self.k_B.append(nn.Parameter(torch.zeros((self.dim, r), dtype=torch.float))) + self.v_B.append(nn.Parameter(torch.zeros((self.dim, r), dtype=torch.float))) def _get_var_id( self, @@ -372,21 +289,13 @@ def apply_lora(self, layer: nn.Linear, A: nn.Parameter, B: nn.Parameter, - alpha: float = 1.0, ): """ 在给定的线性层上应用 LoRA。 """ - # 获取线性层的权重和偏置 - W_no_grad = layer.weight.detach() # 冻结权重 - - # LoRA 更新部分 - lora_update = alpha * (B @ A) # (in_features, out_features) - - # 合成 LoRA 后的权重 - W_lora = W_no_grad + lora_update # 最终的权重 (in_features, out_features) - - # 计算输出 + W_no_grad = layer.weight.detach() + lora_update = (self.alpha / self.r) * (B @ A) # (in_features, out_features) + W_lora = W_no_grad + lora_update # (in_features, out_features) out = torch.matmul(input, W_lora.T) return out @@ -407,7 +316,7 @@ def forward( # key = self.k_proj(key) # value = self.v_proj(value) - + # ToDO: Apply lora for each scale updated_query = query.clone() updated_key = key.clone() 
updated_value = value.clone() @@ -430,38 +339,6 @@ def forward( key = updated_key value = updated_value - - # # ToDo: 这个可以 v1 - # updated_query = query.clone() - # updated_key = key.clone() - # updated_value = value.clone() - # - # if self.num_new_scales is not None: - # index_by_variate = self.get_token_index_by_variate(query_var_id) - # assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id" - # - # for scale in range(1 + self.num_new_scales): - # index = index_by_variate[scale] - # query_scale = query[..., index, :] - # key_scale = key[..., index, :] - # value_scale = value[..., index, :] - # - # if scale == 0: - # updated_query[..., index, :] = self.q_proj(query_scale) - # updated_key[..., index, :] = self.k_proj(key_scale) - # updated_value[..., index, :] = self.v_proj(value_scale) - # - # else: - # i = scale-1 - # updated_query[..., index, :] = self.apply_lora(query_scale, self.q_proj, self.q_A[i], self.q_B[i]) - # updated_key[..., index, :] = self.apply_lora(key_scale, self.k_proj, self.k_A[i], self.k_B[i]) - # updated_value[..., index, :] = self.apply_lora(value_scale, self.v_proj, self.v_A[i], self.v_B[i]) - # - # query = updated_query - # key = updated_key - # value = updated_value - - query = self.q_norm( rearrange( query,
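
The change to apply_lora above scales the low-rank update by alpha / r on top of the frozen projection weight, with each A matrix initialized to small random values and each B matrix to zeros. A minimal, self-contained sketch of that convention (the class and variable names are illustrative and not part of uni2ts; like apply_lora, the projection bias is ignored):

import torch
import torch.nn as nn


class ScaledLoRALinear(nn.Module):
    """Frozen linear projection plus a low-rank delta scaled by alpha / r."""

    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # keep the pretrained weight frozen
        self.r, self.alpha = r, alpha
        out_dim, in_dim = base.weight.shape
        self.A = nn.Parameter(torch.randn(r, in_dim) * 0.01)  # small random init
        self.B = nn.Parameter(torch.zeros(out_dim, r))         # zero init: delta starts at 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same convention as apply_lora: W_lora = W + (alpha / r) * (B @ A)
        w = self.base.weight.detach() + (self.alpha / self.r) * (self.B @ self.A)
        return x @ w.T


# Example usage with an illustrative 384-dim projection.
layer = ScaledLoRALinear(nn.Linear(384, 384, bias=False), r=16, alpha=16)
out = layer(torch.randn(2, 10, 384))  # -> shape (2, 10, 384)

Dividing by r keeps the effective magnitude of the adapter roughly independent of the chosen rank, which is why r and alpha are surfaced as arguments here instead of the previously hard-coded rank = 16 and the old apply_lora default of alpha = 1.0.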
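On the initialization side, init_multi_scale_modules now creates one (A, B) pair per scale, base scale included (1 + num_new_scales pairs), and freezes q_proj / k_proj / v_proj. A small sketch of that setup and of the zero-init property it relies on; the hidden size, number of scales, and standalone names below are assumptions for illustration, not repository code:

import torch
import torch.nn as nn

dim, r, alpha, num_new_scales = 384, 16, 16, 3

q_proj = nn.Linear(dim, dim, bias=False)
q_proj.requires_grad_(False)  # pretrained projection stays frozen

# One adapter per scale, including the base scale.
q_A = nn.ParameterList(
    [nn.Parameter(torch.randn(r, dim) * 0.01) for _ in range(1 + num_new_scales)]
)
q_B = nn.ParameterList(
    [nn.Parameter(torch.zeros(dim, r)) for _ in range(1 + num_new_scales)]
)

# Because every B starts at zero, the adapted projection initially reproduces the
# frozen one exactly, so fine-tuning starts from the pretrained behavior.
x = torch.randn(2, 8, dim)
w0 = q_proj.weight.detach() + (alpha / r) * (q_B[0] @ q_A[0])
assert torch.allclose(x @ w0.T, q_proj(x))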