Add scalewise lora for attn
zqiao11 committed Dec 7, 2024
1 parent 3fceb49 commit 07824a8
Showing 12 changed files with 242 additions and 130 deletions.
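The commit adds scale-wise LoRA adapters to the attention projections: the pretrained q/k/v weights stay frozen (hence `freezeqkv` in the new experiment name) while per-scale low-rank factors such as `q_A`/`q_B` are trained and, as the hunks below show, registered for weight decay. A minimal sketch of that idea, assuming one (A, B) pair per scale and a token-to-scale index; the class name, constructor arguments, and masking scheme are illustrative, not the repository's actual `GroupedQueryAttention` code:

```python
import torch
import torch.nn as nn


class ScaleWiseLoRAProjection(nn.Module):
    """Hypothetical sketch of a scale-wise LoRA adapter on a frozen attention
    projection. Parameter names mirror the q_A / q_B names this commit adds to
    the optimizer's weight-decay group; everything else is illustrative."""

    def __init__(self, base_proj: nn.Linear, num_scales: int, rank: int = 8):
        super().__init__()
        self.base = base_proj
        for p in self.base.parameters():  # "freezeqkv": keep pretrained weights fixed
            p.requires_grad = False
        d_in, d_out = base_proj.in_features, base_proj.out_features
        # One low-rank (A, B) pair per scale.
        self.q_A = nn.ParameterList(
            [nn.Parameter(0.01 * torch.randn(rank, d_in)) for _ in range(num_scales)]
        )
        self.q_B = nn.ParameterList(
            [nn.Parameter(torch.zeros(d_out, rank)) for _ in range(num_scales)]
        )

    def forward(self, x: torch.Tensor, scale_id: torch.Tensor) -> torch.Tensor:
        # x: (..., seq, d_in); scale_id: (..., seq) integer scale index per token
        out = self.base(x)
        for s in range(len(self.q_A)):
            delta = x @ self.q_A[s].T @ self.q_B[s].T  # low-rank update for scale s
            out = out + (scale_id == s).unsqueeze(-1) * delta
        return out
```

Initialising `q_B` to zero keeps the adapted projection identical to the pretrained one at the start of fine-tuning.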
@@ -23,7 +23,7 @@ trainer_warmup:
logger:
_target_: lightning.pytorch.loggers.TensorBoardLogger
save_dir: ${hydra:runtime.output_dir}
-name: logs
+name: logs_warmup
callbacks:
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: epoch
3 changes: 2 additions & 1 deletion project/lsf-setup/lsf/finetune/base/electricity.sh
@@ -33,5 +33,6 @@ for pl in 96 192 336 720; do
val_data.prediction_length=$pl \
val_data.mode=${mode} \
train_dataloader.batch_size=256 \
-model.lr=5e-6
+model.lr=1e-5 \
+trainer.callbacks.'3'.patience=1
done
8 changes: 4 additions & 4 deletions project/lsf-setup/multi_scale/eval/small/ettm1.sh
@@ -9,10 +9,10 @@ exp_name=lsf
cl=4000
model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl96/checkpoints/epoch_4-step_2085.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl96/checkpoints/epoch_5-step_2502.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'

index=1
for pl in 96 192 336 720; do
10 changes: 5 additions & 5 deletions project/lsf-setup/multi_scale/eval/small/ettm2.sh
@@ -1,18 +1,18 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=3
+export CUDA_VISIBLE_DEVICES=2

mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
cl=3000
model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl96/checkpoints/epoch_13-step_6034.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl192/checkpoints/epoch_4-step_2145.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl96/checkpoints/epoch_16-step_7327.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl192/checkpoints/epoch_3-step_1716.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

index=1
for pl in 96 192 336 720; do
24 changes: 10 additions & 14 deletions project/lsf-setup/multi_scale/eval/small/weather.sh
@@ -1,30 +1,26 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=3

mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
cl=2000
model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl96/checkpoints/epoch_7-step_11424.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl336/checkpoints/epoch_2-step_4254.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl96/checkpoints/epoch_14-step_21420.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl192/checkpoints/epoch_9-step_14240.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt'

index=1
-for pl in 192 336 720; do # 96
+for pl in 96 192 336 720; do
case $index in
-1) cpp=$cpp2 ;;
-2) cpp=$cpp3 ;;
-3) cpp=$cpp4 ;;
-
-# 1) cpp=$cpp1 ;;
-# 2) cpp=$cpp2 ;;
-# 3) cpp=$cpp3 ;;
-# 4) cpp=$cpp4 ;;
+1) cpp=$cpp1 ;;
+2) cpp=$cpp2 ;;
+3) cpp=$cpp3 ;;
+4) cpp=$cpp4 ;;
esac

pretrained_model=$(echo $cpp | cut -d'/' -f6)
@@ -1,10 +1,10 @@
#!/bin/bash

-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
data=ettm1
cl=4000
ps=128
@@ -1,10 +1,10 @@
#!/bin/bash

-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=2;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
data=ettm2
cl=3000
ps=64
@@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;

model=moirai_1.0_R_small
cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage_valMSE
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
data=weather
cl=2000
ps=128
@@ -33,5 +33,7 @@ for pl in 96 192 336 720; do
val_data.prediction_length=$pl \
val_data.mode=${mode} \
trainer.callbacks."1".monitor=val/PackedMSELoss \
-trainer.callbacks."3".monitor=val/PackedMSELoss
+trainer.callbacks."3".monitor=val/PackedMSELoss \
+trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \
+trainer_warmup.callbacks."2".monitor=val/PackedMSELoss
done
16 changes: 11 additions & 5 deletions src/uni2ts/model/multi_scale_moirai/finetune.py
@@ -143,11 +143,11 @@ def post_init(self):

self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size)

-# for layer in self.module.encoder.layers:
-#     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-#     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-#         # Call post_init() method of the GroupedQueryAttention object
-#         layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
+for layer in self.module.encoder.layers:
+    # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+    if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+        # Call post_init() method of the GroupedQueryAttention object
+        layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)

for module in self.module.encoder.modules():
if isinstance(module, MultiScaleRotaryProjection):
@@ -455,6 +455,12 @@ def configure_optimizers(self) -> dict:
decay.add(fpn)
elif 'pe_weights' in pn:
decay.add(fpn)
elif 'q_A' in pn or 'q_B' in pn or 'q_bias' in pn:
decay.add(fpn)
elif 'k_A' in pn or 'k_B' in pn or 'k_bias' in pn:
decay.add(fpn)
elif 'v_A' in pn or 'v_B' in pn or 'v_bias' in pn:
decay.add(fpn)

# elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj
# decay.add(fpn)
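The lines added to this hunk only extend the `decay` name set so that the new LoRA factors receive weight decay. For context, such `decay`/`no_decay` name sets are commonly turned into two AdamW parameter groups; a minimal self-contained sketch, with an assumed helper name and placeholder hyperparameters rather than the module's exact code:

```python
from typing import Iterable

import torch
import torch.nn as nn


def build_adamw(module: nn.Module,
                decay: Iterable[str],
                no_decay: Iterable[str],
                lr: float = 1e-5,
                weight_decay: float = 1e-1) -> torch.optim.AdamW:
    """Illustrative helper: apply weight decay only to the parameter names
    collected in `decay` (now including the q_A/q_B, k_A/k_B, v_A/v_B LoRA
    factors), and none to those in `no_decay`."""
    params = {pn: p for pn, p in module.named_parameters() if p.requires_grad}
    groups = [
        {"params": [params[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [params[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.98), eps=1e-6)
```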
58 changes: 53 additions & 5 deletions src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py
@@ -145,11 +145,11 @@ def post_init(self):

self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size)

-# for layer in self.module.encoder.layers:
-#     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-#     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-#         # Call post_init() method of the GroupedQueryAttention object
-#         layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
+for layer in self.module.encoder.layers:
+    # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+    if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+        # Call post_init() method of the GroupedQueryAttention object
+        layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)

for module in self.module.encoder.modules():
if isinstance(module, MultiScaleRotaryProjection):
@@ -331,6 +331,48 @@ def validation_step(

def configure_optimizers(self) -> dict:

# if self.current_stage == 1:
# warmup_pn_group1 = ['param_proj', 'time_id_q_proj', 'time_id_k_proj']
# warmup_pn_group2 = ['in_proj_new_scales']
#
# optim_groups = [
# {
# "params": [
# p for pn, p in self.named_parameters()
# if any(keyword in pn for keyword in warmup_pn_group1)
# ],
# "lr": 5e-4,
# "weight_decay": self.hparams.weight_decay,
# },
# {
# "params": [
# p for pn, p in self.named_parameters()
# if any(keyword in pn for keyword in warmup_pn_group2)
# ],
# "lr": 5e-6,
# "weight_decay": self.hparams.weight_decay,
# },
# ]
#
# optimizer = torch.optim.AdamW(
# optim_groups,
# betas=(self.hparams.beta1, self.hparams.beta2),
# eps=1e-6,
# )
#
# warmup_params_all = {
# pn: p for pn, p in self.named_parameters()
# if any(keyword in pn for keyword in warmup_pn_group1 + warmup_pn_group2)
# }
# self.trainable_params = warmup_params_all
#
# scheduler = get_scheduler(
# SchedulerType.CONSTANT, # Use constant lr scheduler
# optimizer,
# num_warmup_steps=self.hparams.num_warmup_steps,
# num_training_steps=self.hparams.num_training_steps,
# )

if self.current_stage == 1:
warmup_pn = ['param_proj', 'time_id_q_proj', 'time_id_k_proj']
warmup_params = {
@@ -474,6 +516,12 @@ def configure_optimizers(self) -> dict:
decay.add(fpn)
elif 'pe_weights' in pn:
decay.add(fpn)
elif 'q_A' in pn or 'q_B' in pn:
decay.add(fpn)
elif 'k_A' in pn or 'k_B' in pn:
decay.add(fpn)
elif 'v_A' in pn or 'v_B' in pn:
decay.add(fpn)

# elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn: # Shared time_qk_proj
# decay.add(fpn)
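The commented-out block added above experiments with two warmup parameter groups at different learning rates; the active stage-1 path keeps a single group built from the `param_proj` / `time_id_q_proj` / `time_id_k_proj` keywords. A compact sketch of a single-group stage-1 warmup in that spirit; whether the active code uses a constant schedule like the commented variant is an assumption, the hyperparameter values are placeholders, and `get_scheduler`/`SchedulerType` here are the Hugging Face `transformers` helpers:

```python
import torch
import torch.nn as nn
from transformers import SchedulerType, get_scheduler


def stage1_warmup_setup(model: nn.Module,
                        lr: float = 5e-4,
                        weight_decay: float = 1e-1) -> dict:
    """Sketch of a stage-1 warmup: optimise only the param_proj / time-id
    projection modules and keep everything else (including the frozen qkv
    weights and the per-scale LoRA factors) out of the optimizer."""
    warmup_pn = ["param_proj", "time_id_q_proj", "time_id_k_proj"]
    warmup_params = [
        p for pn, p in model.named_parameters()
        if any(k in pn for k in warmup_pn)
    ]
    optimizer = torch.optim.AdamW(
        warmup_params, lr=lr, weight_decay=weight_decay,
        betas=(0.9, 0.98), eps=1e-6,
    )
    scheduler = get_scheduler(SchedulerType.CONSTANT, optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```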
10 changes: 5 additions & 5 deletions src/uni2ts/model/multi_scale_moirai/forecast.py
@@ -143,11 +143,11 @@ def post_init(self):

self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.hparams.patch_size)

-# for layer in self.module.encoder.layers:
-#     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-#     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-#         # Call post_init() method of the GroupedQueryAttention object
-#         layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)
+for layer in self.module.encoder.layers:
+    # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+    if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+        # Call post_init() method of the GroupedQueryAttention object
+        layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)

for module in self.module.encoder.modules():
if isinstance(module, MultiScaleRotaryProjection):
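The now-active loop calls `init_multi_scale_modules(context_length, patch_size, num_new_scales, ds_factor)` on every `GroupedQueryAttention` layer, but the hook's body is not part of this diff. One plausible piece of bookkeeping it needs, given those arguments, is how many patch tokens each downsampled scale contributes; the helper below is a hypothetical illustration, not the repository's implementation:

```python
import math


def tokens_per_scale(context_length: int, patch_size: int,
                     num_new_scales: int, ds_factor: int) -> list[int]:
    """Hypothetical helper: token counts per scale when the context is
    repeatedly downsampled by ds_factor and then split into patches."""
    counts = [math.ceil(context_length / patch_size)]  # base scale
    length = context_length
    for _ in range(num_new_scales):
        length = math.ceil(length / ds_factor)          # one more downsampling step
        counts.append(math.ceil(length / patch_size))
    return counts


# With the ettm1 settings used in the scripts above (cl=4000, ps=128) and the
# assumed values ds_factor=2, num_new_scales=2:
print(tokens_per_scale(4000, 128, num_new_scales=2, ds_factor=2))  # [32, 16, 8]
```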