diff --git a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml
index f35ddd8..6243e3a 100644
--- a/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml
+++ b/cli/conf/lsf-setup/multi_scale/finetune_two_stage/default.yaml
@@ -23,7 +23,7 @@ trainer_warmup:
   logger:
     _target_: lightning.pytorch.loggers.TensorBoardLogger
     save_dir: ${hydra:runtime.output_dir}
-    name: logs
+    name: logs_warmup
   callbacks:
     - _target_: lightning.pytorch.callbacks.LearningRateMonitor
       logging_interval: epoch
diff --git a/project/lsf-setup/lsf/finetune/base/electricity.sh b/project/lsf-setup/lsf/finetune/base/electricity.sh
index 8999086..05b980f 100644
--- a/project/lsf-setup/lsf/finetune/base/electricity.sh
+++ b/project/lsf-setup/lsf/finetune/base/electricity.sh
@@ -33,5 +33,6 @@ for pl in 96 192 336 720; do
     val_data.prediction_length=$pl \
     val_data.mode=${mode} \
     train_dataloader.batch_size=256 \
-    model.lr=5e-6
+    model.lr=1e-5 \
+    trainer.callbacks.'3'.patience=1
 done
\ No newline at end of file
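Note: the override trainer.callbacks.'3'.patience=1 relies on Hydra/OmegaConf list indexing — the quoted integer selects the fourth entry of the trainer's callbacks list and overrides a single field on it. A minimal OmegaConf sketch of the same mechanism (the callback list below is illustrative; the real layout lives in the trainer config, which is not shown here):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "trainer": {"callbacks": [
            {"_target_": "lightning.pytorch.callbacks.LearningRateMonitor"},
            {"_target_": "lightning.pytorch.callbacks.EarlyStopping", "patience": 3},
        ]}
    })
    # Equivalent of a command-line override callbacks.'1'.patience=1:
    OmegaConf.update(cfg, "trainer.callbacks.1.patience", 1)
    assert cfg.trainer.callbacks[1].patience == 1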
diff --git a/project/lsf-setup/multi_scale/eval/small/ettm1.sh b/project/lsf-setup/multi_scale/eval/small/ettm1.sh
index 4f25b21..1b79a84 100644
--- a/project/lsf-setup/multi_scale/eval/small/ettm1.sh
+++ b/project/lsf-setup/multi_scale/eval/small/ettm1.sh
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=4000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl96/checkpoints/epoch_4-step_2085.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl96/checkpoints/epoch_5-step_2502.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl192/checkpoints/epoch_1-step_832.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl336/checkpoints/epoch_0-step_414.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm1/S/cl4000_pl720/checkpoints/epoch_0-step_408.ckpt'
 
 index=1
 for pl in 96 192 336 720; do
diff --git a/project/lsf-setup/multi_scale/eval/small/ettm2.sh b/project/lsf-setup/multi_scale/eval/small/ettm2.sh
index 4dbbc44..fb091af 100644
--- a/project/lsf-setup/multi_scale/eval/small/ettm2.sh
+++ b/project/lsf-setup/multi_scale/eval/small/ettm2.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=3
+export CUDA_VISIBLE_DEVICES=2
 
 mode=S
 cp=conf/lsf-setup/multi_scale/eval
@@ -9,10 +9,10 @@ exp_name=lsf
 cl=3000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl96/checkpoints/epoch_13-step_6034.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl192/checkpoints/epoch_4-step_2145.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl96/checkpoints/epoch_16-step_7327.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl192/checkpoints/epoch_3-step_1716.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl336/checkpoints/epoch_1-step_854.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/ettm2/S/cl3000_pl720/checkpoints/epoch_0-step_422.ckpt'
 
 index=1
 for pl in 96 192 336 720; do
diff --git a/project/lsf-setup/multi_scale/eval/small/weather.sh b/project/lsf-setup/multi_scale/eval/small/weather.sh
index a192c5d..072fcaa 100644
--- a/project/lsf-setup/multi_scale/eval/small/weather.sh
+++ b/project/lsf-setup/multi_scale/eval/small/weather.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=3
 
 mode=S
 cp=conf/lsf-setup/multi_scale/eval
@@ -9,22 +9,18 @@ exp_name=lsf
 cl=2000
 model=moirai_lightning_ckpt
 
-cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl96/checkpoints/epoch_7-step_11424.ckpt'
-cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl192/checkpoints/epoch_6-step_9968.ckpt'
-cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl336/checkpoints/epoch_2-step_4254.ckpt'
-cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/learned_time_id_2stage_valMSE/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt'
+cpp1='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl96/checkpoints/epoch_14-step_21420.ckpt'
+cpp2='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl192/checkpoints/epoch_9-step_14240.ckpt'
+cpp3='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl336/checkpoints/epoch_5-step_8508.ckpt'
+cpp4='./outputs/lsf-setup/multi_scale/finetune_two_stage/moirai_1.0_R_small/1tid_2inproj_all_scale_lora_freezeqkv/full/weather/S/cl2000_pl720/checkpoints/epoch_1-step_2804.ckpt'
 
 index=1
-for pl in 192 336 720; do  # 96
+for pl in 96 192 336 720; do
   case $index in
-    1) cpp=$cpp2 ;;
-    2) cpp=$cpp3 ;;
-    3) cpp=$cpp4 ;;
-
-#    1) cpp=$cpp1 ;;
-#    2) cpp=$cpp2 ;;
-#    3) cpp=$cpp3 ;;
-#    4) cpp=$cpp4 ;;
+    1) cpp=$cpp1 ;;
+    2) cpp=$cpp2 ;;
+    3) cpp=$cpp3 ;;
+    4) cpp=$cpp4 ;;
   esac
 
   pretrained_model=$(echo $cpp | cut -d'/' -f6)
diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh
index f2bd104..631098b 100644
--- a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh
+++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm1.sh
@@ -1,10 +1,10 @@
 #!/bin/bash
 
-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
 data=ettm1
 cl=4000
 ps=128
diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh
index f9282ef..59ef855 100644
--- a/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh
+++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/ettm2.sh
@@ -1,10 +1,10 @@
 #!/bin/bash
 
-export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=2;
+export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=3;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
 data=ettm2
 cl=3000
 ps=64
diff --git a/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh
index 4c10e1e..6412134 100644
--- a/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh
+++ b/project/lsf-setup/multi_scale/finetune_two_stage/small/weather.sh
@@ -4,7 +4,7 @@ export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
 
 model=moirai_1.0_R_small
 cp=conf/lsf-setup/multi_scale/finetune_two_stage
-exp_name=learned_time_id_2stage_valMSE
+exp_name=1tid_2inproj_all_scale_lora_freezeqkv
 data=weather
 cl=2000
 ps=128
@@ -33,5 +33,7 @@ for pl in 96 192 336 720; do
     val_data.prediction_length=$pl \
     val_data.mode=${mode} \
     trainer.callbacks."1".monitor=val/PackedMSELoss \
-    trainer.callbacks."3".monitor=val/PackedMSELoss
+    trainer.callbacks."3".monitor=val/PackedMSELoss \
+    trainer_warmup.callbacks."1".monitor=val/PackedMSELoss \
+    trainer_warmup.callbacks."2".monitor=val/PackedMSELoss
 done
\ No newline at end of file
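Note: the extra overrides in weather.sh point the warmup trainer's callbacks at val/PackedMSELoss as well, so the stage-1 warmup and the main finetuning stage both checkpoint and early-stop on validation MSE. Assuming the indexed callbacks are a ModelCheckpoint and an EarlyStopping (the trainer config itself is not shown), the overrides amount to constructing them as in this minimal Lightning sketch:

    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

    # e.g. trainer.callbacks."1" / trainer_warmup.callbacks."1" after the override:
    checkpoint = ModelCheckpoint(monitor="val/PackedMSELoss", mode="min")
    # e.g. trainer.callbacks."3" / trainer_warmup.callbacks."2" after the override:
    early_stop = EarlyStopping(monitor="val/PackedMSELoss", mode="min")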
diff --git a/src/uni2ts/model/multi_scale_moirai/finetune.py b/src/uni2ts/model/multi_scale_moirai/finetune.py
index 19c151a..6429b3a 100644
--- a/src/uni2ts/model/multi_scale_moirai/finetune.py
+++ b/src/uni2ts/model/multi_scale_moirai/finetune.py
@@ -143,11 +143,11 @@ def post_init(self):
 
         self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size)
 
-        # for layer in self.module.encoder.layers:
-        #     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-        #     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-        #         # Call post_init() method of the GroupedQueryAttention object
-        #         layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
+        for layer in self.module.encoder.layers:
+            # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+                # Call post_init() method of the GroupedQueryAttention object
+                layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
 
         for module in self.module.encoder.modules():
             if isinstance(module, MultiScaleRotaryProjection):
@@ -455,6 +455,12 @@ def configure_optimizers(self) -> dict:
                     decay.add(fpn)
                 elif 'pe_weights' in pn:
                     decay.add(fpn)
+                elif 'q_A' in pn or 'q_B' in pn or 'q_bias' in pn:
+                    decay.add(fpn)
+                elif 'k_A' in pn or 'k_B' in pn or 'k_bias' in pn:
+                    decay.add(fpn)
+                elif 'v_A' in pn or 'v_B' in pn or 'v_bias' in pn:
+                    decay.add(fpn)
                 # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn:  # Shared time_qk_proj
                 #     decay.add(fpn)
 
diff --git a/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py
index 45d216d..84830c6 100644
--- a/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py
+++ b/src/uni2ts/model/multi_scale_moirai/finetune_two_stage.py
@@ -145,11 +145,11 @@ def post_init(self):
 
         self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.patch_size)
 
-        # for layer in self.module.encoder.layers:
-        #     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-        #     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-        #         # Call post_init() method of the GroupedQueryAttention object
-        #         layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
+        for layer in self.module.encoder.layers:
+            # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+                # Call post_init() method of the GroupedQueryAttention object
+                layer.self_attn.init_multi_scale_modules(self.context_length, self.patch_size, self.num_new_scales, self.ds_factor)
 
         for module in self.module.encoder.modules():
             if isinstance(module, MultiScaleRotaryProjection):
@@ -331,6 +331,48 @@ def validation_step(
 
     def configure_optimizers(self) -> dict:
+        # if self.current_stage == 1:
+        #     warmup_pn_group1 = ['param_proj', 'time_id_q_proj', 'time_id_k_proj']
+        #     warmup_pn_group2 = ['in_proj_new_scales']
+        #
+        #     optim_groups = [
+        #         {
+        #             "params": [
+        #                 p for pn, p in self.named_parameters()
+        #                 if any(keyword in pn for keyword in warmup_pn_group1)
+        #             ],
+        #             "lr": 5e-4,
+        #             "weight_decay": self.hparams.weight_decay,
+        #         },
+        #         {
+        #             "params": [
+        #                 p for pn, p in self.named_parameters()
+        #                 if any(keyword in pn for keyword in warmup_pn_group2)
+        #             ],
+        #             "lr": 5e-6,
+        #             "weight_decay": self.hparams.weight_decay,
+        #         },
+        #     ]
+        #
+        #     optimizer = torch.optim.AdamW(
+        #         optim_groups,
+        #         betas=(self.hparams.beta1, self.hparams.beta2),
+        #         eps=1e-6,
+        #     )
+        #
+        #     warmup_params_all = {
+        #         pn: p for pn, p in self.named_parameters()
+        #         if any(keyword in pn for keyword in warmup_pn_group1 + warmup_pn_group2)
+        #     }
+        #     self.trainable_params = warmup_params_all
+        #
+        #     scheduler = get_scheduler(
+        #         SchedulerType.CONSTANT,  # Use a constant lr scheduler
+        #         optimizer,
+        #         num_warmup_steps=self.hparams.num_warmup_steps,
+        #         num_training_steps=self.hparams.num_training_steps,
+        #     )
+
 
         if self.current_stage == 1:
             warmup_pn = ['param_proj', 'time_id_q_proj', 'time_id_k_proj']
             warmup_params = {
@@ -474,6 +516,12 @@ def configure_optimizers(self) -> dict:
                     decay.add(fpn)
                 elif 'pe_weights' in pn:
                     decay.add(fpn)
+                elif 'q_A' in pn or 'q_B' in pn:
+                    decay.add(fpn)
+                elif 'k_A' in pn or 'k_B' in pn:
+                    decay.add(fpn)
+                elif 'v_A' in pn or 'v_B' in pn:
+                    decay.add(fpn)
                 # elif 'layers.0.self_attn.time_qk_proj.query_proj.pe_weights' in pn:  # Shared time_qk_proj
                 #     decay.add(fpn)
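For context: in both finetune.py and finetune_two_stage.py, configure_optimizers sorts every trainable parameter into a decay or a no-decay set by name before building AdamW, so the new LoRA parameters (q_A/q_B, k_A/k_B, v_A/v_B, plus the *_bias lists in finetune.py) must be routed explicitly or they fall through the classification. A condensed, illustrative sketch of that pattern (a hypothetical helper, not the repo's exact method; the real code also dispatches on module types and asserts full coverage):

    import torch
    from torch import nn

    def build_optim_groups(model: nn.Module, weight_decay: float = 0.01):
        # Route the LoRA factors into the weight-decay group by name.
        lora_keys = ("q_A", "q_B", "q_bias", "k_A", "k_B", "k_bias",
                     "v_A", "v_B", "v_bias")
        decay, no_decay = set(), set()
        for pn, p in model.named_parameters():
            if not p.requires_grad:
                continue  # e.g. the frozen q/k/v projections
            if any(k in pn for k in lora_keys) or pn.endswith("weight"):
                decay.add(pn)
            else:
                no_decay.add(pn)
        assert not decay & no_decay, "parameter routed to both groups"
        params = dict(model.named_parameters())
        return [
            {"params": [params[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
            {"params": [params[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]

    optimizer = torch.optim.AdamW(build_optim_groups(nn.Linear(8, 8)), lr=1e-5)

(finetune_two_stage.py omits the *_bias names, which is consistent with init_multi_scale_modules below leaving the q_bias/k_bias/v_bias ParameterLists empty — empty lists contribute no parameters.)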
diff --git a/src/uni2ts/model/multi_scale_moirai/forecast.py b/src/uni2ts/model/multi_scale_moirai/forecast.py
index f5cba9e..23c9a54 100644
--- a/src/uni2ts/model/multi_scale_moirai/forecast.py
+++ b/src/uni2ts/model/multi_scale_moirai/forecast.py
@@ -143,11 +143,11 @@ def post_init(self):
 
         self.module.post_init(self.token_idx_per_scale, self.base_ctx_token_idx, self.hparams.patch_size)
 
-        # for layer in self.module.encoder.layers:
-        #     # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
-        #     if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
-        #         # Call post_init() method of the GroupedQueryAttention object
-        #         layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)
+        for layer in self.module.encoder.layers:
+            # Check if the layer has an attribute named `self_attn` and if it is an instance of GroupedQueryAttention
+            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, GroupedQueryAttention):
+                # Call post_init() method of the GroupedQueryAttention object
+                layer.self_attn.init_multi_scale_modules(self.hparams.context_length, self.hparams.patch_size, self.num_new_scales, self.ds_factor)
 
         for module in self.module.encoder.modules():
             if isinstance(module, MultiScaleRotaryProjection):
diff --git a/src/uni2ts/module/multi_scale/attention.py b/src/uni2ts/module/multi_scale/attention.py
index e08145b..1f07797 100644
--- a/src/uni2ts/module/multi_scale/attention.py
+++ b/src/uni2ts/module/multi_scale/attention.py
@@ -101,38 +101,84 @@ def __init__(
 
         self.num_new_scales = None
 
-    def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
+    def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
         self.num_new_scales = num_new_scales
 
-        base_len = math.ceil(context_length / patch_size)  # num context patches in base scale
-        scale_len = math.ceil(base_len / ds_factor)
+        rank = 16
 
         # Initialize parameter lists
-        self.query_adapt_weight = nn.ParameterList()
-        self.key_adapt_weight = nn.ParameterList()
-        self.value_adapt_weight = nn.ParameterList()
-        self.query_adapt_bias = nn.ParameterList()
-        self.key_adapt_bias = nn.ParameterList()
-        self.value_adapt_bias = nn.ParameterList()
-
-        for _ in range(num_new_scales):
+        self.q_A = nn.ParameterList()
+        self.q_B = nn.ParameterList()
+        self.q_bias = nn.ParameterList()
+
+        self.k_A = nn.ParameterList()
+        self.k_B = nn.ParameterList()
+        self.k_bias = nn.ParameterList()
+
+        self.v_A = nn.ParameterList()
+        self.v_B = nn.ParameterList()
+        self.v_bias = nn.ParameterList()
+
+        # Apply LoRA to the original scale as well; freeze the q/k/v projections
+        self.q_proj.requires_grad_(False)
+        self.k_proj.requires_grad_(False)
+        self.v_proj.requires_grad_(False)
+
+        for _ in range(1 + num_new_scales):
             # Append the new parameters for the current scale
-            self.query_adapt_weight.append(
-                nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
-            self.key_adapt_weight.append(
-                nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
-            self.value_adapt_weight.append(
-                nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+            self.q_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+            self.k_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+            self.v_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+
+            self.q_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+            self.k_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+            self.v_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+
+        # for _ in range(num_new_scales):
+        #     # Append the new parameters for the current scale
+        #     self.q_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+        #     self.k_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+        #     self.v_A.append(nn.Parameter(torch.randn((rank, self.dim), dtype=torch.float) * 0.01))
+        #
+        #     self.q_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+        #     self.k_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+        #     self.v_B.append(nn.Parameter(torch.zeros((self.dim, rank), dtype=torch.float)))
+
 
-            self.query_adapt_bias.append(
-                nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
-            self.key_adapt_bias.append(
-                nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
-            self.value_adapt_bias.append(
-                nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
-            # Update scale length for the next iteration
-            scale_len = math.ceil(scale_len / ds_factor)
+        # base_len = math.ceil(context_length / patch_size)  # num context patches in base scale
+        # scale_len = math.ceil(base_len / ds_factor)
+
+        # # Initialize parameter lists
+        # self.query_adapt_weight = nn.ParameterList()
+        # self.key_adapt_weight = nn.ParameterList()
+        # self.value_adapt_weight = nn.ParameterList()
+        # self.query_adapt_bias = nn.ParameterList()
+        # self.key_adapt_bias = nn.ParameterList()
+        # self.value_adapt_bias = nn.ParameterList()
+        #
+        # for _ in range(num_new_scales):
+        #     # Append the new parameters for the current scale
+        #     self.query_adapt_weight.append(
+        #         nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #     self.key_adapt_weight.append(
+        #         nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #     self.value_adapt_weight.append(
+        #         nn.Parameter(torch.ones((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #
+        #     self.query_adapt_bias.append(
+        #         nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #     self.key_adapt_bias.append(
+        #         nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #     self.value_adapt_bias.append(
+        #         nn.Parameter(torch.zeros((scale_len, self.dim), dtype=torch.float), requires_grad=True))
+        #
+        #     # Update scale length for the next iteration
+        #     scale_len = math.ceil(scale_len / ds_factor)
 
     # def init_multi_scale_modules(self, context_length, patch_size, num_new_scales, ds_factor):
@@ -321,6 +367,30 @@ def get_token_index_by_variate(
 
         return indices_by_variate
 
+    def apply_lora(self,
+                   input: torch.Tensor,
+                   layer: nn.Linear,
+                   A: nn.Parameter,
+                   B: nn.Parameter,
+                   alpha: float = 1.0,
+                   ):
+        """
+        Apply LoRA on top of the given linear layer (only the layer's weight is used).
+        """
+        # Frozen weight of the base linear layer
+        W_no_grad = layer.weight.detach()
+
+        # LoRA update term
+        lora_update = alpha * (B @ A)  # (out_features, in_features)
+
+        # Compose the LoRA-adapted weight
+        W_lora = W_no_grad + lora_update  # final weight, (out_features, in_features)
+
+        # Compute the output
+        out = torch.matmul(input, W_lora.T)
+
+        return out
+
     def forward(
         self,
         query: Float[torch.Tensor, "*batch q_len dim"],
@@ -332,75 +402,64 @@ def forward(
         query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None,
         kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None,
     ) -> Float[torch.Tensor, "*batch q_len dim"]:
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
 
-        # init_query = self.q_proj(query)
-        # init_key = self.k_proj(key)
-        # init_value = self.v_proj(value)
-        #
-        # query = init_query.clone()
-        # key = init_key.clone()
-        # value = init_value.clone()
-        #
-        # # ToDo: Plan B: Directly apply a different Film on query / key for each scale, w/o revising RoPE
-        # # Clone before in-place slice assignment
-        # if self.num_new_scales is not None:
-        #     index_by_variate = self.get_token_index_by_variate(query_var_id)
-        #
-        #     for scale in range(self.num_new_scales):
-        #         assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
-        #         index = index_by_variate[scale + 1]
-        #         query_scale = init_query[..., index, :]  # (bs, num_patch_new_scale, dim)
-        #         query[..., index, :] = self.query_adapt_weight[scale] * query_scale + self.query_adapt_bias[scale]
-        #
-        #         key_scale = init_key[..., index, :]  # (bs, num_patch_new_scale, dim)
-        #         key[..., index, :] = self.key_adapt_weight[scale] * key_scale + self.key_adapt_bias[scale]
-        #
-        #         value_scale = init_value[..., index, :]  # (bs, num_patch_new_scale, dim)
-        #         value[..., index, :] = self.value_adapt_weight[scale] * value_scale + self.value_adapt_bias[scale]
+        # query = self.q_proj(query)
+        # key = self.k_proj(key)
+        # value = self.v_proj(value)
+        updated_query = query.clone()
+        updated_key = key.clone()
+        updated_value = value.clone()
 
-        # # Apply a different transformation for each dimension. All tokens share the same transformation.
-        # if self.num_new_scales is not None:
-        #     index_by_variate = self.get_token_index_by_variate(query_var_id)
-        #
-        #     for scale in range(self.num_new_scales):
-        #         assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
-        #         index = index_by_variate[scale + 1]
-        #
-        #         query_scale = query[..., index, :]  # (bs, num_patch_new_scale, dim)
-        #         query_scale_reprs = self.film_controller(torch.mean(query_scale, dim=1))
-        #         query_adapt_weight = self.query_film_generator[scale](query_scale_reprs)  # (bs, dim)
-        #         query[..., index, :] = query_adapt_weight.unsqueeze(-2) * query_scale
-        #
-        #         key_scale = key[..., index, :]
-        #         key_scale_reprs = self.film_controller(torch.mean(key_scale, dim=1))
-        #         key_adapt_weight = self.key_film_generator[scale](key_scale_reprs)
-        #         key[..., index, :] = key_adapt_weight.unsqueeze(-2) * key_scale
+        if self.num_new_scales is not None:
+            index_by_variate = self.get_token_index_by_variate(query_var_id)
+            assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
+
+            for i in range(1 + self.num_new_scales):
+                index = index_by_variate[i]
+                query_scale = query[..., index, :]
+                key_scale = key[..., index, :]
+                value_scale = value[..., index, :]
+
+                updated_query[..., index, :] = self.apply_lora(query_scale, self.q_proj, self.q_A[i], self.q_B[i])
+                updated_key[..., index, :] = self.apply_lora(key_scale, self.k_proj, self.k_A[i], self.k_B[i])
+                updated_value[..., index, :] = self.apply_lora(value_scale, self.v_proj, self.v_A[i], self.v_B[i])
+        query = updated_query
+        key = updated_key
+        value = updated_value
 
-        # # Apply a different transformation for each token. All dimensions of a token share the same transformation.
-        # if self.num_new_scales is not None:
-        #     index_by_variate = self.get_token_index_by_variate(query_var_id)
-        #
-        #     for scale in range(self.num_new_scales):
-        #         assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
-        #         index = index_by_variate[scale + 1]
-        #         query_scale = query[..., index, :]  # (bs, num_patch_new_scale, dim)
-        #         query_film_out = self.query_film_generator[scale](torch.mean(query_scale, dim=1))  # ToDo: try flatten instead of mean?
-        #         query_adapt_weight, query_adapt_bias = query_film_out[:, :int(query_film_out.size(-1) / 2)], query_film_out[:, int(query_film_out.size(-1) / 2):]
-        #         query[..., index, :] = query_adapt_weight.unsqueeze(-1) * query_scale + query_adapt_bias.unsqueeze(-1)
-        #
-        #         key_scale = key[..., index, :]
-        #         key_film_out = self.key_film_generator[scale](torch.mean(key_scale, dim=1))
-        #         key_adapt_weight, key_adapt_bias = key_film_out[:, :int(key_film_out.size(-1) / 2)], key_film_out[:, int(key_film_out.size(-1) / 2):]
-        #         key[..., index, :] = key_adapt_weight.unsqueeze(-1) * key_scale + key_adapt_bias.unsqueeze(-1)
-
+
+        # # ToDo: this variant works (v1)
+        # updated_query = query.clone()
+        # updated_key = key.clone()
+        # updated_value = value.clone()
+        #
+        # if self.num_new_scales is not None:
+        #     index_by_variate = self.get_token_index_by_variate(query_var_id)
+        #     assert torch.equal(query_var_id, kv_var_id), "query_var_id is different from kv_var_id"
+        #
+        #     for scale in range(1 + self.num_new_scales):
+        #         index = index_by_variate[scale]
+        #         query_scale = query[..., index, :]
+        #         key_scale = key[..., index, :]
+        #         value_scale = value[..., index, :]
+        #
+        #         if scale == 0:
+        #             updated_query[..., index, :] = self.q_proj(query_scale)
+        #             updated_key[..., index, :] = self.k_proj(key_scale)
+        #             updated_value[..., index, :] = self.v_proj(value_scale)
+        #
+        #         else:
+        #             i = scale - 1
+        #             updated_query[..., index, :] = self.apply_lora(query_scale, self.q_proj, self.q_A[i], self.q_B[i])
+        #             updated_key[..., index, :] = self.apply_lora(key_scale, self.k_proj, self.k_A[i], self.k_B[i])
+        #             updated_value[..., index, :] = self.apply_lora(value_scale, self.v_proj, self.v_A[i], self.v_B[i])
+        #
+        #     query = updated_query
+        #     key = updated_key
+        #     value = updated_value
 
         query = self.q_norm(
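Note on the LoRA scheme above: q_proj/k_proj/v_proj are frozen, and every scale (the base scale plus each downsampled scale) gets its own rank-16 A/B pair. Because each B is zero-initialized, B @ A is the zero matrix at step 0, so every adapted projection starts out exactly equal to the pretrained one, and only A and B receive gradients since the base weight is detached inside apply_lora. Also note that apply_lora reads only layer.weight, so it silently drops any bias on the frozen projection; that is equivalent only if the projections are bias-free. A self-contained sketch of the computation (the dim and rank values here are illustrative):

    import torch
    import torch.nn as nn

    dim, rank, alpha = 384, 16, 1.0
    base = nn.Linear(dim, dim, bias=False)
    base.requires_grad_(False)                       # frozen pretrained projection

    A = nn.Parameter(torch.randn(rank, dim) * 0.01)  # small random init
    B = nn.Parameter(torch.zeros(dim, rank))         # zero init => no change at step 0

    def apply_lora(x: torch.Tensor) -> torch.Tensor:
        # Effective weight W + alpha * B @ A, with W detached so only A and B train.
        w = base.weight.detach() + alpha * (B @ A)   # (out_features, in_features)
        return x @ w.T

    x = torch.randn(2, 10, dim)
    assert torch.allclose(apply_lora(x), base(x))    # identical to the frozen layer at init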