Skip to content

Commit

Permalink
Add Lora and learnable RoPE
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Nov 24, 2024
1 parent 3998641 commit 9a5b10d
Show file tree
Hide file tree
Showing 25 changed files with 404 additions and 388 deletions.
8 changes: 8 additions & 0 deletions cli/conf/lsf-setup/lsf/finetune/model/moirai_1.0_R_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ patch_size: null
context_length: null
prediction_length: null
finetune_pattern: full

use_lora: True
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
10 changes: 9 additions & 1 deletion cli/conf/lsf-setup/lsf/finetune/model/moirai_1.0_R_small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ num_warmup_steps: 0
patch_size: null
context_length: null
prediction_length: null
finetune_pattern: full
finetune_pattern: full

use_lora: False
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
10 changes: 9 additions & 1 deletion cli/conf/lsf-setup/lsf/finetune/model/moirai_1.1_R_small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ num_warmup_steps: 0
patch_size: null
context_length: null
prediction_length: null
finetune_pattern: full
finetune_pattern: full

use_lora: False
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,11 @@ prediction_length: null
finetune_pattern: full
num_new_scales: 3
ds_factor: 2

use_lora: True
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,12 @@ context_length: null
prediction_length: null
finetune_pattern: full
num_new_scales: 3
ds_factor: 2
ds_factor: 2

use_lora: False
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,12 @@ context_length: null
prediction_length: null
finetune_pattern: full
num_new_scales: 3
ds_factor: 2
ds_factor: 2

use_lora: False
lora_kwargs:
_target_: builtins.dict
r: 16
target_modules: ["q_proj", "k_proj", "v_proj"]
lora_alpha: 32
lora_dropout: 0.05
14 changes: 13 additions & 1 deletion cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,26 @@ def main(cfg: DictConfig):
# QZ: If evaluating the finetuned model, we need to load Moirai's frozen params manually.
if "pretrained_checkpoint_path" in cfg.model:
checkpoint = torch.load(cfg.model.checkpoint_path)
hyper_params = checkpoint['hyper_parameters']
lora_target_modules = hyper_params['lora_kwargs']['target_modules']

tuned_state_dict = checkpoint["state_dict"]
pretrained_moirai_state_dict = torch.load(
cfg.model.pretrained_checkpoint_path, weights_only=True
)

new_state_dict = {}
for name, tensor in pretrained_moirai_state_dict.items():
new_name = "module." + name
# If using LoRA, the pretrained weights must be renamed before loading.
if hyper_params['use_lora']:
new_name = 'module.model.' + name
# In LoraModel, a '.base_layer' suffix is appended to each of LoRA's target_modules.
for module in lora_target_modules:
if module in new_name:
new_name = new_name.replace(module, module + '.base_layer')
break
else:
new_name = "module." + name
new_state_dict[new_name] = tensor
pretrained_moirai_state_dict = new_state_dict

Expand Down
2 changes: 1 addition & 1 deletion cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def main(cfg: DictConfig):
)

# Validate before training, check the performance of original pretrained model.
trainer.validate(model, datamodule=DataModule(cfg, train_dataset, val_dataset))
# trainer.validate(model, datamodule=DataModule(cfg, train_dataset, val_dataset))

trainer.fit(
model,
Expand Down
8 changes: 4 additions & 4 deletions project/lsf-setup/lsf/eval/small/electricity.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ ps=64
mode=S


cpp1=''
cpp2=''
cpp3=''
cpp4=''
cpp1='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/electricity/S/cl5000_pl96/checkpoints/epoch_13-step_58450.ckpt'
cpp2='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/electricity/S/cl5000_pl192/checkpoints/epoch_7-step_33160.ckpt'
cpp3='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/electricity/S/cl5000_pl336/checkpoints/epoch_6-step_28700.ckpt'
cpp4='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/electricity/S/cl5000_pl720/checkpoints/epoch_2-step_11937.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
10 changes: 5 additions & 5 deletions project/lsf-setup/lsf/eval/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=1

cp=conf/lsf-setup/lsf/eval
exp_name=lsf
Expand All @@ -12,10 +12,10 @@ ps=64
mode=S


cpp1='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/cl3000_pl96/checkpoints/epoch_12-step_2808.ckpt'
cpp2='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/cl3000_pl192/checkpoints/epoch_4-step_1075.ckpt'
cpp3='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/cl3000_pl336/checkpoints/epoch_2-step_642.ckpt'
cpp4='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/cl3000_pl720/checkpoints/epoch_1-step_422.ckpt'
cpp1='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl96/checkpoints/epoch_15-step_3456.ckpt'
cpp2='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl192/checkpoints/epoch_7-step_1720.ckpt'
cpp3='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl336/checkpoints/epoch_5-step_1284.ckpt'
cpp4='./outputs/lsf-setup/lsf/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl720/checkpoints/epoch_3-step_844.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
2 changes: 1 addition & 1 deletion project/lsf-setup/lsf/finetune/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

model=moirai_1.0_R_small
cp=conf/lsf-setup/lsf/finetune
Expand Down
8 changes: 4 additions & 4 deletions project/lsf-setup/multi_scale/eval/small/ettm1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ exp_name=lsf
cl=4000
model=moirai_lightning_ckpt

cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/lsf/full/ettm1/S/cl4000_pl96/checkpoints/epoch_2-step_1251.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/lsf/full/ettm1/S/cl4000_pl192/checkpoints/epoch_0-step_416.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/lsf/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/lsf/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl96/checkpoints/epoch_3-step_1668.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
10 changes: 5 additions & 5 deletions project/lsf-setup/multi_scale/eval/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=3
export CUDA_VISIBLE_DEVICES=1

mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
cl=3000
model=moirai_lightning_ckpt

cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/ms_qkv_1.1_ctx3000/full/ettm2/S/cl3000_pl96/checkpoints/epoch_4-step_2155.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/ms_qkv_1.1_ctx3000/full/ettm2/S/cl3000_pl192/checkpoints/epoch_2-step_1287.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/ms_qkv_1.1_ctx3000/full/ettm2/S/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.1_R_small/ms_qkv_1.1_ctx3000/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
10 changes: 5 additions & 5 deletions project/lsf-setup/multi_scale/eval/small/weather.sh
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=2
export CUDA_VISIBLE_DEVICES=3

mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
cl=2000
model=moirai_lightning_ckpt

cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_1.0/full/weather/S/cl2000_pl96/checkpoints/epoch_10-step_15708.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_1.0/full/weather/S/cl2000_pl192/checkpoints/epoch_7-step_11392.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_1.0/full/weather/S/cl2000_pl336/checkpoints/epoch_4-step_7090.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_1.0/full/weather/S/cl2000_pl720/checkpoints/epoch_2-step_4206.ckpt'
cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl96/checkpoints/epoch_9-step_14280.ckpt'
cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt'
cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl336/checkpoints/epoch_4-step_7090.ckpt'
cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl720/checkpoints/epoch_2-step_4206.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
4 changes: 2 additions & 2 deletions project/lsf-setup/multi_scale/finetune/small/ettm1.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune
exp_name=ms_qkv_1.0
exp_name=ms_new_scale_learned_pe
data=ettm1
cl=4000
ps=128
Expand Down
4 changes: 2 additions & 2 deletions project/lsf-setup/multi_scale/finetune/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=2;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune
exp_name=ms_qkv_1.0
exp_name=ms_new_scale_learned_pe
data=ettm2
cl=3000
ps=64
Expand Down
2 changes: 1 addition & 1 deletion project/lsf-setup/multi_scale/finetune/small/weather.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune
exp_name=ms_qkv_1.0
exp_name=ms_qkv_rope_1.0
data=weather
cl=2000
ps=128
Expand Down
Loading

0 comments on commit 9a5b10d

Please sign in to comment.