diff --git a/.gitignore b/.gitignore
index 5d4e82e0..feb4ef97 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,8 @@
 log
 generated
 data
 text
+datasets
+testout
 
 # Created by https://www.gitignore.io
diff --git a/README.md b/README.md
index 4a47a2ef..fa610fd0 100644
--- a/README.md
+++ b/README.md
@@ -21,8 +21,8 @@ A notebook supposed to be executed on https://colab.research.google.com is avail
 - Convolutional sequence-to-sequence model with attention for text-to-speech synthesis
 - Multi-speaker and single speaker versions of DeepVoice3
 - Audio samples and pre-trained models
-- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets
-- Language-dependent frontend text processor for English and Japanese
+- Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets, as well as [carpedm20/multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-Speaker-tacotron-tensorflow)-compatible custom datasets (in JSON format)
+- Language-dependent frontend text processor for English and Japanese
 
 ### Samples
@@ -102,7 +102,7 @@ python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljs
 - LJSpeech (en): https://keithito.com/LJ-Speech-Dataset/
 - VCTK (en): http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
 - JSUT (jp): https://sites.google.com/site/shinnosuketakamichi/publication/jsut
-- NIKL (ko): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464
+- NIKL (ko) (**requires a Korean cellphone number for access**): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464
 
 ### 1. Preprocessing
@@ -128,6 +128,47 @@ python preprocess.py --preset=presets/deepvoice3_ljspeech.json ljspeech ~/data/L
 When this is done, you will see extracted features (mel-spectrograms and linear spectrograms) in `./data/ljspeech`.
 
+#### 1-1. Building a custom dataset (using json_meta)
+Building your own dataset, with metadata in JSON format (compatible with [carpedm20/multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-Speaker-tacotron-tensorflow)), is currently supported.
+Usage:
+
+```
+python preprocess.py json_meta ${list-of-JSON-metadata-paths} ${out_dir} --preset=<json>
+```
+You may need to modify a pre-existing preset JSON file, especially `n_speakers`. For English multi-speaker training, start with `presets/deepvoice3_vctk.json`.
+
+Assuming you have dataset A (Speaker A) and dataset B (Speaker B), described by the JSON metadata files `./datasets/datasetA/alignment.json` and `./datasets/datasetB/alignment.json` respectively, you can preprocess the data by:
+
+```
+python preprocess.py json_meta "./datasets/datasetA/alignment.json,./datasets/datasetB/alignment.json" "./datasets/processed_A+B" --preset=(path to preset json file)
+```
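+
+For reference, each `alignment.json` is a flat mapping from an audio path (absolute, or relative to the directory containing the JSON file) to its transcription. A minimal, made-up example:
+
+```
+{
+    "./datasetA/audio/0001.wav": "A fully aligned transcript goes here.",
+    "./datasetA/audio/0002.wav": ["candidate sentence - not aligned", "recognized words"],
+    "./datasetA/audio/0003.wav": ["recognized words"]
+}
+```
+Entries whose value is a list (partially aligned or recognition-only) are kept or skipped according to the `ignore_recognition_level` hyperparameter (see `hparams.py`).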
+
+#### 1-2. Preprocessing custom English datasets with long silence (based on [vctk_preprocess](vctk_preprocess/))
+
+Some datasets, especially automatically generated ones, may include long silences and undesirable leading/trailing noise that undermine the char-level seq2seq model (e.g. VCTK, although this is covered in vctk_preprocess).
+
+To deal with the problem, `gentle_web_align.py` will
+- **Prepare phoneme alignments for all utterances**
+- Cut silences during preprocessing
+
+`gentle_web_align.py` uses [Gentle](https://github.com/lowerquality/gentle), a Kaldi-based speech-text alignment tool. It accesses a web-served Gentle application, aligns the given sound segments with their transcripts, and converts the result into HTK-style label files to be processed by `preprocess.py`. Gentle can be run on Linux/Mac/Windows (via Docker).
+
+Preliminary results show that while the HTK/festival/merlin-based method in `vctk_preprocess/prepare_vctk_labels.py` works better on VCTK, Gentle is more stable on audio clips with ambient noise (e.g. movie excerpts).
+
+Usage (assuming Gentle is running at `localhost:8567`, the default when not specified):
+1. When sound files and transcript files are saved in separate folders (e.g. sound files are at `datasetA/wavs` and transcripts are at `datasetA/txts`):
+```
+python gentle_web_align.py -w "datasetA/wavs/*.wav" -t "datasetA/txts/*.txt" --server_addr=localhost --port=8567
+```
+
+2. When sound files and transcript files are saved in a nested structure (e.g. `datasetB/speakerN/blahblah.wav` and `datasetB/speakerN/blahblah.txt`):
+```
+python gentle_web_align.py --nested-directories="datasetB" --server_addr=localhost --port=8567
+```
+**Once you have a phoneme alignment for each utterance, you can extract features by running `preprocess.py`.**
+
 ### 2. Training
 
 Usage:
@@ -139,7 +180,7 @@ python train.py --data-root=${data-root} --preset= --hparams="parameters y
 Suppose you build a DeepVoice3-style model using LJSpeech dataset, then you can train your model by:
 
 ```
-python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
+python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
 ```
 
 Model checkpoints (.pth) and alignments (.png) are saved in `./checkpoints` directory per 10000 steps by default.
@@ -247,7 +288,9 @@ From my experience, it can get reasonable speech quality very quickly rather tha
 
 There are two important options used above:
 
 - `--restore-parts=`: It specifies where to load model parameters. The differences from the option `--checkpoint=` are 1) `--restore-parts=` ignores all invalid parameters, while `--checkpoint=` doesn't. 2) `--restore-parts=` tell trainer to start from 0-step, while `--checkpoint=` tell trainer to continue from last step. `--checkpoint=` should be ok if you are using exactly same model and continue to train, but it would be useful if you want to customize your model architecture and take advantages of pre-trained model.
-- `--speaker-id=`: It specifies what speaker of data is used for training. This should only be specified if you are using multi-speaker dataset. As for VCTK, speaker id is automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
+- `--speaker-id=`: It specifies what speaker of data is used for training. This should only be specified if you are using multi-speaker dataset. As for VCTK, speaker id is automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
+
+If you are training a multi-speaker model, speaker adaptation will only work **when `n_speakers` is identical**.
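+
+For example (paths and the checkpoint name below are placeholders, not shipped files), adapting a pre-trained multi-speaker model to speaker 0 of a VCTK-style dataset could look like:
+
+```
+python train.py --data-root=./data/vctk --preset=presets/deepvoice3_vctk.json \
+    --restore-parts="checkpoints/pretrained_model.pth" --speaker-id=0
+```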
 ## Acknowledgements
diff --git a/deepvoice3_pytorch/conv.py b/deepvoice3_pytorch/conv.py
index e60377f2..121de23d 100644
--- a/deepvoice3_pytorch/conv.py
+++ b/deepvoice3_pytorch/conv.py
@@ -40,7 +40,7 @@ def incremental_forward(self, input):
                 self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = torch.Tensor(self.input_buffer)
+            input = self.input_buffer.clone()
             if dilation > 1:
                 input = input[:, 0::dilation, :].contiguous()
         output = F.linear(input.view(bsz, -1), weight, self.bias)
diff --git a/gentle_web_align.py b/gentle_web_align.py
new file mode 100644
index 00000000..856deec6
--- /dev/null
+++ b/gentle_web_align.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Apr 21 09:06:37 2018
+Phoneme alignment and conversion to HTK-style label files using web-served Gentle.
+This works on any type of English dataset.
+Unlike prepare_htk_alignments_vctk.py, this is Python 3 and Windows (with Docker) compatible.
+Preliminary results show that Gentle performs better on noisy datasets
+(e.g. audio clips extracted from movies).
+*This work was derived from vctk_preprocess/prepare_htk_alignments_vctk.py
+@author: engiecat(github)
+
+usage:
+    gentle_web_align.py (-w wav_pattern) (-t text_pattern) [options]
+    gentle_web_align.py (--nested-directories=) [options]
+
+options:
+    -w --wav_pattern=       Pattern of wav files to be aligned
+    -t --txt_pattern=       Pattern of txt transcript files to be aligned (same name required)
+    --nested-directories=   Process every wav/txt file in the subfolders of the given folder
+    --server_addr=          Server address that serves Gentle. [default: localhost]
+    --port=                 Server port that serves Gentle. [default: 8567]
+    --max_unalign=          Maximum threshold for unalignment occurrence (0.0 ~ 1.0) [default: 0.3]
+    --skip-already-done     Skip utterances that already have a .lab file
+    -h --help               show this help message and exit
+"""
+
+from docopt import docopt
+from glob import glob
+from tqdm import tqdm
+import os.path
+import requests
+import numpy as np
+
+def write_hts_label(labels, lab_path):
+    lab = ""
+    for s, e, l in labels:
+        s, e = float(s) * 1e7, float(e) * 1e7
+        s, e = int(s), int(e)
+        lab += "{} {} {}\n".format(s, e, l)
+    print(lab)
+    with open(lab_path, "w", encoding='utf-8') as f:
+        f.write(lab)
+
+
+def json2hts(data):
+    emit_bos = False
+    emit_eos = False
+
+    phone_start = 0
+    phone_end = None
+    labels = []
+    failure_count = 0
+
+    for word in data["words"]:
+        case = word["case"]
+        if case != "success":
+            failure_count += 1  # instead of failing everything, count and skip this word
+            # raise RuntimeError("Alignment failed")
+            continue
+        start = float(word["start"])
+        word_end = float(word["end"])
+
+        if not emit_bos:
+            labels.append((phone_start, start, "silB"))
+            emit_bos = True
+
+        phone_start = start
+        phone_end = None
+        for phone in word["phones"]:
+            ph = str(phone["phone"][:-2])
+            duration = float(phone["duration"])
+            phone_end = phone_start + duration
+            labels.append((phone_start, phone_end, ph))
+            phone_start += duration
+        assert np.allclose(phone_end, word_end)
+    if not emit_eos:
+        labels.append((phone_start, phone_end, "silE"))
+        emit_eos = True
+    unalign_ratio = float(failure_count) / len(data['words'])
+    return unalign_ratio, labels
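+
+# For reference, json2hts() above consumes a Gentle response of roughly this shape
+# (illustrative values only, not taken from a real transcription run):
+# {
+#   "words": [
+#     {"case": "success", "word": "hello", "start": 0.32, "end": 0.91,
+#      "phones": [{"phone": "hh_B", "duration": 0.12}, {"phone": "ah_I", "duration": 0.15},
+#                 {"phone": "l_I", "duration": 0.10}, {"phone": "ow_E", "duration": 0.22}]},
+#     {"case": "not-found-in-audio", "word": "world"}
+#   ]
+# }
+# A word's phone durations sum to (end - start); any "case" other than "success"
+# is counted as unaligned and skipped.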
+
+
+def gentle_request(wav_path, txt_path, server_addr, port, debug=False):
+    print('\n')
+    response = None
+    wav_name = os.path.basename(wav_path)
+    txt_name = os.path.basename(txt_path)
+    if os.path.splitext(wav_name)[0] != os.path.splitext(txt_name)[0]:
+        print(' [!] wav name and transcript name do not match - exiting...')
+        return response
+    with open(txt_path, 'r', encoding='utf-8-sig') as txt_file:
+        print('Transcript - ' + ''.join(txt_file.readlines()))
+    with open(wav_path, 'rb') as wav_file, open(txt_path, 'rb') as txt_file:
+        params = (('async', 'false'),)
+        files = {'audio': (wav_name, wav_file),
+                 'transcript': (txt_name, txt_file),
+                 }
+        server_path = 'http://' + server_addr + ':' + str(port) + '/transcriptions'
+        response = requests.post(server_path, params=params, files=files)
+        if response.status_code != 200:
+            print(' [!] External server({}) returned bad response({})'.format(server_path, response.status_code))
+    if debug:
+        print('Response')
+        print(response.json())
+    return response
+
+if __name__ == '__main__':
+    arguments = docopt(__doc__)
+    server_addr = arguments['--server_addr']
+    port = int(arguments['--port'])
+    max_unalign = float(arguments['--max_unalign'])
+    if arguments['--nested-directories'] is None:
+        wav_paths = sorted(glob(arguments['--wav_pattern']))
+        txt_paths = sorted(glob(arguments['--txt_pattern']))
+    else:
+        # if this is a multi-foldered environment
+        # (e.g. DATASET/speaker1/blahblah.wav)
+        wav_paths = []
+        txt_paths = []
+        topdir = arguments['--nested-directories']
+        subdirs = [f for f in os.listdir(topdir) if os.path.isdir(os.path.join(topdir, f))]
+        for subdir in subdirs:
+            wav_pattern_subdir = os.path.join(topdir, subdir, '*.wav')
+            txt_pattern_subdir = os.path.join(topdir, subdir, '*.txt')
+            wav_paths.extend(sorted(glob(wav_pattern_subdir)))
+            txt_paths.extend(sorted(glob(txt_pattern_subdir)))
+
+    t = tqdm(range(len(wav_paths)))
+    for idx in t:
+        try:
+            t.set_description("Align via Gentle")
+            wav_path = wav_paths[idx]
+            txt_path = txt_paths[idx]
+            lab_path = os.path.splitext(wav_path)[0] + '.lab'
+            if os.path.exists(lab_path) and arguments['--skip-already-done']:
+                print('[!] skipping because of pre-existing .lab file - {}'.format(lab_path))
+                continue
+            res = gentle_request(wav_path, txt_path, server_addr, port)
+            unalign_ratio, lab = json2hts(res.json())
+            print('[*] Unaligned Ratio - {}'.format(unalign_ratio))
+            if unalign_ratio > max_unalign:
+                print('[!] skipping this utterance due to bad alignment')
+                continue
+            write_hts_label(lab, lab_path)
+        except Exception:
+            # if something goes wrong, log the traceback and skip this file
+            import traceback
+            tb = traceback.format_exc()
+            print('[!] ERROR while processing {}'.format(wav_paths[idx]))
+            print('[!] StackTrace - ')
+            print(tb)
+
\ No newline at end of file
diff --git a/hparams.py b/hparams.py
index 0a1e15fa..2373a050 100644
--- a/hparams.py
+++ b/hparams.py
@@ -125,6 +125,14 @@
     # Forced garbage collection probability
     # Use only when MemoryError continues in Windows (Disabled by default)
     #gc_probability = 0.001,
+
+    # json_meta mode only
+    # 0: "use all",
+    # 1: "ignore only unmatched_alignment",
+    # 2: "fully ignore recognition",
+    ignore_recognition_level = 2,
+    min_text=20,  # when dealing with non-dedicated speech datasets (e.g. movie excerpts), setting min_text above 15 is desirable; adjust per dataset
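+    # Illustrative note (not a shipped preset): like other hyperparameters, the json_meta
+    # options above and below can also be overridden from a preset JSON, e.g.
+    #   {"n_speakers": 2, "ignore_recognition_level": 1, "min_text": 15}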
+ process_only_htk_aligned = False, # if true, data without phoneme alignment file(.lab) will be ignored ) diff --git a/json_meta.py b/json_meta.py new file mode 100644 index 00000000..2e654e9a --- /dev/null +++ b/json_meta.py @@ -0,0 +1,260 @@ +''' +Started in 1945h, Mar 10, 2018 +First done in 2103h, Mar 11, 2018 +Test done in 2324h, Mar 11, 2018 +Modified for HTK labeling in 1426h, Apr 21, 2018 +by engiecat(github) + +This makes r9y9/deepvoice3_pytorch compatible with json format of carpedm20/multi-speaker-tacotron-tensorflow and keithito/tacotron. +The json file is given per speaker, generated in the format of + (if completely aligned) + (path-to-the-audio):aligned text + + (if partially aligned) + (path-to-the-audio):[candidate sentence - not aligned,recognized words] + + (if non-aligned) + (path-to-the-audio):[recognized words] +is given per speaker. + +(e.g. python preprocess.py json_meta "./datasets/LJSpeech_1_0/alignment.json,./datasets/GoTBookRev/alignment.json" "./datasets/LJ+GoTBookRev" --preset=./presets/deepvoice3_vctk.json ) + +usage: + python preprocess.py [option] + + +options: + --preset Path of preset parameters (json). + -h --help show this help message and exit + + +''' + +from concurrent.futures import ProcessPoolExecutor +from functools import partial +import numpy as np +import os +import audio +from nnmnkwii.io import hts +from hparams import hparams +from os.path import exists +import librosa +import json + +def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): + executor = ProcessPoolExecutor(max_workers=num_workers) + futures = [] + + json_paths = in_dir.split(',') + json_paths = [json_path.replace("'", "").replace('"',"") for json_path in json_paths] + num_speakers = len(json_paths) + is_aligned = {} + + speaker_id=0 + for json_path in json_paths: + # Loads json metadata info + if json_path.endswith("json"): + with open(json_path, encoding='utf8') as f: + content = f.read() + info = json.loads(content) + elif json_path.endswith("csv"): + with open(json_path) as f: + info = {} + for line in f: + path, text = line.strip().split('|') + info[path] = text + else: + raise Exception(" [!] Unknown metadata format: {}".format(json_path)) + + print(" [*] Loaded - {}".format(json_path)) + # check audio file existence + base_dir = os.path.dirname(json_path) + new_info = {} + for path in info.keys(): + if not os.path.exists(path): + new_path = os.path.join(base_dir, path) + if not os.path.exists(new_path): + print(" [!] 
Audio not found: {}".format([path, new_path])) + continue + else: + new_path = path + + new_info[new_path] = info[path] + + info = new_info + + # ignore_recognition_level check + for path in info.keys(): + is_aligned[path] = True + if isinstance(info[path], list): + if hparams.ignore_recognition_level == 1 and len(info[path]) == 1 or \ + hparams.ignore_recognition_level == 2: + # flag the path to be 'non-aligned' text + is_aligned[path] = False + info[path] = info[path][0] + + # Reserve for future processing + queue_count = 0 + for audio_path, text in info.items(): + if isinstance(text, list): + if hparams.ignore_recognition_level == 0: + text = text[-1] + else: + text = text[0] + if hparams.ignore_recognition_level > 0 and not is_aligned[audio_path]: + continue + if hparams.min_text > len(text): + continue + if num_speakers == 1: + # Single-speaker + futures.append(executor.submit( + partial(_process_utterance_single, out_dir, text, audio_path))) + else: + # Multi-speaker + futures.append(executor.submit( + partial(_process_utterance, out_dir, text, audio_path, speaker_id))) + queue_count += 1 + print(" [*] Appended {} entries in the queue".format(queue_count)) + + # increase speaker_id + speaker_id += 1 + + # Show ignore_recognition_level description + ignore_description = { + 0: "use all", + 1: "ignore only unmatched_alignment", + 2: "fully ignore recognition", + } + print(" [!] Skip recognition level: {} ({})". \ + format(hparams.ignore_recognition_level, + ignore_description[hparams.ignore_recognition_level])) + + if num_speakers == 1: + print(" [!] Single-speaker mode activated!") + else: + print(" [!] Multi-speaker({}) mode activated!".format(num_speakers)) + + # Now, Do the job! + results = [future.result() for future in tqdm(futures)] + # Remove entries with None (That has been filtered due to bad htk alginment (if process_only_htk_aligned is enabled in hparams) + results = [result for result in results if result != None] + return results + + +def start_at(labels): + has_silence = labels[0][-1] == "pau" + if not has_silence: + return labels[0][0] + for i in range(1, len(labels)): + if labels[i][-1] != "pau": + return labels[i][0] + assert False + + +def end_at(labels): + has_silence = labels[-1][-1] == "pau" + if not has_silence: + return labels[-1][1] + for i in range(len(labels) - 2, 0, -1): + if labels[i][-1] != "pau": + return labels[i][1] + assert False + + +def _process_utterance(out_dir, text, wav_path, speaker_id=None): + + # check whether singlespeaker_mode + if speaker_id is None: + return _process_utterance_single(out_dir,text,wav_path) + # modified version of VCTK _process_utterance + sr = hparams.sample_rate + + # Load the audio to a numpy array: + wav = audio.load_wav(wav_path) + + lab_path = wav_path.replace("wav48/", "lab/").replace(".wav", ".lab") + if not exists(lab_path): + lab_path = os.path.splitext(wav_path)[0]+'.lab' + + # Trim silence from hts labels if available + if exists(lab_path): + labels = hts.load(lab_path) + b = int(start_at(labels) * 1e-7 * sr) + e = int(end_at(labels) * 1e-7 * sr) + wav = wav[b:e] + wav, _ = librosa.effects.trim(wav, top_db=25) + else: + if hparams.process_only_htk_aligned: + return None + wav, _ = librosa.effects.trim(wav, top_db=15) + + if hparams.rescaling: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + + # Compute the linear-scale spectrogram from the wav: + spectrogram = audio.spectrogram(wav).astype(np.float32) + n_frames = spectrogram.shape[1] + + # Compute a mel-scale spectrogram from the wav: + mel_spectrogram 
= audio.melspectrogram(wav).astype(np.float32) + + # Write the spectrograms to disk: + # Get filename from wav_path + wav_name = os.path.basename(wav_path) + wav_name = os.path.splitext(wav_name)[0] + + # case if wave files across different speakers have the same naming format. + # e.g. Recording0.wav + spectrogram_filename = 'spec-{}-{}.npy'.format(speaker_id, wav_name) + mel_filename = 'mel-{}-{}.npy'.format(speaker_id, wav_name) + np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) + np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) + # Return a tuple describing this training example: + return (spectrogram_filename, mel_filename, n_frames, text, speaker_id) + +def _process_utterance_single(out_dir, text, wav_path): + # modified version of LJSpeech _process_utterance + + # Load the audio to a numpy array: + wav = audio.load_wav(wav_path) + sr = hparams.sample_rate + # Added from the multispeaker version + lab_path = wav_path.replace("wav48/", "lab/").replace(".wav", ".lab") + if not exists(lab_path): + lab_path = os.path.splitext(wav_path)[0]+'.lab' + + # Trim silence from hts labels if available + if exists(lab_path): + labels = hts.load(lab_path) + b = int(start_at(labels) * 1e-7 * sr) + e = int(end_at(labels) * 1e-7 * sr) + wav = wav[b:e] + wav, _ = librosa.effects.trim(wav, top_db=25) + else: + if hparams.process_only_htk_aligned: + return None + wav, _ = librosa.effects.trim(wav, top_db=15) + # End added from the multispeaker version + + if hparams.rescaling: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + + # Compute the linear-scale spectrogram from the wav: + spectrogram = audio.spectrogram(wav).astype(np.float32) + n_frames = spectrogram.shape[1] + + # Compute a mel-scale spectrogram from the wav: + mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) + + # Write the spectrograms to disk: + # Get filename from wav_path + wav_name = os.path.basename(wav_path) + wav_name = os.path.splitext(wav_name)[0] + spectrogram_filename = 'spec-{}.npy'.format(wav_name) + mel_filename = 'mel-{}.npy'.format(wav_name) + np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) + np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) + + # Return a tuple describing this training example: + return (spectrogram_filename, mel_filename, n_frames, text) + diff --git a/preprocess.py b/preprocess.py index 0785897f..d76de83f 100644 --- a/preprocess.py +++ b/preprocess.py @@ -54,6 +54,6 @@ def write_metadata(metadata, out_dir): assert hparams.name == "deepvoice3" print(hparams_debug_string()) - assert name in ["jsut", "ljspeech", "vctk", "nikl_m", "nikl_s"] + assert name in ["jsut", "ljspeech", "vctk", "nikl_m", "nikl_s", "json_meta"] mod = importlib.import_module(name) preprocess(mod, in_dir, out_dir, num_workers) diff --git a/setup.py b/setup.py index fa7c9308..17f19ba0 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,8 @@ def create_readme_rst(): "numba", "lws <= 1.0", "nltk", + "requests", + "PyQt5", ], extras_require={ "train": [ diff --git a/train.py b/train.py index 0a253f44..b7918066 100644 --- a/train.py +++ b/train.py @@ -59,6 +59,9 @@ fs = hparams.sample_rate +# Prevent Issue #5 +plt.switch_backend('Qt5Agg') + global_step = 0 global_epoch = 0 use_cuda = torch.cuda.is_available() @@ -916,8 +919,7 @@ def restore_parts(path, model): # Preventing Windows specific error such as MemoryError # Also reduces the occurrence of THAllocator.c 0x05 error in 
Windows build of PyTorch
     if platform.system() == "Windows":
-        print("Windows Detected - num_workers set to 1")
-        hparams.set_hparam('num_workers', 1)
+        print(" [!] Windows Detected - IF THAllocator.c 0x05 error occurs SET num_workers to 1")
 
     assert hparams.name == "deepvoice3"
     print(hparams_debug_string())
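A note on the `plt.switch_backend('Qt5Agg')` line added to `train.py`: it assumes PyQt5 (the new `setup.py` dependency) is importable on the training machine. A more defensive variant, sketched below and not part of this patch, would fall back to a headless backend when Qt is unavailable:

```python
import matplotlib.pyplot as plt

try:
    # Interactive backend; requires PyQt5 to be installed.
    plt.switch_backend('Qt5Agg')
except ImportError:
    # Headless fallback (e.g. a server without a display); alignment plots are
    # still rendered and saved to disk during training.
    plt.switch_backend('Agg')
```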