
Commit

Merge pull request #1233 from rockerBOO/lora-plus
Add LoRA+ support
kohya-ss authored Apr 29, 2024
2 parents 0540c33 + 68467bd commit 834445a
Showing 5 changed files with 220 additions and 74 deletions.
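
For context, LoRA+ trains the LoRA "up" (B) matrices with a larger learning rate than the "down" (A) matrices, and this PR exposes that as a configurable ratio. A minimal, self-contained sketch of the underlying idea (not code from this commit; module sizes and values are illustrative only):

import torch
import torch.nn as nn

lora_down = nn.Linear(768, 16, bias=False)  # "A" matrix, trained at the base learning rate
lora_up = nn.Linear(16, 768, bias=False)    # "B" matrix, trained at base_lr * ratio
base_lr, ratio = 1e-4, 16.0                 # hypothetical values

optimizer = torch.optim.AdamW([
    {"params": lora_down.parameters(), "lr": base_lr},
    {"params": lora_up.parameters(), "lr": base_lr * ratio},
])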
library/train_util.py (3 changes: 3 additions & 0 deletions)
@@ -2920,6 +2920,9 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
)
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
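Judging from the `text_encoder_loraplus_ratio or loraplus_ratio` pattern in the network code below, the model-specific options override the generic one and --loraplus_lr_ratio serves as the fallback (assuming the training script forwards these flags to prepare_optimizer_params; that wiring is not shown in this excerpt). A hypothetical helper illustrating that resolution:

def resolve_loraplus_ratios(loraplus_lr_ratio=None, loraplus_unet_lr_ratio=None, loraplus_text_encoder_lr_ratio=None):
    # the model-specific ratio wins; the generic ratio fills whichever is unset
    unet_ratio = loraplus_unet_lr_ratio or loraplus_lr_ratio
    text_encoder_ratio = loraplus_text_encoder_lr_ratio or loraplus_lr_ratio
    return unet_ratio, text_encoder_ratio

print(resolve_loraplus_ratios(loraplus_lr_ratio=16.0, loraplus_unet_lr_ratio=4.0))  # -> (4.0, 16.0)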
networks/dylora.py (60 changes: 48 additions & 12 deletions)
@@ -406,27 +406,63 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info(f"weights are merged")
"""

def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

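In the DyLoRA network the "plus" group collects parameters whose names contain "lora_B", and a group is dropped when its resolved learning rate is 0 or None. A short worked example of the resulting group learning rates (values are hypothetical):

unet_lr, ratio = 1e-4, 4.0
group_lrs = {"lora (lora_A and the rest)": unet_lr, "plus (lora_B)": unet_lr * ratio}
print(group_lrs)  # {'lora (lora_A and the rest)': 0.0001, 'plus (lora_B)': 0.0004}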
networks/lora.py (75 changes: 54 additions & 21 deletions)
@@ -1034,21 +1034,55 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
return lr_weight

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
print("NO LR skipping!")
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
if self.block_lr:
@@ -1062,21 +1096,20 @@ def enumerate_params(loras):

# set the parameters for each block
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}

if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

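When block-wise learning rates are enabled (self.block_lr), the per-block weight scales the base learning rate before the LoRA+ ratio is applied, so each block's "plus" (lora_up) group ends up at base_lr * block_weight * ratio. A short worked example with hypothetical numbers:

unet_lr, block_weight, ratio = 1e-4, 0.5, 16.0
base_group_lr = unet_lr * block_weight   # lr of the block's "lora" group (lora_down and the rest)
plus_group_lr = base_group_lr * ratio    # lr of the block's "plus" group (lora_up)
print(base_group_lr, plus_group_lr)      # 5e-05 0.0008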
networks/lora_fa.py (78 changes: 56 additions & 22 deletions)
@@ -1033,22 +1033,54 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
return lr_weight

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras: List[LoRAModule]):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
# params.extend(lora.parameters())
params.extend(lora.get_trainable_params())
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
if self.block_lr:
@@ -1062,21 +1094,20 @@ def enumerate_params(loras: List[LoRAModule]):

# set the parameters for each block
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}

if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

@@ -1093,6 +1124,9 @@ def on_epoch_start(self, text_encoder, unet):
def get_trainable_params(self):
return self.parameters()

def get_trainable_named_params(self):
return self.named_parameters()

def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
