
Commit

Fix default_lr being applied
rockerBOO committed Apr 3, 2024
1 parent c769160 commit 1933ab4
Showing 3 changed files with 64 additions and 17 deletions.
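
The fix is the same in all three network implementations: wherever a group-specific learning rate (text_encoder_lr or unet_lr) is None, the parameter group now falls back to default_lr instead of being assembled with None. A minimal sketch of the fallback rule, assuming the helper name below (illustrative only, not part of the repository):

    def resolve_lr(specific_lr, default_lr):
        # Use the group-specific LR when given, otherwise fall back to the default.
        return specific_lr if specific_lr is not None else default_lr

    # e.g. Text Encoder groups use resolve_lr(text_encoder_lr, default_lr),
    #      U-Net groups use resolve_lr(unet_lr, default_lr)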
21 changes: 18 additions & 3 deletions networks/dylora.py
@@ -407,7 +407,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         """
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -442,11 +449,19 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)
 
         return all_params
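
For illustration, a self-contained sketch (assumed names and values, not the repository's code) of what the fixed fallback means for the resulting optimizer param groups when only default_lr is provided:

    import torch

    text_encoder_lr = None
    unet_lr = None
    default_lr = 1e-4

    te_param = torch.nn.Parameter(torch.zeros(4))
    unet_param = torch.nn.Parameter(torch.zeros(4))

    # Mirrors the two ternaries in the diff above: both groups now get 1e-4, not None.
    param_groups = [
        {"params": [te_param], "lr": text_encoder_lr if text_encoder_lr is not None else default_lr},
        {"params": [unet_param], "lr": default_lr if unet_lr is None else unet_lr},
    ]
    optimizer = torch.optim.AdamW(param_groups)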
30 changes: 23 additions & 7 deletions networks/lora.py
@@ -1035,7 +1035,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -1070,7 +1077,11 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
@@ -1085,14 +1096,19 @@ def assemble_params(loras, lr, lora_plus_ratio):
 
                 # set the parameters for each block
                 for idx, block_loras in block_idx_to_lora.items():
-                    if unet_lr is not None:
-                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                    elif default_lr is not None:
-                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    params = assemble_params(
+                        block_loras,
+                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                        unet_lora_plus_ratio
+                    )
                     all_params.extend(params)
 
             else:
-                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                params = assemble_params(
+                    self.unet_loras,
+                    default_lr if unet_lr is None else unet_lr,
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)
 
         return all_params
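
In the block-wise branch above, the per-block rate is the resolved base rate multiplied by get_lr_weight(block_loras[0]). A small self-contained sketch of that computation under the new logic; the weight function here is a stand-in for LoRANetwork.get_lr_weight (which takes a LoRAModule, not an index), so names and values are illustrative only:

    def get_lr_weight_stub(block_index):
        # Placeholder for LoRANetwork.get_lr_weight, which scales the LR per block.
        return 0.5 if block_index < 6 else 1.0

    def block_lr(unet_lr, default_lr, block_index):
        # After the fix: fall back to default_lr, then apply the per-block weight.
        base_lr = unet_lr if unet_lr is not None else default_lr
        return base_lr * get_lr_weight_stub(block_index)

    print(block_lr(None, 1e-4, 3))   # 5e-05: default_lr is applied, then scaled
    print(block_lr(2e-4, 1e-4, 8))   # 0.0002: an explicit unet_lr still wins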
30 changes: 23 additions & 7 deletions networks/lora_fa.py
@@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

@@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
@@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
 
                 # set the parameters for each block
                 for idx, block_loras in block_idx_to_lora.items():
-                    if unet_lr is not None:
-                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                    elif default_lr is not None:
-                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    params = assemble_params(
+                        block_loras,
+                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                        unet_lora_plus_ratio
+                    )
                     all_params.extend(params)
 
             else:
-                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                params = assemble_params(
+                    self.unet_loras,
+                    default_lr if unet_lr is None else unet_lr,
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)
 
         return all_params
