convert_model.py
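"""Convert a VITS .pth checkpoint into ncnn models via pnnx.

Pipeline sketch, as implemented below: load the checkpoint into
SynthesizerTrn, fold weight normalization into plain weights, trace each
submodule (enc_p, dp, flow, dec, and, for multi-speaker models, enc_q)
with torch.jit.trace, convert the TorchScript with pnnx, then collect the
.ncnn.bin files plus the raw embedding tables into an output folder named
after the checkpoint.
"""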
import argparse
import os
import re
import shutil

import torch
from torch import _weight_norm
from torch.nn import Parameter

from models_ncnn import SynthesizerTrn
from utils import get_hparams_from_file, load_checkpoint

def create_folders(model_path, multi):
    # create the temporary cache dirs and the final output dir
    cache_root = "ncnn_cache/"  # temporary working folder
    model_path = model_path.replace("\\", "/")
    match = re.match(r'(?:.*/)?(.*)\.pth', model_path)
    if match is None:
        raise RuntimeError("Model path must point to a .pth file!")
    out_folder = match.group(1)  # output folder, named after the checkpoint
    flow_folder = cache_root + "flow"                   # forward flow
    flow_reversed_folder = cache_root + "flow_reverse"  # reversed flow (inference direction)
    enc_p_folder = cache_root + "enc_p"                 # prior (text) encoder
    dp_folder = cache_root + "dp"                       # duration predictor
    dec_folder = cache_root + "dec"                     # decoder (vocoder)
    enc_q_folder = cache_root + "enc_q"                 # posterior encoder
    folders = [cache_root, out_folder, flow_reversed_folder, flow_folder,
               enc_p_folder, dp_folder, dec_folder, enc_q_folder]
    if not multi:
        # single-speaker export does not need the forward flow or the posterior encoder
        folders.remove(flow_folder)
        folders.remove(enc_q_folder)
    for folder in folders:
        os.makedirs(folder, exist_ok=True)
        if not os.path.exists(folder):
            raise RuntimeError("Directory creation failed!")
    return folders

def convert_model(net, folder, name, multi):
    # dummy example inputs for torch.jit.trace; the sequence lengths
    # (100 / 255 / 336) are arbitrary trace-time placeholders
    if multi:
        layer_inputs = {
            "enc_p": [torch.randint(0, 20, (1, 100)), torch.LongTensor([100]), net.enc_p.emb.weight.data],
            "dp": [torch.randn((1, 192, 100)), torch.ones((1, 1, 100)), torch.randn((1, 2, 100)),
                   0.8 * torch.ones((1, 2, 100)), torch.randn((1, 256, 1))],
            "flow": [torch.randn((1, 192, 255)), torch.ones((1, 1, 255)), torch.randn((1, 256, 1))],
            "flow.reverse": [torch.randn((1, 192, 255)), torch.ones((1, 1, 255)), torch.randn((1, 256, 1))],
            "dec": [torch.randn((1, 192, 255)), torch.randn((1, 256, 1))],
            "enc_q": [torch.randn((1, 513, 336)), torch.LongTensor([336]), torch.randn((1, 256, 1))]
        }
    else:
        layer_inputs = {
            "enc_p": [torch.randint(0, 20, (1, 100)), torch.LongTensor([100]), net.enc_p.emb.weight.data],
            "dp": [torch.randn((1, 192, 100)), torch.ones((1, 1, 100)), torch.randn((1, 2, 100)),
                   0.8 * torch.ones((1, 2, 100))],
            "flow.reverse": [torch.randn((1, 192, 255)), torch.ones((1, 1, 255))],
            "dec": [torch.randn((1, 192, 255))]
        }
    # custom modules that pnnx must keep as single ops instead of tracing into primitives
    custom_ops = {
        "enc_p": "modules.Transpose,modules.SequenceMask,modules.Embedding,attentions.Attention,attentions.ExpandDim,attentions.SamePadding",
        "dp": "modules.PRQTransform,modules.Transpose,modules.ReduceDims",
        "flow": "modules.ResidualReverse",
        "flow.reverse": "modules.ResidualReverse",
        "dec": "",
        "enc_q": "modules.RandnLike,modules.ResidualReverse,modules.SequenceMask",
    }
    # both flow variants share the same module; a flag switches the direction
    if name == "flow_reverse":
        name = name.replace("_", ".")
        layer = getattr(net, "flow")
        layer.reverse = True
    elif name == "flow":
        layer = getattr(net, name)
        layer.reverse = False
    else:
        layer = getattr(net, name)
    path_pt = os.path.join(folder, name + ".pt")
    torch.jit.trace(layer, layer_inputs[name]).save(path_pt)
    # pnnx_path and fp16 are set as globals in main()
    os.system("{} {} fp16={} moduleop={}".format(pnnx_path, path_pt, 1 if fp16 else 0, custom_ops[name]))

def export(root_folder, out_folder, name):
    # copy the converted ncnn weights from the cache into the output folder
    src_folder = os.path.join(root_folder, name)
    if name == "flow_reverse":
        name = name.replace("_", ".")
    src_path = os.path.join(src_folder, name + ".ncnn.bin")
    target_path = os.path.join(out_folder, name + ".ncnn.bin")
    shutil.copy(src_path, target_path)

def main(args):
    multi = False  # multi-speaker model or not
    global pnnx_path
    global fp16
    # input paths
    config_path = args.config_path
    model_path = args.model_path
    fp16 = args.fp16
    if os.name == "nt":
        pnnx_path = "pnnx\\pnnx.exe"
    elif os.name == "posix":
        pnnx_path = "pnnx/pnnx"
    else:
        raise RuntimeError("Unsupported system!")
    if not os.path.exists(config_path):
        raise RuntimeError("Config file does not exist!")
    if not os.path.exists(model_path):
        raise RuntimeError("Model file does not exist!")
    if not os.path.exists(pnnx_path):
        raise RuntimeError("pnnx does not exist!")
    # load configs
    hps = get_hparams_from_file(config_path)
    # Support for https://github.com/JOETtheIV/VITS-Paimon: manually add an
    # n_vocabs entry under "data" in the config file.
    if "n_vocabs" in hps.data.keys():
        n_symbols = hps.data.n_vocabs
    elif "symbols" in hps.keys():
        n_symbols = len(hps.symbols)
    else:
        n_symbols = 0
    if n_symbols == 0:
        raise RuntimeError("Symbols cannot be empty!")
    if hps.data.n_speakers > 0:
        multi = True
        if hps.data.n_speakers != len(hps.speakers):
            raise RuntimeError("n_speakers and speakers mismatch!")
    # create the model
    if multi:
        net_g = SynthesizerTrn(
            n_symbols,
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model)
    else:
        net_g = SynthesizerTrn(
            n_symbols,
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model)
    # load the checkpoint
    _ = net_g.eval()
    _ = load_checkpoint(model_path, net_g, None)
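    # Fold weight normalization into plain weights: ncnn has no notion of the
    # weight_g/weight_v reparameterization (w = g * v / ||v||), so we bake the
    # normalized weight in and drop the extra parameters before tracing.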
    for _, module in net_g.named_modules():
        g = getattr(module, "weight_g", None)
        v = getattr(module, "weight_v", None)
        if g is not None and v is not None:
            normed = _weight_norm(v, g, 0)
            module.weight = Parameter(normed)
            delattr(module, "weight_g")
            delattr(module, "weight_v")
    # create folders
    folders = create_folders(model_path, multi)
    cache_root = folders[0]
    out_root = folders[1]
    # trace each submodule and convert it with pnnx
    for folder in folders[2:]:
        name = folder.replace(cache_root, "")
        convert_model(net_g, folder, name, multi)
    # export the text-embedding table as raw fp32
    emb_weight = net_g.enc_p.emb.weight.data.flatten().numpy().astype("float32")
    with open(os.path.join(out_root, "emb_t.bin"), "wb") as f:
        f.write(emb_weight)
    if multi:
        # multi-speaker models also need the speaker-embedding table
        emb_weight = net_g.emb_g.weight.data.flatten().numpy().astype("float32")
        with open(os.path.join(out_root, "emb_g.bin"), "wb") as f:
            f.write(emb_weight)
    # copy the converted ncnn weights into the output folder
    for folder in folders[2:]:
        name = folder.replace(cache_root, "")
        export(cache_root, out_root, name)
    # clean up the cache and pnnx debug artifacts
    shutil.rmtree(cache_root)
    for debug_file in ("debug.bin", "debug.param", "debug2.bin", "debug2.param"):
        if os.path.exists(debug_file):
            os.remove(debug_file)
    print("Cleaned!")
    shutil.copy(config_path, os.path.join(out_root, "config.json"))
    print("Success!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config_path", type=str, help="path/to/config.json")
    parser.add_argument("-m", "--model_path", type=str, help="path/to/model.pth")
    parser.add_argument("-fp16", "--fp16", action="store_true", help="half precision on/off")
    args = parser.parse_args()
    main(args)
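
# Example invocation (the paths are illustrative):
#   python convert_model.py -c path/to/config.json -m path/to/G_latest.pth -fp16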