Skip to content

Commit

Permalink
Film dim
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Oct 24, 2024
1 parent 7f43cc1 commit c715305
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 75 deletions.
2 changes: 1 addition & 1 deletion cli/conf/lsf_point/finetune/model/moirai_1.1_R_small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ val_metric:
- _target_: uni2ts.loss.packed.PackedMSELoss
- _target_: uni2ts.loss.packed.PackedNRMSELoss
normalize: absolute_target_squared
lr: 5e-7 # On ETT dataset, using 1e-6/5e-7 converge within 1-2 epochs. 1e-7 converge in tens of epochs
lr: 5e-6 # On the ETT datasets, 1e-6/5e-7 converges within 1-2 epochs; 1e-7 takes tens of epochs to converge
weight_decay: 1e-1
beta1: 0.9
beta2: 0.98
Expand Down
3 changes: 2 additions & 1 deletion cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def main(cfg: DictConfig):

model: L.LightningModule = instantiate(cfg.model, _convert_="all")

# model.module.replace_forecast_head(seq_len=49, pred_len=96)
if hasattr(model, 'post_init') and callable(getattr(model, 'post_init')):
model.post_init()

if "collate_fn" not in cfg.train_dataloader:
model.seq_fields = model.seq_fields + ("sample_id",)
Expand Down
10 changes: 5 additions & 5 deletions project/lsf_point/eval/small/ettm1.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=1

mode=S
cp=conf/lsf_point/eval
Expand All @@ -10,10 +10,10 @@ cl=3000
model=moirai_lightning_ckpt


cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl96/checkpoints/epoch_84-step_18360.ckpt'
cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl192/checkpoints/epoch_79-step_17200.ckpt'
cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl336/checkpoints/epoch_68-step_14766.ckpt'
cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl720/checkpoints/epoch_54-step_11605.ckpt'
cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl96/checkpoints/epoch_44-step_9720.ckpt'
cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl192/checkpoints/epoch_36-step_7955.ckpt'
cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl336/checkpoints/epoch_11-step_2568.ckpt'
cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl720/checkpoints/epoch_7-step_1688.ckpt'


index=1
Expand Down
10 changes: 5 additions & 5 deletions project/lsf_point/eval/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=1

mode=S
cp=conf/lsf_point/eval
Expand All @@ -10,10 +10,10 @@ cl=3000
model=moirai_lightning_ckpt


cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl96/checkpoints/epoch_63-step_13824.ckpt'
cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl192/checkpoints/epoch_52-step_11395.ckpt'
cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl336/checkpoints/epoch_37-step_8132.ckpt'
cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl720/checkpoints/epoch_32-step_6963.ckpt'
cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl96/checkpoints/epoch_13-step_3024.ckpt'
cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl192/checkpoints/epoch_6-step_1505.ckpt'
cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl336/checkpoints/epoch_5-step_1284.ckpt'
cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl720/checkpoints/epoch_5-step_1266.ckpt'


index=1
Expand Down
2 changes: 1 addition & 1 deletion project/lsf_point/finetune/small/ettm1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model=moirai_1.1_R_small
cp=conf/lsf_point/finetune
exp_name=lsf
cl=3000
ft_pattern=full
ft_pattern=head_dp07

data=ettm1
ps=128
Expand Down
4 changes: 2 additions & 2 deletions project/lsf_point/finetune/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;

model=moirai_1.1_R_small
cp=conf/lsf_point/finetune
exp_name=lsf
cl=3000
ft_pattern=full
ft_pattern=head_dp07

data=ettm2
ps=64
Expand Down
10 changes: 5 additions & 5 deletions project/multi_scale/eval/small/ettm1.sh
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=3
export CUDA_VISIBLE_DEVICES=0

mode=S
cp=conf/multi_scale/eval
exp_name=lsf
cl=3000
model=moirai_lightning_ckpt

cpp1='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl96/checkpoints/epoch_4-step_50.ckpt'
cpp2='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl192/checkpoints/epoch_1-step_20.ckpt'
cpp3='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl336/checkpoints/epoch_1-step_20.ckpt'
cpp4='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl720/checkpoints/epoch_1-step_20.ckpt'
cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl96/checkpoints/epoch_2-step_1293.ckpt'
cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
10 changes: 5 additions & 5 deletions project/multi_scale/eval/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=3
export CUDA_VISIBLE_DEVICES=0

mode=S
cp=conf/multi_scale/eval
exp_name=lsf
cl=3000
model=moirai_lightning_ckpt

cpp1='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
cpp2='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
cpp3='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
cpp4='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

index=1
for pl in 96 192 336 720; do
Expand Down
2 changes: 1 addition & 1 deletion project/multi_scale/finetune/small/ettm1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model=moirai_1.1_R_small
cp=conf/multi_scale/finetune
exp_name=lsf
cl=3000
ft_pattern=full_0
ft_pattern=full_film_dim

data=ettm1
ps=128
Expand Down
2 changes: 1 addition & 1 deletion project/multi_scale/finetune/small/ettm2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model=moirai_1.1_R_small
cp=conf/multi_scale/finetune
exp_name=lsf
cl=3000
ft_pattern=full_0
ft_pattern=full_film_dim

data=ettm2
ps=64
Expand Down
2 changes: 1 addition & 1 deletion src/uni2ts/model/lsf_moirai_point/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def replace_forecast_head(self, seq_len, pred_len):
self.distr_output = None
self.param_proj = None

self.head_dropout = nn.Dropout(p=0, inplace=False)
self.head_dropout = nn.Dropout(p=0.7, inplace=False)
self.head_fc1 = nn.Linear(
in_features=seq_len * self.d_model, out_features=self.d_model
)
Expand Down
13 changes: 11 additions & 2 deletions src/uni2ts/model/multi_scale_moirai/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
)

from .module import MoiraiModule
from uni2ts.module.multi_scale.attention import GroupedQueryAttention



class MoiraiFinetune(L.LightningModule):
Expand Down Expand Up @@ -112,7 +114,7 @@ def __init__(
prediction_length: Optional[int | list[int]] = None,
patch_size: Optional[int] = None,
finetune_pattern: str | list[str] = "full",
num_new_scales: int = 1,
num_new_scales: Optional[int] = None,
ds_factor: int = 2,
):
super().__init__()
Expand All @@ -126,6 +128,13 @@ def __init__(
self.num_new_scales = num_new_scales
self.ds_factor = ds_factor

def post_init(self):
    """Create the multi-scale FiLM modules on the encoder's attention layers.

    Iterates over ``self.module.encoder.layers`` and, for every layer whose
    ``self_attn`` attribute is a ``GroupedQueryAttention``, calls
    ``init_multi_scale_modules`` with this model's ``context_length``,
    ``patch_size``, ``num_new_scales`` and ``ds_factor``.

    Does nothing when ``self.num_new_scales`` is ``None`` (the constructor
    default). ``init_multi_scale_modules`` builds one generator per scale
    with ``range(num_new_scales)``, which raises ``TypeError`` for ``None``,
    and the attention forward pass already skips the FiLM branch while its
    own ``num_new_scales`` is ``None``.
    """
    # No extra scales configured: leave the attention layers untouched.
    if self.num_new_scales is None:
        return

    for layer in self.module.encoder.layers:
        # Only layers exposing a GroupedQueryAttention as `self_attn` own
        # multi-scale modules; every other layer is skipped.
        if not hasattr(layer, "self_attn"):
            continue
        attn = layer.self_attn
        if not isinstance(attn, GroupedQueryAttention):
            continue
        attn.init_multi_scale_modules(
            self.context_length,
            self.patch_size,
            self.num_new_scales,
            self.ds_factor,
        )

def forward(
self,
target: Float[torch.Tensor, "*batch seq_len max_patch"],
Expand Down Expand Up @@ -275,7 +284,7 @@ def configure_optimizers(self) -> dict:
# p.requires_grad = True

for pn, p in self.named_parameters():
if "filmed_generator" in pn:
if "film" in pn:
p.requires_grad = True

# Unfreeze the corresponding params
Expand Down
12 changes: 12 additions & 0 deletions src/uni2ts/model/multi_scale_moirai/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from uni2ts.loss.packed import PackedNLLLoss as _PackedNLLLoss

from .module import MoiraiModule
from uni2ts.module.multi_scale.attention import GroupedQueryAttention


class SampleNLLLoss(_PackedNLLLoss):
Expand Down Expand Up @@ -107,9 +108,20 @@ def __init__(
self.module = MoiraiModule(**module_kwargs) if module is None else module
self.per_sample_loss_func = SampleNLLLoss()
self.num_new_scales = num_new_scales

self.ds_factor = ds_factor
self.strict_loading = False

self.post_init() # ToDO: Make it optional.


def post_init(self):
    """Initialise multi-scale modules on each grouped-query attention layer.

    For every entry of ``self.module.encoder.layers`` that has a
    ``self_attn`` attribute of type ``GroupedQueryAttention``, forwards the
    context length and patch size stored in ``self.hparams`` together with
    ``self.num_new_scales`` and ``self.ds_factor`` to that attention
    object's ``init_multi_scale_modules``.
    """
    for encoder_layer in self.module.encoder.layers:
        # Layers without a `self_attn` attribute are ignored.
        if not hasattr(encoder_layer, "self_attn"):
            continue
        # Only GroupedQueryAttention provides `init_multi_scale_modules`.
        if not isinstance(encoder_layer.self_attn, GroupedQueryAttention):
            continue
        encoder_layer.self_attn.init_multi_scale_modules(
            self.hparams.context_length,
            self.hparams.patch_size,
            self.num_new_scales,
            self.ds_factor,
        )

@contextmanager
def hparams_context(
self,
Expand Down
117 changes: 72 additions & 45 deletions src/uni2ts/module/multi_scale/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,30 +97,46 @@ def __init__(
self.attn_dropout_p = attn_dropout_p
self.out_proj = nn.Linear(dim, dim, bias=bias)

# Todo: Create these mudules based on num_new_scales & ds_factor & seq_len
# base = 24 # 24 for ETTH1, 12 for ETTm1
# self.query_filmed_generator = nn.ModuleList(
# [
# nn.Linear(in_features=dim, out_features=2 * base),
# nn.Linear(in_features=dim, out_features=2 * base // 2),
# nn.Linear(in_features=dim, out_features=2 * base // 4),
# ]
# )
#
# self.key_filmed_generator = nn.ModuleList(
# [
# nn.Linear(in_features=dim, out_features=2 * base),
# nn.Linear(in_features=dim, out_features=2 * base // 2),
# nn.Linear(in_features=dim, out_features=2 * base // 4),
# ]
# )
#
# # self.value_filmed_generator = nn.ModuleList(
# # [
# # nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length
# # nn.Linear(in_features=dim, out_features=2 * 6)
# # ]
# # )
self.dim = dim
self.num_new_scales = None

def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
    """Create the FiLM controller and the per-scale query/key generators.

    Stores ``num_new_scales`` on the instance, then builds:

    * ``film_controller``: ``Linear(dim, dim // 4)`` followed by ``SiLU``;
    * ``query_film_generator`` / ``key_film_generator``: ``ModuleList``s
      holding one ``Linear(dim // 4, dim)`` per new scale.

    Args:
        context_length: Accepted for interface compatibility; not used by
            this implementation.
        patch_size: Accepted for interface compatibility; not used by this
            implementation.
        num_new_scales: Number of per-scale generators to create. ``None``
            keeps multi-scale FiLM disabled: the value is recorded and no
            module is created.
        ds_factor: Accepted for interface compatibility; not used by this
            implementation.
    """
    self.num_new_scales = num_new_scales

    # `None` is the "disabled" state (it is also what `__init__` sets), so
    # there is nothing to build; `range(None)` below would raise TypeError.
    if num_new_scales is None:
        return

    hidden_dim = self.dim // 4
    self.film_controller = nn.Sequential(nn.Linear(self.dim, hidden_dim), nn.SiLU())

    # Query generators are created before key generators so parameter
    # initialisation consumes the RNG in the same order as before.
    self.query_film_generator = nn.ModuleList(
        [nn.Linear(in_features=hidden_dim, out_features=self.dim) for _ in range(num_new_scales)]
    )
    self.key_film_generator = nn.ModuleList(
        [nn.Linear(in_features=hidden_dim, out_features=self.dim) for _ in range(num_new_scales)]
    )

# def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
#
# self.num_new_scales = num_new_scales
#
# base_len = math.ceil(context_length / patch_size) # num context patches in base scale
#
# scale_len = math.ceil(base_len / ds_factor)
# self.query_film_generator = nn.ModuleList([
# nn.Linear(in_features=self.dim, out_features=2 * scale_len)
# ])
# self.key_film_generator = nn.ModuleList([
# nn.Linear(in_features=self.dim, out_features=2 * scale_len)
# ])
#
# for _ in range(1, num_new_scales):
# scale_len = math.ceil(scale_len / ds_factor)
# self.query_film_generator.append(
# nn.Linear(in_features=self.dim, out_features=2 * scale_len)
# )
# self.key_film_generator.append(
# nn.Linear(in_features=self.dim, out_features=2 * scale_len)
# )

def _get_var_id(
self,
Expand Down Expand Up @@ -290,29 +306,40 @@ def forward(
value = self.v_proj(value)

# TODO: Plan B: directly apply a different FiLM to the query / key of each scale, without revising RoPE.
# index_by_variate = self.get_token_index_by_variate(query_var_id)
#
# for scale in range(3): # ToDO: number_of scales:
# assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
# index = index_by_variate[scale+1]
# query_scale = query[..., index, :]
# query_film_out = self.query_filmed_generator[scale](torch.mean(query_scale, dim=1))
# query_weight, query_bias = query_film_out[:, :int(query_film_out.size(-1) / 2)], query_film_out[:, int(query_film_out.size(-1) / 2):]
# query[..., index, :] = query_weight.unsqueeze(-1) * query_scale + query_bias.unsqueeze(-1)
if self.num_new_scales is not None:
index_by_variate = self.get_token_index_by_variate(query_var_id)

for scale in range(self.num_new_scales):
assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
index = index_by_variate[scale + 1]

query_scale = query[..., index, :] # (bs, num_patch_new_scale, dim)
query_scale_reprs = self.film_controller(torch.mean(query_scale, dim=1))
query_weight = self.query_film_generator[scale](query_scale_reprs)
query[..., index, :] = query_weight.unsqueeze(-2) * query_scale

key_scale = key[..., index, :]
key_scale_reprs = self.film_controller(torch.mean(key_scale, dim=1))
key_weight = self.key_film_generator[scale](key_scale_reprs)
key[..., index, :] = key_weight.unsqueeze(-2) * key_scale

# if self.num_new_scales is not None:
# index_by_variate = self.get_token_index_by_variate(query_var_id)
#
# key_scale = key[..., index, :]
# key_film_out = self.key_filmed_generator[scale](torch.mean(key_scale, dim=1))
# key_weight, key_bias = key_film_out[:, :int(key_film_out.size(-1) / 2)], key_film_out[:,
# int(key_film_out.size(
# -1) / 2):]
# key[..., index, :] = key_weight.unsqueeze(-1) * key_scale + key_bias.unsqueeze(-1)
# for scale in range(self.num_new_scales):
# assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
# index = index_by_variate[scale+1]
# query_scale = query[..., index, :] # (bs, num_patch_new_scale, dim)
# query_film_out = self.query_film_generator[scale](torch.mean(query_scale, dim=1)) # TODO: try flatten instead of mean?
# query_weight, query_bias = query_film_out[:, :int(query_film_out.size(-1) / 2)], query_film_out[:, int(query_film_out.size(-1) / 2):]
# query[..., index, :] = query_weight.unsqueeze(-1) * query_scale + query_bias.unsqueeze(-1)
#
# # value_i = value[..., index, :]
# # value_film_out = self.value_filmed_generator[scale](torch.mean(value_i, dim=1))
# # value_weight, value_bias = value_film_out[:, :int(value_film_out.size(-1) / 2)], value_film_out[:,
# # int(value_film_out.size(
# # -1) / 2):]
# # value[..., index, :] = value_weight.unsqueeze(-1) * value_i + value_bias.unsqueeze(-1)
# key_scale = key[..., index, :]
# key_film_out = self.key_film_generator[scale](torch.mean(key_scale, dim=1))
# key_weight, key_bias = key_film_out[:, :int(key_film_out.size(-1) / 2)], key_film_out[:,
# int(key_film_out.size(
# -1) / 2):]
# key[..., index, :] = key_weight.unsqueeze(-1) * key_scale + key_bias.unsqueeze(-1)

query = self.q_norm(
rearrange(
Expand Down

0 comments on commit c715305

Please sign in to comment.