Skip to content
This repository has been archived by the owner on Feb 14, 2025. It is now read-only.

Commit

Permalink
Merge VQGAN v2 to dev (myshell-ai#56)
Browse files Browse the repository at this point in the history
* squash vqgan v2 changes

* Merge pretrain stage 1 and 2

* Optimize vqgan inference (remove redundant code)

* Implement data mixing

* Optimize vqgan v2 config

* Add support to freeze discriminator

* Add stft loss & larger segment size
  • Loading branch information
leng-yue authored Jan 11, 2024
1 parent 39f6902 commit 1609e9b
Show file tree
Hide file tree
Showing 17 changed files with 1,737 additions and 495 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ repos:
hooks:
- id: codespell
files: ^.*\.(py|md|rst|yml)$
args: [-L=fro]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
84 changes: 59 additions & 25 deletions fish_speech/configs/vqgan_pretrain_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ defaults:
- _self_

project: vqgan_pretrain_v2
ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
resume_weights_only: true

# Lightning Trainer
trainer:
Expand All @@ -15,22 +17,36 @@ trainer:

sample_rate: 44100
hop_length: 512
num_mels: 128
num_mels: 160
n_fft: 2048
win_length: 2048
segment_size: 256

# Dataset Configuration
train_dataset:
_target_: fish_speech.datasets.vqgan.VQGANDataset
filelist: data/Genshin/vq_train_filelist.txt
sample_rate: ${sample_rate}
hop_length: ${hop_length}
slice_frames: ${segment_size}
_target_: fish_speech.datasets.vqgan.MixDatast
datasets:
high-quality-441:
prob: 0.5
dataset:
_target_: fish_speech.datasets.vqgan.VQGANDataset
filelist: data/vocoder_data_441/vq_train_filelist.txt
sample_rate: ${sample_rate}
hop_length: ${hop_length}
slice_frames: ${segment_size}

common-voice:
prob: 0.5
dataset:
_target_: fish_speech.datasets.vqgan.VQGANDataset
filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
sample_rate: ${sample_rate}
hop_length: ${hop_length}
slice_frames: ${segment_size}

val_dataset:
_target_: fish_speech.datasets.vqgan.VQGANDataset
filelist: data/Genshin/vq_val_filelist.txt
filelist: data/vocoder_data_441/vq_val_filelist.txt
sample_rate: ${sample_rate}
hop_length: ${hop_length}

Expand All @@ -47,8 +63,9 @@ model:
_target_: fish_speech.models.vqgan.VQGAN
sample_rate: ${sample_rate}
hop_length: ${hop_length}
segment_size: 8192
mode: pretrain-stage1
segment_size: 32768
mode: pretrain
freeze_discriminator: true

downsample:
_target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
Expand All @@ -67,8 +84,8 @@ model:
_target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
in_channels: 256
vq_channels: 256
codebook_size: 1024
codebook_layers: 4
codebook_size: 256
codebook_groups: 4
downsample: 1

decoder:
Expand All @@ -80,33 +97,50 @@ model:
n_layers: 6

generator:
_target_: fish_speech.models.vqgan.modules.decoder.Generator
initial_channel: ${num_mels}
resblock: "1"
_target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
hop_length: ${hop_length}
upsample_rates: [8, 8, 2, 2, 2] # aka. strides
upsample_kernel_sizes: [16, 16, 4, 4, 4]
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
upsample_rates: [8, 8, 2, 2, 2]
num_mels: ${num_mels}
upsample_initial_channel: 512
upsample_kernel_sizes: [16, 16, 4, 4, 4]

discriminator:
_target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
periods: [2, 3, 5, 7, 11, 17, 23, 37]

use_template: true
pre_conv_kernel_size: 7
post_conv_kernel_size: 7

discriminators:
_target_: torch.nn.ModuleDict
modules:
mpd:
_target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
periods: [2, 3, 5, 7, 11, 17, 23, 37]

mrd:
_target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
resolutions:
- ["${n_fft}", "${hop_length}", "${win_length}"]
- [1024, 120, 600]
- [2048, 240, 1200]
- [4096, 480, 2400]
- [512, 50, 240]

multi_resolution_stft_loss:
_target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
resolutions: ${model.discriminators.modules.mrd.resolutions}

mel_transform:
_target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
sample_rate: ${sample_rate}
n_fft: ${n_fft}
hop_length: ${hop_length}
win_length: ${win_length}
n_mels: ${num_mels}
f_min: 0
f_max: 16000

optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 2e-4
lr: 1e-4
betas: [0.8, 0.99]
eps: 1e-5

Expand All @@ -119,7 +153,7 @@ callbacks:
grad_norm_monitor:
sub_module:
- generator
- discriminator
- discriminators
- mel_encoder
- vq_encoder
- decoder
31 changes: 29 additions & 2 deletions fish_speech/datasets/vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader, Dataset, IterableDataset

from fish_speech.utils import RankedLogger

Expand Down Expand Up @@ -72,6 +72,33 @@ def __getitem__(self, idx):
return None


class MixDatast(IterableDataset):
    """Infinite dataset that mixes samples from several child datasets.

    Each iteration step picks one child dataset at random, weighted by the
    normalized ``prob`` values, and yields that child's next sample. An
    exhausted child is transparently restarted, so the stream never ends.

    NOTE(review): the class name keeps the original spelling ("MixDatast")
    because Hydra configs reference it by fully-qualified path.
    """

    def __init__(self, datasets: dict[str, dict], seed: int = 42) -> None:
        """
        Args:
            datasets: mapping of name -> {"prob": float, "dataset": iterable}.
                Probabilities are normalized internally, so they need not
                sum to 1.
            seed: base seed for the sampling RNG.
        """
        super().__init__()

        values = list(datasets.values())
        weights = [entry["prob"] for entry in values]
        self.datasets = [entry["dataset"] for entry in values]

        total_weight = sum(weights)
        self.probs = [w / total_weight for w in weights]
        self.seed = seed

    def __iter__(self):
        # Derive a per-worker seed so that DataLoader workers do not all
        # replay the identical sample stream (seeding every worker with the
        # bare self.seed would duplicate data across workers).
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        rng = np.random.default_rng(self.seed + worker_id)

        iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Weighted pick of which child dataset supplies the next sample.
            idx = rng.choice(len(self.datasets), p=self.probs)

            try:
                yield next(iterators[idx])
            except StopIteration:
                # Child exhausted: restart it from scratch. If the fresh
                # iterator is also empty, fail with an explicit message
                # instead of an opaque PEP 479 RuntimeError.
                iterators[idx] = iter(self.datasets[idx])
                try:
                    yield next(iterators[idx])
                except StopIteration:
                    raise RuntimeError(
                        f"Dataset at index {idx} yielded no samples"
                    ) from None


@dataclass
class VQGANCollator:
def __call__(self, batch):
Expand Down Expand Up @@ -116,7 +143,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
collate_fn=VQGANCollator(),
num_workers=self.num_workers,
shuffle=True,
shuffle=not isinstance(self.train_dataset, IterableDataset),
)

def val_dataloader(self):
Expand Down
Loading

0 comments on commit 1609e9b

Please sign in to comment.