This repository has been archived by the owner on Mar 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathlora_load_weight_only.py
141 lines (122 loc) · 4.78 KB
/
lora_load_weight_only.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import comfy
import folder_paths
import os
import re
# Directory containing this source file; used to locate the optional preset file.
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
# Optional preset file with one "name:w0,w1,..." block-weight definition per line.
PRESET_FILE = os.path.join(CURRENT_DIR, "preset.txt")
def extract_numbers(s):
    """Return every run of consecutive digits in *s* as a list of ints."""
    return list(map(int, re.findall(r'\d+', s)))
def expand_lbw(weight_list):
    """Expand a short block-weight list to its full length.

    A 17-entry list becomes 26 entries and a 12-entry list becomes 20
    entries, with 0.0 inserted at the positions listed in LBW17TO26 /
    LBW12TO20 respectively.  Any other length is returned unchanged.
    """
    if len(weight_list) == 17:
        zero_positions, full_length = LBW17TO26, 26
    elif len(weight_list) == 12:
        zero_positions, full_length = LBW12TO20, 20
    else:
        return weight_list
    source = iter(weight_list)
    return [0.0 if idx in zero_positions else next(source)
            for idx in range(full_length)]
def parse_weight_preset(text):
    """Parse a preset file body into ``{name: [float, ...]}``.

    Each meaningful line has the form ``name:w0,w1,...``.  Blank lines and
    lines without a ``:`` are skipped (the original raised ValueError on a
    file containing an interior blank line), the name is stripped of
    surrounding whitespace, and only the FIRST ``:`` splits name from
    values so a stray colon in the value part cannot break unpacking.

    Raises ValueError only when a weight itself is not a valid float.
    """
    weight_dict = {}
    for line in text.strip().split("\n"):
        line = line.strip()
        # Tolerate blank or malformed lines instead of crashing on unpack.
        if not line or ":" not in line:
            continue
        key, values = line.split(":", 1)
        weight_dict[key.strip()] = [float(x) for x in values.split(",")]
    return weight_dict
def parse_weight_list(text):
    """Resolve *text* to a list of float block weights.

    If the preset file exists and defines a preset named *text*, that
    preset's weights are returned; otherwise *text* is parsed as a
    comma-separated list of floats.
    """
    presets = {}
    if os.path.exists(PRESET_FILE):
        with open(PRESET_FILE, "r") as f:
            presets = parse_weight_preset(f.read())
    if text in presets:
        return presets[text]
    return [float(token) for token in text.split(",")]
# Positions that receive 0.0 when a 17-entry block-weight list is expanded to 26.
LBW17TO26 = [2, 5, 8, 11, 12, 13, 15, 16, 17]
# Positions that receive 0.0 when a 12-entry block-weight list is expanded to 20.
LBW12TO20 = [2, 3, 4, 5, 8, 18, 19, 20]
# Index of the middle/mid block keyed by full weight-list length.
# NOTE(review): presumably 26 and 20 correspond to two different UNet
# layouts (e.g. SD1.x vs SDXL) — not verifiable from this file alone.
MID_ID = {26:13, 20:10}
class LoraLoaderWeightOnly:
    """ComfyUI node that loads a LoRA state dict and applies per-block
    weights ("lbw") WITHOUT patching a model.

    Instead of a patched model/clip pair it returns a dict bundling the
    (possibly block-rescaled) LoRA tensors with the requested model/clip
    strengths, so a downstream merge node can apply them.
    """
    def __init__(self):
        # (lora_path, state_dict) cache of the most recently loaded file, or None.
        self.loaded_lora = None
        # lbw string used to scale the cached state dict.  Scaling mutates
        # the tensors in place, so a different lbw must force a reload.
        self.lbw = None
    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input declaration: LoRA file name, model/clip
        # strengths, and an optional "lbw" string (a preset name from
        # preset.txt or a comma-separated per-block weight list).
        return {
            "required": {
                "lora_name": (folder_paths.get_filename_list("loras"), ),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
                "lbw": ("STRING", {
                    "multiline": False,
                    "default": ""
                }),
            }
        }
    RETURN_TYPES = ("LoRA", )
    FUNCTION = "load_lora_weight_only"
    CATEGORY = "lora_merge"
    def load_lora_weight_only(self, lora_name, strength_model, strength_clip, lbw):
        """Load (or reuse from cache) the LoRA named *lora_name*, scale its
        blocks by the parsed *lbw* weights, and return the bundle dict.

        Returns a 1-tuple containing
        ``{"lora": state_dict, "strength_model": ..., "strength_clip": ...}``.
        """
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
                lora = self.loaded_lora[1]
            else:
                # A different file was requested: drop the cached entry.
                temp = self.loaded_lora
                self.loaded_lora = None
                del temp
        if lora is None or self.lbw != lbw:
            # Reload from disk whenever uncached OR the lbw changed, because
            # the loop below rescales/deletes tensors of the dict in place.
            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            if lbw != "":
                weight_list = parse_weight_list(lbw)
                print(f"{lora_name} block weight is :{weight_list}")
                weight_list = expand_lbw(weight_list)
                length = len(weight_list)
                # Entry 0 scales the text encoder via strength_clip.
                strength_clip = strength_clip * weight_list[0]
                # Only UNet up-projection keys; text-encoder ("lora_te")
                # keys are excluded from per-block scaling.
                up_keys = [key for key in lora.keys() if "lora_up" in key and not "lora_te" in key]
                for key in up_keys:
                    ids = extract_numbers(key)
                    # Map the layer name to an index into weight_list.
                    # input/middle/output_blocks look like kohya-style SD
                    # keys; down/mid/up_blocks look like diffusers-style
                    # keys — TODO(review) confirm both layouts occur.
                    if "input_blocks" in key:
                        block_id = ids[0]
                    elif "middle_block" in key:
                        # NOTE(review): MID_ID only has entries for lengths
                        # 26 and 20; any other expanded length raises
                        # KeyError here — confirm expand_lbw guarantees this.
                        block_id = MID_ID[length]
                    elif "output_blocks" in key:
                        block_id = ids[0] + MID_ID[length] + 1
                    elif "down_blocks" in key:
                        # presumably 3 scalable layers per diffusers down
                        # block — verify against the actual key layout.
                        block_id = ids[0]*3 + ids[1] + 1
                        if "down_sampler" in key:
                            block_id += 2
                    elif "mid_block" in key:
                        block_id = MID_ID[length]
                    elif "up_blocks" in key:
                        block_id = ids[0]*3 + ids[1] + MID_ID[length] + 1
                        if "up_sampler" in key:
                            block_id += 2
                    else:
                        # Unrecognized key: fall back to the first weight.
                        block_id = 0
                    #print(key, block_id)
                    weight = weight_list[block_id]
                    if weight != 0.0:
                        lora[key] = lora[key] * weight
                    else:
                        # Zero weight: remove the whole LoRA module
                        # (up, down, and alpha entries).
                        del lora[key]
                        del lora[key.replace("lora_up", "lora_down")]
                        # NOTE(review): assumes every module has an "alpha"
                        # entry; a LoRA without one raises KeyError — confirm.
                        del lora[key.replace("lora_up.weight", "alpha")]
            self.loaded_lora = (lora_path, lora)
            self.lbw = lbw
        return ({"lora": lora, "strength_model": strength_model, "strength_clip": strength_clip}, )