diff --git a/cli/conf/lsf_point/finetune/model/moirai_1.1_R_small.yaml b/cli/conf/lsf_point/finetune/model/moirai_1.1_R_small.yaml
index f5761e3..160d3b1 100644
--- a/cli/conf/lsf_point/finetune/model/moirai_1.1_R_small.yaml
+++ b/cli/conf/lsf_point/finetune/model/moirai_1.1_R_small.yaml
@@ -29,7 +29,7 @@ val_metric:
   - _target_: uni2ts.loss.packed.PackedMSELoss
   - _target_: uni2ts.loss.packed.PackedNRMSELoss
     normalize: absolute_target_squared
-lr: 5e-7 # On ETT dataset, using 1e-6/5e-7 converge within 1-2 epochs. 1e-7 converge in tens of epochs
+lr: 5e-6 # On the ETT datasets, 1e-6/5e-7 converges within 1-2 epochs; 1e-7 takes tens of epochs
 weight_decay: 1e-1
 beta1: 0.9
 beta2: 0.98
diff --git a/cli/train.py b/cli/train.py
index 14cade8..1ece54a 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -128,7 +128,8 @@ def main(cfg: DictConfig):
     model: L.LightningModule = instantiate(cfg.model, _convert_="all")

-    # model.module.replace_forecast_head(seq_len=49, pred_len=96)
+    if hasattr(model, 'post_init') and callable(getattr(model, 'post_init')):
+        model.post_init()

     if "collate_fn" not in cfg.train_dataloader:
         model.seq_fields = model.seq_fields + ("sample_id",)
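Note on the `cli/train.py` hunk: the hard-coded head swap is replaced by an optional `post_init()` hook that runs only when the instantiated LightningModule defines it, so configs without multi-scale modules are unaffected. A minimal sketch of that duck-typed hook (the model classes here are hypothetical stand-ins for whatever Hydra instantiates, not repo classes):

```python
# Sketch of the optional post-init hook pattern from cli/train.py.
class PlainModel:
    pass  # defines no post_init, so the hook is simply skipped


class MultiScaleModel:
    def post_init(self):
        # build modules that need fully resolved hyperparameters
        print("multi-scale modules initialized")


def maybe_post_init(model) -> None:
    # Same guard as in train.py: the attribute must exist and be callable.
    if hasattr(model, "post_init") and callable(getattr(model, "post_init")):
        model.post_init()


maybe_post_init(PlainModel())       # no-op
maybe_post_init(MultiScaleModel())  # prints once
```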
diff --git a/project/lsf_point/eval/small/ettm1.sh b/project/lsf_point/eval/small/ettm1.sh
index 6db5afb..be14748 100644
--- a/project/lsf_point/eval/small/ettm1.sh
+++ b/project/lsf_point/eval/small/ettm1.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1

 mode=S
 cp=conf/lsf_point/eval
@@ -10,10 +10,10 @@ cl=3000

 model=moirai_lightning_ckpt

-cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl96/checkpoints/epoch_84-step_18360.ckpt'
-cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl192/checkpoints/epoch_79-step_17200.ckpt'
-cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl336/checkpoints/epoch_68-step_14766.ckpt'
-cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm1/cl3000_pl720/checkpoints/epoch_54-step_11605.ckpt'
+cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl96/checkpoints/epoch_44-step_9720.ckpt'
+cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl192/checkpoints/epoch_36-step_7955.ckpt'
+cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl336/checkpoints/epoch_11-step_2568.ckpt'
+cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm1/cl3000_pl720/checkpoints/epoch_7-step_1688.ckpt'

 index=1
diff --git a/project/lsf_point/eval/small/ettm2.sh b/project/lsf_point/eval/small/ettm2.sh
index de68ea3..c0d2332 100644
--- a/project/lsf_point/eval/small/ettm2.sh
+++ b/project/lsf_point/eval/small/ettm2.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1

 mode=S
 cp=conf/lsf_point/eval
@@ -10,10 +10,10 @@ cl=3000

 model=moirai_lightning_ckpt

-cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl96/checkpoints/epoch_63-step_13824.ckpt'
-cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl192/checkpoints/epoch_52-step_11395.ckpt'
-cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl336/checkpoints/epoch_37-step_8132.ckpt'
-cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head/ettm2/cl3000_pl720/checkpoints/epoch_32-step_6963.ckpt'
+cpp1='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl96/checkpoints/epoch_13-step_3024.ckpt'
+cpp2='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl192/checkpoints/epoch_6-step_1505.ckpt'
+cpp3='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl336/checkpoints/epoch_5-step_1284.ckpt'
+cpp4='./outputs/lsf_point/finetune/moirai_1.1_R_small/lsf/head_dp02/ettm2/cl3000_pl720/checkpoints/epoch_5-step_1266.ckpt'

 index=1
diff --git a/project/lsf_point/finetune/small/ettm1.sh b/project/lsf_point/finetune/small/ettm1.sh
index 3b54426..a4d7df9 100644
--- a/project/lsf_point/finetune/small/ettm1.sh
+++ b/project/lsf_point/finetune/small/ettm1.sh
@@ -6,7 +6,7 @@ model=moirai_1.1_R_small
 cp=conf/lsf_point/finetune
 exp_name=lsf
 cl=3000
-ft_pattern=full
+ft_pattern=head_dp07

 data=ettm1
 ps=128
diff --git a/project/lsf_point/finetune/small/ettm2.sh b/project/lsf_point/finetune/small/ettm2.sh
index 0f2e754..55b6dea 100644
--- a/project/lsf_point/finetune/small/ettm2.sh
+++ b/project/lsf_point/finetune/small/ettm2.sh
@@ -1,12 +1,12 @@
 #!/bin/bash

-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;

 model=moirai_1.1_R_small
 cp=conf/lsf_point/finetune
 exp_name=lsf
 cl=3000
-ft_pattern=full
+ft_pattern=head_dp07

 data=ettm2
 ps=64
diff --git a/project/multi_scale/eval/small/ettm1.sh b/project/multi_scale/eval/small/ettm1.sh
index cc5a513..65c39d3 100644
--- a/project/multi_scale/eval/small/ettm1.sh
+++ b/project/multi_scale/eval/small/ettm1.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=3
+export CUDA_VISIBLE_DEVICES=0

 mode=S
 cp=conf/multi_scale/eval
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=3000
 model=moirai_lightning_ckpt

-cpp1='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl96/checkpoints/epoch_4-step_50.ckpt'
-cpp2='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl192/checkpoints/epoch_1-step_20.ckpt'
-cpp3='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl336/checkpoints/epoch_1-step_20.ckpt'
-cpp4='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl720/checkpoints/epoch_1-step_20.ckpt'
+cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl96/checkpoints/epoch_2-step_1293.ckpt'
+cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
+cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
+cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm1/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

 index=1
 for pl in 96 192 336 720; do
diff --git a/project/multi_scale/eval/small/ettm2.sh b/project/multi_scale/eval/small/ettm2.sh
index f931b21..bf0c1d9 100644
--- a/project/multi_scale/eval/small/ettm2.sh
+++ b/project/multi_scale/eval/small/ettm2.sh
@@ -1,7 +1,7 @@
 #!/bin/bash

 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=3
+export CUDA_VISIBLE_DEVICES=0

 mode=S
 cp=conf/multi_scale/eval
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=3000
 model=moirai_lightning_ckpt

-cpp1='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
-cpp2='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
-cpp3='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
-cpp4='./outputs/origin/moirai_1.1_R_small/lsf/full/ettm2/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+cpp1='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl96/checkpoints/epoch_5-step_2586.ckpt'
+cpp2='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl192/checkpoints/epoch_1-step_858.ckpt'
+cpp3='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl336/checkpoints/epoch_0-step_427.ckpt'
+cpp4='./outputs/multi_scale/finetune/moirai_1.1_R_small/lsf/full_film_dim/ettm2/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'

 index=1
 for pl in 96 192 336 720; do
diff --git a/project/multi_scale/finetune/small/ettm1.sh b/project/multi_scale/finetune/small/ettm1.sh
index 5c5b7d7..cefde43 100644
--- a/project/multi_scale/finetune/small/ettm1.sh
+++ b/project/multi_scale/finetune/small/ettm1.sh
@@ -6,7 +6,7 @@ model=moirai_1.1_R_small
 cp=conf/multi_scale/finetune
 exp_name=lsf
 cl=3000
-ft_pattern=full_0
+ft_pattern=full_film_dim

 data=ettm1
 ps=128
diff --git a/project/multi_scale/finetune/small/ettm2.sh b/project/multi_scale/finetune/small/ettm2.sh
index 37569fa..7af0203 100644
--- a/project/multi_scale/finetune/small/ettm2.sh
+++ b/project/multi_scale/finetune/small/ettm2.sh
@@ -6,7 +6,7 @@ model=moirai_1.1_R_small
 cp=conf/multi_scale/finetune
 exp_name=lsf
 cl=3000
-ft_pattern=full_0
+ft_pattern=full_film_dim

 data=ettm2
 ps=64
diff --git a/src/uni2ts/model/lsf_moirai_point/module.py b/src/uni2ts/model/lsf_moirai_point/module.py
index 323223e..3d519b6 100644
--- a/src/uni2ts/model/lsf_moirai_point/module.py
+++ b/src/uni2ts/model/lsf_moirai_point/module.py
@@ -132,7 +132,7 @@ def replace_forecast_head(self, seq_len, pred_len):
         self.distr_output = None
         self.param_proj = None

-        self.head_dropout = nn.Dropout(p=0, inplace=False)
+        self.head_dropout = nn.Dropout(p=0.7, inplace=False)
         self.head_fc1 = nn.Linear(
             in_features=seq_len * self.d_model, out_features=self.d_model
         )
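Note on the `module.py` hunk: `p=0.7` matches the `head_dp07` finetune pattern used by the updated scripts. For reference, a hedged sketch of the replaced point-forecast head; the diff only shows `head_dropout` and `head_fc1`, so the flatten step, the activation, and the `head_fc2` output projection are assumptions, `seq_len=49` comes from the old commented-out call in `cli/train.py`, and `d_model=384` is simply Moirai-small's width:

```python
import torch
import torch.nn as nn


# Hedged sketch of the replaced point-forecast head; fc2 and the flatten/ReLU
# wiring are assumptions, only dropout p=0.7 and fc1's shape come from the diff.
class PointForecastHead(nn.Module):
    def __init__(self, seq_len: int, pred_len: int, d_model: int):
        super().__init__()
        self.head_dropout = nn.Dropout(p=0.7, inplace=False)  # was p=0
        self.head_fc1 = nn.Linear(seq_len * d_model, d_model)
        self.head_fc2 = nn.Linear(d_model, pred_len)  # assumed output layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -> one flat vector per series
        x = self.head_dropout(x.flatten(start_dim=1))
        return self.head_fc2(torch.relu(self.head_fc1(x)))


head = PointForecastHead(seq_len=49, pred_len=96, d_model=384)
print(head(torch.randn(2, 49, 384)).shape)  # torch.Size([2, 96])
```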
diff --git a/src/uni2ts/model/multi_scale_moirai/finetune.py b/src/uni2ts/model/multi_scale_moirai/finetune.py
index f6b1d27..2397447 100644
--- a/src/uni2ts/model/multi_scale_moirai/finetune.py
+++ b/src/uni2ts/model/multi_scale_moirai/finetune.py
@@ -70,6 +70,8 @@
 )

 from .module import MoiraiModule
+from uni2ts.module.multi_scale.attention import GroupedQueryAttention


 class MoiraiFinetune(L.LightningModule):
@@ -112,7 +114,7 @@ def __init__(
         prediction_length: Optional[int | list[int]] = None,
         patch_size: Optional[int] = None,
         finetune_pattern: str | list[str] = "full",
-        num_new_scales: int = 1,
+        num_new_scales: Optional[int] = None,
         ds_factor: int = 2,
     ):
         super().__init__()
@@ -126,6 +128,13 @@ def __init__(
         self.num_new_scales = num_new_scales
         self.ds_factor = ds_factor

+    def post_init(self):
+        for layer in self.module.encoder.layers:
+            # Initialize the multi-scale FiLM modules on every GroupedQueryAttention layer
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+                layer.self_attn.init_multi_scale_modules(
+                    self.context_length, self.patch_size, self.num_new_scales, self.ds_factor
+                )
+
     def forward(
         self,
         target: Float[torch.Tensor, "*batch seq_len max_patch"],
@@ -275,7 +284,7 @@ def configure_optimizers(self) -> dict:
         #     p.requires_grad = True

         for pn, p in self.named_parameters():
-            if "filmed_generator" in pn:
+            if "film" in pn:
                 p.requires_grad = True

         # Unfreeze the corresponding params
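Note on the `configure_optimizers` hunk: widening the pattern from `"filmed_generator"` to `"film"` keeps the unfreezing logic in sync with the renamed modules, since it now matches `film_controller`, `query_film_generator`, and `key_film_generator`, none of which the old substring would catch. A toy check of that selection logic:

```python
import torch.nn as nn

# Toy module tree mirroring the renamed FiLM submodules in attention.py.
model = nn.ModuleDict({
    "film_controller": nn.Linear(8, 2),
    "query_film_generator": nn.Linear(2, 8),
    "out_proj": nn.Linear(8, 8),
})
for p in model.parameters():
    p.requires_grad = False

# Substring match as in configure_optimizers: "film" catches every FiLM
# parameter; the old "filmed_generator" pattern would match none of them.
for pn, p in model.named_parameters():
    if "film" in pn:
        p.requires_grad = True

print(sorted(pn for pn, p in model.named_parameters() if p.requires_grad))
# ['film_controller.bias', 'film_controller.weight',
#  'query_film_generator.bias', 'query_film_generator.weight']
```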
diff --git a/src/uni2ts/model/multi_scale_moirai/forecast.py b/src/uni2ts/model/multi_scale_moirai/forecast.py
index f881b6d..556098e 100644
--- a/src/uni2ts/model/multi_scale_moirai/forecast.py
+++ b/src/uni2ts/model/multi_scale_moirai/forecast.py
@@ -40,6 +40,7 @@
 from uni2ts.loss.packed import PackedNLLLoss as _PackedNLLLoss

 from .module import MoiraiModule
+from uni2ts.module.multi_scale.attention import GroupedQueryAttention


 class SampleNLLLoss(_PackedNLLLoss):
@@ -107,9 +108,20 @@ def __init__(
         self.module = MoiraiModule(**module_kwargs) if module is None else module
         self.per_sample_loss_func = SampleNLLLoss()
         self.num_new_scales = num_new_scales
+        self.ds_factor = ds_factor
         self.strict_loading = False

+        self.post_init()  # ToDo: make this call optional.
+
+    def post_init(self):
+        for layer in self.module.encoder.layers:
+            # Initialize the multi-scale FiLM modules on every GroupedQueryAttention layer
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+                layer.self_attn.init_multi_scale_modules(
+                    self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor
+                )
+
     @contextmanager
     def hparams_context(
         self,
diff --git a/src/uni2ts/module/multi_scale/attention.py b/src/uni2ts/module/multi_scale/attention.py
index 91a297d..9288c4a 100644
--- a/src/uni2ts/module/multi_scale/attention.py
+++ b/src/uni2ts/module/multi_scale/attention.py
@@ -97,30 +97,46 @@ def __init__(
         self.attn_dropout_p = attn_dropout_p
         self.out_proj = nn.Linear(dim, dim, bias=bias)

-        # Todo: Create these mudules based on num_new_scales & ds_factor & seq_len
-        # base = 24  # 24 for ETTH1, 12 for ETTm1
-        # self.query_filmed_generator = nn.ModuleList(
-        #     [
-        #         nn.Linear(in_features=dim, out_features=2 * base),
-        #         nn.Linear(in_features=dim, out_features=2 * base // 2),
-        #         nn.Linear(in_features=dim, out_features=2 * base // 4),
-        #     ]
-        # )
-        #
-        # self.key_filmed_generator = nn.ModuleList(
-        #     [
-        #         nn.Linear(in_features=dim, out_features=2 * base),
-        #         nn.Linear(in_features=dim, out_features=2 * base // 2),
-        #         nn.Linear(in_features=dim, out_features=2 * base // 4),
-        #     ]
-        # )
-        #
-        # # self.value_filmed_generator = nn.ModuleList(
-        # #     [
-        # #         nn.Linear(in_features=dim, out_features=2 * 12),  # each scale's length
-        # #         nn.Linear(in_features=dim, out_features=2 * 6)
-        # #     ]
-        # # )
+        self.dim = dim
+        self.num_new_scales = None
+
+    def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
+        self.num_new_scales = num_new_scales
+
+        nh = self.dim // 4
+        self.film_controller = nn.Sequential(nn.Linear(self.dim, nh), nn.SiLU())
+
+        self.query_film_generator = nn.ModuleList([
+            nn.Linear(in_features=nh, out_features=self.dim) for _ in range(num_new_scales)
+        ])
+
+        self.key_film_generator = nn.ModuleList([
+            nn.Linear(in_features=nh, out_features=self.dim) for _ in range(num_new_scales)
+        ])
+
+    # def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
+    #     self.num_new_scales = num_new_scales
+    #
+    #     base_len = math.ceil(context_length / patch_size)  # num context patches in base scale
+    #
+    #     scale_len = math.ceil(base_len / ds_factor)
+    #     self.query_film_generator = nn.ModuleList([
+    #         nn.Linear(in_features=self.dim, out_features=2 * scale_len)
+    #     ])
+    #     self.key_film_generator = nn.ModuleList([
+    #         nn.Linear(in_features=self.dim, out_features=2 * scale_len)
+    #     ])
+    #
+    #     for _ in range(1, num_new_scales):
+    #         scale_len = math.ceil(scale_len / ds_factor)
+    #         self.query_film_generator.append(
+    #             nn.Linear(in_features=self.dim, out_features=2 * scale_len)
+    #         )
+    #         self.key_film_generator.append(
+    #             nn.Linear(in_features=self.dim, out_features=2 * scale_len)
+    #         )

     def _get_var_id(
         self,
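Note on the `attention.py` hunk: because `init_multi_scale_modules` registers the FiLM parameters only after construction, a pretrained checkpoint loaded before `post_init()` runs will not contain them, which fits `forecast.py` keeping `strict_loading = False`. A minimal sketch of that late registration, with the unused `context_length`/`patch_size`/`ds_factor` arguments dropped for brevity:

```python
import torch.nn as nn


# Sketch: FiLM modules are registered after __init__, so they are absent from
# any state_dict captured (or loaded) before init_multi_scale_modules runs.
class TinyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.out_proj = nn.Linear(dim, dim)
        self.dim = dim
        self.num_new_scales = None

    def init_multi_scale_modules(self, num_new_scales: int) -> None:
        self.num_new_scales = num_new_scales
        nh = self.dim // 4  # shared bottleneck width, as in the diff
        self.film_controller = nn.Sequential(nn.Linear(self.dim, nh), nn.SiLU())
        self.query_film_generator = nn.ModuleList(
            nn.Linear(nh, self.dim) for _ in range(num_new_scales)
        )
        self.key_film_generator = nn.ModuleList(
            nn.Linear(nh, self.dim) for _ in range(num_new_scales)
        )


attn = TinyAttention(dim=16)
print(len(attn.state_dict()))  # 2: out_proj weight and bias
attn.init_multi_scale_modules(num_new_scales=2)
print(len(attn.state_dict()))  # 12: FiLM parameters are now registered
```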
@@ -290,29 +306,40 @@ def forward(
         value = self.v_proj(value)

-        # ToDo: Plan B: Directly apply different Film on query / key to different scales. W.o revising RoPE
-        # index_by_variate = self.get_token_index_by_variate(query_var_id)
-        #
-        # for scale in range(3):  # ToDO: number_of scales:
-        #     assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
-        #     index = index_by_variate[scale+1]
-        #     query_scale = query[..., index, :]
-        #     query_film_out = self.query_filmed_generator[scale](torch.mean(query_scale, dim=1))
-        #     query_weight, query_bias = query_film_out[:, :int(query_film_out.size(-1) / 2)], query_film_out[:, int(query_film_out.size(-1) / 2):]
-        #     query[..., index, :] = query_weight.unsqueeze(-1) * query_scale + query_bias.unsqueeze(-1)
-        #
-        #     key_scale = key[..., index, :]
-        #     key_film_out = self.key_filmed_generator[scale](torch.mean(key_scale, dim=1))
-        #     key_weight, key_bias = key_film_out[:, :int(key_film_out.size(-1) / 2)], key_film_out[:, int(key_film_out.size(-1) / 2):]
-        #     key[..., index, :] = key_weight.unsqueeze(-1) * key_scale + key_bias.unsqueeze(-1)
-        #
-        #     # value_i = value[..., index, :]
-        #     # value_film_out = self.value_filmed_generator[scale](torch.mean(value_i, dim=1))
-        #     # value_weight, value_bias = value_film_out[:, :int(value_film_out.size(-1) / 2)], value_film_out[:, int(value_film_out.size(-1) / 2):]
-        #     # value[..., index, :] = value_weight.unsqueeze(-1) * value_i + value_bias.unsqueeze(-1)
+        if self.num_new_scales is not None:
+            index_by_variate = self.get_token_index_by_variate(query_var_id)
+
+            for scale in range(self.num_new_scales):
+                assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
+                index = index_by_variate[scale + 1]
+
+                query_scale = query[..., index, :]  # (bs, num_patch_new_scale, dim)
+                query_scale_reprs = self.film_controller(torch.mean(query_scale, dim=1))
+                query_weight = self.query_film_generator[scale](query_scale_reprs)
+                query[..., index, :] = query_weight.unsqueeze(-2) * query_scale
+
+                key_scale = key[..., index, :]
+                key_scale_reprs = self.film_controller(torch.mean(key_scale, dim=1))
+                key_weight = self.key_film_generator[scale](key_scale_reprs)
+                key[..., index, :] = key_weight.unsqueeze(-2) * key_scale
+        # if self.num_new_scales is not None:
+        #     index_by_variate = self.get_token_index_by_variate(query_var_id)
+        #
+        #     for scale in range(self.num_new_scales):
+        #         assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
+        #         index = index_by_variate[scale+1]
+        #         query_scale = query[..., index, :]  # (bs, num_patch_new_scale, dim)
+        #         query_film_out = self.query_film_generator[scale](torch.mean(query_scale, dim=1))  # ToDo: try flattening instead of the mean?
+        #         query_weight, query_bias = query_film_out[:, :int(query_film_out.size(-1) / 2)], query_film_out[:, int(query_film_out.size(-1) / 2):]
+        #         query[..., index, :] = query_weight.unsqueeze(-1) * query_scale + query_bias.unsqueeze(-1)
+        #
+        #         key_scale = key[..., index, :]
+        #         key_film_out = self.key_film_generator[scale](torch.mean(key_scale, dim=1))
+        #         key_weight, key_bias = key_film_out[:, :int(key_film_out.size(-1) / 2)], key_film_out[:, int(key_film_out.size(-1) / 2):]
+        #         key[..., index, :] = key_weight.unsqueeze(-1) * key_scale + key_bias.unsqueeze(-1)

         query = self.q_norm(
             rearrange(
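Note on the active `forward` hunk above: each new scale's query/key tokens are selected by variate id (index 0 is the base scale, hence `scale + 1`), mean-pooled, pushed through the shared `film_controller`, and rescaled in place by a per-scale multiplicative FiLM weight; unlike the commented-out variant, the current version drops the additive bias. A self-contained dry run that emulates `index_by_variate` with toy indices:

```python
import torch
import torch.nn as nn

# Dry run of the per-scale FiLM pass in GroupedQueryAttention.forward.
# Tokens 0-3 play the base scale, tokens 4-5 the single new scale.
bs, seq, dim, num_new_scales = 2, 6, 8, 1
query = torch.randn(bs, seq, dim)
key = torch.randn(bs, seq, dim)
index_by_variate = [torch.arange(0, 4), torch.arange(4, 6)]

film_controller = nn.Sequential(nn.Linear(dim, dim // 4), nn.SiLU())
query_film_generator = nn.ModuleList([nn.Linear(dim // 4, dim)])
key_film_generator = nn.ModuleList([nn.Linear(dim // 4, dim)])

with torch.no_grad():
    for scale in range(num_new_scales):
        index = index_by_variate[scale + 1]  # scale 0 (base) is left untouched

        query_scale = query[..., index, :]   # (bs, num_patch_new_scale, dim)
        q_w = query_film_generator[scale](film_controller(query_scale.mean(dim=1)))
        query[..., index, :] = q_w.unsqueeze(-2) * query_scale  # broadcast over patches

        key_scale = key[..., index, :]
        k_w = key_film_generator[scale](film_controller(key_scale.mean(dim=1)))
        key[..., index, :] = k_w.unsqueeze(-2) * key_scale

print(query.shape, key.shape)  # shapes unchanged; only new-scale tokens rescaled
```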