From bea1cf56d051befb66dba8a93cd83f2296a9137b Mon Sep 17 00:00:00 2001
From: Qiao Zhongzheng
Date: Thu, 26 Sep 2024 22:07:21 +0800
Subject: [PATCH] Revise configure_optimizers to finetune attention params

---
 project/finetune/eval/small/run_multi.sh  |  8 +++++
 project/finetune/train/small/run_multi.sh | 10 ++++++
 src/uni2ts/model/moirai/finetune.py       | 37 ++++++++++++++++++++---
 3 files changed, 52 insertions(+), 3 deletions(-)
 create mode 100644 project/finetune/eval/small/run_multi.sh
 create mode 100644 project/finetune/train/small/run_multi.sh

diff --git a/project/finetune/eval/small/run_multi.sh b/project/finetune/eval/small/run_multi.sh
new file mode 100644
index 0000000..933462c
--- /dev/null
+++ b/project/finetune/eval/small/run_multi.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+bash project/finetune/eval/small/etth1.sh
+bash project/finetune/eval/small/etth2.sh
+bash project/finetune/eval/small/ettm1.sh
+bash project/finetune/eval/small/ettm2.sh
+#bash project/finetune/eval/small/weather.sh
+
diff --git a/project/finetune/train/small/run_multi.sh b/project/finetune/train/small/run_multi.sh
new file mode 100644
index 0000000..16a8301
--- /dev/null
+++ b/project/finetune/train/small/run_multi.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+#bash project/finetune/train/small/weather.sh
+#bash project/finetune/train/small/weather_2.sh
+bash project/finetune/train/small/etth1.sh
+bash project/finetune/train/small/etth2.sh
+bash project/finetune/train/small/ettm1.sh
+bash project/finetune/train/small/ettm2.sh
+
+
diff --git a/src/uni2ts/model/moirai/finetune.py b/src/uni2ts/model/moirai/finetune.py
index a2d9fff..cea6483 100644
--- a/src/uni2ts/model/moirai/finetune.py
+++ b/src/uni2ts/model/moirai/finetune.py
@@ -304,9 +304,40 @@ def configure_optimizers(self) -> dict:
                 if "ffn" in pn:
                     p.requires_grad = True
 
-        if "self_attn" in self.finetune_pattern:
-            # Todo: Analyze each component in self_attn & Lora's impact.
-            pass
+        # if "self_attn" in self.finetune_pattern:
+        #     # Todo: Analyze each component in self_attn & Lora's impact.
+        #     pass
+
+        if "q_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "q_proj" in pn:
+                    p.requires_grad = True
+
+        if "k_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "k_proj" in pn:
+                    p.requires_grad = True
+
+        if "v_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "v_proj" in pn:
+                    p.requires_grad = True
+
+        if "att_norm" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "self_attn.q_norm" in pn or "self_attn.k_norm" in pn:
+                    p.requires_grad = True
+
+        if "var_attn_bias" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "var_attn_bias" in pn:
+                    p.requires_grad = True
+
+        if "out_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "out_proj" in pn:
+                    p.requires_grad = True
+
         whitelist_params = (
             LearnedProjection,
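
The six gated blocks in the hunk share one shape and could be collapsed into a
single table-driven loop. A minimal sketch of that consolidation, assuming the
same finetune_pattern semantics; unfreeze_by_pattern and PATTERN_TO_SUBSTRINGS
are hypothetical names, with the att_norm mapping taken from the hunk above:

    from collections.abc import Iterable

    import torch.nn as nn

    # Substring table mirroring the patch; "att_norm" expands to the q/k norm
    # parameters inside self_attn, exactly as in the hunk.
    PATTERN_TO_SUBSTRINGS: dict[str, tuple[str, ...]] = {
        "q_proj": ("q_proj",),
        "k_proj": ("k_proj",),
        "v_proj": ("v_proj",),
        "att_norm": ("self_attn.q_norm", "self_attn.k_norm"),
        "var_attn_bias": ("var_attn_bias",),
        "out_proj": ("out_proj",),
    }

    def unfreeze_by_pattern(model: nn.Module, finetune_pattern: Iterable[str]) -> None:
        """Enable grads for every parameter matching a requested pattern."""
        for pattern in finetune_pattern:
            substrings = PATTERN_TO_SUBSTRINGS.get(pattern, (pattern,))
            for pn, p in model.named_parameters():
                if any(s in pn for s in substrings):
                    p.requires_grad = True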
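
A quick, runnable illustration of the same substring gating on a toy module;
the layer names below echo the substrings the patch matches on, not Moirai's
actual module tree, and it assumes all parameters were frozen beforehand:

    import torch.nn as nn

    # Toy stand-in for an attention block; purely illustrative.
    class ToyAttn(nn.Module):
        def __init__(self, d: int = 8):
            super().__init__()
            self.q_proj = nn.Linear(d, d)
            self.k_proj = nn.Linear(d, d)
            self.v_proj = nn.Linear(d, d)
            self.out_proj = nn.Linear(d, d)
            self.ffn = nn.Linear(d, d)

    model = ToyAttn()
    finetune_pattern = ["q_proj", "out_proj"]

    for p in model.parameters():  # freeze everything first, so only
        p.requires_grad = False   # pattern-matched parameters will train

    for pattern in finetune_pattern:  # same substring gating as the hunk
        for pn, p in model.named_parameters():
            if pattern in pn:
                p.requires_grad = True

    print([pn for pn, p in model.named_parameters() if p.requires_grad])
    # ['q_proj.weight', 'q_proj.bias', 'out_proj.weight', 'out_proj.bias']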