diff --git a/networks/lora_fa.py b/networks/lora_fa.py
index 58bcb2206..919222ce8 100644
--- a/networks/lora_fa.py
+++ b/networks/lora_fa.py
@@ -15,10 +15,8 @@
 import torch
 import re
 from library.utils import setup_logging
-
 setup_logging()
 import logging
-
 logger = logging.getLogger(__name__)
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -506,15 +504,6 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
-    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
-    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
-    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
-    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
-    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
-    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
-        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
-
     return network
 
 
@@ -540,9 +529,7 @@ def parse_floats(s):
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(
-            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
-        )
+        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
@@ -816,17 +803,11 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
-        self.loraplus_lr_ratio = None
-        self.loraplus_unet_lr_ratio = None
-        self.loraplus_text_encoder_lr_ratio = None
-
         if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
@@ -834,13 +815,9 @@ def __init__(
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
-                logger.info(
-                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
-                )
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
         def create_modules(
@@ -962,11 +939,6 @@ def create_modules(
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
-    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
-        self.loraplus_lr_ratio = loraplus_lr_ratio
-        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
-        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
-
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -1065,42 +1037,18 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
-        def assemble_params(loras, lr, ratio):
-            param_groups = {"lora": {}, "plus": {}}
-            for lora in loras:
-                for name, param in lora.named_parameters():
-                    if ratio is not None and "lora_up" in name:
-                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
-                    else:
-                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
-
+        def enumerate_params(loras: List[LoRAModule]):
             params = []
-            for key in param_groups.keys():
-                param_data = {"params": param_groups[key].values()}
-
-                if len(param_data["params"]) == 0:
-                    continue
-
-                if lr is not None:
-                    if key == "plus":
-                        param_data["lr"] = lr * ratio
-                    else:
-                        param_data["lr"] = lr
-
-                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
-                    continue
-
-                params.append(param_data)
-
+            for lora in loras:
+                # params.extend(lora.parameters())
+                params.extend(lora.get_trainable_params())
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(
-                self.text_encoder_loras,
-                text_encoder_lr if text_encoder_lr is not None else default_lr,
-                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
-            )
-            all_params.extend(params)
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1114,20 +1062,21 @@ def assemble_params(loras, lr, ratio):
 
                 # blockごとにパラメータを設定する
                 for idx, block_loras in block_idx_to_lora.items():
-                    params = assemble_params(
-                        block_loras,
-                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        self.loraplus_unet_lr_ratio or self.loraplus_ratio,
-                    )
-                    all_params.extend(params)
+                    param_data = {"params": enumerate_params(block_loras)}
+
+                    if unet_lr is not None:
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    elif default_lr is not None:
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
+                        continue
+                    all_params.append(param_data)
 
             else:
-                params = assemble_params(
-                    self.unet_loras,
-                    unet_lr if unet_lr is not None else default_lr,
-                    self.loraplus_unet_lr_ratio or self.loraplus_ratio,
-                )
-                all_params.extend(params)
+                param_data = {"params": enumerate_params(self.unet_loras)}
+                if unet_lr is not None:
+                    param_data["lr"] = unet_lr
+                all_params.append(param_data)
 
         return all_params
 
@@ -1144,9 +1093,6 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()
 
-    def get_trainable_named_params(self):
-        return self.named_parameters()
-
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None