Interface for data mode
zqiao11 committed Nov 18, 2024
1 parent 6f65513 commit 1810c19
Showing 5 changed files with 41 additions and 34 deletions.
7 changes: 3 additions & 4 deletions cli/conf/lsf-setup/multi_scale/finetune/default.yaml
@@ -33,6 +33,9 @@ trainer:
mode: min
save_top_k: 1 # Qz: Sometimes the 1st validation gets anomalous results. Discard that ckpt, and use the 2nd one.
every_n_epochs: 1
+- _target_: lightning.pytorch.callbacks.ModelCheckpoint
+dirpath: ${hydra:runtime.output_dir}/checkpoints
+save_weights_only: true
- _target_: lightning.pytorch.callbacks.EarlyStopping # uni2ts.callbacks.earlystop.WarmupEarlyStopping
monitor: val/PackedNLLLoss
min_delta: 0.0
@@ -41,10 +44,6 @@ trainer:
strict: false
verbose: true
# warmup_steps: 1
-- _target_: lightning.pytorch.callbacks.ModelCheckpoint
-dirpath: ${hydra:runtime.output_dir}/checkpoints
-save_last: true
-save_weights_only: true
max_epochs: 1000
enable_progress_bar: true
accumulate_grad_batches: 1
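
The position of the new weights-only ModelCheckpoint in the callbacks list matters because cli/train.py (next file) patches callbacks by index. A minimal sketch of how Hydra turns such a list into ordered trainer callbacks, using a trimmed-down, assumed config rather than the project's actual default.yaml:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Assumed two-entry callbacks list; the real config has more entries.
    callbacks_cfg = OmegaConf.create(
        [
            {"_target_": "lightning.pytorch.callbacks.ModelCheckpoint", "save_weights_only": True},
            {"_target_": "lightning.pytorch.callbacks.EarlyStopping", "monitor": "val/PackedNLLLoss"},
        ]
    )
    # Hydra builds one object per `_target_` entry, in order, so where a
    # ModelCheckpoint sits in the YAML determines its index in trainer.callbacks.
    callbacks = [instantiate(c) for c in callbacks_cfg]
    print([type(cb).__name__ for cb in callbacks])  # ['ModelCheckpoint', 'EarlyStopping']
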
3 changes: 2 additions & 1 deletion cli/train.py
@@ -139,7 +139,8 @@ def main(cfg: DictConfig):
trainer: L.Trainer = instantiate(cfg.trainer)

# A '=' in the ckpt name means it cannot be loaded directly with hydra, so change it to '_'.
-trainer.callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
+trainer.callbacks[1].CHECKPOINT_EQUALS_CHAR = "_"
+trainer.callbacks[2].CHECKPOINT_EQUALS_CHAR = "_"

train_dataset: Dataset = instantiate(cfg.data).load_dataset(
model.train_transform_map
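
Hard-coding callbacks[1] and callbacks[2] assumes the two ModelCheckpoint entries keep exactly those positions in the YAML. An index-independent alternative, sketched with a hypothetical helper (use_underscore_ckpt_names is not part of the repository):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    def use_underscore_ckpt_names(trainer: Trainer) -> None:
        """Use '_' instead of '=' in checkpoint filenames for every ModelCheckpoint."""
        for callback in trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                # '=' clashes with hydra's key=value override syntax when the ckpt
                # path is later passed back on the command line, so switch to '_'.
                callback.CHECKPOINT_EQUALS_CHAR = "_"
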
23 changes: 13 additions & 10 deletions project/lsf-setup/multi_scale/eval/small/ettm1.sh
@@ -6,21 +6,24 @@ export CUDA_VISIBLE_DEVICES=0
mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
-cl=3000
+cl=4000
model=moirai_lightning_ckpt

-cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm1/cl3000_pl96/checkpoints/epoch_2-step_1293.ckpt'
-cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm1/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
-cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm1/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
-cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm1/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+#cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm1/S/cl4000_pl96/checkpoints/epoch_3-step_1668.ckpt'
+#cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl192/checkpoints/epoch_2-step_1287.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm1/S/cl4000_pl336/checkpoints/epoch_3-step_1656.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm1/S/cl4000_pl720/checkpoints/epoch_3-step_1632.ckpt'

index=1
-for pl in 96 192 336 720; do
+for pl in 336 720; do # 96 192
case $index in
-1) cpp=$cpp1 ;;
-2) cpp=$cpp2 ;;
-3) cpp=$cpp3 ;;
-4) cpp=$cpp4 ;;
+1) cpp=$cpp3 ;;
+2) cpp=$cpp4 ;;
+
+# 1) cpp=$cpp1 ;;
+# 2) cpp=$cpp2 ;;
+# 3) cpp=$cpp3 ;;
+# 4) cpp=$cpp4 ;;
esac

pretrained_model=$(echo $cpp | cut -d'/' -f6)
24 changes: 14 additions & 10 deletions project/lsf-setup/multi_scale/eval/small/ettm2.sh
@@ -1,26 +1,30 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=3

mode=S
cp=conf/lsf-setup/multi_scale/eval
exp_name=lsf
cl=3000
model=moirai_lightning_ckpt

-cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm2/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
-cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm2/cl3000_pl192/checkpoints/epoch_2-step_1287.ckpt'
-cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm2/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
-cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_ms_rope/ettm2/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+#cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl192/checkpoints/epoch_2-step_1287.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl3000_pl336/checkpoints/epoch_3-step_1708.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/lsf/full/ettm2/S/cl4000_pl720/checkpoints/epoch_3-step_1688.ckpt'

index=1
-for pl in 96 192 336 720; do
+for pl in 192 336 720; do # 96
case $index in
-1) cpp=$cpp1 ;;
-2) cpp=$cpp2 ;;
-3) cpp=$cpp3 ;;
-4) cpp=$cpp4 ;;
+1) cpp=$cpp2 ;;
+2) cpp=$cpp3 ;;
+3) cpp=$cpp4 ;;
+
+# 1) cpp=$cpp1 ;;
+# 2) cpp=$cpp2 ;;
+# 3) cpp=$cpp3 ;;
+# 4) cpp=$cpp4 ;;
esac

pretrained_model=$(echo $cpp | cut -d'/' -f6)
18 changes: 9 additions & 9 deletions src/uni2ts/model/multi_scale_moirai/forecast.py
@@ -128,15 +128,15 @@ def __init__(


def post_init(self):
-# for layer in self.module.encoder.layers:
-# # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-# if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-# # Call post_init() method of the GroupedQueryAttention object
-# layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)
-
-for module in self.module.encoder.modules():
-if isinstance(module, MultiScaleRotaryProjection):
-module.post_init(self.token_idx_per_scale)
+for layer in self.module.encoder.layers:
+# Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+# Call post_init() method of the GroupedQueryAttention object
+layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)
+
+# for module in self.module.encoder.modules():
+# if isinstance(module, MultiScaleRotaryProjection):
+# module.post_init(self.token_idx_per_scale)


def _get_token_idx_per_scale(self):
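
The commit switches between the two multi-scale initialization paths by commenting code in and out. A sketch of the same toggle driven by a flag instead, written as a drop-in variant of the method above; use_ms_rope is an assumed attribute, not an existing hyperparameter, and all other names are taken from the diff:

    def post_init(self):
        # Sketch only: `use_ms_rope` is a hypothetical flag, not an existing hparam.
        if getattr(self, "use_ms_rope", False):
            # Multi-scale RoPE path: tell each rotary projection which token
            # indices belong to which scale.
            for module in self.module.encoder.modules():
                if isinstance(module, MultiScaleRotaryProjection):
                    module.post_init(self.token_idx_per_scale)
        else:
            # Attention path: initialize the multi-scale modules inside each
            # GroupedQueryAttention layer.
            for layer in self.module.encoder.layers:
                if hasattr(layer, "self_attn") and isinstance(layer.self_attn, GroupedQueryAttention):
                    layer.self_attn.init_multi_scale_modules(
                        self.hparams.context_length,
                        self.hparams.patch_size,
                        self.num_new_scales,
                        self.ds_factor,
                    )
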
