Add r and alpha for attn_lora
zqiao11 committed Dec 10, 2024
1 parent 07824a8 commit 1bc73f2
Showing 26 changed files with 179 additions and 1,468 deletions.
5 changes: 1 addition & 4 deletions cli/conf/lsf-setup/multi_scale/finetune/default.yaml
@@ -33,13 +33,10 @@ trainer:
       mode: min
       save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one.
       every_n_epochs: 1
-    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
-      dirpath: ${hydra:runtime.output_dir}/checkpoints
-      save_weights_only: true
     - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping
       monitor: val/PackedNLLLoss
       min_delta: 0.0
-      patience: 3 # Set to a small value as now each epoch has many batches.
+      patience: 3
       mode: min
       strict: false
       verbose: true
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: True
 lora_kwargs:
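
The two new keys are the standard LoRA hyperparameters: r is the adapter rank and alpha the scaling numerator, with the low-rank update scaled by alpha / r. As a minimal sketch of how these values typically enter an attention LoRA layer, assuming the usual formulation (the LoRALinear wrapper below is illustrative, not this repo's actual lora_kwargs consumer):

import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen base linear layer plus a trainable rank-r update, scaled by alpha / r."""

    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # the pretrained weight stays frozen
        self.lora_A = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: adapter starts as a no-op
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = base(x) + (alpha / r) * x A^T B^T
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)

With r = alpha = 16 as configured here, the scaling factor is 1, so the adapter output is added to the frozen projection unscaled.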
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: False
 lora_kwargs:
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: False
 lora_kwargs:
@@ -69,13 +69,10 @@ trainer:
       mode: min
       save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one.
       every_n_epochs: 1
-    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
-      dirpath: ${hydra:runtime.output_dir}/checkpoints
-      save_weights_only: true
     - _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping
       monitor: val/PackedNLLLoss
       min_delta: 0.0
-      patience: 3 # Set to a small value as now each epoch has many batches.
+      patience: 3
       mode: min
       strict: false
       verbose: true
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: True
 lora_kwargs:
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: False
 lora_kwargs:
@@ -41,6 +41,8 @@ prediction_length: null
 finetune_pattern: full
 num_new_scales: 3
 ds_factor: 2
+r: 16
+alpha: 16
 
 use_lora: False
 lora_kwargs:
23 changes: 20 additions & 3 deletions cli/train_two_stage.py
@@ -27,6 +27,21 @@
 from uni2ts.common import hydra_util  # noqa: hydra resolvers
 from uni2ts.data.loader import DataLoader
 
+import os
+import glob
+
+
+def get_best_checkpoint_path(checkpoint_dir: str):
+    # List all .ckpt files in the checkpoint directory.
+    ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
+
+    if len(ckpt_files) == 1:
+        return ckpt_files[0]  # Return the path of the only .ckpt file.
+    elif len(ckpt_files) == 0:
+        raise FileNotFoundError(f"No .ckpt file found in {checkpoint_dir}")
+    else:
+        raise ValueError(f"Multiple .ckpt files found in {checkpoint_dir}. Expected only one.")
+
 
 class DataModule(L.LightningDataModule):
     def __init__(
@@ -139,13 +154,12 @@ def main(cfg: DictConfig):
 
     # ToDo: write the config for training_warmup
     trainer_warmup: L.Trainer = instantiate(cfg.trainer_warmup)
-
+    trainer_warmup.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
 
     trainer: L.Trainer = instantiate(cfg.trainer)
 
     # A '=' in the ckpt name prevents it from being loaded directly with hydra. Change it to '_'.
     trainer.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
-    trainer.callbacks[-2].CHECKPOINT_EQUALS_CHAR = "_"
 
     train_dataset: Dataset = instantiate(cfg.data).load_dataset(
         model.train_transform_map
@@ -190,7 +204,10 @@ def main(cfg: DictConfig):
         ckpt_path=cfg.ckpt_path,
     )
 
-    print("Finished warmup stage. Now finetuning the whole model...")
+    # Load the saved ckpt of the best model from stage 1.
+    print("Finished warmup stage. Now loading the saved model and finetuning the whole model...")
+    checkpoint = torch.load(get_best_checkpoint_path(trainer_warmup.callbacks[-1].dirpath))
+    model.load_state_dict(checkpoint["state_dict"])
     model.current_stage = 2
 
     trainer.fit(
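
Taken together, the changes in this file implement a load-best-then-continue handoff: stage 1 trains with the warmup trainer, the best warmup checkpoint is read back, and stage 2 finetunes from those weights. A condensed sketch of that flow under the same assumptions the script makes (the warmup trainer's last callback is its ModelCheckpoint; run_two_stage and datamodule are illustrative names, not this repo's API):

import glob
import os

import lightning as L
import torch


def best_checkpoint_path(checkpoint_dir: str) -> str:
    # Same contract as the get_best_checkpoint_path helper added above: exactly one ckpt expected.
    ckpt_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
    if len(ckpt_files) != 1:
        raise RuntimeError(f"Expected exactly one .ckpt in {checkpoint_dir}, found {len(ckpt_files)}")
    return ckpt_files[0]


def run_two_stage(model: L.LightningModule, trainer_warmup: L.Trainer, trainer: L.Trainer, datamodule) -> None:
    # Stage 1: warmup training; ModelCheckpoint tracks the best validation score.
    trainer_warmup.fit(model, datamodule=datamodule)

    # Reload the best stage-1 weights, as train_two_stage.py now does after warmup.
    ckpt = torch.load(best_checkpoint_path(trainer_warmup.callbacks[-1].dirpath))
    model.load_state_dict(ckpt["state_dict"])

    # Stage 2: switch the module's stage flag and finetune the whole model.
    model.current_stage = 2
    trainer.fit(model, datamodule=datamodule)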
10 changes: 5 additions & 5 deletions project/lsf-setup/multi_scale/eval/small/ettm1.sh
@@ -1,18 +1,18 @@
 #!/bin/bash
 
 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=2
 
 mode=S
 cp=conf/lsf-setup/multi_scale/eval
 exp_name=lsf
 cl=4000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl96/checkpoints/epoch_5-step_2502.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl96/checkpoints_warmup/epoch_1-step_834-v2.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl192/checkpoints_warmup/epoch_3-step_1664.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl336/checkpoints_warmup/epoch_2-step_1242.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm1/S/cl4000_pl720/checkpoints_warmup/epoch_2-step_1224.ckpt'
 
 index=1
 for pl in 96 192 336 720; do
8 changes: 4 additions & 4 deletions project/lsf-setup/multi_scale/eval/small/ettm2.sh
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=3000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl96/checkpoints/epoch_16-step_7327.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl192/checkpoints/epoch_3-step_1716.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl96/checkpoints_warmup/epoch_1-step_862-v1.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl192/checkpoints_warmup/epoch_0-step_429-v1.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl336/checkpoints_warmup/epoch_0-step_427.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/ettm2/S/cl3000_pl720/checkpoints_warmup/epoch_26-step_11394.ckpt'
 
 index=1
 for pl in 96 192 336 720; do
8 changes: 4 additions & 4 deletions project/lsf-setup/multi_scale/eval/small/weather.sh
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=2000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl96/checkpoints/epoch_14-step_21420.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl192/checkpoints/epoch_9-step_14240.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl96/checkpoints_warmup/epoch_2-step_4284.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl192/checkpoints_warmup/epoch_1-step_2848.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl336/checkpoints_warmup/epoch_1-step_2836.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/direct_1full_2head/full/weather/S/cl2000_pl720/checkpoints_warmup/epoch_1-step_2804.ckpt'
 
 index=1
 for pl in 96 192 336 720; do
@@ -1,10 +1,10 @@
 #!/bin/bash
 
-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=1tid_2inproj_all_scale_lora_freezeqkv
+exp_name=direct_1full_2head
 data=ettm1
 cl=4000
 ps=128

@@ -32,4 +32,6 @@ for pl in 96 192 336 720; do
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
     val_data.mode=${mode}
+    # trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \
+    # trainer_warmup.callbacks."2".monitor=val/PackedMSELoss
 done
@@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=1tid_2inproj_all_scale_lora_freezeqkv
+exp_name=direct_1full_2head
 data=ettm2
 cl=3000
 ps=64

@@ -32,4 +32,6 @@ for pl in 96 192 336 720; do
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
     val_data.mode=${mode}
+    # trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \
+    # trainer_warmup.callbacks."2".monitor=val/PackedMSELoss
 done
@@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=1tid_2inproj_all_scale_lora_freezeqkv
+exp_name=direct_1full_2head
 data=weather
 cl=2000
 ps=128

@@ -33,7 +33,7 @@ for pl in 96 192 336 720; do
     val_data.prediction_length=$pl \
     val_data.mode=${mode} \
     trainer.callbacks."1".monitor=val/PackedMSELoss \
-    trainer.callbacks."3".monitor=val/PackedMSELoss \
+    trainer.callbacks."2".monitor=val/PackedMSELoss \
     trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \
     trainer_warmup.callbacks."2".monitor=val/PackedMSELoss
 done
4 changes: 2 additions & 2 deletions src/uni2ts/model/lsf_moirai/finetune.py
@@ -414,7 +414,7 @@ def configure_optimizers(self) -> dict:
 
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
-        self.trainable_params = param_dict
+        self.updated_params = param_dict
 
         inter_params = decay & no_decay
         union_params = decay | no_decay

@@ -736,7 +736,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         filtered_state = {
             name: tensor
             for name, tensor in state.items()
-            if name in self.trainable_params
+            if name in self.updated_params
         }
         return filtered_state
4 changes: 2 additions & 2 deletions src/uni2ts/model/lsf_moirai_point/finetune.py
@@ -343,7 +343,7 @@ def configure_optimizers(self) -> dict:
 
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
-        self.trainable_params = param_dict
+        self.updated_params = param_dict
 
         inter_params = decay & no_decay
         union_params = decay | no_decay

@@ -665,6 +665,6 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         filtered_state = {
             name: tensor
             for name, tensor in state.items()
-            if name in self.trainable_params
+            if name in self.updated_params
         }
         return filtered_state
4 changes: 2 additions & 2 deletions src/uni2ts/model/moirai/finetune.py
@@ -372,7 +372,7 @@ def configure_optimizers(self) -> dict:
 
         # validate that we considered every parameter
         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
-        self.trainable_params = param_dict
+        self.updated_params = param_dict
 
         inter_params = decay & no_decay
         union_params = decay | no_decay

@@ -702,7 +702,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         filtered_state = {
             name: tensor
             for name, tensor in state.items()
-            if name in self.trainable_params
+            if name in self.updated_params
         }
         return filtered_state
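
The rename from trainable_params to updated_params touches a pattern worth spelling out: these finetune modules record, at optimizer setup, exactly which parameters still require gradients, then override state_dict so checkpoints contain only those tensors, which keeps finetuning checkpoints small. A self-contained sketch of that pattern on a plain nn.Module (PartiallyFrozenModel is illustrative, not one of the Moirai classes):

import torch.nn as nn


class PartiallyFrozenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)  # stands in for the frozen pretrained body
        self.head = nn.Linear(32, 1)       # stands in for the finetuned head
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        # Record which parameters are actually being updated (mirrors configure_optimizers).
        self.updated_params = {
            pn: p for pn, p in self.named_parameters() if p.requires_grad
        }

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Save only the updated parameters; frozen weights are recoverable from the pretrained ckpt.
        state = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        return {name: t for name, t in state.items() if name in self.updated_params}


model = PartiallyFrozenModel()
print(sorted(model.state_dict()))  # ['head.bias', 'head.weight']

Note that a checkpoint filtered this way must be loaded with strict=False when the receiving model still holds the frozen parameters.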