
Commit f848942
Initial version of learnable time id
zqiao11 committed Nov 28, 2024
1 parent 9a5b10d commit f848942
Showing 17 changed files with 156 additions and 67 deletions.
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/electricity.sh
@@ -32,5 +32,6 @@ for pl in 96 192 336 720; do
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
     val_data.mode=${mode} \
+    train_dataloader.batch_size=256 \
     model.lr=5e-6
 done
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/etth1.sh
@@ -31,5 +31,6 @@ for pl in 96 192 336 720; do
     val_data.patch_size=${ps} \
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
+    train_dataloader.batch_size=256 \
     val_data.mode=${mode}
 done
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/etth2.sh
@@ -31,5 +31,6 @@ for pl in 96 192 336 720; do
     val_data.patch_size=${ps} \
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
+    train_dataloader.batch_size=256 \
     val_data.mode=${mode}
 done
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/ettm1.sh
@@ -31,5 +31,6 @@ for pl in 96 192 336 720; do
     val_data.patch_size=${ps} \
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
+    train_dataloader.batch_size=256 \
     val_data.mode=${mode}
 done
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/ettm2.sh
@@ -31,5 +31,6 @@ for pl in 96 192 336 720; do
     val_data.patch_size=${ps} \
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
+    train_dataloader.batch_size=256 \
     val_data.mode=${mode}
 done
1 change: 1 addition & 0 deletions project/lsf-setup/lsf/finetune/base/weather.sh
@@ -31,5 +31,6 @@ for pl in 96 192 336 720; do
     val_data.patch_size=${ps} \
     val_data.context_length=$cl \
     val_data.prediction_length=$pl \
+    train_dataloader.batch_size=256 \
     val_data.mode=${mode}
 done
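Note on the change repeated across the six finetune scripts above: `train_dataloader.batch_size=256` is a Hydra-style dotted override, merged into the composed config at launch time. A minimal OmegaConf sketch of how such an override lands in the dataloader config (the key names and defaults here are assumptions, not read from this repo's configs):

    from omegaconf import OmegaConf

    # A dotted CLI override merges over the base config (assumed keys).
    base = OmegaConf.create({"train_dataloader": {"batch_size": 512, "num_workers": 4}})
    override = OmegaConf.from_dotlist(["train_dataloader.batch_size=256"])
    cfg = OmegaConf.merge(base, override)
    assert cfg.train_dataloader.batch_size == 256  # the override wins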
10 changes: 5 additions & 5 deletions project/lsf-setup/multi_scale/eval/small/ettm1.sh
@@ -1,18 +1,18 @@
 #!/bin/bash

 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1

 mode=S
 cp=conf/lsf-setup/multi_scale/eval
 exp_name=lsf
 cl=4000
 model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl96/checkpoints/epoch_3-step_1668.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl96/checkpoints/epoch_3-step_1668.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'

 index=1
 for pl in 96 192 336 720; do
8 changes: 4 additions & 4 deletions project/lsf-setup/multi_scale/eval/small/ettm2.sh
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=3000
 model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_new_scale_learned_pe/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl96/checkpoints/epoch_6-step_3017.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

 index=1
 for pl in 96 192 336 720; do
21 changes: 12 additions & 9 deletions project/lsf-setup/multi_scale/eval/small/weather.sh
@@ -9,18 +9,21 @@ exp_name=lsf
 cl=2000
 model=moirai_lightning_ckpt

-cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl96/checkpoints/epoch_9-step_14280.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl336/checkpoints/epoch_4-step_7090.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/ms_qkv_rope_1.0/full/weather/S/cl2000_pl720/checkpoints/epoch_2-step_4206.ckpt'
+#cpp1='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl96/checkpoints/epoch_8-step_12852.ckpt'
+#cpp2='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl720/checkpoints/epoch_3-step_5608.ckpt'

 index=1
-for pl in 96 192 336 720; do
+for pl in 336 720; do # 96 192
     case $index in
-        1) cpp=$cpp1 ;;
-        2) cpp=$cpp2 ;;
-        3) cpp=$cpp3 ;;
-        4) cpp=$cpp4 ;;
+        1) cpp=$cpp3 ;;
+        2) cpp=$cpp4 ;;
+
+#        1) cpp=$cpp1 ;;
+#        2) cpp=$cpp2 ;;
+#        3) cpp=$cpp3 ;;
+#        4) cpp=$cpp4 ;;
     esac

 pretrained_model=$(echo $cpp | cut -d'/' -f6)
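Note on the eval loop above: `index` pairs each prediction length with its checkpoint via the `case` block, and `cut -d'/' -f6` recovers the pretrained-model name from the checkpoint path. A rough Python equivalent of that bookkeeping (paths copied from the diff; the loop itself is illustrative only):

    # Illustrative mirror of the shell's index/case + cut -d'/' -f6 logic.
    ckpts = {
        336: "./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt",
        720: "./outputs/lsf-setup/multi_scale/finetune/moirai_1.0_R_small/learned_time_id/full/weather/S/cl2000_pl720/checkpoints/epoch_3-step_5608.ckpt",
    }
    for pl, cpp in ckpts.items():
        # cut fields are 1-based and the path begins with './', so shell
        # field 6 is Python index 5 after splitting on '/'.
        pretrained_model = cpp.split("/")[5]  # -> 'moirai_1.0_R_small'
        print(pl, pretrained_model)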
2 changes: 1 addition & 1 deletion project/lsf-setup/multi_scale/finetune/small/ettm1.sh
@@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune
-exp_name=ms_new_scale_learned_pe
+exp_name=learned_time_id
 data=ettm1
 cl=4000
 ps=128
4 changes: 2 additions & 2 deletions project/lsf-setup/multi_scale/finetune/small/ettm2.sh
@@ -1,10 +1,10 @@
 #!/bin/bash

-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=2;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;

 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune
-exp_name=ms_new_scale_learned_pe
+exp_name=learned_time_id
 data=ettm2
 cl=3000
 ps=64
4 changes: 2 additions & 2 deletions project/lsf-setup/multi_scale/finetune/small/weather.sh
@@ -1,10 +1,10 @@
 #!/bin/bash

-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune
-exp_name=ms_qkv_rope_1.0
+exp_name=learned_time_id
 data=weather
 cl=2000
 ps=128
31 changes: 24 additions & 7 deletions src/uni2ts/model/multi_scale_moirai/finetune.py
@@ -131,7 +131,7 @@ def __init__(
         self.num_new_scales = num_new_scales
         self.ds_factor = ds_factor

-        self.token_idx_per_scale = self._get_token_idx_per_scale()
+        self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale()

         # Lora config
         self.lora_config = LoraConfig(**lora_kwargs) if use_lora else None
@@ -147,9 +147,9 @@ def post_init(self):
         # # Call post_init() method of the GroupedQueryAttention object
         # layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)

-        # for module in self.module.encoder.modules():
-        #     if isinstance(module, MultiScaleRotaryProjection):
-        #         module.post_init(self.token_idx_per_scale)
+        for module in self.module.encoder.modules():
+            if isinstance(module, MultiScaleRotaryProjection):
+                module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx)

         if self.lora_config is not None:
             self.module = LoraModel(self.module, self.lora_config, "default")
@@ -185,7 +185,10 @@ def _get_token_idx_per_scale(self):
             index = list(range(start, end))
             token_idx_per_scale.append(index)

-        return token_idx_per_scale
+        base_ctx_token_len = math.ceil(self.context_length / self.patch_size)
+        base_ctx_token_idx = list(range(base_ctx_token_len))
+
+        return token_idx_per_scale, base_ctx_token_idx
@@ -353,6 +356,9 @@ def configure_optimizers(self) -> dict:
                 if "pe_weights" in pn:  # Learnable RoPE for time id proj
                     p.requires_grad = True

+                if "seq_id_q_proj" in pn or "seq_id_k_proj" in pn:
+                    p.requires_grad = True
+
         # Unfreeze the corresponding params
         if "param_proj" in self.finetune_pattern:
             for pn, p in self.named_parameters():
@@ -437,16 +443,27 @@ def configure_optimizers(self) -> dict:
                 continue

             fpn = f"{mn}.{pn}" if mn else pn
-            if pn.endswith("bias"):
+            if pn.endswith("bias") and 'time_qk_proj' not in pn:
                 no_decay.add(fpn)
-            elif pn.endswith("weight") and isinstance(m, whitelist_params):
+            elif pn.endswith("weight") and isinstance(m, whitelist_params) and 'time_qk_proj' not in pn:
                 decay.add(fpn)
             elif pn.endswith("weight") and isinstance(m, blacklist_params):
                 no_decay.add(fpn)
             elif "adapt_weight" in pn or "adapt_bias" in pn:
                 decay.add(fpn)
             elif 'pe_weights' in pn:
                 decay.add(fpn)
+
+            elif 'layers.0.self_attn.time_qk_proj.seq_id_q_proj' in pn and isinstance(m, whitelist_params):
+                decay.add(fpn)
+            elif 'layers.0.self_attn.time_qk_proj.seq_id_k_proj' in pn and isinstance(m, whitelist_params):
+                decay.add(fpn)
+
+            elif 'layers.0.self_attn.time_qk_proj.seq_id_q_proj' in pn and pn.endswith("bias"):
+                no_decay.add(fpn)
+            elif 'layers.0.self_attn.time_qk_proj.seq_id_k_proj' in pn and pn.endswith("bias"):
+                no_decay.add(fpn)
+
             # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn:  # Shared time_qk_proj
             #     decay.add(fpn)
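Note on the configure_optimizers hunks above: the new `seq_id_q_proj` / `seq_id_k_proj` parameters are unfrozen and routed into the decay/no-decay sets explicitly, since the generic bias/weight rules now skip anything under `time_qk_proj`. A condensed sketch of that grouping logic; the simplifications are assumptions (whitelist reduced to nn.Linear, module prefixes shortened to the parameter-name markers actually tested):

    from torch import nn, optim

    def param_groups(model: nn.Module, weight_decay: float = 0.01) -> list[dict]:
        decay, no_decay = set(), set()
        for mn, m in model.named_modules():
            for pn, _ in m.named_parameters(recurse=False):
                fpn = f"{mn}.{pn}" if mn else pn
                if "seq_id_q_proj" in fpn or "seq_id_k_proj" in fpn:
                    # New learnable time-id projections: weights decay, biases don't.
                    (no_decay if pn.endswith("bias") else decay).add(fpn)
                elif pn.endswith("bias"):
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, nn.Linear):
                    decay.add(fpn)
                else:
                    no_decay.add(fpn)  # e.g. norm/embedding weights
        params = dict(model.named_parameters())
        return [
            {"params": [params[n] for n in sorted(decay)], "weight_decay": weight_decay},
            {"params": [params[n] for n in sorted(no_decay)], "weight_decay": 0.0},
        ]

    # Usage sketch: optimizer = optim.AdamW(param_groups(model), lr=5e-6)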
15 changes: 9 additions & 6 deletions src/uni2ts/model/multi_scale_moirai/forecast.py
@@ -117,14 +117,14 @@ def __init__(
         ), "if module is not provided, module_kwargs is required"
         super().__init__()
         self.save_hyperparameters(ignore=["module"])
-        self.module = MoiraiModule(**module_kwargs) if module is None else module
+        self.module = MoiraiModule(**module_kwargs) if module is None else module  # module is None. Initialized by module_kwargs
         self.per_sample_loss_func = SampleNLLLoss()
         self.num_new_scales = num_new_scales

         self.ds_factor = ds_factor
         self.strict_loading = False

-        self.token_idx_per_scale = self._get_token_idx_per_scale()
+        self.token_idx_per_scale, self.base_ctx_token_idx = self._get_token_idx_per_scale()

         # Set Lora for Moirai
         if use_lora:
@@ -147,9 +147,9 @@ def post_init(self):
         # # Call post_init() method of the GroupedQueryAttention object
         # layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)

-        # for module in self.module.encoder.modules():
-        #     if isinstance(module, MultiScaleRotaryProjection):
-        #         module.post_init(self.token_idx_per_scale)
+        for module in self.module.encoder.modules():
+            if isinstance(module, MultiScaleRotaryProjection):
+                module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx)

         pass
@@ -175,7 +175,10 @@ def _get_token_idx_per_scale(self):
             index = list(range(start, end))
             token_idx_per_scale.append(index)

-        return token_idx_per_scale
+        base_ctx_token_len = math.ceil(self.hparams.context_length / self.hparams.patch_size)
+        base_ctx_token_idx = list(range(base_ctx_token_len))
+
+        return token_idx_per_scale, base_ctx_token_idx


 @contextmanager
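Note on `_get_token_idx_per_scale` (the same change appears in finetune.py above): it now also returns `base_ctx_token_idx`, the ceil(context_length / patch_size) positions of the base-scale context tokens, which `MultiScaleRotaryProjection.post_init` consumes together with the per-scale indices. A standalone sketch, under the assumption that the new scales are laid out back-to-back after the base scale and each is downsampled by `ds_factor` (the loop body is only partially visible in this diff):

    import math

    def get_token_idx_per_scale(context_length, patch_size, num_new_scales, ds_factor):
        base_len = math.ceil(context_length / patch_size)
        token_idx_per_scale, start, length = [], 0, base_len
        for _ in range(num_new_scales + 1):  # base scale + downsampled scales
            end = start + length
            token_idx_per_scale.append(list(range(start, end)))
            start, length = end, math.ceil(length / ds_factor)
        base_ctx_token_idx = list(range(base_len))  # as in the diff: ceil(cl / ps) positions
        return token_idx_per_scale, base_ctx_token_idx

    # e.g. cl=4000, ps=128 (the ettm1 run above): base scale holds ceil(4000/128) = 32 tokens
    idx, base_idx = get_token_idx_per_scale(4000, 128, num_new_scales=2, ds_factor=2)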
12 changes: 6 additions & 6 deletions src/uni2ts/model/multi_scale_moirai/module.py
@@ -123,15 +123,15 @@ def __init__(
                 # num_vars=4  # ToDo: num_vars should be exposed via an external interface
             ),
             time_qk_proj_layer=partial(
-                # QueryKeyProjection,
-                # proj_layer=MultiScaleRotaryProjection,
-                # kwargs=dict(max_len=max_seq_len),
-                # partial_factor=(0.0, 0.5),  # the previous partial factor was 0-0.5
-
                 QueryKeyProjection,
-                proj_layer=RotaryProjection,  # ToDo: can be changed
+                proj_layer=MultiScaleRotaryProjection,
                 kwargs=dict(max_len=max_seq_len),
                 partial_factor=(0.0, 0.5),  # the previous partial factor was 0-0.5
+
+                # QueryKeyProjection,
+                # proj_layer=RotaryProjection,  # ToDo: can be changed
+                # kwargs=dict(max_len=max_seq_len),
+                # partial_factor=(0.0, 0.5),  # the previous partial factor was 0-0.5
             ),
             shared_var_attn_bias=False,
             shared_time_qk_proj=True,  # True by default
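Note on the module.py swap above: because the time-axis query/key projection is built through functools.partial, moving from RotaryProjection to MultiScaleRotaryProjection is a one-argument rewire. A runnable toy with stub classes standing in for the uni2ts ones (constructor signatures assumed from the diff):

    from functools import partial

    class RotaryProjection:
        def __init__(self, max_len): self.max_len = max_len

    class MultiScaleRotaryProjection(RotaryProjection):
        pass  # stub; the real class learns per-scale time ids

    class QueryKeyProjection:
        def __init__(self, proj_layer, kwargs, partial_factor):
            self.proj = proj_layer(**kwargs)        # one projection instance
            self.partial_factor = partial_factor    # rotate only part of the head dim

    # was: proj_layer=RotaryProjection
    time_qk_proj_layer = partial(
        QueryKeyProjection,
        proj_layer=MultiScaleRotaryProjection,
        kwargs=dict(max_len=512),
        partial_factor=(0.0, 0.5),
    )
    proj = time_qk_proj_layer()  # built once per encoder when shared_time_qk_proj=True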
2 changes: 1 addition & 1 deletion src/uni2ts/module/multi_scale/attention.py
@@ -434,7 +434,7 @@ def forward(
         )

         # Add attn_bias
-        attn_mask = self._update_attn_mask(
+        attn_mask = self._update_attn_mask(  # (bs, 6, 1, len, len)
             attn_mask,
             query,
             key,
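Note on the shape comment added above: a plausible reading is that the updated attention mask is an additive bias of shape (bs, groups, 1, q_len, kv_len) that broadcasts over the heads-per-group axis when added to grouped-query attention scores. This is an interpretation of the annotation, not confirmed by the diff:

    import torch

    bs, groups, heads_per_group, seq = 2, 6, 8, 16
    scores = torch.randn(bs, groups, heads_per_group, seq, seq)
    bias = torch.randn(bs, groups, 1, seq, seq)  # the annotated (bs, 6, 1, len, len)
    scores = scores + bias  # broadcasts across heads_per_group
    print(scores.shape)  # torch.Size([2, 6, 8, 16, 16])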