Skip to content

Commit

Permalink
[TTS]add starganv2 vc trainer (#3143)
Browse files Browse the repository at this point in the history
* add starganv2 vc trainer

* fix StarGANv2VCUpdater and losses

* fix StarGANv2VCEvaluator

* add some typehint
  • Loading branch information
yt605155624 authored Apr 10, 2023
1 parent 54ef90f commit 72aa19c
Show file tree
Hide file tree
Showing 6 changed files with 911 additions and 97 deletions.
111 changes: 106 additions & 5 deletions examples/vctk/vc3/conf/default.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,123 @@
generator_params:
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# 其实没用上,其实用的是 16000
sr: 24000
n_fft: 2048
win_length: 1200
hop_length: 300
n_mels: 80
###########################################################
# MODEL SETTING #
###########################################################
generator_params:
dim_in: 64
style_dim: 64
max_conv_dim: 512
w_hpf: 0
F0_channel: 256
mapping_network_params:
mapping_network_params:
num_domains: 20 # num of speakers in StarGANv2
latent_dim: 16
style_dim: 64 # same as style_dim in generator_params
hidden_dim: 512 # same as max_conv_dim in generator_params
style_encoder_params:
style_encoder_params:
dim_in: 64 # same as dim_in in generator_params
style_dim: 64 # same as style_dim in generator_params
num_domains: 20 # same as num_domains in generator_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
discriminator_params:
discriminator_params:
dim_in: 64 # same as dim_in in generator_params
num_domains: 20 # same as num_domains in mapping_network_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
n_repeat: 4

asr_params:
input_dim: 80
hidden_dim: 256
n_token: 80
token_embedding_dim: 256

###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
loss_params:
g_loss:
lambda_sty: 1.
lambda_cyc: 5.
lambda_ds: 1.
lambda_norm: 1.
lambda_asr: 10.
lambda_f0: 5.
lambda_f0_sty: 0.1
lambda_adv: 2.
lambda_adv_cls: 0.5
norm_bias: 0.5
d_loss:
lambda_reg: 1.
lambda_adv_cls: 0.1
lambda_con_reg: 10.

adv_cls_epoch: 50
con_reg_epoch: 30


###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 5 # Batch size.
num_workers: 2 # Number of workers in DataLoader.

###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
generator_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
style_encoder_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
style_encoder_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
mapping_network_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
mapping_network_scheduler_params:
max_learning_rate: 2e-6
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-6
discriminator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
discriminator_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4

###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 150
num_snapshots: 5
seed: 1
32 changes: 26 additions & 6 deletions paddlespeech/t2s/datasets/am_batch_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def erniesat_batch_fn(examples,
]
span_bdy = paddle.to_tensor(span_bdy)

# dual_mask 的是混合中英时候同时 mask 语音和文本
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
Expand Down Expand Up @@ -153,7 +153,7 @@ def erniesat_batch_fn(examples,
batch = {
"text": text,
"speech": speech,
# need to generate
# need to generate
"masked_pos": masked_pos,
"speech_mask": speech_mask,
"text_mask": text_mask,
Expand Down Expand Up @@ -415,10 +415,13 @@ def fastspeech2_multi_spk_batch_fn(examples):


def diffsinger_single_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", \
# "speech", "speech_lengths", "durations", "pitch", "energy"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
note_dur = [
np.array(item["note_dur"], dtype=np.float32) for item in examples
]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
Expand Down Expand Up @@ -471,10 +474,13 @@ def diffsinger_single_spk_batch_fn(examples):


def diffsinger_multi_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", \
# "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
note_dur = [
np.array(item["note_dur"], dtype=np.float32) for item in examples
]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
Expand Down Expand Up @@ -663,6 +669,20 @@ def vits_multi_spk_batch_fn(examples):
return batch


# 未完成
def starganv2_vc_batch_fn(examples):
batch = {
"x_real": None,
"y_org": None,
"x_ref": None,
"x_ref2": None,
"y_trg": None,
"z_trg": None,
"z_trg2": None,
}
return batch


# for PaddleSlim
def fastspeech2_single_spk_batch_fn_static(examples):
text = [np.array(item["text"], dtype=np.int64) for item in examples]
Expand Down
Loading

0 comments on commit 72aa19c

Please sign in to comment.