Revise configure_optimizer to finetune att params
zqiao11 committed Sep 26, 2024
1 parent: b6bcbfa · commit: bea1cf5
Showing 3 changed files with 53 additions and 3 deletions.
project/finetune/eval/small/run_multi.sh · 8 additions, 0 deletions
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+bash project/finetune/eval/small/etth1.sh
+bash project/finetune/eval/small/etth2.sh
+bash project/finetune/eval/small/ettm1.sh
+bash project/finetune/eval/small/ettm2.sh
+#bash project/finetune/eval/small/weather.sh
+
project/finetune/train/small/run_multi.sh · 10 additions, 0 deletions
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+#bash project/finetune/train/small/weather.sh
+#bash project/finetune/train/small/weather_2.sh
+bash project/finetune/train/small/etth1.sh
+bash project/finetune/train/small/etth2.sh
+bash project/finetune/train/small/ettm1.sh
+bash project/finetune/train/small/ettm2.sh
+
+
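Both run_multi.sh scripts simply chain the per-dataset scripts in sequence; because neither uses "set -e", a failing dataset run does not stop the ones after it. As a rough sketch only (not part of the repository), a small Python driver could make a fail-fast sweep explicit; the script paths are the ones from the diff above.

    # Hypothetical driver mirroring project/finetune/train/small/run_multi.sh;
    # not part of uni2ts. Paths are copied from the diff above.
    import subprocess

    SCRIPTS = [
        "project/finetune/train/small/etth1.sh",
        "project/finetune/train/small/etth2.sh",
        "project/finetune/train/small/ettm1.sh",
        "project/finetune/train/small/ettm2.sh",
    ]

    for script in SCRIPTS:
        # check=True raises CalledProcessError on a non-zero exit, aborting the
        # sweep at the first failure, unlike the plain bash scripts above.
        subprocess.run(["bash", script], check=True)
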
src/uni2ts/model/moirai/finetune.py · 35 additions, 3 deletions
@@ -304,9 +304,41 @@ def configure_optimizers(self) -> dict:
                 if "ffn" in pn:
                     p.requires_grad = True
 
-        if "self_attn" in self.finetune_pattern:
-            # Todo: Analyze each component in self_attn & Lora's impact.
-            pass
+        # if "self_attn" in self.finetune_pattern:
+        #     # Todo: Analyze each component in self_attn & Lora's impact.
+        #     pass
+
+        if "q_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "q_proj" in pn:
+                    p.requires_grad = True
+
+        if "k_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "k_proj" in pn:
+                    p.requires_grad = True
+
+        if "v_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "v_proj" in pn:
+                    p.requires_grad = True
+
+        if "att_norm" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "self_attn.q_norm" in pn or "self_attn.k_norm" in pn:
+                    p.requires_grad = True
+
+
+        if "var_attn_bias" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "var_attn_bias" in pn:
+                    p.requires_grad = True
+
+        if "out_proj" in self.finetune_pattern:
+            for pn, p in self.named_parameters():
+                if "out_proj" in pn:
+                    p.requires_grad = True
+
 
         whitelist_params = (
             LearnedProjection,
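Every branch added above follows the same pattern: configure_optimizers walks self.named_parameters() and re-enables gradients for any parameter whose name contains the substring tied to an entry of self.finetune_pattern. The six near-identical if-blocks could be collapsed into a table-driven loop; the sketch below is a standalone illustration of that idea, where apply_finetune_pattern and PATTERN_TO_SUBSTRINGS are hypothetical helpers (not part of uni2ts) and parameters are assumed to start trainable.

    # Standalone sketch of the selective-unfreezing pattern in the diff above.
    # apply_finetune_pattern and PATTERN_TO_SUBSTRINGS are hypothetical, not
    # part of uni2ts; the substring mapping mirrors the if-blocks in the commit.
    import torch.nn as nn

    PATTERN_TO_SUBSTRINGS: dict[str, tuple[str, ...]] = {
        "q_proj": ("q_proj",),
        "k_proj": ("k_proj",),
        "v_proj": ("v_proj",),
        "att_norm": ("self_attn.q_norm", "self_attn.k_norm"),
        "var_attn_bias": ("var_attn_bias",),
        "out_proj": ("out_proj",),
    }

    def apply_finetune_pattern(model: nn.Module, finetune_pattern: list[str]) -> None:
        # Freeze everything first, then re-enable gradients only for parameters
        # whose names contain a substring mapped from a requested pattern.
        for p in model.parameters():
            p.requires_grad = False
        for pattern in finetune_pattern:
            substrings = PATTERN_TO_SUBSTRINGS.get(pattern, (pattern,))
            for pn, p in model.named_parameters():
                if any(s in pn for s in substrings):
                    p.requires_grad = True

For example, apply_finetune_pattern(model, ["q_proj", "k_proj", "v_proj", "out_proj"]) would leave only the attention projection weights trainable, matching what the corresponding branches in the commit enable.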
