diff --git a/.gitignore b/.gitignore index f401bc6d..6537fcf1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,12 @@ split/ singing/ toucan_conda_venv/ venv/ +.venv/ +.hypothesis vis/ +wandb/ +wav2vec2_checkpoints/ +testing* *_graph diff --git a/Evaluation/objective_evaluation.py b/Evaluation/objective_evaluation.py new file mode 100644 index 00000000..8f44fad3 --- /dev/null +++ b/Evaluation/objective_evaluation.py @@ -0,0 +1,499 @@ +import os +from tqdm import tqdm +import csv +from numpy import trim_zeros +import string +import subprocess +import time +from statistics import median, mean + +import torch +import torchaudio +from torch.nn import CosineSimilarity +from datasets import load_dataset +import pandas as pd +import soundfile as sf +from torchmetrics import WordErrorRate + +from Utility.storage_config import PREPROCESSING_DIR +from Utility.utils import get_emotion_from_path +from Utility.utils import float2pcm +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Preprocessing.AudioPreprocessor import AudioPreprocessor + +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] + +def extract_dailydialogue_sentences(): + dataset = load_dataset("daily_dialog", split="train", cache_dir=os.path.join(PREPROCESSING_DIR, 'DailyDialogues')) + id_to_emotion = {0: "neutral", 1: "anger", 2: "disgust", 3: "fear", 4: "joy", 5: "sadness", 6: "surprise"} + emotion_to_sents = emotion_to_sents = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + for dialog, emotions in tqdm(zip(dataset["dialog"], dataset["emotion"])): + for sent, emotion in zip(dialog, emotions): + emotion_to_sents[id_to_emotion[emotion]].append(sent.strip()) + + return emotion_to_sents + +def extract_tales_sentences(data_dir): + id_to_emotion = {"N": "neutral", "A": "anger", "D": "disgust", "F": "fear", "H": "joy", "Sa": "sadness", "Su+": "surprise", "Su-": "surprise"} + emotion_to_sents = emotion_to_sents = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + for author in tqdm(os.listdir(data_dir)): + if not author.endswith(".pt"): + for file in os.listdir(os.path.join(data_dir, author, "emmood")): + df = pd.read_csv(os.path.join(data_dir, author, "emmood", file), sep="\t", header=None, quoting=csv.QUOTE_NONE) + for index, (sent_id, emo, mood, sent) in df.iterrows(): + emotions = emo.split(":") + if emotions[0] == emotions[1]: + emotion_to_sents[id_to_emotion[emotions[0]]].append(sent) + return emotion_to_sents + +def get_sorted_test_sentences(emotion_to_sents, classifier): + emotion_to_sents_sorted = {} + for emotion, sentences in emotion_to_sents.items(): + if emotion == "disgust" or emotion == "fear": + continue + sent_score = {} + for sent in tqdm(sentences): + result = classifier(sent) + emo = result[0][0]['label'] + score = result[0][0]['score'] + if emo == emotion: + sent_score[sent] = score + sent_score = dict(sorted(sent_score.items(), key=lambda item: item[1], reverse=True)) + emotion_to_sents_sorted[emotion] = list(sent_score.keys()) + return emotion_to_sents_sorted + +def synthesize_test_sentences(version="Baseline", + exec_device="cpu", + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + test_sentences=None, + silent=False): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/Evaluation", exist_ok=True) + + if version == "Baseline": + os.makedirs(f"audios/Evaluation/Baseline", exist_ok=True) + os.makedirs(f"audios/Evaluation/Baseline/dailydialogues", exist_ok=True) + 
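# --- Editorial usage sketch (assumptions noted in comments, not part of the change set) ---
# The `classifier` handed to get_sorted_test_sentences is only characterised by how its
# output is indexed (result[0][0]['label'] / result[0][0]['score']); the stand-in below
# mimics that nesting so the extraction/ranking helpers can be exercised end to end.
# The real text-emotion classifier checkpoint is not visible in this diff.
def _dummy_emotion_classifier(sentence):
    # always returns a single (label, score) candidate in the expected nesting
    return [[{"label": "neutral", "score": 1.0}]]

def _build_test_sentences(num_per_emotion=20):
    # hypothetical helper name; mirrors how the functions above fit together
    emotion_to_sents = extract_dailydialogue_sentences()
    sorted_sents = get_sorted_test_sentences(emotion_to_sents, _dummy_emotion_classifier)
    # keep only the highest-confidence sentences per emotion for synthesis
    return {"dailydialogues": {emo: sents[:num_per_emotion] for emo, sents in sorted_sents.items()}}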
os.makedirs(f"audios/Evaluation/Baseline/tales", exist_ok=True) + model_id = "Baseline_Finetuning_2_80k" + if version == "Sent": + os.makedirs(f"audios/Evaluation/Sent", exist_ok=True) + os.makedirs(f"audios/Evaluation/Sent/dailydialogues", exist_ok=True) + os.makedirs(f"audios/Evaluation/Sent/tales", exist_ok=True) + model_id = "Sent_Finetuning_2_80k" + if version == "Prompt": + os.makedirs(f"audios/Evaluation/Prompt", exist_ok=True) + os.makedirs(f"audios/Evaluation/Prompt/dailydialogues", exist_ok=True) + os.makedirs(f"audios/Evaluation/Prompt/tales", exist_ok=True) + model_id = "Sent_Finetuning_2_80k" + + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor) + tts.set_language("en") + for speaker_id in tqdm(range(25, 35), "Speakers"): + os.makedirs(f"audios/Evaluation/{version}/dailydialogues/{speaker_id - 14}", exist_ok=True) + os.makedirs(f"audios/Evaluation/{version}/tales/{speaker_id - 14}", exist_ok=True) + tts.set_speaker_id(speaker_id) + for dataset, emotion_to_sents in tqdm(test_sentences.items(), "Datasets"): + for emotion, sentences in tqdm(emotion_to_sents.items(), "Emotions"): + os.makedirs(f"audios/Evaluation/{version}/{dataset}/{speaker_id - 14}/{emotion}", exist_ok=True) + for i, sent in enumerate(tqdm(sentences, "Sentences")): + if version == 'Prompt': + for prompt_emotion in list(emotion_to_sents.keys()): + os.makedirs(f"audios/Evaluation/{version}/{dataset}/{speaker_id - 14}/{emotion}/{prompt_emotion}", exist_ok=True) + prompt = emotion_to_sents[prompt_emotion][len(sentences) - 1 - i] + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], + file_location=f"audios/Evaluation/{version}/{dataset}/{speaker_id - 14}/{emotion}/{prompt_emotion}/{i}.wav", + increased_compatibility_mode=True, + silent=silent) + else: + tts.read_to_file(text_list=[sent], + file_location=f"audios/Evaluation/{version}/{dataset}/{speaker_id - 14}/{emotion}/{i}.wav", + increased_compatibility_mode=True, + silent=silent) + +def extract_speaker_embeddings(audio_dir, classifier, version): + speaker_embeddings = {} + if version == "Original": + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version)), "Speaker"): + speaker_embeddings[speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, speaker)), "Emotion"): + speaker_embeddings[speaker][emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, speaker, emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + speaker_embeddings[speaker][emotion][file_id] = embedding + return speaker_embeddings + if version == "Baseline" or version == "Sent": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + speaker_embeddings[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + speaker_embeddings[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + speaker_embeddings[dataset][speaker][emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, 
dataset, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, dataset, speaker, emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + speaker_embeddings[dataset][speaker][emotion][file_id] = embedding + return speaker_embeddings + if version == "Prompt": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + speaker_embeddings[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + speaker_embeddings[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + speaker_embeddings[dataset][speaker][emotion] = {} + for prompt_emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion)), "Prompt Emotion"): + speaker_embeddings[dataset][speaker][emotion][prompt_emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + speaker_embeddings[dataset][speaker][emotion][prompt_emotion][file_id] = embedding + return speaker_embeddings + +def vocode_original(mel2wav, num_sentences, device): + esds_data_dir = "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore" + for speaker in tqdm(os.listdir(esds_data_dir), "Speaker"): + if speaker.startswith("00"): + if int(speaker) > 10: + for emotion in tqdm(os.listdir(os.path.join(esds_data_dir, speaker)), "Emotion"): + if not emotion.endswith(".txt") and not emotion.endswith(".DS_Store"): + counter = 0 + for audio_file in tqdm(os.listdir(os.path.join(esds_data_dir, speaker, emotion))): + if audio_file.endswith(".wav"): + counter += 1 + if counter > num_sentences: + break + emo = get_emotion_from_path(os.path.join(esds_data_dir, speaker, emotion, audio_file)) + sent_id = audio_file.split("_")[1].split(".wav")[0] + + wave, sr = sf.read(os.path.join(esds_data_dir, speaker, emotion, audio_file)) + ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=True, device='cpu') + norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave) + norm_wave = torch.tensor(trim_zeros(norm_wave.numpy())) + spec = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000) + + wave = mel2wav(spec.to(device)).cpu() + silence = torch.zeros([10600]) + wav = silence.clone() + wav = torch.cat((wav, wave, silence), 0) + + wav = [val for val in wav.detach().numpy() for _ in (0, 1)] # doubling the sampling rate for better compatibility (24kHz is not as standard as 48kHz) + os.makedirs(os.path.join(f"./audios/Evaluation/Original/{int(speaker)}/{emo}"), exist_ok=True) + sf.write(file=f"./audios/Evaluation/Original/{int(speaker)}/{emo}/{sent_id}.wav", data=float2pcm(wav), samplerate=48000, subtype="PCM_16") + +def 
compute_speaker_similarity(speaker_embeddings_original, speaker_embeddings, version): + speaker_similarities = {} + if version == "Baseline" or version == "Sent": + for dataset, speakers in tqdm(speaker_embeddings.items()): + speaker_similarities[dataset] = {} + for speaker, emotions in speakers.items(): + speaker_similarities[dataset][speaker] = {} + for emotion, file_ids in emotions.items(): + cos_sims_emotion = [] + for file_id, embedding in file_ids.items(): + cos_sims_file = [] + for file_id_original, embedding_original in speaker_embeddings_original[speaker][emotion].items(): + cos_sims_file.append(speaker_similarity(embedding_original, embedding)) + cos_sims_emotion.append(median(cos_sims_file)) + speaker_similarities[dataset][speaker][emotion] = median(cos_sims_emotion) + return speaker_similarities + if version == "Prompt": + for dataset, speakers in tqdm(speaker_embeddings.items()): + speaker_similarities[dataset] = {} + for speaker, emotions in speakers.items(): + speaker_similarities[dataset][speaker] = {} + for emotion, prompt_emotions in emotions.items(): + speaker_similarities[dataset][speaker][emotion] = {} + for prompt_emotion, file_ids in prompt_emotions.items(): + cos_sims_prompt_emotion = [] + for file_id, embedding in file_ids.items(): + cos_sims_file = [] + for file_id_original, embedding_original in speaker_embeddings_original[speaker][emotion].items(): + cos_sims_file.append(speaker_similarity(embedding_original, embedding)) + cos_sims_prompt_emotion.append(median(cos_sims_file)) + speaker_similarities[dataset][speaker][emotion][prompt_emotion] = median(cos_sims_prompt_emotion) + + speaker_similarities_prompt_emotions = {} + for dataset, speakers in speaker_similarities.items(): + speaker_similarities_prompt_emotions[dataset] = {} + for speaker, emotions in speakers.items(): + speaker_similarities_prompt_emotions[dataset][speaker] = {} + for prompt_emotion in EMOTIONS: + cos_sims = [] + for emotion in EMOTIONS: + cos_sims.append(speaker_similarities[dataset][speaker][emotion][prompt_emotion]) + speaker_similarities_prompt_emotions[dataset][speaker][prompt_emotion] = median(cos_sims) + return speaker_similarities_prompt_emotions + +def speaker_similarity(speaker_embedding1, speaker_embedding2): + cosine_similarity = CosineSimilarity(dim=-1) + return cosine_similarity(speaker_embedding1, speaker_embedding2).numpy() + +def asr_transcribe(audio_dir, processor, model, version): + transcriptions = {} + if version == "Original": + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version)), "Speaker"): + transcriptions[speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, speaker)), "Emotion"): + transcriptions[speaker][emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, speaker, emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + input_values = processor(wave, sampling_rate=16000, return_tensors="pt", padding="longest").input_values.to(model.device) + with torch.no_grad(): + logits = model(input_values).logits + predicted_ids = torch.argmax(logits, dim=-1) + transcription = processor.batch_decode(predicted_ids) + transcriptions[speaker][emotion][file_id] = transcription + return transcriptions + if version == "Baseline" or version 
== "Sent": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + transcriptions[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + transcriptions[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + transcriptions[dataset][speaker][emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, dataset, speaker, emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + input_values = processor(wave, sampling_rate=16000, return_tensors="pt", padding="longest").input_values.to(model.device) + with torch.no_grad(): + logits = model(input_values).logits + predicted_ids = torch.argmax(logits, dim=-1) + transcription = processor.batch_decode(predicted_ids) + transcriptions[dataset][speaker][emotion][file_id] = transcription + return transcriptions + if version == "Prompt": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + transcriptions[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + transcriptions[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + transcriptions[dataset][speaker][emotion] = {} + for prompt_emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion)), "Prompt Emotion"): + transcriptions[dataset][speaker][emotion][prompt_emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + wave, sr = torchaudio.load(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion, audio_file)) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + input_values = processor(wave, sampling_rate=16000, return_tensors="pt", padding="longest").input_values.to(model.device) + with torch.no_grad(): + logits = model(input_values).logits + predicted_ids = torch.argmax(logits, dim=-1) + transcription = processor.batch_decode(predicted_ids) + transcriptions[dataset][speaker][emotion][prompt_emotion][file_id] = transcription + return transcriptions + +def compute_word_error_rate(transcriptions, test_sentences, version): + wer_calc = WordErrorRate() + word_error_rates = {} + if version == "Original": + for speaker, emotions in tqdm(transcriptions.items()): + word_error_rates[speaker] = {} + for emotion, sent_ids in emotions.items(): + wers = [] + for sent_id, transcript in sent_ids.items(): + target = get_esds_target_transcript(speaker, sent_id) + wers.append(word_error_rate(target, transcript, wer_calc)) + word_error_rates[speaker][emotion] = median(wers) + return word_error_rates + if version == "Baseline" or version == "Sent": + for dataset, speakers in tqdm(transcriptions.items()): + word_error_rates[dataset] = {} + for speaker, emotions in speakers.items(): + word_error_rates[dataset][speaker] = {} + for emotion, sent_ids in emotions.items(): + wers = [] + for sent_id, transcript in sent_ids.items(): + target 
= test_sentences[dataset][emotion][sent_id] + wers.append(word_error_rate(target, transcript, wer_calc)) + word_error_rates[dataset][speaker][emotion] = median(wers) + return word_error_rates + if version == "Prompt": + for dataset, speakers in tqdm(transcriptions.items()): + word_error_rates[dataset] = {} + for speaker, emotions in speakers.items(): + word_error_rates[dataset][speaker] = {} + for emotion, prompt_emotions in emotions.items(): + word_error_rates[dataset][speaker][emotion] = {} + for prompt_emotion, sent_ids in prompt_emotions.items(): + wers = [] + for sent_id, transcript in sent_ids.items(): + target = test_sentences[dataset][emotion][sent_id] + wers.append(word_error_rate(target, transcript, wer_calc)) + word_error_rates[dataset][speaker][emotion][prompt_emotion] = median(wers) + word_error_rates_prompt_emotions = {} + for dataset, speakers in word_error_rates.items(): + word_error_rates_prompt_emotions[dataset] = {} + for speaker, emotions in speakers.items(): + word_error_rates_prompt_emotions[dataset][speaker] = {} + for prompt_emotion in EMOTIONS: + wers = [] + for emotion in EMOTIONS: + wers.append(word_error_rates[dataset][speaker][emotion][prompt_emotion]) + word_error_rates_prompt_emotions[dataset][speaker][prompt_emotion] = median(wers) + return word_error_rates_prompt_emotions + +def get_esds_target_transcript(speaker, sent_id): + sent_id = '0' * (6 - len(str(sent_id))) + str(sent_id) # insert zeros at the beginning + root = "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore" + speaker_dir = f"00{speaker}" + with open(f"{root}/{speaker_dir}/fixed_unicode.txt", mode="r", encoding="utf8") as f: + transcripts = f.read() + for line in transcripts.replace("\n\n", "\n").replace(",", ", ").split("\n"): + if line.strip() != "": + filename, text, emo_dir = line.split("\t") + if filename.split("_")[1] == sent_id: + return text + +def word_error_rate(target, predicted, wer): + target = target.translate(str.maketrans('', '', string.punctuation)).upper() + return float(wer(predicted, target)) + +def classify_speech_emotion(audio_dir, classifier, version): + emotions_classified = {} + if version == "Original": + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version)), "Speaker"): + emotions_classified[speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, speaker)), "Emotion"): + emotions_classified[speaker][emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + out_prob, score, index, text_lab = classifier.classify_file(os.path.join(audio_dir, version, speaker, emotion, audio_file)) + emotions_classified[speaker][emotion][file_id] = text_lab[0] + # wav2vec2 saves wav files, they have to be deleted such they are not used again for the next iteration, since they are named the same + command = 'rm *.wav' + process = subprocess.Popen(command, shell=True) + process.wait() + time.sleep(10) # ensure that files are deleted before next iteration + return emotions_classified + if version == "Baseline" or version == "Sent": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + emotions_classified[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + emotions_classified[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + emotions_classified[dataset][speaker][emotion] = {} + for 
audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + out_prob, score, index, text_lab = classifier.classify_file(os.path.join(audio_dir, version, dataset, speaker, emotion, audio_file)) + emotions_classified[dataset][speaker][emotion][file_id] = text_lab[0] + # wav2vec2 saves wav files, they have to be deleted such they are not used again for the next iteration, since they are named the same + command = 'rm *.wav' + process = subprocess.Popen(command, shell=True) + process.wait() + time.sleep(10) # ensure that files are deleted before next iteration + return emotions_classified + if version == "Prompt": + for dataset in tqdm(os.listdir(os.path.join(audio_dir, version)), "Dataset"): + emotions_classified[dataset] = {} + for speaker in tqdm(os.listdir(os.path.join(audio_dir, version, dataset)), "Speaker"): + emotions_classified[dataset][speaker] = {} + for emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker)), "Emotion"): + emotions_classified[dataset][speaker][emotion] = {} + for prompt_emotion in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion)), "Prompt Emotion"): + emotions_classified[dataset][speaker][emotion][prompt_emotion] = {} + for audio_file in tqdm(os.listdir(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion)), "Audio File"): + file_id = int(audio_file.split('.wav')[0]) + out_prob, score, index, text_lab = classifier.classify_file(os.path.join(audio_dir, version, dataset, speaker, emotion, prompt_emotion, audio_file)) + emotions_classified[dataset][speaker][emotion][prompt_emotion][file_id] = text_lab[0] + # wav2vec2 saves wav files, they have to be deleted such they are not used again for the next iteration, since they are named the same + command = 'rm *.wav' + process = subprocess.Popen(command, shell=True) + process.wait() + time.sleep(10) # ensure that files are deleted before next iteration + return emotions_classified + +def compute_predicted_emotions_frequencies(predicted_emotions, version): + predicted_frequencies = {} + if version == "Original": + for speaker, emotions in tqdm(predicted_emotions.items()): + predicted_frequencies[speaker] = {} + for emotion, sent_ids in emotions.items(): + predicted_frequencies[speaker][emotion] = {} + for sent_id, predicted_emotion in sent_ids.items(): + if predicted_emotion not in predicted_frequencies[speaker][emotion]: + predicted_frequencies[speaker][emotion][predicted_emotion] = 0 + predicted_frequencies[speaker][emotion][predicted_emotion] += 1 + return predicted_frequencies + if version == "Baseline" or version == "Sent": + for dataset, speakers in tqdm(predicted_emotions.items()): + predicted_frequencies[dataset] = {} + for speaker, emotions in speakers.items(): + predicted_frequencies[dataset][speaker] = {} + for emotion, sent_ids in emotions.items(): + predicted_frequencies[dataset][speaker][emotion] = {} + for sent_id, predicted_emotion in sent_ids.items(): + if predicted_emotion not in predicted_frequencies[dataset][speaker][emotion]: + predicted_frequencies[dataset][speaker][emotion][predicted_emotion] = 0 + predicted_frequencies[dataset][speaker][emotion][predicted_emotion] += 1 + return predicted_frequencies + if version == "Prompt": + for dataset, speakers in tqdm(predicted_emotions.items()): + predicted_frequencies[dataset] = {} + for speaker, emotions in speakers.items(): + predicted_frequencies[dataset][speaker] = {} + for emotion, 
prompt_emotions in emotions.items(): + predicted_frequencies[dataset][speaker][emotion] = {} + for prompt_emotion, sent_ids in prompt_emotions.items(): + predicted_frequencies[dataset][speaker][emotion][prompt_emotion] = {} + for sent_id, predicted_emotion in sent_ids.items(): + if predicted_emotion not in predicted_frequencies[dataset][speaker][emotion][prompt_emotion]: + predicted_frequencies[dataset][speaker][emotion][prompt_emotion][predicted_emotion] = 0 + predicted_frequencies[dataset][speaker][emotion][prompt_emotion][predicted_emotion] += 1 + predicted_frequencies_prompt_emotions = {} + for dataset, speakers in predicted_frequencies.items(): + predicted_frequencies_prompt_emotions[dataset] = {} + for speaker, emotions in speakers.items(): + predicted_frequencies_prompt_emotions[dataset][speaker] = {} + for prompt_emotion in EMOTIONS: + predicted_frequencies_prompt_emotions[dataset][speaker][prompt_emotion] = {} + for emotion in EMOTIONS: + pred_freqs = predicted_frequencies[dataset][speaker][emotion][prompt_emotion] + for pred_emo, freq in pred_freqs.items(): + if pred_emo not in predicted_frequencies_prompt_emotions[dataset][speaker][prompt_emotion]: + predicted_frequencies_prompt_emotions[dataset][speaker][prompt_emotion][pred_emo] = freq + else: + predicted_frequencies_prompt_emotions[dataset][speaker][prompt_emotion][pred_emo] += freq + return predicted_frequencies_prompt_emotions + \ No newline at end of file diff --git a/Evaluation/plotting.py b/Evaluation/plotting.py new file mode 100644 index 00000000..3003ffd1 --- /dev/null +++ b/Evaluation/plotting.py @@ -0,0 +1,831 @@ +import matplotlib.pyplot as plt +from matplotlib.ticker import MultipleLocator +import numpy as np + +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] +EMOTIONS_SHORT = ["a", "j", "n", "sa", "su"] +COLORS = ['red', 'green', 'blue', 'gray', 'orange'] + +def barplot_counts(d: dict, v, save_dir): + labels = get_variable_labels(v) + if labels is None: + labels = {k:k for k in list(d.keys())} + values = [d[label] for label in labels if label in d] + + plt.bar(range(len(values)), values, align='center') + plt.xticks(range(len(values)), [labels[label] for label in labels if label in d]) + plt.xlabel(v) + plt.ylabel('counts') + plt.savefig(save_dir) + plt.close() + +def pie_chart_counts(d: dict, v, save_dir): + labels = get_variable_labels(v) + if labels is None: + labels = {k:k for k in list(d.keys())} + relevant_labels = set(labels.keys()) & set(d.keys()) + #values = [d[label] for label in labels if label in d] + values = [d[label] for label in relevant_labels] + total = sum(values) + labels = [labels[label] for label in relevant_labels] + percentages = [(value / total) * 100 for value in values] + + colors = plt.cm.Set3(range(len(labels))) + + plt.figure(figsize=(8, 8)) + plt.pie(percentages, labels=labels, autopct=lambda p: '{:.1f}%'.format(p) if p >= 2 else '', colors=colors, startangle=90, textprops={'fontsize': 16}) + plt.axis('equal') + plt.rcParams['font.size'] = 12 + plt.savefig(save_dir) + plt.close() + +def barplot_pref(d, save_dir): + # Define the emotions and their corresponding colors + emotions = EMOTIONS + colors = COLORS + + # Define the x-axis tick labels + x_ticks = ['Baseline', 'Proposed', 'No Preference'] + + # Initialize the plot + fig, ax = plt.subplots() + + # Set the x-axis tick locations and labels + ax.set_xticks(np.arange(len(x_ticks))) + ax.set_xticklabels(x_ticks) + + # Calculate the width of each bar + bar_width = 0.15 + + # Calculate the offset for each emotion's bars 
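# --- Editorial sketch of the result shape produced above (illustrative names only) ---
# For the "Baseline"/"Sent" case, compute_predicted_emotions_frequencies returns
# {dataset: {speaker: {true_emotion: {predicted_emotion: count}}}}. One hedged way to
# collapse that into a per-emotion recognition accuracy:
def _accuracy_from_frequencies(freqs):
    accuracies = {}
    for dataset, speakers in freqs.items():
        for speaker, emotions in speakers.items():
            for emotion, counts in emotions.items():
                total = sum(counts.values())
                # fraction of files whose predicted label matches the intended emotion
                accuracies[(dataset, speaker, emotion)] = counts.get(emotion, 0) / total if total else 0.0
    return accuracies

# e.g. _accuracy_from_frequencies({"dailydialogues": {"11": {"anger": {"anger": 8, "neutral": 2}}}})
# -> {("dailydialogues", "11", "anger"): 0.8}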
+ offsets = np.linspace(-2 * bar_width, 2 * bar_width, len(emotions)) + + # Calculate the total count for each emotion + total_counts = [sum(d[emotion].values()) for emotion in emotions] + + # Iterate over the emotions and plot the bars + for i, emotion in enumerate(emotions): + counts = [d[emotion].get(1.0, 0), d[emotion].get(2.0, 0), d[emotion].get(3.0, 0)] + percentages = [count / total_counts[i] * 100 for count in counts] + positions = np.arange(len(percentages)) + offsets[i] + ax.bar(positions, percentages, width=bar_width, color=colors[i], label=emotion) + + # Set the legend + ax.legend() + + # Set the plot title and axis labels + ax.set_ylabel('Percentage (%)') + + # Adjust the x-axis limits and labels + ax.set_xlim(-2 * bar_width - bar_width, len(x_ticks) - 1 + 2 * bar_width + bar_width) + ax.set_xticks(np.arange(len(x_ticks))) + + # Save the plot + plt.savefig(save_dir) + plt.close() + +def barplot_pref2(d, save_dir): + # Define the emotions and their corresponding colors + emotions = EMOTIONS + colors = COLORS + emotions = emotions[::-1] + colors = colors[::-1] + patterns = ['///', '\\\\\\', 'xxx'] + + # Define the y-axis tick labels + y_ticks = emotions + + # Calculate the height of each bar + bar_height = 0.5 + + # Initialize the plot + fig, ax = plt.subplots() + + # Set the y-axis tick locations and labels + ax.set_yticks(np.arange(len(y_ticks))) + ax.set_yticklabels(y_ticks) + + # Calculate the width of each bar segment + total_counts = [d[emotion].get(1.0, 0) + d[emotion].get(2.0, 0) + d[emotion].get(3.0, 0) for emotion in emotions] + baseline_counts = [d[emotion].get(1.0, 0) for emotion in emotions] + proposed_counts = [d[emotion].get(2.0, 0) for emotion in emotions] + no_pref_counts = [d[emotion].get(3.0, 0) for emotion in emotions] + percentages = [count / sum(total_counts) * 100 for count in total_counts] + baseline_percentages = [count / total * 100 if total != 0 else 0 for count, total in zip(baseline_counts, total_counts)] + proposed_percentages = [count / total * 100 if total != 0 else 0 for count, total in zip(proposed_counts, total_counts)] + no_pref_percentages = [count / total * 100 if total != 0 else 0 for count, total in zip(no_pref_counts, total_counts)] + total_width = np.max(percentages) + + # Plot the bars with different shadings and colors + for i in range(len(emotions)): + # Plot only the first bar of each emotion with the corresponding label + ax.barh(y_ticks[i], baseline_percentages[i], height=bar_height, color=colors[i], edgecolor='black', hatch=patterns[0], label='Baseline' if i == 0 else '') + ax.barh(y_ticks[i], no_pref_percentages[i], height=bar_height, color=colors[i], edgecolor='black', hatch=patterns[2], left=baseline_percentages[i], label='No Preference' if i == 0 else '') + ax.barh(y_ticks[i], proposed_percentages[i], height=bar_height, color=colors[i], edgecolor='black', hatch=patterns[1], left=np.add(baseline_percentages[i], no_pref_percentages[i]), label='Proposed' if i == 0 else '') + ax.text(total_width + 10, i, emotions[i], color=colors[i], verticalalignment='center') + + # Set the legend outside the plot + legend_labels = [plt.Rectangle((0, 0), 0, 0, edgecolor='black', hatch=patterns[i], facecolor='white') for i in range(len(patterns))] + legend_names = ['Baseline', 'Proposed', 'No Preference'] + ax.legend(legend_labels, legend_names, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3, edgecolor='black', facecolor='none') + + # Set the plot title and axis labels + ax.set_xlabel('Percentage (%)') + + # Adjust the y-axis limits and 
labels + ax.set_ylim(-0.5, len(y_ticks) - 0.5) + ax.set_yticks(np.arange(len(y_ticks))) + for tick_label, color in zip(ax.get_yticklabels(), colors): + tick_label.set_color(color) + + # Save the plot + plt.savefig(save_dir) + plt.close() + +def barplot_pref3(d, save_dir): + # Define the emotions and their corresponding colors + emotions = EMOTIONS + colors = ['black', 'gray', 'lightgray'] # Colors for Original, Baseline, and Proposed bars + emotion_colors = COLORS + + # Define the x-axis tick labels (emotions) + x_ticks = emotions + + # Initialize the plot + fig, ax = plt.subplots(figsize=(8,6)) + + # Set the x-axis tick locations and labels + ax.set_xticks(np.arange(len(x_ticks))) + labels = ax.set_xticklabels(x_ticks) + for label, color in zip(labels, emotion_colors): + label.set_color(color) + + # Calculate the width of each bar group + bar_width = 0.2 + + # Calculate the offset for each scenario's bars within each group + offsets = np.linspace(-bar_width, bar_width, 3) + + scenarios = ['Baseline', 'Proposed', 'No Preference'] + # Iterate over the scenarios and plot the bars for each emotion + for i, scenario in enumerate([1.0, 2.0, 3.0]): + percentages = [d[emotion].get(scenario, 0) / sum(d[emotion].values()) * 100 for emotion in emotions] + positions = np.arange(len(percentages)) + offsets[i] + ax.bar(positions, percentages, width=bar_width, color=colors[i], label=scenarios[i]) + + # Set the legend + ax.legend() + + # Set the plot title and axis labels + ax.set_ylabel('Percentage (%)', fontsize=16) + + # Adjust the x-axis limits and labels + ax.set_xlim(-bar_width * 2, len(x_ticks) - 1 + bar_width * 2) + ax.set_xticks(np.arange(len(x_ticks))) + ax.set_ylim(0, 100) + + ax.tick_params(axis='x', labelsize=16) + ax.tick_params(axis='y', labelsize=16) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Save the plot + plt.savefig(save_dir) + plt.close() + +def barplot_pref_total(d, save_dir): + + # Define the x-axis tick labels + x_ticks = ['Baseline', 'Prompt Conditioned', 'No Preference'] + colors = ['black', 'gray', 'lightgray'] + + # Initialize the plot + fig, ax = plt.subplots(figsize=(10, 6)) + + # Set the x-axis tick locations and labels + ax.set_xticks(np.arange(len(x_ticks))) + ax.set_xticklabels(x_ticks, fontsize=20) + + # Calculate the width of each bar + bar_width = 0.5 + + # Calculate the offset for each emotion's bars + offsets = np.linspace(-2 * bar_width, 2 * bar_width, len(x_ticks)) + + # Calculate the total count for each emotion + total_counts = sum(list(d.values())) + + # Iterate over the emotions and plot the bars + for i, scenario in enumerate([1.0, 2.0, 3.0]): + counts = [d.get(scenario, 0)] + percentages = [count / total_counts * 100 for count in counts] + positions = i + ax.bar(positions, percentages, width=bar_width, color=colors[i], align='center') + plt.text(i, round(percentages[0], 2), str(round(percentages[0], 2)), ha='center', va='bottom', fontsize=20) + + # Set the plot title and axis labels + ax.set_ylabel('Percentage (%)', fontsize=20) + + # Adjust the x-axis limits and labels + #ax.set_xlim(-2 * bar_width - bar_width, len(x_ticks) - 1 + 2 * bar_width + bar_width) + ax.set_xticks(np.arange(len(x_ticks))) + ax.set_ylim(0, 60) + + ax.tick_params(axis='x', labelsize=20) + ax.tick_params(axis='y', labelsize=20) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Save the plot + plt.savefig(save_dir, format="pdf") + plt.close() + +def horizontal_barplot_pref_total(d, save_dir): + plt.rc('font', 
size=20) + + # Define the scenarios and corresponding colors + labels = ['Baseline', 'Prompt Conditioned', 'No Preference'] + scenarios = [1.0, 2.0, 3.0] + colors = ['black', '#404040', '#696969'] + + # Calculate total count + total_count = sum(d.values()) + + # Calculate percentages for each scenario + percentages = [d[scenario] / total_count * 100 for scenario in scenarios] + + # Initialize the plot + fig, ax = plt.subplots(figsize=(10, 3)) + + # Plot the horizontal bar + left = 0 + bar_height = 0.3 # Adjust the height of the bar + for i, (scenario, percentage, color) in enumerate(zip(labels, percentages, colors)): + ax.barh(0, percentage, height=bar_height, color=color, left=left) + ax.text(left + percentage / 2, -0.2, f"{percentage:.1f}%", ha='center', va='top', color=color) + ax.text(left + percentage / 2, 0.2, f"{scenario}", ha='center', va='bottom', color=color) + left += percentage + + # Set the y-axis limits and labels + ax.set_ylim(-0.5, 0.5) + ax.set_yticks([]) + ax.set_xlabel('Percentage (%)') + + # Set the title + #ax.set_title('Preference Distribution') + + # Remove spines + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Save the plot + plt.tight_layout() + plt.savefig(save_dir, format="pdf") + plt.close() + +def pie_barplot_pref_total(d, save_dir): + plt.rc('font', size=20) + + # Define the scenarios and corresponding colors + labels = ['Baseline', 'Prompt Conditioned', 'No Preference'] + scenarios = [1.0, 2.0, 3.0] + colors = ['black', 'gray', 'lightgray'] + + # Calculate total count + total_count = sum(d.values()) + + # Calculate percentages for each scenario + percentages = [d[scenario] / total_count * 100 for scenario in scenarios] + + # Initialize the plot + fig, ax = plt.subplots(figsize=(7.5, 7.5)) # Make the figure square for a pie chart + + # Plot the pie chart + wedges, texts, autotexts = ax.pie(percentages, labels=labels, colors=colors, autopct='%1.1f%%', startangle=140) + + # Set the color of all percentage labels to black + for autotext in autotexts: + autotext.set_color('black') + + # Set the color of the percentage label on the "Baseline" slice to white + autotexts[0].set_color('white') + + # Equal aspect ratio ensures that pie is drawn as a circle + ax.axis('equal') + + # Set the title + #ax.set_title('Preference Distribution') + + # Save the plot + plt.tight_layout() + plt.savefig(save_dir, format="pdf") + plt.close() + +def pie_barplot_pref(emotion_dicts, save_dir): + plt.rc('font', size=28) + + # Define the scenarios and corresponding colors + labels = ['Baseline', 'Prompt Conditioned', 'No Preference'] + scenarios = [1.0, 2.0, 3.0] + colors = [(0.6, 0.85, 0.6), (0.8, 0.6, 0.5), (0.7, 0.9, 1.0)] + + num_emotions = len(emotion_dicts) + + fig, axs = plt.subplots(2, 3, figsize=(20, 10)) # Create a 2x3 grid layout + + # Loop over each emotion dictionary + for i, (emotion, d) in enumerate(emotion_dicts.items()): + # Calculate total count + total_count = sum(d.values()) + + # Calculate percentages for each scenario + percentages = [d[scenario] / total_count * 100 for scenario in scenarios] + + # Get the axis for the current subplot + ax_row = i // 3 + ax_col = i % 3 + ax = axs[ax_row, ax_col] + + # Plot the pie chart + wedges, _ = ax.pie(percentages, colors=colors, startangle=140) + + # Equal aspect ratio ensures that pie is drawn as a circle + ax.axis('equal') + + # Set the title as the emotion label + ax.set_title(f'{emotion.capitalize()}') + + # Remove the bottom right subplot + axs[1, 2].remove() + + # Add a single subplot 
spanning both columns for the bottom row + bottom_ax = fig.add_subplot(2, 2, (3, 4)) + bottom_ax.axis('off') # Turn off axis for the centered subplot + + # Plot centered pie charts + # You can plot your desired centered pie charts here using the bottom_ax + + # Create a legend with labels and colors at the right of the last subplot + last_ax = axs[1, 1] # Assuming the last subplot is at row 1, col 1 (modify if needed) + fig.legend(labels, loc='center left', bbox_to_anchor=(last_ax.get_position().x1 + 0.05, 0.3)) + + # Adjust layout to prevent overlap + plt.tight_layout() + + # Save the plot + plt.savefig(save_dir, format="pdf") + plt.close() + +def barplot_sim(data, save_dir): + emotions = EMOTIONS + ratings = [data[emotion] for emotion in emotions] + colors = COLORS + + fig, ax = plt.subplots(figsize=(8,6)) + + x_ticks = emotions + + ax.set_xticks(np.arange(len(x_ticks))) + labels = ax.set_xticklabels(x_ticks, fontsize=16) + for label, color in zip(labels, colors): + label.set_color(color) + + bar_width = 0.5 + + positions = np.arange(len(ratings)) + + bars = ax.bar(positions, ratings, width=bar_width, color=colors) + + for bar, rating in zip(bars, ratings): + ax.annotate(f'{rating:.2f}', # Format the rating to one decimal place + xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()), + xytext=(0, 3), # 3 points vertical offset from the bar + textcoords='offset points', + ha='center', va='bottom', fontsize=14) + + ax.set_ylabel('Mean Similatrity Score', fontsize=16) + + ax.set_yticks(np.arange(6)) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + plt.savefig(save_dir) + plt.close() + +def barplot_sim_total(data, save_dir): + speakers = list(data.keys()) + ratings = list(data.values()) + + fig, ax = plt.subplots() + + x_ticks = speakers + + ax.set_xticks(np.arange(len(x_ticks))) + ax.set_xticklabels(x_ticks, fontsize=16) + + bar_width = 0.5 + + positions = np.arange(len(ratings)) + + bars = ax.bar(positions, ratings, width=bar_width, color='gray') + + for bar, rating in zip(bars, ratings): + ax.annotate(f'{rating:.2f}', # Format the rating to one decimal place + xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()), + xytext=(0, 3), # 3 points vertical offset from the bar + textcoords='offset points', + ha='center', va='bottom', fontsize=14) + + ax.set_ylabel('Mean Similatrity Score', fontsize=16) + + ax.set_yticks(np.arange(6)) + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + plt.savefig(save_dir) + plt.close() + +def boxplot_rating(data, save_dir): + emotions = EMOTIONS + ratings = {emotion: list(rating for rating, count in data[emotion].items() for _ in range(count)) for emotion in emotions} + colors = COLORS + + plt.figure(figsize=(10, 6)) + box_plot = plt.boxplot(ratings.values(), patch_artist=True, widths=0.7) + for patch, color in zip(box_plot['boxes'], colors): + patch.set_facecolor(color) + for median in box_plot['medians']: + median.set(color='black', linestyle='-', linewidth=3) + plt.xticks(range(1, len(emotions) + 1), emotions, fontsize=16) + ax = plt.gca() + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_ylim(0.9, 5.1) + ax.tick_params(axis='y', labelsize=16) + ax.yaxis.set_major_locator(MultipleLocator(base=1)) + for i, t in enumerate(ax.xaxis.get_ticklabels()): + t.set_color(colors[i]) + plt.savefig(save_dir) + plt.close() + +def barplot_mos(l: list, save_dir): + values = l + labels = ["Ground Truth", "Baseline", "Proposed"] + + plt.bar(range(len(values)), values, 
align='center', color="gray") + plt.xticks(range(len(values)), labels, fontsize=16) + plt.ylabel('MOS', fontsize=16) + ax = plt.gca() + ax.set_yticks(np.arange(6)) + ax.set_ylim(0.9, 5.1) + ax.tick_params(axis='y', labelsize=16) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + for i, v in enumerate(values): + plt.text(i, round(v, 2), str(round(v, 2)), ha='center', va='bottom', fontsize=14) + plt.savefig(save_dir) + plt.close() + +def barplot_emotion(d: dict, save_dir): + emotions = EMOTIONS + colors = COLORS + bar_width = 0.5 + + d = {key: d[key] for key in sorted(d, key=lambda x: emotions.index(x))} + + fig, ax = plt.subplots() + + subdicts = d.items() + num_subdicts = len(subdicts) + num_labels = len(list(subdicts)[0][1]) # Assuming all sub-dicts have the same labels + + total_width = bar_width * num_subdicts + group_width = bar_width / num_subdicts + x = np.arange(num_labels) - (total_width / 2) + (group_width / 2) + + color_id = 0 + for i, (v, subdict) in enumerate(subdicts): + values = [subdict[label] for label in subdict] + offset = i * group_width + color = colors[color_id] + ax.bar(x + offset, values, group_width, align='center', color=color) + color_id += 1 + + ax.set_xticks(x + 0.2) + ax.set_xticklabels(list(list(subdicts)[0][1].keys())) + for i, t in enumerate(ax.xaxis.get_ticklabels()): + t.set_color(colors[i]) + ax.tick_params(axis='x', length=0) + ax.legend(emotions, loc='upper right', bbox_to_anchor=(0.9, 1)) + ax.set_ylabel('Counts') + + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + plt.savefig(save_dir) + plt.close() + +def heatmap_emotion(d: dict, save_dir): + emotions = EMOTIONS + + # Create a numpy array to store the counts + counts = np.array([[d[emotion].get(label, 0) for emotion in emotions] for label in emotions]) + # normalize counts for each emotion category + normalized_counts = np.around(counts / counts.sum(axis=0, keepdims=True), decimals=2) + + # Set up the figure and heatmap + fig, ax = plt.subplots(figsize=(10, 8)) + im = ax.imshow(normalized_counts, cmap='viridis', vmin=0, vmax=1) + + # Show counts as text in each cell + for i in range(len(emotions)): + for j in range(len(emotions)): + text = ax.text(i, j, normalized_counts[j, i], ha='center', va='center', color='black', fontsize=16) + + # Set the axis labels and title + ax.set_xticks(np.arange(len(emotions))) + ax.set_yticks(np.arange(len(emotions))) + ax.set_xticklabels(emotions, fontsize=16) + ax.set_yticklabels(emotions, fontsize=16) + + # Rotate the tick labels for better readability (optional) + plt.xticks(rotation=45, ha='right', fontsize=16) + + # Create a colorbar + cbar = ax.figure.colorbar(im, ax=ax) + cbar.set_label('Relative Frequency', fontsize=16) + + # Save the heatmap + plt.savefig(save_dir) + plt.close() + +def heatmap_emotion_multiple(dicts: list, titles: list, save_dir): + emotions = EMOTIONS + emotions_short = EMOTIONS_SHORT + num_dicts = len(dicts) + fig, axs = plt.subplots(1, num_dicts, figsize=(5*num_dicts, 8)) + + max_rows = max(len(ax.get_yticklabels()) for ax in axs) # Get the maximum number of rows in the heatmaps + + for idx, (d, title) in enumerate(zip(dicts, titles)): + # Create a numpy array to store the counts + counts = np.array([[d[emotion].get(label, 0) for emotion in emotions] for label in emotions]) + # normalize counts for each emotion category + normalized_counts = np.around(counts / counts.sum(axis=0, keepdims=True), decimals=2) + + # Plot heatmap with consistent colormap + im = 
axs[idx].imshow(normalized_counts, cmap='viridis', vmin=0, vmax=1, interpolation=None) + + # Show counts as text in each cell + for i in range(len(emotions)): + for j in range(len(emotions)): + # Get the brightness of the cell's color + brightness = np.mean(im.norm(normalized_counts[j, i])) + # Set text color based on brightness + if brightness > 0.7: + color = 'black' + else: + color = 'white' + text = axs[idx].text(i, j, normalized_counts[j, i], ha='center', va='center', color=color, fontsize=20) + + # Set the axis labels and title + if idx == 0: # Only label the y-axis for the first subplot + axs[idx].set_yticks(np.arange(len(emotions_short))) + axs[idx].set_yticklabels(emotions_short, fontsize=20) + axs[idx].set_ylabel('Emotion Labels', fontsize=20) + else: + axs[idx].set_yticks([]) # Remove y-ticks for the other subplots + + axs[idx].set_xticks(np.arange(len(emotions_short))) + axs[idx].set_xticklabels(emotions_short, fontsize=20) + axs[idx].tick_params(axis='x', rotation=45, labelsize=20) + + # Set title + axs[idx].set_title(title, fontsize=20) + + # Calculate color bar height based on the number of rows in the heatmaps + color_bar_height = 0.45 + + # Create a colorbar + cbar_ax = fig.add_axes([0.92, 0.27, 0.02, color_bar_height]) # [left, bottom, width, height] + cbar = fig.colorbar(im, cax=cbar_ax) + cbar.set_label('Relative Frequency', fontsize=20) + + # Set font size for color bar tick labels + cbar.ax.tick_params(labelsize=20) + + # Adjust layout + plt.subplots_adjust(wspace=0.1) # Reduce the space between heatmaps + + # Save the heatmap as PDF + plt.savefig(save_dir, format='pdf') + plt.close() + +def scatterplot_va(v: dict, a: dict, save_dir: str): + # Initialize a color map for differentiating emotions + colors = COLORS + + # Create a figure and axis + fig, ax = plt.subplots() + + # Iterate over emotions and variations + color_id = 0 + for emo in EMOTIONS: + valence = v[emo] + arousal = a[emo] + + # Plot a single point with a unique color + ax.scatter(valence, arousal, color=colors[color_id], label=emo) + color_id += 1 + + # Set plot title and labels + #ax.set_xlabel("Valence") + #ax.set_ylabel("Arousal") + + # Set axis limits and center the plot at (3, 3) + ax.set_xlim(0.9, 5.1) + ax.set_ylim(0.9, 5.1) + ax.set_xticks([1, 5]) + ax.set_yticks([1, 5]) + ax.set_xticklabels(['negative', 'positive']) + ax.set_yticklabels(['calm', 'excited']) + ax.set_aspect('equal') + ax.spines['left'].set_position(('data', 3)) + ax.spines['bottom'].set_position(('data', 3)) + ax.spines['right'].set_color('none') + ax.spines['top'].set_color('none') + ax.xaxis.set_ticks_position('bottom') + ax.yaxis.set_ticks_position('left') + + # Add a legend in the top-left corner + legend = ax.legend(loc='upper left', bbox_to_anchor=(0.01, 0.99), bbox_transform=plt.gcf().transFigure) + + # Adjust the plot layout to accommodate the legend + plt.subplots_adjust(top=0.95, right=0.95) + + # Save the figure + plt.savefig(save_dir, bbox_inches='tight') + plt.close() + +def boxplot_objective(data, save_dir): + data = dict(sorted(data.items())) + speakers = list(data.keys()) + ratings = list(data.values()) + + plt.figure(figsize=(10, 6)) + box_plot = plt.boxplot(ratings, patch_artist=True, widths=0.7) + for patch in box_plot['boxes']: + patch.set_facecolor('white') + for median in box_plot['medians']: + median.set(color='black', linestyle='-', linewidth=3) + plt.xticks(range(1, len(speakers) + 1), speakers, fontsize=16) + plt.xlabel('Speaker', fontsize=16) + ax = plt.gca() + ax.spines['top'].set_visible(False) + 
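# --- Editorial sketch: the column-wise normalisation used by heatmap_emotion and
# heatmap_emotion_multiple, shown on a tiny hand-made count matrix (values illustrative).
# Each column (one true emotion category) is divided by its own total, so every column
# sums to 1 and a cell reads as the relative frequency of a predicted label given the
# true emotion. `np` is already imported at the top of this module.
_demo_counts = np.array([[8, 1],
                         [2, 9]])
_demo_normalized = np.around(_demo_counts / _demo_counts.sum(axis=0, keepdims=True), decimals=2)
# _demo_normalized -> [[0.8, 0.1], [0.2, 0.9]]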
ax.spines['right'].set_visible(False) + ax.tick_params(axis='y', labelsize=16) + ax.yaxis.set_major_locator(MultipleLocator(base=0.01)) + plt.savefig(save_dir) + plt.close() + +def boxplot_objective2(data, save_dir): + data = dict(sorted(data.items())) + speakers = list(data.keys()) + ratings = list(data.values()) + + plt.figure(figsize=(10, 6)) + box_plot = plt.boxplot(ratings, patch_artist=True, widths=0.7) + for patch in box_plot['boxes']: + patch.set_facecolor('white') + for median in box_plot['medians']: + median.set(color='black', linestyle='-', linewidth=3) + plt.xticks(range(1, len(speakers) + 1), speakers, fontsize=16) + plt.xlabel('Speaker', fontsize=16) + ax = plt.gca() + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.tick_params(axis='y', labelsize=16) + ax.yaxis.set_major_locator(MultipleLocator(base=0.1)) + plt.savefig(save_dir) + plt.close() + +def barplot_speaker_similarity(data, save_dir): + labels = ['Baseline', 'Proposed'] + + plt.bar(range(len(data)), data, align='center', color="gray") + plt.xticks(range(len(data)), labels, fontsize=16) + plt.ylabel('Cosine Similarity', fontsize=16) + ax = plt.gca() + ax.tick_params(axis='y', labelsize=16) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + for i, v in enumerate(data): + plt.text(i, round(v, 4), str(round(v, 4)), ha='center', va='bottom', fontsize=14) + plt.savefig(save_dir) + plt.close() + +def barplot_wer(data, save_dir): + labels = ['Ground Thruth', 'Baseline', 'Proposed'] + + plt.bar(range(len(data)), data, align='center', color="gray") + plt.xticks(range(len(data)), labels, fontsize=16) + plt.ylabel('Word Error Rate', fontsize=16) + ax = plt.gca() + ax.tick_params(axis='y', labelsize=16) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + for i, v in enumerate(data): + plt.text(i, round(v, 3), str(round(v, 3)), ha='center', va='bottom', fontsize=14) + plt.subplots_adjust(left=0.15) + plt.savefig(save_dir) + plt.close() + +def barplot_emotion_recognition(data, save_dir): + labels = ['Ground Truth', 'Baseline', 'Proposed Same', 'Proposed Other'] + + # Set up the figure with a larger width to accommodate tick labels + fig, ax = plt.subplots(figsize=(10, 6)) + + # Use a horizontal bar plot with reversed data + bars = ax.barh(range(len(data)), data[::-1], color="gray") # Reversed data + + # Set the tick positions and labels + ax.set_yticks(range(len(data))) + ax.set_yticklabels(labels[::-1], fontsize=16) # Reversed labels + + # Set the x-axis label + ax.set_xlabel('Accuracy', fontsize=16) + + # Set the font size for y-axis tick labels + ax.tick_params(axis='y', labelsize=16) + ax.tick_params(axis='x', labelsize=16) + + # Remove the right and top spines + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Add accuracy values as text beside each bar + for i, v in enumerate(data[::-1]): # Reversed data + ax.text(v, i, str(round(v, 2)), ha='left', va='center', fontsize=14) # Reversed data + + # Adjust the plot layout to prevent the labels from being cut off + plt.subplots_adjust(left=0.2, right=0.95) + + # Save the bar plot + plt.savefig(save_dir) + plt.close() + +def get_variable_labels(v): + labels = None + + if v == "age": + labels = {1: "<20", + 2: "20-29", + 3: "30-39", + 4: "40-49", + 5: "50-59", + 6: "60-69", + 7: "70-79", + 8: ">80", + -9: "--"} + if v == "gender": + labels = {1: "female", + 2: "male", + 3: "divers", + -9: "--"} + if v == "english_skills": + labels = {1: "none", + 2: "beginner", 
+ 3: "intermediate", + 4: "advanced", + 5: "fluent", + 6: "native", + -9: "--"} + if v == "experience": + labels = {1: "daily", + 2: "regularly", + 3: "rarely", + 4: "never", + -9: "--"} + if v.startswith("pref_"): + labels = {1: "A", + 2: "B", + 3: "no preference"} + if v.startswith("sim_"): + labels = {1: "1", + 2: "2", + 3: "3", + 4: "4", + 5: "5"} + + return labels + +def emo_sent_speaker(speaker): + if speaker == "f": + return ["anger_0", "joy_0", "neutral_0", "sadness_0", "surprise_1"] + else: + return ["anger_1", "joy_1", "neutral_1", "sadness_1", "surprise_0"] diff --git a/Evaluation/subjective_evaluation.py b/Evaluation/subjective_evaluation.py new file mode 100644 index 00000000..ae382202 --- /dev/null +++ b/Evaluation/subjective_evaluation.py @@ -0,0 +1,263 @@ +import pandas as pd +import numpy as np +import scipy.stats as stats + +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] + +def index_to_emotion(i): + id_emotion = {1: "anger", 2: "joy", 3: "neutral", 4: "sadness", 5: "surprise"} + return id_emotion[i] + +def read_data(path_to_data): + return pd.read_csv(path_to_data, encoding="utf-8", delimiter=";") + +def sociodemographics(data): + d = {} + d["age"] = dict(data["SD03"].value_counts().sort_index()) + d["gender"] = dict(data["SD01"].value_counts().sort_index()) + d["english_skills"] = dict(data["SD20"].value_counts().sort_index()) + d["experience"] = dict(data["SD19"].value_counts().sort_index()) + return d + +def preference(data): + d = {} + # A baseline, B proposed + d["pref_anger_0"] = dict(data["CP01"].value_counts().sort_index()) + d["pref_anger_1"] = dict(data["CP02"].value_counts().sort_index()) + d["pref_joy_0"] = dict(data["CP03"].value_counts().sort_index()) + d["pref_joy_1"] = dict(data["CP04"].value_counts().sort_index()) + d["pref_neutral_0"] = dict(data["CP05"].value_counts().sort_index()) + # A proposed, B baseline, so keys are switched such that the order is always the same as above + d_tmp = dict(data["CP06"].value_counts().sort_index()) + d["pref_neutral_1"] = {1.0: d_tmp.get(2.0), 2.0: d_tmp.get(1.0), 3.0: d_tmp.get(3.0)} + d_tmp = dict(data["CP07"].value_counts().sort_index()) + d["pref_sadness_0"] = {1.0: d_tmp.get(2.0), 2.0: d_tmp.get(1.0), 3.0: d_tmp.get(3.0)} + d_tmp = dict(data["CP08"].value_counts().sort_index()) + d["pref_sadness_1"] = {1.0: d_tmp.get(2.0), 2.0: d_tmp.get(1.0), 3.0: d_tmp.get(3.0)} + d_tmp = dict(data["CP09"].value_counts().sort_index()) + d["pref_surprise_0"] = {1.0: d_tmp.get(2.0), 2.0: d_tmp.get(1.0), 3.0: d_tmp.get(3.0)} + d_tmp = dict(data["CP10"].value_counts().sort_index()) + d["pref_surprise_1"] = {1.0: d_tmp.get(2.0), 2.0: d_tmp.get(1.0), 3.0: d_tmp.get(3.0)} + return d + +def similarity(data): + d = {} + d["sim_anger_0"] = dict(data["CS01_01"].value_counts().sort_index()) + d["sim_anger_1"] = dict(data["CS02_01"].value_counts().sort_index()) + d["sim_joy_0"] = dict(data["CS03_01"].value_counts().sort_index()) + d["sim_joy_1"] = dict(data["CS04_01"].value_counts().sort_index()) + d["sim_neutral_0"] = dict(data["CS05_01"].value_counts().sort_index()) + d["sim_neutral_1"] = dict(data["CS06_01"].value_counts().sort_index()) + d["sim_sadness_0"] = dict(data["CS07_01"].value_counts().sort_index()) + d["sim_sadness_1"] = dict(data["CS08_01"].value_counts().sort_index()) + d["sim_surprise_0"] = dict(data["CS09_01"].value_counts().sort_index()) + d["sim_surprise_1"] = dict(data["CS10_01"].value_counts().sort_index()) + return d + +def mean_opinion_score(data, version): + d = {} + d["mos_anger_0"] = 
dict(data[f"M{version}01"].value_counts().sort_index()) + d["mos_anger_1"] = dict(data[f"M{version}02"].value_counts().sort_index()) + d["mos_joy_0"] = dict(data[f"M{version}03"].value_counts().sort_index()) + d["mos_joy_1"] = dict(data[f"M{version}04"].value_counts().sort_index()) + d["mos_neutral_0"] = dict(data[f"M{version}05"].value_counts().sort_index()) + d["mos_neutral_1"] = dict(data[f"M{version}06"].value_counts().sort_index()) + d["mos_sadness_0"] = dict(data[f"M{version}07"].value_counts().sort_index()) + d["mos_sadness_1"] = dict(data[f"M{version}08"].value_counts().sort_index()) + d["mos_surprise_0"] = dict(data[f"M{version}09"].value_counts().sort_index()) + d["mos_surprise_1"] = dict(data[f"M{version}10"].value_counts().sort_index()) + return d + +def emotion(data, version): + d = {} + d["emotion_anger_0"] = {} + d["emotion_anger_1"] = {} + d["emotion_joy_0"] = {} + d["emotion_joy_1"] = {} + d["emotion_neutral_0"] = {} + d["emotion_neutral_1"] = {} + d["emotion_sadness_0"] = {} + d["emotion_sadness_1"] = {} + d["emotion_surprise_0"] = {} + d["emotion_surprise_1"] = {} + variable_count = 1 + for emo in EMOTIONS: + for j in range(2): + for k, emo_count in enumerate(EMOTIONS): + try: + variable = f"0{variable_count}" if variable_count < 10 else variable_count + d[f"emotion_{emo}_{j}"][emo_count] = dict(data[f"E{version}{variable}_0{k+1}"].value_counts().sort_index())[2] + except KeyError: + d[f"emotion_{emo}_{j}"][emo_count] = 0 + variable_count += 1 + return d + +def valence(data, version): + d = {} + d["anger_0"] = dict(data[f"V{version}01_01"].value_counts().sort_index()) + d["anger_1"] = dict(data[f"V{version}02_01"].value_counts().sort_index()) + d["joy_0"] = dict(data[f"V{version}03_01"].value_counts().sort_index()) + d["joy_1"] = dict(data[f"V{version}04_01"].value_counts().sort_index()) + d["neutral_0"] = dict(data[f"V{version}05_01"].value_counts().sort_index()) + d["neutral_1"] = dict(data[f"V{version}06_01"].value_counts().sort_index()) + d["sadness_0"] = dict(data[f"V{version}07_01"].value_counts().sort_index()) + d["sadness_1"] = dict(data[f"V{version}08_01"].value_counts().sort_index()) + d["surprise_0"] = dict(data[f"V{version}09_01"].value_counts().sort_index()) + d["surprise_1"] = dict(data[f"V{version}10_01"].value_counts().sort_index()) + return d + +def arousal(data, version): + d = {} + d["anger_0"] = dict(data[f"V{version}01_02"].value_counts().sort_index()) + d["anger_1"] = dict(data[f"V{version}02_02"].value_counts().sort_index()) + d["joy_0"] = dict(data[f"V{version}03_02"].value_counts().sort_index()) + d["joy_1"] = dict(data[f"V{version}04_02"].value_counts().sort_index()) + d["neutral_0"] = dict(data[f"V{version}05_02"].value_counts().sort_index()) + d["neutral_1"] = dict(data[f"V{version}06_02"].value_counts().sort_index()) + d["sadness_0"] = dict(data[f"V{version}07_02"].value_counts().sort_index()) + d["sadness_1"] = dict(data[f"V{version}08_02"].value_counts().sort_index()) + d["surprise_0"] = dict(data[f"V{version}09_02"].value_counts().sort_index()) + d["surprise_1"] = dict(data[f"V{version}10_02"].value_counts().sort_index()) + return d + +def get_mean_rating_nested(d: dict): + d_mean = {} + for k, v in d.items(): + total_sum = 0 + count = 0 + for rating, count_value in v.items(): + total_sum += rating * count_value + count += count_value + mean_rating = total_sum / count + d_mean[k] = mean_rating + return d_mean + +def split_female_male(combined): + female = {} + male = {} + for sent_id, d in combined.items(): + try: + if 
sent_id.split("_")[1] + "_" + sent_id.split("_")[2] in emo_sent_speaker("f"): + female[sent_id.split("_")[1]] = d + if sent_id.split("_")[1] + "_" + sent_id.split("_")[2] in emo_sent_speaker("m"): + male[sent_id.split("_")[1]] = d + except IndexError: + if sent_id in emo_sent_speaker("f"): + female[sent_id.split("_")[0]] = d + if sent_id in emo_sent_speaker("m"): + male[sent_id.split("_")[0]] = d + return female, male + +def make_emotion_prompts(emotion_d, speaker): + emo_prompt_match = emo_prompt_speaker(speaker) + emotion_prompt_d = {} + for emo, d in emotion_d.items(): + emotion_prompt_d[emo_prompt_match[emo]] = d + return emotion_prompt_d + +def emo_sent_speaker(speaker): + if speaker == "f": + return ["anger_0", "joy_0", "neutral_0", "sadness_0", "surprise_1"] + else: + return ["anger_1", "joy_1", "neutral_1", "sadness_1", "surprise_0"] + +def emo_prompt_speaker(speaker): + if speaker == "f": + return {"anger" : "neutral", + "joy" : "surprise", + "neutral" : "joy", + "sadness" : "anger", + "surprise": "sadness"} + else: + return {"anger" : "surprise", + "joy" : "sadness", + "neutral" : "anger", + "sadness" : "joy", + "surprise": "neutral"} + +def remove_outliers(data): + ratings = {emotion: list(rating for rating, count in ratings_counts.items() for _ in range(count)) for emotion, ratings_counts in data.items()} + cleaned_data = {} + for emotion, ratings_list in ratings.items(): + sorted_data = sorted(ratings_list) + q1, q3 = np.percentile(sorted_data, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + cleaned_data[emotion] = [x for x in sorted_data if lower_bound <= x <= upper_bound] + counts = {} + for emotion, ratings in cleaned_data.items(): + rating_counts = {} + for rating in ratings: + if rating in rating_counts: + rating_counts[rating] += 1 + else: + rating_counts[rating] = 1 + counts[emotion] = rating_counts + return counts + +def combine_dicts(d1, d2): + combined_dict = {} + for emotion in d1: + combined_dict[emotion] = {} + all_keys = sorted(set(d1[emotion].keys()) | set(d2[emotion].keys())) + for rating in all_keys: + combined_dict[emotion][rating] = d1[emotion].get(rating, 0) + d2[emotion].get(rating, 0) + return combined_dict + +def collapse_subdicts(d): + collapsed_dict = {} + for emotion, sub_dict in d.items(): + for scenario, count in sub_dict.items(): + if scenario not in collapsed_dict: + collapsed_dict[scenario] = 0 + collapsed_dict[scenario] += count + return collapsed_dict + +def independent_samples_t_test(data11, data12, data21, data22): + # remove outliers and unfold data + data11 = remove_outliers(data11) + data12 = remove_outliers(data12) + data21 = remove_outliers(data21) + data22 = remove_outliers(data22) + ratings11 = [] + for emotion, ratings in data11.items(): + for rating, count in ratings.items(): + ratings11.extend([rating] * count) + ratings12 = [] + for emotion, ratings in data12.items(): + for rating, count in ratings.items(): + ratings12.extend([rating] * count) + ratings21 = [] + for emotion, ratings in data21.items(): + for rating, count in ratings.items(): + ratings21.extend([rating] * count) + ratings22 = [] + for emotion, ratings in data22.items(): + for rating, count in ratings.items(): + ratings22.extend([rating] * count) + ratings11.extend(ratings12) + ratings21.extend(ratings22) + # Perform independent samples t-test + # ratings are assumed to be independent because every participant only sees a selection of samples + # i.e. 
there isn't a rating for each sample by each participant + t_statistic, p_value = stats.ttest_ind(ratings11, ratings21) + return t_statistic, p_value + +def cramers_v(data): + # Convert the data dictionary into a 2D array + counts = np.array([[data[emotion][label] for emotion in data] for label in data[list(data.keys())[0]]]) + + # Compute the chi-squared statistic and p-value + chi2, p, _, _ = stats.chi2_contingency(counts) + + # Number of observations (total counts) + n = np.sum(counts) + + # Number of rows and columns in the contingency table + num_rows = len(data[list(data.keys())[0]]) + num_cols = len(data) + + # Compute Cramér's V + cramer_v = np.sqrt(chi2 / (n * (min(num_rows, num_cols) - 1))) + return p, cramer_v diff --git a/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py b/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py index 14a8695b..672f5b85 100644 --- a/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py +++ b/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py @@ -1,4 +1,5 @@ import torch +from torchvision.ops import SqueezeExcitation from torch.nn import Linear from torch.nn import Sequential from torch.nn import Tanh @@ -72,7 +73,10 @@ def __init__(self, utt_embed_dim=64, detach_postflow=True, lang_embs=8000, - weights=None): + weights=None, + sent_embed_dim=None, + word_embed_dim=None, + static_speaker_embed=False): super().__init__() self.input_feature_dimensions = input_feature_dimensions @@ -82,6 +86,27 @@ def __init__(self, self.use_scaled_pos_enc = use_scaled_positional_encoding self.multilingual_model = lang_embs is not None self.multispeaker_model = utt_embed_dim is not None + self.use_sent_embed = sent_embed_dim is not None + self.use_word_embed = word_embed_dim is not None + self.static_speaker_embed = static_speaker_embed + + if self.static_speaker_embed: + # emovdb - 4, cremad - 91, esds - 10, ravdess - 24, ljspeech - 1, librittsr - 1230, tess - 2 + self.speaker_embedding = torch.nn.Embedding(10 + 24 + 1 + 1230 + 2, utt_embed_dim) + + if self.use_sent_embed: + self.sentence_embedding_adaptation = Linear(sent_embed_dim, 512) + sent_embed_dim = 512 + + self.squeeze_excitation = SqueezeExcitation(utt_embed_dim + sent_embed_dim, 192) + self.style_embedding_projection = Sequential(Linear(utt_embed_dim + sent_embed_dim, 512), + Tanh(), + Linear(512, 192)) + utt_embed_dim = 192 + else: + if utt_embed_dim is not None: + self.speaker_embedding_adaptation = Linear(utt_embed_dim, 192) + utt_embed_dim = 192 articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension)) self.encoder = Conformer(idim=input_feature_dimensions, @@ -102,7 +127,9 @@ def __init__(self, zero_triu=False, utt_embed=utt_embed_dim, lang_embs=lang_embs, - use_output_norm=True) + word_embed_dim=word_embed_dim, + use_output_norm=True, + conformer_encoder=True) self.duration_predictor = DurationPredictor(idim=attention_dimension, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans, @@ -149,6 +176,7 @@ def __init__(self, macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_decoder_kernel_size, + utt_embed=utt_embed_dim, use_output_norm=False) self.feat_out = Linear(attention_dimension, output_spectrogram_channels) @@ -188,6 +216,9 @@ def _forward(self, gold_energy=None, duration_scaling_factor=1.0, utterance_embedding=None, + speaker_id=None, + sentence_embedding=None, + word_embedding=None, lang_ids=None, 
pitch_variance_scale=1.0, energy_variance_scale=1.0, @@ -199,11 +230,44 @@ def _forward(self, if not self.multispeaker_model: utterance_embedding = None else: - utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + if self.static_speaker_embed: + utterance_embedding = self.speaker_embedding(speaker_id) + else: + utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + + if not self.use_sent_embed: + sentence_embedding = None + utterance_embedding = self.speaker_embedding_adaptation(utterance_embedding) + else: + sentence_embedding = torch.nn.functional.normalize(sentence_embedding) + # forward sentence embedding adaptation + sentence_embedding = self.sentence_embedding_adaptation(sentence_embedding) + utterance_embedding = torch.cat([utterance_embedding, sentence_embedding], dim=1) + utterance_embedding = self.squeeze_excitation(utterance_embedding.transpose(0, 1).unsqueeze(-1)).squeeze(-1).transpose(0, 1) + utterance_embedding = self.style_embedding_projection(utterance_embedding) + + if not self.use_word_embed: + word_embedding = None + word_boundaries_batch = None + else: + # get word boundaries + word_boundaries_batch = [] + for batch_id, batch in enumerate(text_tensors): + word_boundaries = [] + for phoneme_index, phoneme_vector in enumerate(batch): + if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1: + word_boundaries.append(phoneme_index) + word_boundaries.append(text_lengths[batch_id].cpu().numpy()-1) # marker for last word of sentence + word_boundaries_batch.append(torch.tensor(word_boundaries)) # encoding the texts text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2) - encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) + encoded_texts, _ = self.encoder(text_tensors, + text_masks, + utterance_embedding=utterance_embedding, + word_embedding=word_embedding, + word_boundaries=word_boundaries_batch, + lang_ids=lang_ids) # predicting pitch, energy and durations pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_pitch is None else gold_pitch @@ -235,7 +299,9 @@ def _forward(self, upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, predicted_durations) # decoding spectrogram - decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None) + decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, + None, + utterance_embedding=utterance_embedding) decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels) refined_spectrogram = decoded_spectrogram + self.conv_postnet(decoded_spectrogram.transpose(1, 2)).transpose(1, 2) @@ -256,6 +322,9 @@ def forward(self, pitch=None, energy=None, utterance_embedding=None, + speaker_id=None, + sentence_embedding=None, + word_embedding=None, return_duration_pitch_energy=False, lang_id=None, duration_scaling_factor=1.0, @@ -309,7 +378,11 @@ def forward(self, gold_durations=durations, gold_pitch=pitch, gold_energy=energy, - utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id, + utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, + speaker_id=speaker_id, + sentence_embedding=sentence_embedding.unsqueeze(0) if sentence_embedding is not None else None, + word_embedding=word_embedding.unsqueeze(0) if word_embedding is not None 
else None, + lang_ids=lang_id, duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, diff --git a/InferenceInterfaces/ToucanTTSInterface.py b/InferenceInterfaces/ToucanTTSInterface.py index 85cf7188..8d243399 100644 --- a/InferenceInterfaces/ToucanTTSInterface.py +++ b/InferenceInterfaces/ToucanTTSInterface.py @@ -6,6 +6,7 @@ import sounddevice import soundfile import torch +import torchaudio from InferenceInterfaces.InferenceArchitectures.InferenceAvocodo import HiFiGANGenerator from InferenceInterfaces.InferenceArchitectures.InferenceBigVGAN import BigVGAN @@ -27,6 +28,9 @@ def __init__(self, vocoder_model_path=None, # path to the hifigan/avocodo/bigvgan checkpoint faster_vocoder=True, # whether to use the quicker HiFiGAN or the better BigVGAN language="en", # initial language of the model, can be changed later with the setter methods + sent_emb_extractor=None, + word_emb_extractor=None, + xvect_model=None ): super().__init__() self.device = device @@ -53,14 +57,49 @@ def __init__(self, # load phone to mel model # ################################ self.use_lang_id = True + self.use_sent_emb = False + self.static_speaker_embed=False + self.use_word_emb = False try: + if "sent_emb" in tts_model_path: + raise RuntimeError self.phone2mel = ToucanTTS(weights=checkpoint["model"]) # multi speaker multi language except RuntimeError: try: + if "sent_emb" in tts_model_path: + raise RuntimeError self.use_lang_id = False - self.phone2mel = ToucanTTS(weights=checkpoint["model"], lang_embs=None) # multi speaker single language + self.phone2mel = ToucanTTS(weights=checkpoint["model"], + lang_embs=None) # multi speaker single language except RuntimeError: - self.phone2mel = ToucanTTS(weights=checkpoint["model"], lang_embs=None, utt_embed_dim=None) # single speaker + try: + if "sent_emb" in tts_model_path: + raise RuntimeError + self.use_lang_id = False + self.phone2mel = ToucanTTS(weights=checkpoint["model"], + lang_embs=None, + utt_embed_dim=None) # single speaker, single language + except RuntimeError: + try: + if "sent_emb" in tts_model_path: + raise RuntimeError + print("Loading baseline architecture") + self.use_lang_id = False + self.static_speaker_embed = True + self.phone2mel = ToucanTTS(weights=checkpoint["model"], + lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=self.static_speaker_embed) + except RuntimeError: + print("Loading sent emb architecture") + self.use_lang_id = False + self.use_sent_emb = True + self.static_speaker_embed = True + self.phone2mel = ToucanTTS(weights=checkpoint["model"], + lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=self.static_speaker_embed) with torch.no_grad(): self.phone2mel.store_inverse_all() # this also removes weight norm self.phone2mel = self.phone2mel.to(torch.device(device)) @@ -68,13 +107,35 @@ def __init__(self, ################################# # load mel to style models # ################################# - self.style_embedding_function = StyleEmbedding() - if embedding_model_path is None: - check_dict = torch.load(os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt"), map_location="cpu") - else: + if embedding_model_path is not None: + self.style_embedding_function = StyleEmbedding() check_dict = torch.load(embedding_model_path, map_location="cpu") - self.style_embedding_function.load_state_dict(check_dict["style_emb_func"]) - self.style_embedding_function.to(self.device) + 
self.style_embedding_function.load_state_dict(check_dict["style_emb_func"]) + self.style_embedding_function.to(self.device) + else: + self.style_embedding_function = None + + self.xvect_model = xvect_model if xvect_model is not None else None + + ################################# + # load sent emb extractor # + ################################# + self.sentence_embedding_extractor = None + if self.use_sent_emb: + if sent_emb_extractor is not None: + self.sentence_embedding_extractor = sent_emb_extractor + else: + raise KeyError("Please specify a sentence embedding extractor.") + + ################################# + # load word emb extractor # + ################################# + self.word_embedding_extractor = None + if self.use_word_emb: + if word_emb_extractor is not None: + self.word_embedding_extractor = word_emb_extractor + else: + raise KeyError("Please specify a word embedding extractor.") ################################ # load mel to wave model # @@ -88,11 +149,20 @@ def __init__(self, ################################ # set defaults # ################################ - self.default_utterance_embedding = checkpoint["default_emb"].to(self.device) + try: + self.default_utterance_embedding = checkpoint["default_emb"].to(self.device) + except KeyError: + self.default_utterance_embedding = None + if self.static_speaker_embed: + self.default_speaker_id = torch.LongTensor([0]).to(self.device) + else: + self.default_speaker_id = None + self.sentence_embedding = None self.audio_preprocessor = AudioPreprocessor(input_sr=16000, output_sr=16000, cut_silence=True, device=self.device) self.phone2mel.eval() self.mel2wav.eval() - self.style_embedding_function.eval() + if self.style_embedding_function is not None: + self.style_embedding_function.eval() if self.use_lang_id: self.lang_id = get_language_id(language) else: @@ -105,13 +175,35 @@ def set_utterance_embedding(self, path_to_reference_audio="", embedding=None): self.default_utterance_embedding = embedding.squeeze().to(self.device) return assert os.path.exists(path_to_reference_audio) - wave, sr = soundfile.read(path_to_reference_audio) - if sr != self.audio_preprocessor.sr: - self.audio_preprocessor = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True, device=self.device) - spec = self.audio_preprocessor.audio_to_mel_spec_tensor(wave).transpose(0, 1) - spec_len = torch.LongTensor([len(spec)]) - self.default_utterance_embedding = self.style_embedding_function(spec.unsqueeze(0).to(self.device), - spec_len.unsqueeze(0).to(self.device)).squeeze() + if self.xvect_model is not None: + print("Extracting xvect from reference audio.") + wave, sr = torchaudio.load(path_to_reference_audio) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + self.default_utterance_embedding = self.xvect_model.encode_batch(wave).squeeze(0).squeeze(0) + else: + wave, sr = soundfile.read(path_to_reference_audio) + if sr != self.audio_preprocessor.sr: + self.audio_preprocessor = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True, device=self.device) + spec = self.audio_preprocessor.audio_to_mel_spec_tensor(wave).transpose(0, 1) + spec_len = torch.LongTensor([len(spec)]) + self.default_utterance_embedding = self.style_embedding_function(spec.unsqueeze(0).to(self.device), + spec_len.unsqueeze(0).to(self.device)).squeeze() + + def set_sentence_embedding(self, prompt:str, silent=True): + if self.use_sent_emb: + if not silent: + 
print(f"Using sentence embedding of given prompt: {prompt}") + prompt_embedding = self.sentence_embedding_extractor.encode([prompt]).squeeze().to(self.device) + self.sentence_embedding = prompt_embedding + else: + print("Skipping setting sentence embedding.") + + def set_speaker_id(self, id:int): + self.default_speaker_id = torch.LongTensor([id]).to(self.device) def set_language(self, lang_id): """ @@ -132,6 +224,7 @@ def set_accent_language(self, lang_id): def forward(self, text, view=False, + view_contours=False, duration_scaling_factor=1.0, pitch_variance_scale=1.0, energy_variance_scale=1.0, @@ -140,7 +233,9 @@ def forward(self, pitch=None, energy=None, input_is_phones=False, - return_plot_as_filepath=False): + return_plot_as_filepath=False, + plot_name="tmp", + silent=False): """ duration_scaling_factor: reasonable values are 0.8 < scale < 1.2. 1.0 means no scaling happens, higher values increase durations for the whole @@ -154,9 +249,25 @@ def forward(self, """ with torch.inference_mode(): phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device)) + if self.use_sent_emb and self.sentence_embedding is None: + if not silent: + print("Using sentence embedding of input text.") + sentence_embedding = self.sentence_embedding_extractor.encode([text]).squeeze().to(self.device) + else: + sentence_embedding = self.sentence_embedding + if self.use_word_emb: + if not silent: + print("Extracting word embeddings.") + word_embeddings, _ = self.word_embedding_extractor.encode([text]) + word_embeddings = word_embeddings.squeeze().to(self.device) + else: + word_embeddings = None mel, durations, pitch, energy = self.phone2mel(phones, return_duration_pitch_energy=True, utterance_embedding=self.default_utterance_embedding, + speaker_id=self.default_speaker_id, + sentence_embedding=sentence_embedding, + word_embedding=word_embeddings, durations=durations, pitch=pitch, energy=energy, @@ -170,7 +281,7 @@ def forward(self, if view or return_plot_as_filepath: from Utility.utils import cumsum_durations - fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(9, 6)) + fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(9, 4)) ax[0].plot(wave.cpu().numpy()) lbd.specshow(mel.cpu().numpy(), ax=ax[1], @@ -218,14 +329,73 @@ def forward(self, pitch_array = pitch.cpu().numpy() for pitch_index, xrange in enumerate(zip(duration_splits[:-1], duration_splits[1:])): if pitch_array[pitch_index] != 0: - ax[1].hlines(pitch_array[pitch_index] * 1000, xmin=xrange[0], xmax=xrange[1], color="magenta", + ax[1].hlines(pitch_array[pitch_index] * 1000, xmin=xrange[0], xmax=xrange[1], color="red", linestyles="solid", linewidth=1.) 
plt.subplots_adjust(left=0.05, bottom=0.12, right=0.95, top=.9, wspace=0.0, hspace=0.0) if not return_plot_as_filepath: plt.show() else: - plt.savefig("tmp.png") - return wave, "tmp.png" + plt.savefig(f"{plot_name}.png") + return wave, f"{plot_name}.png" + if view_contours: + from Utility.utils import cumsum_durations + fig, ax = plt.subplots(figsize=(8,4)) + lbd.specshow(mel.cpu().numpy(), + ax=ax, + sr=16000, + cmap='GnBu', + y_axis='mel', + x_axis=None, + hop_length=256) + ax.yaxis.set_visible(False) + #ax.set_ylim(200, 4000) + duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) + ax.xaxis.grid(True, which='minor') + ax.set_xticks(label_positions, minor=False) + if input_is_phones: + phones = text.replace(" ", "|") + else: + phones = self.text2phone.get_phone_string(text, for_plot_labels=True) + ax.set_xticklabels(phones, fontsize=28) + word_boundaries = list() + for label_index, phone in enumerate(phones): + if phone == "|": + word_boundaries.append(label_positions[label_index]) + + try: + prev_word_boundary = 0 + word_label_positions = list() + for word_boundary in word_boundaries: + word_label_positions.append((word_boundary + prev_word_boundary) / 2) + prev_word_boundary = word_boundary + word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2) + + #secondary_ax = ax.secondary_xaxis('bottom') + #secondary_ax.tick_params(axis="x", direction="out", pad=24) + #secondary_ax.set_xticks(word_label_positions, minor=False) + #secondary_ax.set_xticklabels(text.split()) + #secondary_ax.tick_params(axis='x', colors='black', labelsize=16) + #secondary_ax.xaxis.label.set_color('black') + except ValueError: + ax.set_title(text) + except IndexError: + ax.set_title(text) + + #ax.vlines(x=duration_splits, colors="black", linestyles="dotted", ymin=0.0, ymax=8000, linewidth=1.0) + #ax.vlines(x=word_boundaries, colors="black", linestyles="solid", ymin=0.0, ymax=8000, linewidth=1.2) + pitch_array = pitch.cpu().numpy() + for pitch_index, xrange in enumerate(zip(duration_splits[:-1], duration_splits[1:])): + if pitch_array[pitch_index] != 0: + ax.hlines(pitch_array[pitch_index] * 1000, xmin=xrange[0], xmax=xrange[1], color="red", + linestyles="solid", linewidth=5) + #energy_array = energy.cpu().numpy() + #for energy_index, xrange in enumerate(zip(duration_splits[:-1], duration_splits[1:])): + # if energy_array[energy_index] != 0: + # ax.hlines(energy_array[energy_index] * 1000, xmin=xrange[0], xmax=xrange[1], color="orange", + # linestyles="solid", linewidth=2.5) + plt.subplots_adjust(left=0.05, bottom=0.12, right=0.95, top=.9, wspace=0.0, hspace=0.0) + plt.savefig(f"{plot_name}.pdf", format="pdf") + plt.close() return wave def read_to_file(self, @@ -238,7 +408,10 @@ def read_to_file(self, dur_list=None, pitch_list=None, energy_list=None, - increased_compatibility_mode=False): + increased_compatibility_mode=False, + view=False, + view_contours=False, + plot_name="tmp"): """ Args: increased_compatibility_mode: Whether to export audio as 16bit integer 48kHz audio for maximum compatibility across systems and devices @@ -276,7 +449,11 @@ def read_to_file(self, energy=energy.to(self.device) if energy is not None else None, duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, - energy_variance_scale=energy_variance_scale).cpu() + energy_variance_scale=energy_variance_scale, + silent=silent, + view=view, + view_contours=view_contours, + plot_name=plot_name).cpu() wav = torch.cat((wav, spoken_sentence, silence), 0) if 
increased_compatibility_mode: wav = [val for val in wav.numpy() for _ in (0, 1)] # doubling the sampling rate for better compatibility (24kHz is not as standard as 48kHz) diff --git a/Layers/Conformer.py b/Layers/Conformer.py index 33f707fd..1932721d 100644 --- a/Layers/Conformer.py +++ b/Layers/Conformer.py @@ -3,6 +3,8 @@ """ import torch +from torchvision.ops import SqueezeExcitation +from torch.nn.utils.rnn import pad_sequence from Layers.Attention import RelPositionMultiHeadedAttention from Layers.Convolution import ConvolutionModule @@ -45,14 +47,35 @@ class Conformer(torch.nn.Module): """ - def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, - attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, - macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, use_output_norm=True): + def __init__(self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_conv_kernel_size=1, + macaron_style=False, + use_cnn_module=False, + cnn_module_kernel=31, + zero_triu=False, + use_output_norm=True, + utt_embed=None, + lang_embs=None, + word_embed_dim=None, + conformer_encoder=False, + ): super(Conformer, self).__init__() activation = Swish() self.conv_subsampling_factor = 1 self.use_output_norm = use_output_norm + self.conformer_encoder = conformer_encoder if isinstance(input_layer, torch.nn.Module): self.embed = input_layer @@ -66,11 +89,21 @@ def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048 if self.use_output_norm: self.output_norm = LayerNorm(attention_dim) self.utt_embed = utt_embed - if utt_embed is not None: - self.hs_emb_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim) + self.word_embed_dim = word_embed_dim + + if self.utt_embed is not None: + self.hs_emb_projection = torch.nn.Linear(attention_dim + self.utt_embed, attention_dim) + if self.conformer_encoder: + self.encoder_projection = torch.nn.Linear(attention_dim + self.utt_embed, attention_dim) + if lang_embs is not None: self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim) + if self.word_embed_dim is not None: + self.word_embed_adaptation = torch.nn.Linear(word_embed_dim, attention_dim) + self.word_phoneme_projection = torch.nn.Linear(attention_dim * 2, attention_dim) + self.squeeze_excitation_word = SqueezeExcitation(attention_dim * 2, attention_dim) + # self-attention module definition encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) @@ -93,7 +126,9 @@ def forward(self, xs, masks, utterance_embedding=None, - lang_ids=None): + lang_ids=None, + word_embedding=None, + word_boundaries=None): """ Encode input sequence. 
Args:
@@ -109,6 +144,20 @@ def forward(self,
         if self.embed is not None:
             xs = self.embed(xs)
+        if self.word_embed_dim is not None:
+            word_embedding = torch.nn.functional.normalize(word_embedding, dim=0)
+            word_embedding = self.word_embed_adaptation(word_embedding)
+            xs = self._integrate_with_word_embed(xs=xs,
+                                                 word_boundaries=word_boundaries,
+                                                 word_embedding=word_embedding,
+                                                 word_phoneme_projection=self.word_phoneme_projection,
+                                                 word_phoneme_squeeze_excitation=self.squeeze_excitation_word)
+
+        if self.utt_embed is not None:
+            xs = self._integrate_with_utt_embed(hs=xs,
+                                                utt_embeddings=utterance_embedding,
+                                                projection=self.hs_emb_projection)
+
         if lang_ids is not None:
             lang_embs = self.language_embedding(lang_ids)
             xs = xs + lang_embs  # offset phoneme representation by language specific offset
@@ -122,13 +171,51 @@ def forward(self,
         if self.use_output_norm:
             xs = self.output_norm(xs)
-        if self.utt_embed:
-            xs = self._integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding)
+        if self.utt_embed is not None and self.conformer_encoder:  # only do this in the encoder
+            xs = self._integrate_with_utt_embed(hs=xs,
+                                                utt_embeddings=utterance_embedding,
+                                                projection=self.encoder_projection)
         return xs, masks
-    def _integrate_with_utt_embed(self, hs, utt_embeddings):
+    def _integrate_with_utt_embed(self, hs, utt_embeddings, projection):
         # concat hidden states with spk embeds and then apply projection
         embeddings_expanded = torch.nn.functional.normalize(utt_embeddings).unsqueeze(1).expand(-1, hs.size(1), -1)
-        hs = self.hs_emb_projection(torch.cat([hs, embeddings_expanded], dim=-1))
+        hs = torch.cat([hs, embeddings_expanded], dim=-1)
+        hs = projection(hs)
         return hs
+
+    def _integrate_with_word_embed(self, xs, word_boundaries, word_embedding, word_phoneme_projection, word_phoneme_squeeze_excitation):
+        xs_enhanced = []
+        for batch_id, wbs in enumerate(word_boundaries):
+            w_start = 0
+            phoneme_sequence = []
+            for i, wb_id in enumerate(wbs):
+                # get phoneme embeddings corresponding to words according to word boundaries
+                phoneme_embeds = xs[batch_id, w_start:wb_id+1]
+                # get corresponding word embedding
+                try:
+                    word_embed = word_embedding[batch_id, i]
+                # in case of a mismatch between the words and the phonemizer output
+                except IndexError:
+                    # take the last word embedding again to avoid errors
+                    word_embed = word_embedding[batch_id, -1]
+                # concatenate phoneme embeddings with word embedding
+                phoneme_embeds = self._cat_with_word_embed(phoneme_embeddings=phoneme_embeds, word_embedding=word_embed)
+                phoneme_sequence.append(phoneme_embeds)
+                w_start = wb_id + 1
+            # put whole phoneme sequence back together
+            phoneme_sequence = torch.cat(phoneme_sequence, dim=0)
+            xs_enhanced.append(phoneme_sequence)
+        # pad phoneme sequences to get whole batch
+        xs_enhanced_padded = pad_sequence(xs_enhanced, batch_first=True)
+        # squeeze-excitation gating over the concatenated features, then project back down to the attention dimension
+        xs = word_phoneme_squeeze_excitation(xs_enhanced_padded.transpose(0, 2)).transpose(0, 2)
+        xs = word_phoneme_projection(xs)
+        return xs
+
+    def _cat_with_word_embed(self, phoneme_embeddings, word_embedding):
+        # concat phoneme embeddings with the corresponding word embedding (the projection happens in _integrate_with_word_embed)
+        word_embeddings_expanded = torch.nn.functional.normalize(word_embedding, dim=0).unsqueeze(0).expand(phoneme_embeddings.size(0), -1)
+        phoneme_embeddings = torch.cat([phoneme_embeddings, word_embeddings_expanded], dim=-1)
+        return phoneme_embeddings
\ No newline at end of file
diff --git a/Layers/EncoderLayer.py b/Layers/EncoderLayer.py
index 6ae91c25..45d8bb6a 100644
---
a/Layers/EncoderLayer.py +++ b/Layers/EncoderLayer.py @@ -36,7 +36,7 @@ class EncoderLayer(nn.Module): """ - def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ): + def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False): super(EncoderLayer, self).__init__() self.self_attn = self_attn self.feed_forward = feed_forward diff --git a/Preprocessing/sentence_embeddings/BERTSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/BERTSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..44171302 --- /dev/null +++ b/Preprocessing/sentence_embeddings/BERTSentenceEmbeddingExtractor.py @@ -0,0 +1,28 @@ +import torch +from transformers import BertTokenizer, BertModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class BERTSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=self.cache_dir) + self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=self.cache_dir).to(device) + assert pooling in ["cls", "last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "cls": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + return self.model(**encoded_input).last_hidden_state[:,0].detach().cpu() + if self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/BLOOMSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/BLOOMSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..faa0cb76 --- /dev/null +++ b/Preprocessing/sentence_embeddings/BLOOMSentenceEmbeddingExtractor.py @@ -0,0 +1,25 @@ +import torch +from transformers import AutoTokenizer, BloomModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class BLOOMSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", cache_dir=self.cache_dir) + self.model = BloomModel.from_pretrained("bigscience/bloom-560m", cache_dir=self.cache_dir).to(device) + assert pooling in ["last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, 
return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/ByT5SentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/ByT5SentenceEmbeddingExtractor.py new file mode 100644 index 00000000..e91409b1 --- /dev/null +++ b/Preprocessing/sentence_embeddings/ByT5SentenceEmbeddingExtractor.py @@ -0,0 +1,25 @@ +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class ByT5SentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = AutoTokenizer.from_pretrained('google/byt5-base', cache_dir=self.cache_dir) + self.model = T5EncoderModel.from_pretrained('google/byt5-base', cache_dir=self.cache_dir).to(device) + assert pooling in ["last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(encoded_input.input_ids).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(encoded_input.input_ids, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/CAMEMBERTSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/CAMEMBERTSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..6c37ccd1 --- /dev/null +++ b/Preprocessing/sentence_embeddings/CAMEMBERTSentenceEmbeddingExtractor.py @@ -0,0 +1,30 @@ +import torch +import os +from transformers import CamembertTokenizer, CamembertModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class CAMEMBERTSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + os.makedirs(self.cache_dir, exist_ok=True) + self.tokenizer = CamembertTokenizer.from_pretrained('camembert-base', cache_dir=self.cache_dir) + self.model = CamembertModel.from_pretrained("camembert-base", cache_dir=self.cache_dir).to(device) + assert pooling in ["cls", "last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "cls": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + return self.model(**encoded_input).last_hidden_state[:,0].detach().cpu() + if 
self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/EmotionRoBERTaSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/EmotionRoBERTaSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..f4729a4e --- /dev/null +++ b/Preprocessing/sentence_embeddings/EmotionRoBERTaSentenceEmbeddingExtractor.py @@ -0,0 +1,28 @@ +import torch +from transformers import RobertaTokenizerFast, RobertaModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class EmotionRoBERTaSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = RobertaTokenizerFast.from_pretrained('j-hartmann/emotion-english-distilroberta-base', cache_dir=self.cache_dir) + self.model = RobertaModel.from_pretrained("j-hartmann/emotion-english-distilroberta-base", cache_dir=self.cache_dir).to(device) + assert pooling in ["cls", "last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "cls": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + return self.model(**encoded_input).last_hidden_state[:,0].detach().cpu() + if self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/LABSESentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/LABSESentenceEmbeddingExtractor.py new file mode 100644 index 00000000..6f604655 --- /dev/null +++ b/Preprocessing/sentence_embeddings/LABSESentenceEmbeddingExtractor.py @@ -0,0 +1,19 @@ +import os +import tensorflow_hub as hub +import tensorflow_text as text # required even if import is not used here +import tensorflow as tf +import torch + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class LABSESentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str=""): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + os.environ['TFHUB_CACHE_DIR']=self.cache_dir + self.tokenizer = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder-cmlm/multilingual-preprocess/2") + self.model = 
hub.KerasLayer("https://tfhub.dev/google/LaBSE/2") + + def encode(self, sentences: list[str]) -> torch.Tensor: + return torch.as_tensor(self.model(self.tokenizer(tf.constant(sentences)))['default'].numpy()) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/LASERSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/LASERSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..d1962818 --- /dev/null +++ b/Preprocessing/sentence_embeddings/LASERSentenceEmbeddingExtractor.py @@ -0,0 +1,12 @@ +import torch + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor +from Preprocessing.sentence_embeddings.laserembeddings.laser import Laser + +class LASERSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self): + super().__init__() + self.model = Laser(mode='spm') + + def encode(self, sentences: list[str]) -> torch.Tensor: + return torch.as_tensor(self.model.embed_sentences(sentences)) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/LEALLASentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/LEALLASentenceEmbeddingExtractor.py new file mode 100644 index 00000000..af6a96a8 --- /dev/null +++ b/Preprocessing/sentence_embeddings/LEALLASentenceEmbeddingExtractor.py @@ -0,0 +1,18 @@ +import os +import tensorflow_hub as hub +import tensorflow_text as text # required even if import is not used here +import tensorflow as tf +import torch + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class LEALLASentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str=""): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + os.environ['TFHUB_CACHE_DIR']=self.cache_dir + self.model = hub.KerasLayer("https://tfhub.dev/google/LEALLA/LEALLA-base/1") + + def encode(self, sentences: list[str]) -> torch.Tensor: + return torch.as_tensor(self.model(tf.constant(sentences)).numpy()) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/MBERTSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/MBERTSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..d2c2bbbc --- /dev/null +++ b/Preprocessing/sentence_embeddings/MBERTSentenceEmbeddingExtractor.py @@ -0,0 +1,28 @@ +import torch +from transformers import BertTokenizer, BertModel + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class MBERTSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, cache_dir:str="", pooling:str='second_to_last_mean', device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', cache_dir=self.cache_dir) + self.model = BertModel.from_pretrained("bert-base-multilingual-cased", cache_dir=self.cache_dir).to(device) + assert pooling in ["cls", "last_mean", "second_to_last_mean"] + self.pooling = pooling + self.device = device + + def encode(self, sentences: list[str]) -> torch.Tensor: + if self.pooling == "cls": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + return self.model(**encoded_input).last_hidden_state[:,0].detach().cpu() + if self.pooling == "last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, 
output_hidden_states=True).last_hidden_state + return torch.mean(token_embeddings, dim=1).detach().cpu() + if self.pooling == "second_to_last_mean": + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + token_embeddings = self.model(**encoded_input, output_hidden_states=True).hidden_states[-2] + return torch.mean(token_embeddings, dim=1).detach().cpu() \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/STSentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/STSentenceEmbeddingExtractor.py new file mode 100644 index 00000000..1ca2b67e --- /dev/null +++ b/Preprocessing/sentence_embeddings/STSentenceEmbeddingExtractor.py @@ -0,0 +1,29 @@ +import os + +import torch +from sentence_transformers import SentenceTransformer + +from Preprocessing.sentence_embeddings.SentenceEmbeddingExtractor import SentenceEmbeddingExtractor + +class STSentenceEmbeddingExtractor(SentenceEmbeddingExtractor): + def __init__(self, model:str='para', cache_dir:str=""): + super().__init__() + assert model in ['para', 'para_mini', 'distil', 'bloom', 'camembert', 'mpnet'] + os.environ["TOKENIZERS_PARALLELISM"] = 'False' + if cache_dir: + self.cache_dir = cache_dir + if model == 'para_mini': + self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', cache_folder=self.cache_dir) + if model == 'para': + self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2', cache_folder=self.cache_dir) + if model == 'distil': + self.model = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v2', cache_folder=self.cache_dir) + if model == 'bloom': + self.model = SentenceTransformer('bigscience-data/sgpt-bloom-1b7-nli', cache_folder=self.cache_dir) + if model == 'camembert': + self.model = SentenceTransformer('dangvantuan/sentence-camembert-base', cache_folder=self.cache_dir) + if model == 'mpnet': + self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', cache_folder=self.cache_dir) + + def encode(self, sentences: list[str]) -> torch.Tensor: + return torch.as_tensor(self.model.encode(sentences)) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/SentenceEmbeddingExtractor.py b/Preprocessing/sentence_embeddings/SentenceEmbeddingExtractor.py new file mode 100644 index 00000000..3537c5bb --- /dev/null +++ b/Preprocessing/sentence_embeddings/SentenceEmbeddingExtractor.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +import torch +import os + +from Utility.storage_config import MODELS_DIR + +class SentenceEmbeddingExtractor(ABC): + + def __init__(self): + self.cache_dir = os.path.join(MODELS_DIR, 'Language_Models') + pass + + @abstractmethod + def encode(self, sentences:list[str]) -> torch.Tensor: + pass \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/laserembeddings/embedding.py b/Preprocessing/sentence_embeddings/laserembeddings/embedding.py new file mode 100644 index 00000000..5b3264f5 --- /dev/null +++ b/Preprocessing/sentence_embeddings/laserembeddings/embedding.py @@ -0,0 +1,91 @@ +from typing import Optional, List, Union +from io import BufferedIOBase + +import numpy as np + +from Preprocessing.sentence_embeddings.laserembeddings.encoder import SentenceEncoder + +__all__ = ['BPESentenceEmbedding'] + + +class BPESentenceEmbedding: + """ + LASER embeddings computation from BPE-encoded sentences. 
+ + Args: + encoder (str or BufferedIOBase): the path to LASER's encoder PyTorch model, + or a binary-mode file object. + max_sentences (int, optional): see ``.encoder.SentenceEncoder``. + max_tokens (int, optional): see ``.encoder.SentenceEncoder``. + stable (bool, optional): if True, mergesort sorting algorithm will be used, + otherwise quicksort will be used. Defaults to False. See ``.encoder.SentenceEncoder``. + cpu (bool, optional): if True, forces the use of the CPU even a GPU is available. Defaults to False. + """ + + def __init__(self, + encoder: Union[str, BufferedIOBase], + max_sentences: Optional[int] = None, + max_tokens: Optional[int] = 12000, + stable: bool = False, + cpu: bool = False): + + self.encoder = SentenceEncoder( + encoder, + max_sentences=max_sentences, + max_tokens=max_tokens, + sort_kind='mergesort' if stable else 'quicksort', + cpu=cpu) + + def embed_bpe_sentences(self, bpe_sentences: List[str]) -> np.ndarray: + """ + Computes the LASER embeddings of BPE-encoded sentences + + Args: + bpe_sentences (List[str]): The list of BPE-encoded sentences + + Returns: + np.ndarray: A N * 1024 NumPy array containing the embeddings, N being the number of sentences provided. + """ + return self.encoder.encode_sentences(bpe_sentences) + +class SPMSentenceEmbedding: + """ + LASER embeddings computation from SPM-encoded sentences. + + Args: + encoder (str or BufferedIOBase): the path to LASER's encoder PyTorch model, + or a binary-mode file object. + max_sentences (int, optional): see ``.encoder.SentenceEncoder``. + max_tokens (int, optional): see ``.encoder.SentenceEncoder``. + stable (bool, optional): if True, mergesort sorting algorithm will be used, + otherwise quicksort will be used. Defaults to False. See ``.encoder.SentenceEncoder``. + cpu (bool, optional): if True, forces the use of the CPU even a GPU is available. Defaults to False. + """ + + def __init__(self, + encoder: Union[str, BufferedIOBase], + spm_vocab: Union[str, BufferedIOBase], + max_sentences: Optional[int] = None, + max_tokens: Optional[int] = 12000, + stable: bool = False, + cpu: bool = False): + + self.encoder = SentenceEncoder( + model_path=encoder, + spm_vocab=spm_vocab, + max_sentences=max_sentences, + max_tokens=max_tokens, + sort_kind='mergesort' if stable else 'quicksort', + cpu=cpu) + + def embed_spm_sentences(self, spm_sentences: List[str]) -> np.ndarray: + """ + Computes the LASER embeddings of SPM-encoded sentences + + Args: + spm_sentences (List[str]): The list of SPM-encoded sentences + + Returns: + np.ndarray: A N * 1024 NumPy array containing the embeddings, N being the number of sentences provided. 
+        """
+        return self.encoder.encode_sentences(spm_sentences)
\ No newline at end of file
diff --git a/Preprocessing/sentence_embeddings/laserembeddings/encoder.py b/Preprocessing/sentence_embeddings/laserembeddings/encoder.py
new file mode 100644
index 00000000..6dbaa77b
--- /dev/null
+++ b/Preprocessing/sentence_embeddings/laserembeddings/encoder.py
@@ -0,0 +1,367 @@
+# The code contained in this file was copied/pasted from LASER's source code (source/embed.py; from the updated version)
+# and nearly kept untouched
+
+import re
+import os
+import sys
+import time
+import numpy as np
+import logging
+from collections import namedtuple
+from subprocess import run
+from pathlib import Path
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from fairseq.models.transformer import (
+    Embedding,
+    TransformerEncoder,
+)
+from fairseq.data.dictionary import Dictionary
+from fairseq.modules import LayerNorm
+
+SPACE_NORMALIZER = re.compile(r"\s+")
+Batch = namedtuple("Batch", "srcs tokens lengths")
+
+logging.basicConfig(
+    stream=sys.stdout,
+    level=logging.INFO,
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
+logger = logging.getLogger('embed')
+
+def buffered_read(fp, buffer_size):
+    buffer = []
+    for src_str in fp:
+        buffer.append(src_str.strip())
+        if len(buffer) >= buffer_size:
+            yield buffer
+            buffer = []
+
+    if len(buffer) > 0:
+        yield buffer
+
+
+class SentenceEncoder:
+    def __init__(
+        self,
+        model_path,
+        max_sentences=None,
+        max_tokens=None,
+        spm_vocab=None,
+        cpu=False,
+        fp16=False,
+        verbose=False,
+        sort_kind="quicksort",
+    ):
+        if verbose:
+            logger.info(f"loading encoder: {model_path}")
+        self.use_cuda = torch.cuda.is_available() and not cpu
+        self.max_sentences = max_sentences
+        self.max_tokens = max_tokens
+        if self.max_tokens is None and self.max_sentences is None:
+            self.max_sentences = 1
+
+        state_dict = torch.load(model_path)
+        if "params" in state_dict:
+            self.encoder = LaserLstmEncoder(**state_dict["params"])
+            self.encoder.load_state_dict(state_dict["model"])
+            self.dictionary = state_dict["dictionary"]
+            self.prepend_bos = False
+            self.left_padding = False
+        else:
+            self.encoder = LaserTransformerEncoder(state_dict, spm_vocab)
+            self.dictionary = self.encoder.dictionary.indices
+            self.prepend_bos = state_dict["cfg"]["model"].prepend_bos
+            self.left_padding = state_dict["cfg"]["model"].left_pad_source
+        del state_dict
+        self.bos_index = self.dictionary["<s>"] = 0
+        self.pad_index = self.dictionary["<pad>"] = 1
+        self.eos_index = self.dictionary["</s>"] = 2
+        self.unk_index = self.dictionary["<unk>"] = 3
+
+        if fp16:
+            self.encoder.half()
+        if self.use_cuda:
+            if verbose:
+                logger.info("transfer encoder to GPU")
+            self.encoder.cuda()
+        self.encoder.eval()
+        self.sort_kind = sort_kind
+
+    def _process_batch(self, batch):
+        tokens = batch.tokens
+        lengths = batch.lengths
+        if self.use_cuda:
+            tokens = tokens.cuda()
+            lengths = lengths.cuda()
+
+        with torch.no_grad():
+            sentemb = self.encoder(tokens, lengths)["sentemb"]
+        embeddings = sentemb.detach().cpu().numpy()
+        return embeddings
+
+    def _tokenize(self, line):
+        tokens = SPACE_NORMALIZER.sub(" ", line).strip().split()
+        ntokens = len(tokens)
+        if self.prepend_bos:
+            ids = torch.LongTensor(ntokens + 2)
+            ids[0] = self.bos_index
+            for i, token in enumerate(tokens):
+                ids[i + 1] = self.dictionary.get(token, self.unk_index)
+            ids[ntokens + 1] = self.eos_index
+        else:
+            ids = torch.LongTensor(ntokens + 1)
+            for i, token in enumerate(tokens):
+                ids[i] = self.dictionary.get(token, self.unk_index)
+
ids[ntokens] = self.eos_index + return ids + + def _make_batches(self, lines): + tokens = [self._tokenize(line) for line in lines] + lengths = np.array([t.numel() for t in tokens]) + indices = np.argsort(-lengths, kind=self.sort_kind) + + def batch(tokens, lengths, indices): + toks = tokens[0].new_full((len(tokens), tokens[0].shape[0]), self.pad_index) + if not self.left_padding: + for i in range(len(tokens)): + toks[i, : tokens[i].shape[0]] = tokens[i] + else: + for i in range(len(tokens)): + toks[i, -tokens[i].shape[0] :] = tokens[i] + return ( + Batch(srcs=None, tokens=toks, lengths=torch.LongTensor(lengths)), + indices, + ) + + batch_tokens, batch_lengths, batch_indices = [], [], [] + ntokens = nsentences = 0 + for i in indices: + if nsentences > 0 and ( + (self.max_tokens is not None and ntokens + lengths[i] > self.max_tokens) + or (self.max_sentences is not None and nsentences == self.max_sentences) + ): + yield batch(batch_tokens, batch_lengths, batch_indices) + ntokens = nsentences = 0 + batch_tokens, batch_lengths, batch_indices = [], [], [] + batch_tokens.append(tokens[i]) + batch_lengths.append(lengths[i]) + batch_indices.append(i) + ntokens += tokens[i].shape[0] + nsentences += 1 + if nsentences > 0: + yield batch(batch_tokens, batch_lengths, batch_indices) + + def encode_sentences(self, sentences): + indices = [] + results = [] + for batch, batch_indices in self._make_batches(sentences): + indices.extend(batch_indices) + results.append(self._process_batch(batch)) + return np.vstack(results)[np.argsort(indices, kind=self.sort_kind)] + + +class HuggingFaceEncoder(): + def __init__(self, encoder_name: str, verbose=False): + from sentence_transformers import SentenceTransformer + encoder = f"sentence-transformers/{encoder_name}" + if verbose: + logger.info(f"loading HuggingFace encoder: {encoder}") + self.encoder = SentenceTransformer(encoder) + + def encode_sentences(self, sentences): + return self.encoder.encode(sentences) + + +class LaserTransformerEncoder(TransformerEncoder): + def __init__(self, state_dict, vocab_path): + self.dictionary = Dictionary.load(vocab_path) + if any( + k in state_dict["model"] + for k in ["encoder.layer_norm.weight", "layer_norm.weight"] + ): + self.dictionary.add_symbol("") + cfg = state_dict["cfg"]["model"] + self.sentemb_criterion = cfg.sentemb_criterion + self.pad_idx = self.dictionary.pad_index + self.bos_idx = self.dictionary.bos_index + embed_tokens = Embedding( + len(self.dictionary), cfg.encoder_embed_dim, self.pad_idx, + ) + super().__init__(cfg, self.dictionary, embed_tokens) + if "decoder.version" in state_dict["model"]: + self._remove_decoder_layers(state_dict) + if "layer_norm.weight" in state_dict["model"]: + self.layer_norm = LayerNorm(cfg.encoder_embed_dim) + self.load_state_dict(state_dict["model"]) + + def _remove_decoder_layers(self, state_dict): + for key in list(state_dict["model"].keys()): + if not key.startswith( + ( + "encoder.layer_norm", + "encoder.layers", + "encoder.embed", + "encoder.version", + ) + ): + del state_dict["model"][key] + else: + renamed_key = key.replace("encoder.", "") + state_dict["model"][renamed_key] = state_dict["model"].pop(key) + + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + if isinstance(encoder_out, dict): + x = encoder_out["encoder_out"][0] # T x B x C + else: + x = encoder_out[0] + if self.sentemb_criterion == "cls": + cls_indices = src_tokens.eq(self.bos_idx).t() + sentemb = x[cls_indices, :] + else: + padding_mask = 
src_tokens.eq(self.pad_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + sentemb = x.max(dim=0)[0] + return {"sentemb": sentemb} + + +class LaserLstmEncoder(nn.Module): + def __init__( + self, + num_embeddings, + padding_idx, + embed_dim=320, + hidden_size=512, + num_layers=1, + bidirectional=False, + left_pad=True, + padding_value=0.0, + ): + super().__init__() + + self.num_layers = num_layers + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + self.padding_idx = padding_idx + self.embed_tokens = nn.Embedding( + num_embeddings, embed_dim, padding_idx=self.padding_idx + ) + + self.lstm = nn.LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_tokens, src_lengths): + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value + ) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return torch.cat( + [ + torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( + 1, bsz, self.output_units + ) + for i in range(self.num_layers) + ], + dim=0, + ) + + final_hiddens = combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + # Set padded outputs to -inf so they are not selected by max-pooling + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + return { + "sentemb": sentemb, + "encoder_out": (x, final_hiddens, final_cells), + "encoder_padding_mask": encoder_padding_mask + if encoder_padding_mask.any() + else None, + } + + +def EncodeLoad(args): + args.buffer_size = max(args.buffer_size, 1) + assert ( + not args.max_sentences or args.max_sentences <= args.buffer_size + ), "--max-sentences/--batch-size cannot be larger than --buffer-size" + + print(" - loading encoder", args.encoder) + return SentenceEncoder( + args.encoder, + max_sentences=args.max_sentences, + max_tokens=args.max_tokens, + cpu=args.cpu, + verbose=args.verbose, + ) + + +def EncodeTime(t): + t = int(time.time() - t) + if t < 1000: + return "{:d}s".format(t) + else: + return "{:d}m{:d}s".format(t // 60, t % 60) + +# Load existing embeddings +def EmbedLoad(fname, dim=1024, verbose=False): + x = np.fromfile(fname, dtype=np.float32, count=-1) + x.resize(x.shape[0] // dim, dim) + if verbose: + print(" - Embeddings: {:s}, {:d}x{:d}".format(fname, x.shape[0], dim)) + return x + +# Get memory mapped embeddings +def 
EmbedMmap(fname, dim=1024, dtype=np.float32, verbose=False): + nbex = int(os.path.getsize(fname) / dim / np.dtype(dtype).itemsize) + E = np.memmap(fname, mode="r", dtype=dtype, shape=(nbex, dim)) + if verbose: + print(" - embeddings on disk: {:s} {:d} x {:d}".format(fname, nbex, dim)) + return E \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/laserembeddings/laser.py b/Preprocessing/sentence_embeddings/laserembeddings/laser.py new file mode 100644 index 00000000..b5aebf30 --- /dev/null +++ b/Preprocessing/sentence_embeddings/laserembeddings/laser.py @@ -0,0 +1,162 @@ +from typing import Dict, Any, Union, List, Optional +from io import TextIOBase, BufferedIOBase +import os + +import numpy as np + +from Preprocessing.sentence_embeddings.laserembeddings.preprocessing import Tokenizer, BPE, SPM +from Preprocessing.sentence_embeddings.laserembeddings.embedding import BPESentenceEmbedding, SPMSentenceEmbedding +from Preprocessing.sentence_embeddings.laserembeddings.utils import sre_performance_patch, download_models +from Utility.storage_config import MODELS_DIR + +__all__ = ['Laser'] + + +class Laser: + """ + End-to-end LASER embedding. + + The pipeline is: ``Tokenizer.tokenize`` -> ``BPE.encode_tokens`` -> ``BPESentenceEmbedding.embed_bpe_sentences`` + Using spm model: ``Tokenizer.tokenize`` -> ``SPM.encode_sentence`` -> ``SPMSentenceEmbedding.embed_spm_sentences`` + + Args: + mode (str): spm or bpe + bpe_codes (str or TextIOBase, optional): the path to LASER's BPE codes (``93langs.fcodes``), + or a text-mode file object. If omitted, ``Laser.DEFAULT_BPE_CODES_FILE`` is used. + bpe_vocab (str or TextIOBase, optional): the path to LASER's BPE vocabulary (``93langs.fvocab``), + or a text-mode file object. If omitted, ``Laser.DEFAULT_BPE_VOCAB_FILE`` is used. + encoder (str or BufferedIOBase, optional): the path to LASER's encoder PyTorch model (``bilstm.93langs.2018-12-26.pt``), + or a binary-mode file object. If omitted, ``Laser.DEFAULT_ENCODER_LASER_FILE`` (bpe mode) or ``Laser.DEFAULT_ENCODER_LASER2_FILE`` (spm mode) is used. + spm_model (str or BufferedIOBase, optional): the path to LASER's SPM model + spm_vocab (str or BufferedIOBase, optional): the path to LASER's SPM vocab + tokenizer_options (Dict[str, Any], optional): additional arguments to pass to the tokenizer. + See ``.preprocessing.Tokenizer``. + embedding_options (Dict[str, Any], optional): additional arguments to pass to the embedding layer. + See ``.embedding.BPESentenceEmbedding``. + + Class attributes: + DATA_DIR (str): the path to the directory of default LASER files. + DEFAULT_BPE_CODES_FILE: the path to default BPE codes file. + DEFAULT_BPE_VOCAB_FILE: the path to default BPE vocabulary file. + DEFAULT_ENCODER_LASER_FILE / DEFAULT_ENCODER_LASER2_FILE: the paths to the default LASER encoder PyTorch model files.
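    Example (editor's sketch, not part of the original docstring; assumes the
    default model files are already in ``DATA_DIR`` or can be downloaded on
    first use via ``download_models``):

        >>> laser = Laser(mode='spm')
        >>> embeddings = laser.embed_sentences(['This is a test.'], lang='en')
        >>> embeddings.shape
        (1, 1024)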
+ """ + + DATA_DIR = os.path.join(MODELS_DIR, 'Language_Models') + DEFAULT_BPE_CODES_FILE = os.path.join(DATA_DIR, '93langs.fcodes') + DEFAULT_BPE_VOCAB_FILE = os.path.join(DATA_DIR, '93langs.fvocab') + DEFAULT_ENCODER_LASER_FILE = os.path.join(DATA_DIR, + 'bilstm.93langs.2018-12-26.pt') + DEFAULT_ENCODER_LASER2_FILE = os.path.join(DATA_DIR, 'laser2.pt') + DEFAULT_SPM_MODEL_FILE = os.path.join(DATA_DIR, 'laser2.spm') + DEFAULT_SPM_VOCAB_FILE = os.path.join(DATA_DIR, 'laser2.cvocab') + + def __init__(self, + mode: str = 'spm', + bpe_codes: Optional[Union[str, TextIOBase]] = None, + bpe_vocab: Optional[Union[str, TextIOBase]] = None, + encoder: Optional[Union[str, BufferedIOBase]] = None, + spm_model: Optional[Union[str, BufferedIOBase]] = None, + spm_vocab: Optional[Union[str, BufferedIOBase]] = None, + tokenizer_options: Optional[Dict[str, Any]] = None, + embedding_options: Optional[Dict[str, Any]] = None): + + if tokenizer_options is None: + tokenizer_options = {} + if embedding_options is None: + embedding_options = {} + + self.bpe = None + self.spm = None + + if mode == 'bpe': + if bpe_codes is None: + if not os.path.isfile(self.DEFAULT_BPE_CODES_FILE): + download_models(self.DATA_DIR, version=1) + bpe_codes = self.DEFAULT_BPE_CODES_FILE + if bpe_vocab is None: + if not os.path.isfile(self.DEFAULT_BPE_VOCAB_FILE): + download_models(self.DATA_DIR, version=1) + bpe_vocab = self.DEFAULT_BPE_VOCAB_FILE + if encoder is None: + if not os.path.isfile(self.DEFAULT_ENCODER_LASER_FILE): + download_models(self.DATA_DIR, version=1) + encoder = self.DEFAULT_ENCODER_LASER_FILE + + print("Mode BPE") + print("Using encoder: {}".format(encoder)) + + self.tokenizer_options = tokenizer_options + self.tokenizers: Dict[str, Tokenizer] = {} + + self.bpe = BPE(bpe_codes, bpe_vocab) + self.bpeSentenceEmbedding = BPESentenceEmbedding( + encoder, **embedding_options) + + if mode == 'spm': + if spm_model is None: + if not os.path.isfile(self.DEFAULT_SPM_MODEL_FILE): + download_models(self.DATA_DIR, version=2) + spm_model = self.DEFAULT_SPM_MODEL_FILE + if spm_vocab is None: + if not os.path.isfile(self.DEFAULT_SPM_VOCAB_FILE): + download_models(self.DATA_DIR, version=2) + spm_vocab = self.DEFAULT_SPM_VOCAB_FILE + if encoder is None: + if not os.path.isfile(self.DEFAULT_ENCODER_LASER2_FILE): + download_models(self.DATA_DIR, version=2) + encoder = self.DEFAULT_ENCODER_LASER2_FILE + + print("Mode SPM") + print("Using encoder: {}".format(encoder)) + + self.tokenizer_options = tokenizer_options + self.tokenizers: Dict[str, Tokenizer] = {} + + self.spm = SPM(spm_model) + self.spmSentenceEmbedding = SPMSentenceEmbedding( + encoder, spm_vocab=spm_vocab, **embedding_options) + + def _get_tokenizer(self, lang: str) -> Tokenizer: + """Returns the Tokenizer instance for the specified language. The returned tokenizers are cached.""" + + if lang not in self.tokenizers: + self.tokenizers[lang] = Tokenizer(lang, **self.tokenizer_options) + + return self.tokenizers[lang] + + def embed_sentences(self, sentences: Union[List[str], str], + lang: Union[str, List[str]]="en") -> np.ndarray: + """ + Computes the LASER embeddings of provided sentences using the tokenizer for the specified language. + + Args: + sentences (str or List[str]): the sentences to compute the embeddings from. + lang (str or List[str]): the language code(s) (ISO 639-1) used to tokenize the sentences + (either as a string - same code for every sentence - or as a list of strings - one code per sentence). 
+ + Returns: + np.ndarray: A N * 1024 NumPy array containing the embeddings, N being the number of sentences provided. + """ + sentences = [sentences] if isinstance(sentences, str) else sentences + lang = [lang] * len(sentences) if isinstance(lang, str) else lang + + if len(sentences) != len(lang): + raise ValueError( + 'lang: invalid length: the number of language codes does not match the number of sentences' + ) + + with sre_performance_patch(): # see https://bugs.python.org/issue37723 + if self.bpe: + sentence_tokens = [ + self._get_tokenizer(sentence_lang).tokenize(sentence) + for sentence, sentence_lang in zip(sentences, lang) + ] + bpe_encoded = [ + self.bpe.encode_tokens(tokens) for tokens in sentence_tokens + ] + return self.bpeSentenceEmbedding.embed_bpe_sentences(bpe_encoded) + if self.spm: + spm_encoded = [ + self.spm.encode_sentence(sentence) for sentence in sentences + ] + return self.spmSentenceEmbedding.embed_spm_sentences(spm_encoded) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/laserembeddings/preprocessing.py b/Preprocessing/sentence_embeddings/laserembeddings/preprocessing.py new file mode 100644 index 00000000..59897c99 --- /dev/null +++ b/Preprocessing/sentence_embeddings/laserembeddings/preprocessing.py @@ -0,0 +1,186 @@ +from typing import Union, Optional +from io import TextIOBase + +from sacremoses import MosesPunctNormalizer, MosesTokenizer +from sacremoses.util import xml_unescape +from subword_nmt.apply_bpe import BPE as subword_nmt_bpe, read_vocabulary +from transliterate import translit +import sentencepiece + +from Preprocessing.sentence_embeddings.laserembeddings.utils import adapt_bpe_codes + +# Extras +try: + import jieba + jieba.setLogLevel(60) +except ImportError: + jieba = None + +try: + import MeCab + import ipadic +except ImportError: + MeCab = None + +__all__ = ['Tokenizer', 'BPE'] + +############################################################################### +# +# Tokenizer +# +############################################################################### + + +class Tokenizer: + """ + Tokenizer. + + Args: + lang (str): the language code (ISO 639-1) of the texts to tokenize + lower_case (bool, optional): if True, the texts are lower-cased before being tokenized. + Defaults to True. + romanize (bool or None, optional): if True, the texts are romanized. + Defaults to None (romanization enabled based on input language). + descape (bool, optional): if True, the XML-escaped symbols get de-escaped. + Default to False. + """ + + def __init__(self, + lang: str = 'en', + lower_case: bool = True, + romanize: Optional[bool] = None, + descape: bool = False): + assert lower_case, 'lower case is needed by all the models' + + if lang in ('cmn', 'wuu', 'yue'): + lang = 'zh' + if lang == 'jpn': + lang = 'ja' + + if lang == 'zh' and jieba is None: + raise ModuleNotFoundError( + '''No module named 'jieba'. Install laserembeddings with 'zh' extra to fix that: "pip install laserembeddings[zh]"''' + ) + if lang == 'ja' and MeCab is None: + raise ModuleNotFoundError( + '''No module named 'MeCab'. 
Install laserembeddings with 'ja' extra to fix that: "pip install laserembeddings[ja]"''' + ) + + self.lang = lang + self.lower_case = lower_case + self.romanize = romanize if romanize is not None else lang == 'el' + self.descape = descape + + self.normalizer = MosesPunctNormalizer(lang=lang) + self.tokenizer = MosesTokenizer(lang=lang) + self.mecab_tokenizer = MeCab.Tagger( + f"{ipadic.MECAB_ARGS} -Owakati -b 50000") if lang == 'ja' else None + + def tokenize(self, text: str) -> str: + """Tokenizes a text and returns the tokens as a string""" + + # REM_NON_PRINT_CHAR + # not implemented + + # NORM_PUNC + text = self.normalizer.normalize(text) + + # DESCAPE + if self.descape: + text = xml_unescape(text) + + # MOSES_TOKENIZER + # see: https://github.com/facebookresearch/LASER/issues/55#issuecomment-480881573 + text = self.tokenizer.tokenize(text, + return_str=True, + escape=False, + aggressive_dash_splits=False) + + # jieba + if self.lang == 'zh': + text = ' '.join(jieba.cut(text.rstrip('\r\n'))) + + # MECAB + if self.lang == 'ja': + text = self.mecab_tokenizer.parse(text).rstrip('\r\n') + + # ROMAN_LC + if self.romanize: + text = translit(text, self.lang, reversed=True) + + if self.lower_case: + text = text.lower() + + return text + + +############################################################################### +# +# Apply BPE +# +############################################################################### + + +class BPE: + """ + BPE encoder. + + Args: + bpe_codes (str or TextIOBase): the path to LASER's BPE codes (``93langs.fcodes``), + or a text-mode file object. + bpe_codes (str or TextIOBase): the path to LASER's BPE vocabulary (``93langs.fvocab``), + or a text-mode file object. + """ + + def __init__(self, bpe_codes: Union[str, TextIOBase], + bpe_vocab: Union[str, TextIOBase]): + + f_bpe_codes = None + f_bpe_vocab = None + + try: + if isinstance(bpe_codes, str): + f_bpe_codes = open(bpe_codes, 'r', encoding='utf-8') # pylint: disable=consider-using-with + if isinstance(bpe_vocab, str): + f_bpe_vocab = open(bpe_vocab, 'r', encoding='utf-8') # pylint: disable=consider-using-with + + self.bpe = subword_nmt_bpe(codes=adapt_bpe_codes(f_bpe_codes + or bpe_codes), + vocab=read_vocabulary(f_bpe_vocab + or bpe_vocab, + threshold=None)) + self.bpe.version = (0, 2) + + finally: + if f_bpe_codes: + f_bpe_codes.close() + if f_bpe_vocab: + f_bpe_vocab.close() + + def encode_tokens(self, sentence_tokens: str) -> str: + """Returns the BPE-encoded sentence from a tokenized sentence""" + return self.bpe.process_line(sentence_tokens) + +############################################################################### +# +# Apply SPM +# +############################################################################### + +class SPM: + def __init__(self, spm_model: Union[str, TextIOBase]): + self.spm = None + try: + if isinstance(spm_model, str): + self.spm = sentencepiece.SentencePieceProcessor(model_file=spm_model) + except FileNotFoundError: + pass + + def encode_sentence(self, sentence: str) -> str: + # NORM_PUNC + LC + normalizer = MosesPunctNormalizer(lang="en") + sentence = normalizer.normalize(sentence) + sentence = sentence.lower() + + pieces = self.spm.encode_as_pieces(sentence) + return ' '.join(pieces) \ No newline at end of file diff --git a/Preprocessing/sentence_embeddings/laserembeddings/utils.py b/Preprocessing/sentence_embeddings/laserembeddings/utils.py new file mode 100644 index 00000000..515e6d37 --- /dev/null +++ b/Preprocessing/sentence_embeddings/laserembeddings/utils.py 
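# Editor's illustration (not part of the patch): the SPM preprocessing path
# defined in preprocessing.py above, reduced to a standalone sketch. The model
# path is an assumption; in this repo it would resolve to
# MODELS_DIR/Language_Models/laser2.spm (see laser.py). The BPE path would
# instead go through Tokenizer.tokenize followed by BPE.encode_tokens.
import sentencepiece
from sacremoses import MosesPunctNormalizer

def sentence_to_spm_pieces(sentence: str, spm_model_path: str = "laser2.spm") -> str:
    # NORM_PUNC + LC, as in SPM.encode_sentence
    normalizer = MosesPunctNormalizer(lang="en")
    sentence = normalizer.normalize(sentence).lower()
    # SentencePiece pieces, joined with spaces as the LASER encoder expects
    spm = sentencepiece.SentencePieceProcessor(model_file=spm_model_path)
    return " ".join(spm.encode_as_pieces(sentence))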
@@ -0,0 +1,102 @@ +from io import TextIOBase, StringIO +import re +import sys +import os +import urllib.request + +def adapt_bpe_codes(bpe_codes_f: TextIOBase) -> TextIOBase: + """ + Converts fastBPE codes to subword_nmt BPE codes. + + Args: + bpe_codes_f (TextIOBase): the text-mode file-like object of fastBPE codes + Returns: + TextIOBase: subword_nmt-compatible BPE codes as a text-mode file-like object + """ + return StringIO( + re.sub(r'^([^ ]+) ([^ ]+) ([^ ]+)$', + r'\1 \2', + bpe_codes_f.read(), + flags=re.MULTILINE)) + + +class sre_performance_patch: + """ + Patch fixing https://bugs.python.org/issue37723 for Python 3.7 (<= 3.7.4) + and Python 3.8 (<= 3.8.0 beta 3) + """ + + def __init__(self): + self.sre_parse = None + self.original_sre_parse_uniq = None + + def __enter__(self): + #pylint: disable=import-outside-toplevel + import sys + + if self.original_sre_parse_uniq is None and ( + 0x03070000 <= sys.hexversion <= 0x030704f0 + or 0x03080000 <= sys.hexversion <= 0x030800b3): + try: + import sre_parse + self.sre_parse = sre_parse + #pylint: disable=protected-access + self.original_sre_parse_uniq = sre_parse._uniq + sre_parse._uniq = lambda x: list(dict.fromkeys(x)) + except (ImportError, AttributeError): + self.sre_parse = None + self.original_sre_parse_uniq = None + + def __exit__(self, type_, value, traceback): + if self.sre_parse and self.original_sre_parse_uniq: + #pylint: disable=protected-access + self.sre_parse._uniq = self.original_sre_parse_uniq + self.original_sre_parse_uniq = None + +# model downloads +IS_WIN = os.name == 'nt' + +def non_win_string(s): + return s if not IS_WIN else '' + +CONSOLE_CLEAR = non_win_string('\033[0;0m') +CONSOLE_BOLD = non_win_string('\033[0;1m') +CONSOLE_WAIT = non_win_string('⏳') +CONSOLE_DONE = non_win_string('✅') +CONSOLE_STARS = non_win_string('✨') +CONSOLE_ERROR = non_win_string('❌') + + +def download_file(url, dest): + print(f'{CONSOLE_WAIT} Downloading {url}...', end='') + sys.stdout.flush() + urllib.request.urlretrieve(url, dest) + print(f'\r{CONSOLE_DONE} Downloaded {url} ') + + +def download_models(output_dir, version=2): + assert version in [1, 2] + print(f'Downloading models into {output_dir}') + print('') + + if version == 1: + download_file('https://dl.fbaipublicfiles.com/laser/models/93langs.fcodes', + os.path.join(output_dir, '93langs.fcodes')) + download_file('https://dl.fbaipublicfiles.com/laser/models/93langs.fvocab', + os.path.join(output_dir, '93langs.fvocab')) + download_file( + 'https://dl.fbaipublicfiles.com/laser/models/bilstm.93langs.2018-12-26.pt', + os.path.join(output_dir, 'bilstm.93langs.2018-12-26.pt')) + if version == 2: + download_file( + 'https://dl.fbaipublicfiles.com/nllb/laser/laser2.pt', + os.path.join(output_dir, 'laser2.pt')) + download_file( + 'https://dl.fbaipublicfiles.com/nllb/laser/laser2.spm', + os.path.join(output_dir, 'laser2.spm')) + download_file( + 'https://dl.fbaipublicfiles.com/nllb/laser/laser2.cvocab', + os.path.join(output_dir, 'laser2.cvocab')) + + print('') + print(f'{CONSOLE_STARS} You\'re all set!') \ No newline at end of file diff --git a/Preprocessing/word_embeddings/BERTWordEmbeddingExtractor.py b/Preprocessing/word_embeddings/BERTWordEmbeddingExtractor.py new file mode 100644 index 00000000..dd8ce562 --- /dev/null +++ b/Preprocessing/word_embeddings/BERTWordEmbeddingExtractor.py @@ -0,0 +1,120 @@ +import os +import re + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from transformers import BertTokenizerFast, BertModel + +from 
Preprocessing.word_embeddings.WordEmbeddingExtractor import WordEmbeddingExtractor +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend +from Preprocessing.TextFrontend import english_text_expansion +from Utility.storage_config import MODELS_DIR + + +class BERTWordEmbeddingExtractor(WordEmbeddingExtractor): + def __init__(self, cache_dir:str ="", device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir=self.cache_dir) + self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=self.cache_dir).to(device) + self.model.eval() + self.device = device + self.tf = ArticulatoryCombinedTextFrontend(language="en") + self.merge_tokens = set() + self.expand_tokens = set() + + def encode(self, sentences: list[str]) -> np.ndarray: + if type(sentences) == str: + sentences = [sentences] + # apply spacing + sentences = [english_text_expansion(sent) for sent in sentences] + # replace words + for sent in sentences: + phone_string = self.tf.get_phone_string(sent) + if len(phone_string.split()) != len(sent.split()): + #print("Warning: length mismatch in following sentence") + #print(sent) + #print(phone_string) + #print(len(phone_string.split())) + self.merge_tokens.update(self.get_merge_tokens(sent)) + self.expand_tokens.update(self.get_expand_tokens(sent)) + # tokenize and encode sentences + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + with torch.no_grad(): + # get all hidden states + hidden_states = self.model(**encoded_input, output_hidden_states=True).hidden_states + # stack and sum last 4 layers for each token + token_embeddings = torch.stack([hidden_states[-4], hidden_states[-3], hidden_states[-2], hidden_states[-1]]).sum(0).squeeze() + if len(sentences) == 1: + token_embeddings = token_embeddings.unsqueeze(0) + word_embeddings_list = [] + lens = [] + for batch_id in range(len(sentences)): + # get word ids corresponding to token embeddings + word_ids = encoded_input.word_ids(batch_id) + word_ids_set = set([word_id for word_id in word_ids if word_id is not None]) + # get ids of hidden states of sub tokens for each word + token_ids_words = [[t_id for t_id, word_id in enumerate(word_ids) if word_id == w_id] for w_id in word_ids_set] + # combine hidden states of sub tokens for each word + word_embeddings = torch.stack([token_embeddings[batch_id, token_ids_word].mean(dim=0) for token_ids_word in token_ids_words]) + # combine word embeddings tokens merged by the phonemizer + tokens = re.findall(r"[\w']+|[.,!?;]", sentences[batch_id]) + merged = False + for i in range(len(tokens)): + if merged: + merged = False + continue + t1 = tokens[i] + try: + t2 = tokens[i + 1] + except IndexError: + t2 = "###" + if (t1, t2) in self.merge_tokens: + if i == 0: + merged_embeddings = torch.stack([word_embeddings[i], word_embeddings[i + 1]]).mean(dim=0).unsqueeze(0) + else: + merged_embedding = torch.stack([word_embeddings[i], word_embeddings[i + 1]]).mean(dim=0).unsqueeze(0) + merged_embeddings = torch.cat([merged_embeddings, merged_embedding]) + merged = True + elif t1 in self.expand_tokens: + if i == 0: + merged_embeddings = torch.cat([word_embeddings[i].unsqueeze(0), word_embeddings[i].unsqueeze(0)]) + else: + merged_embeddings = torch.cat([merged_embeddings, word_embeddings[i].unsqueeze(0), word_embeddings[i].unsqueeze(0)]) + else: + if i == 0: + merged_embeddings = word_embeddings[i].unsqueeze(0) + else: + 
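# --- Editor's note (illustration only, not part of the patch) -----------------
# How the sub-token -> word pooling above behaves, with toy values:
#
#   tokens        : [CLS]  "play"  "##ing"  "outside"  [SEP]
#   word_ids      :  None    0        0         1       None
#   token_ids_words -> [[1, 2], [3]]
#   word_embeddings[0] = mean(token_embeddings[1], token_embeddings[2])
#   word_embeddings[1] =       token_embeddings[3]
#
# The merge/expand logic that follows then aligns the number of word vectors
# with the phonemizer's word segmentation, using the (t1, t2) pairs collected
# in self.merge_tokens and the single tokens in self.expand_tokens.
# -------------------------------------------------------------------------------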
merged_embeddings = torch.cat([merged_embeddings, word_embeddings[i].unsqueeze(0)]) + word_embeddings = merged_embeddings + #print(self.tokenizer.tokenize(sentences[batch_id])) + word_embeddings_list.append(word_embeddings) + # save sentence lengths + lens.append(word_embeddings.shape[0]) + # pad tensors to max sentence length of batch + word_embeddings_batch = pad_sequence(word_embeddings_list, batch_first=True).detach() + # return word embeddings for each word in each sentence along with sentence lengths + return word_embeddings_batch, lens + + def get_merge_tokens(self, sentence:str): + w_list = sentence.split() + merge_tokens = [] + for (w1, w2) in zip(w_list, w_list[1:]): + phonemized = self.tf.get_phone_string(' '.join([w1, w2])) + if len(phonemized.split()) < 2: + merge_tokens.append((w1, w2)) + return merge_tokens + + def get_expand_tokens(self, sentence:str): + w_list = sentence.split() + expand_tokens = [] + for w in w_list: + phonemized = self.tf.get_phone_string(w) + if len(phonemized.split()) == 2: + expand_tokens.append(w) + if len(phonemized.split()) > 2: + print(w) + print(phonemized) + return expand_tokens diff --git a/Preprocessing/word_embeddings/EmotionRoBERTaWordEmbeddingExtractor.py b/Preprocessing/word_embeddings/EmotionRoBERTaWordEmbeddingExtractor.py new file mode 100644 index 00000000..f940d681 --- /dev/null +++ b/Preprocessing/word_embeddings/EmotionRoBERTaWordEmbeddingExtractor.py @@ -0,0 +1,120 @@ +import os +import re + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence +from transformers import RobertaTokenizerFast, RobertaModel + +from Preprocessing.word_embeddings.WordEmbeddingExtractor import WordEmbeddingExtractor +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend +from Preprocessing.TextFrontend import english_text_expansion +from Utility.storage_config import MODELS_DIR + + +class EmotionRoBERTaWordEmbeddingExtractor(WordEmbeddingExtractor): + def __init__(self, cache_dir:str ="", device=torch.device("cuda")): + super().__init__() + if cache_dir: + self.cache_dir = cache_dir + self.tokenizer = RobertaTokenizerFast.from_pretrained('j-hartmann/emotion-english-distilroberta-base', cache_dir=self.cache_dir) + self.model = RobertaModel.from_pretrained("j-hartmann/emotion-english-distilroberta-base", cache_dir=self.cache_dir).to(device) + self.model.eval() + self.device = device + self.tf = ArticulatoryCombinedTextFrontend(language="en") + self.merge_tokens = set() + self.expand_tokens = set() + + def encode(self, sentences: list[str]) -> np.ndarray: + if type(sentences) == str: + sentences = [sentences] + # apply spacing + sentences = [english_text_expansion(sent) for sent in sentences] + # replace words + for sent in sentences: + phone_string = self.tf.get_phone_string(sent) + if len(phone_string.split()) != len(sent.split()): + #print("Warning: length mismatch in following sentence") + #print(sent) + #print(phone_string) + #print(len(phone_string.split())) + self.merge_tokens.update(self.get_merge_tokens(sent)) + self.expand_tokens.update(self.get_expand_tokens(sent)) + # tokenize and encode sentences + encoded_input = self.tokenizer(sentences, padding=True, return_tensors='pt').to(self.device) + with torch.no_grad(): + # get all hidden states + hidden_states = self.model(**encoded_input, output_hidden_states=True).hidden_states + # stack and sum last 4 layers for each token + token_embeddings = torch.stack([hidden_states[-4], hidden_states[-3], hidden_states[-2], hidden_states[-1]]).sum(0).squeeze() + if 
len(sentences) == 1: + token_embeddings = token_embeddings.unsqueeze(0) + word_embeddings_list = [] + lens = [] + for batch_id in range(len(sentences)): + # get word ids corresponding to token embeddings + word_ids = encoded_input.word_ids(batch_id) + word_ids_set = set([word_id for word_id in word_ids if word_id is not None]) + # get ids of hidden states of sub tokens for each word + token_ids_words = [[t_id for t_id, word_id in enumerate(word_ids) if word_id == w_id] for w_id in word_ids_set] + # combine hidden states of sub tokens for each word + word_embeddings = torch.stack([token_embeddings[batch_id, token_ids_word].mean(dim=0) for token_ids_word in token_ids_words]) + # combine word embeddings tokens merged by the phonemizer + tokens = re.findall(r"[\w']+|[.,!?;]", sentences[batch_id]) + merged = False + for i in range(len(tokens)): + if merged: + merged = False + continue + t1 = tokens[i] + try: + t2 = tokens[i + 1] + except IndexError: + t2 = "###" + if (t1, t2) in self.merge_tokens: + if i == 0: + merged_embeddings = torch.stack([word_embeddings[i], word_embeddings[i + 1]]).mean(dim=0).unsqueeze(0) + else: + merged_embedding = torch.stack([word_embeddings[i], word_embeddings[i + 1]]).mean(dim=0).unsqueeze(0) + merged_embeddings = torch.cat([merged_embeddings, merged_embedding]) + merged = True + elif t1 in self.expand_tokens: + if i == 0: + merged_embeddings = torch.cat([word_embeddings[i].unsqueeze(0), word_embeddings[i].unsqueeze(0)]) + else: + merged_embeddings = torch.cat([merged_embeddings, word_embeddings[i].unsqueeze(0), word_embeddings[i].unsqueeze(0)]) + else: + if i == 0: + merged_embeddings = word_embeddings[i].unsqueeze(0) + else: + merged_embeddings = torch.cat([merged_embeddings, word_embeddings[i].unsqueeze(0)]) + word_embeddings = merged_embeddings + #print(self.tokenizer.tokenize(sentences[batch_id])) + word_embeddings_list.append(word_embeddings) + # save sentence lengths + lens.append(word_embeddings.shape[0]) + # pad tensors to max sentence length of batch + word_embeddings_batch = pad_sequence(word_embeddings_list, batch_first=True).detach() + # return word embeddings for each word in each sentence along with sentence lengths + return word_embeddings_batch, lens + + def get_merge_tokens(self, sentence:str): + w_list = sentence.split() + merge_tokens = [] + for (w1, w2) in zip(w_list, w_list[1:]): + phonemized = self.tf.get_phone_string(' '.join([w1, w2])) + if len(phonemized.split()) < 2: + merge_tokens.append((w1, w2)) + return merge_tokens + + def get_expand_tokens(self, sentence:str): + w_list = sentence.split() + expand_tokens = [] + for w in w_list: + phonemized = self.tf.get_phone_string(w) + if len(phonemized.split()) == 2: + expand_tokens.append(w) + if len(phonemized.split()) > 2: + print(w) + print(phonemized) + return expand_tokens diff --git a/Preprocessing/word_embeddings/WordEmbeddingExtractor.py b/Preprocessing/word_embeddings/WordEmbeddingExtractor.py new file mode 100644 index 00000000..a464a65e --- /dev/null +++ b/Preprocessing/word_embeddings/WordEmbeddingExtractor.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +import numpy as np +import os + +from Utility.storage_config import MODELS_DIR + +class WordEmbeddingExtractor(ABC): + + def __init__(self): + self.cache_dir = os.path.join(MODELS_DIR, 'Language_Models') + pass + + @abstractmethod + def encode(self, sentences:list[str]) -> np.ndarray: + pass \ No newline at end of file diff --git a/TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDataset.py 
b/TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDataset.py index a9c91e2a..b1f92365 100644 --- a/TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDataset.py +++ b/TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDataset.py @@ -31,6 +31,7 @@ def __init__(self, ctc_selection=True, save_imgs=False): self.cache_dir = cache_dir + self.path_to_transcript_dict = path_to_transcript_dict os.makedirs(cache_dir, exist_ok=True) if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache: if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache: @@ -176,14 +177,16 @@ def __init__(self, def __getitem__(self, index): return self.datapoints[index][0], \ - self.datapoints[index][1], \ - self.datapoints[index][2], \ - self.datapoints[index][3], \ - self.datapoints[index][4], \ - self.datapoints[index][5], \ - self.datapoints[index][6], \ - self.datapoints[index][7], \ - self.language_id + self.datapoints[index][1], \ + self.datapoints[index][2], \ + self.datapoints[index][3], \ + self.datapoints[index][4], \ + self.datapoints[index][5], \ + self.datapoints[index][6], \ + self.datapoints[index][7], \ + self.language_id, \ + self.path_to_transcript_dict[self.datapoints[index][8]], \ + self.datapoints[index][8] def __len__(self): return len(self.datapoints) diff --git a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/ToucanTTS.py b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/ToucanTTS.py index 7892db33..bc4e9769 100644 --- a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/ToucanTTS.py +++ b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/ToucanTTS.py @@ -1,4 +1,5 @@ import torch +from torchvision.ops import SqueezeExcitation from torch.nn import Linear from torch.nn import Sequential from torch.nn import Tanh @@ -96,7 +97,10 @@ def __init__(self, # additional features utt_embed_dim=64, - lang_embs=8000): + lang_embs=8000, + sent_embed_dim=None, + word_embed_dim=None, + static_speaker_embed=False): super().__init__() self.input_feature_dimensions = input_feature_dimensions @@ -105,6 +109,27 @@ def __init__(self, self.use_scaled_pos_enc = use_scaled_positional_encoding self.multilingual_model = lang_embs is not None self.multispeaker_model = utt_embed_dim is not None + self.use_sent_embed = sent_embed_dim is not None + self.use_word_embed = word_embed_dim is not None + self.static_speaker_embed = static_speaker_embed + + if self.static_speaker_embed: + # emovdb - 4, cremad - 91, esds - 10, ravdess - 24, ljspeech - 1, librittsr - 1230, tess - 2 + self.speaker_embedding = torch.nn.Embedding(10 + 24 + 1 + 1230 + 2, utt_embed_dim) + + if self.use_sent_embed: + self.sentence_embedding_adaptation = Linear(sent_embed_dim, 512) + sent_embed_dim = 512 + + self.squeeze_excitation = SqueezeExcitation(utt_embed_dim + sent_embed_dim, 192) + self.style_embedding_projection = Sequential(Linear(utt_embed_dim + sent_embed_dim, 512), + Tanh(), + Linear(512, 192)) + utt_embed_dim = 192 + else: + if utt_embed_dim is not None: + self.speaker_embedding_adaptation = Linear(utt_embed_dim, 192) + utt_embed_dim = 192 articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension)) self.encoder = Conformer(idim=input_feature_dimensions, @@ -125,7 +150,9 @@ def __init__(self, zero_triu=False, utt_embed=utt_embed_dim, lang_embs=lang_embs, - use_output_norm=True) + word_embed_dim=word_embed_dim, + use_output_norm=True, + conformer_encoder=True) self.duration_predictor = 
DurationPredictor(idim=attention_dimension, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans, @@ -172,6 +199,7 @@ def __init__(self, macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_decoder_kernel_size, + utt_embed=utt_embed_dim, use_output_norm=False) self.feat_out = Linear(attention_dimension, output_spectrogram_channels) @@ -215,7 +243,10 @@ def forward(self, gold_durations, gold_pitch, gold_energy, - utterance_embedding, + utterance_embedding=None, + speaker_id=None, + sentence_embedding=None, + word_embedding=None, return_mels=False, lang_ids=None, run_glow=True @@ -247,6 +278,9 @@ def forward(self, gold_pitch=gold_pitch, gold_energy=gold_energy, utterance_embedding=utterance_embedding, + speaker_id=speaker_id, + sentence_embedding=sentence_embedding, + word_embedding=word_embedding, is_inference=False, lang_ids=lang_ids, run_glow=run_glow) @@ -268,7 +302,7 @@ def forward(self, if return_mels: if after_outs is None: after_outs = before_outs - return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, after_outs + return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, after_outs, return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss def _forward(self, @@ -281,6 +315,9 @@ def _forward(self, gold_energy=None, is_inference=False, utterance_embedding=None, + speaker_id=None, + sentence_embedding=None, + word_embedding=None, lang_ids=None, run_glow=True): @@ -290,12 +327,44 @@ def _forward(self, if not self.multispeaker_model: utterance_embedding = None else: - utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + if self.static_speaker_embed: + utterance_embedding = self.speaker_embedding(speaker_id) + else: + utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + + if not self.use_sent_embed: + sentence_embedding = None + utterance_embedding = self.speaker_embedding_adaptation(utterance_embedding) + else: + sentence_embedding = torch.nn.functional.normalize(sentence_embedding) + sentence_embedding = self.sentence_embedding_adaptation(sentence_embedding) + utterance_embedding = torch.cat([utterance_embedding, sentence_embedding], dim=1) + utterance_embedding = self.squeeze_excitation(utterance_embedding.transpose(0, 1).unsqueeze(-1)).squeeze(-1).transpose(0, 1) + utterance_embedding = self.style_embedding_projection(utterance_embedding) + + if not self.use_word_embed: + word_embedding = None + word_boundaries_batch = None + else: + # get word boundaries + word_boundaries_batch = [] + for batch_id, batch in enumerate(text_tensors): + word_boundaries = [] + for phoneme_index, phoneme_vector in enumerate(batch): + if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1: + word_boundaries.append(phoneme_index) + word_boundaries.append(text_lengths[batch_id].cpu().numpy()-1) # marker for last word of sentence + word_boundaries_batch.append(torch.tensor(word_boundaries)) # encoding the texts text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2) padding_masks = make_pad_mask(text_lengths, device=text_lengths.device) - encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) + encoded_texts, _ = self.encoder(text_tensors, + text_masks, + utterance_embedding=utterance_embedding, + word_embedding=word_embedding, + word_boundaries=word_boundaries_batch, + lang_ids=lang_ids) # (B, Tmax, adim) if is_inference: # predicting pitch, energy and 
durations @@ -331,7 +400,9 @@ def _forward(self, # decoding spectrogram decoder_masks = make_non_pad_mask(speech_lengths, device=speech_lengths.device).unsqueeze(-2) if speech_lengths is not None and not is_inference else None - decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, decoder_masks) + decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, + decoder_masks, + utterance_embedding=utterance_embedding) decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels) refined_spectrogram = decoded_spectrogram + self.conv_postnet(decoded_spectrogram.transpose(1, 2)).transpose(1, 2) @@ -351,6 +422,7 @@ def _forward(self, mel_out=refined_spectrogram.detach().clone(), encoded_texts=upsampled_enriched_encoded_texts.detach().clone(), tgt_nonpadding=decoder_masks) + if is_inference: return decoded_spectrogram.squeeze(), \ refined_spectrogram.squeeze(), \ @@ -370,6 +442,9 @@ def inference(self, text, speech=None, utterance_embedding=None, + speaker_id=None, + sentence_embedding=None, + word_embedding=None, return_duration_pitch_energy=False, lang_id=None, run_postflow=True): @@ -392,7 +467,10 @@ def inference(self, ys = y.unsqueeze(0) if lang_id is not None: lang_id = lang_id.unsqueeze(0) - utterance_embeddings = utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None + utterance_embeddings = utterance_embedding.unsqueeze(0).to(x.device) if utterance_embedding is not None else None + sentence_embeddings = sentence_embedding.unsqueeze(0).to(x.device) if sentence_embedding is not None else None + word_embeddings = word_embedding.unsqueeze(0).to(x.device) if word_embedding is not None else None + speaker_id = speaker_id.to(x.device) if speaker_id is not None else None before_outs, \ after_outs, \ @@ -403,6 +481,9 @@ def inference(self, ys, is_inference=True, utterance_embedding=utterance_embeddings, + speaker_id=speaker_id, + sentence_embedding=sentence_embeddings, + word_embedding=word_embeddings, lang_ids=lang_id, run_glow=run_postflow) # (1, L, odim) self.train() diff --git a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop.py b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop.py index 2cf4e3fc..6b2b461c 100644 --- a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop.py +++ b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop.py @@ -19,10 +19,11 @@ from run_weight_averaging import get_n_recent_checkpoints_paths from run_weight_averaging import load_net_toucan from run_weight_averaging import save_model_for_use +from Utility.utils import get_emotion_from_path, get_speakerid_from_path_all, get_speakerid_from_path, get_speakerid_from_path_all2 def collate_and_pad(batch): - # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id + # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id, sentence string, filepath return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True), torch.stack([datapoint[1] for datapoint in batch]).squeeze(1), pad_sequence([datapoint[2] for datapoint in batch], batch_first=True), @@ -31,7 +32,9 @@ def collate_and_pad(batch): pad_sequence([datapoint[5] for datapoint in batch], batch_first=True), pad_sequence([datapoint[6] for datapoint in batch], batch_first=True), None, - torch.stack([datapoint[8] for datapoint in batch])) + torch.stack([datapoint[8] for datapoint in batch]), + [datapoint[9] for 
datapoint in batch], + [datapoint[10] for datapoint in batch]) def train_loop(net, @@ -49,7 +52,12 @@ def train_loop(net, steps, use_wandb, postnet_start_steps, - use_discriminator + use_discriminator, + sent_embs=None, + emotion_sent_embs=None, + word_embedding_extractor=None, + path_to_xvect=None, + static_speaker_embed=False ): """ see train loop arbiter for explanations of the arguments @@ -58,11 +66,18 @@ def train_loop(net, if use_discriminator: discriminator = SpectrogramDiscriminator().to(device) - style_embedding_function = StyleEmbedding().to(device) - check_dict = torch.load(path_to_embed_model, map_location=device) - style_embedding_function.load_state_dict(check_dict["style_emb_func"]) - style_embedding_function.eval() - style_embedding_function.requires_grad_(False) + if path_to_embed_model is not None: + style_embedding_function = StyleEmbedding().to(device) + check_dict = torch.load(path_to_embed_model, map_location=device) + style_embedding_function.load_state_dict(check_dict["style_emb_func"]) + style_embedding_function.eval() + style_embedding_function.requires_grad_(False) + else: + style_embedding_function = None + + if static_speaker_embed: + with open("/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/librittsr/libri_speakers.txt") as f: + libri_speakers = sorted([int(line.rstrip()) for line in f]) torch.multiprocessing.set_sharing_strategy('file_system') train_loader = DataLoader(batch_size=batch_size, @@ -104,8 +119,45 @@ def train_loop(net, for batch in tqdm(train_loader): train_loss = 0.0 - style_embedding = style_embedding_function(batch_of_spectrograms=batch[2].to(device), - batch_of_spectrogram_lengths=batch[3].to(device)) + + if path_to_xvect is not None: + filepaths = batch[10] + embeddings = [] + for path in filepaths: + embeddings.append(path_to_xvect[path]) + style_embedding = torch.stack(embeddings).to(device) + elif style_embedding_function is not None: + style_embedding = style_embedding_function(batch_of_spectrograms=batch[2].to(device), + batch_of_spectrogram_lengths=batch[3].to(device)) + else: + style_embedding = None + + if sent_embs is not None or emotion_sent_embs is not None: + filepaths = batch[10] + sentences = batch[9] + sentence_embeddings = [] + for path, sentence in zip(filepaths, sentences): + if "LJSpeech" in path or "LibriTTS_R" in path: + sentence_embeddings.append(sent_embs[sentence]) + else: + emotion = get_emotion_from_path(path) + sentence_embeddings.append(random.choice(emotion_sent_embs[emotion])) + sentence_embedding = torch.stack(sentence_embeddings).to(device) + else: + sentence_embedding = None + + if static_speaker_embed: + filepaths = batch[10] + speaker_ids = torch.LongTensor([get_speakerid_from_path_all2(path, libri_speakers) for path in filepaths]).to(device) + else: + speaker_ids = None + + if word_embedding_extractor is not None: + word_embedding, sentence_lens = word_embedding_extractor.encode(sentences=batch[9]) + word_embedding = word_embedding.to(device) + else: + word_embedding = None + sentence_lens = None l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, generated_spectrograms = net( text_tensors=batch[0].to(device), @@ -116,6 +168,9 @@ def train_loop(net, gold_pitch=batch[6].to(device), # mind the switched order gold_energy=batch[5].to(device), # mind the switched order utterance_embedding=style_embedding, + speaker_id=speaker_ids, + sentence_embedding=sentence_embedding, + word_embedding=word_embedding, lang_ids=batch[8].to(device), return_mels=True, run_glow=step_counter > postnet_start_steps 
or fine_tune) @@ -161,10 +216,17 @@ def train_loop(net, # EPOCH IS OVER net.eval() - style_embedding_function.eval() - default_embedding = style_embedding_function( - batch_of_spectrograms=train_dataset[0][2].unsqueeze(0).to(device), - batch_of_spectrogram_lengths=train_dataset[0][3].unsqueeze(0).to(device)).squeeze() + + if style_embedding_function is not None: + style_embedding_function.eval() + if path_to_xvect is not None: + default_embedding = path_to_xvect[train_dataset[0][10]] + elif style_embedding_function is not None: + default_embedding = style_embedding_function(batch_of_spectrograms=train_dataset[0][2].unsqueeze(0).to(device), + batch_of_spectrogram_lengths=train_dataset[0][3].unsqueeze(0).to(device)).squeeze() + else: + default_embedding = None + torch.save({ "model" : net.state_dict(), "optimizer" : optimizer.state_dict(), @@ -200,6 +262,9 @@ def train_loop(net, step=step_counter, lang=lang, default_emb=default_embedding, + static_speaker_embed=static_speaker_embed, + sent_embs=sent_embs if sent_embs is not None else emotion_sent_embs, + word_embedding_extractor=word_embedding_extractor, run_postflow=step_counter - 5 > postnet_start_steps) if use_wandb: wandb.log({ diff --git a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop_arbiter.py b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop_arbiter.py index 0e020f55..6ed35a48 100644 --- a/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop_arbiter.py +++ b/TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop_arbiter.py @@ -23,7 +23,12 @@ def train_loop(net, # an already initialized ToucanTTS model that should be tra fine_tune=False, # whether to use the provided checkpoint as basis for fine-tuning. steps=80000, # how many updates to run until training is completed postnet_start_steps=9000, # how many warmup steps before the postnet starts training - use_discriminator=True # whether to use a discriminator as additional feedback signal for the TTS in the mono-lingual train loop + use_discriminator=True, # whether to use a discriminator as additional feedback signal for the TTS in the mono-lingual train loop + sent_embs=None, + emotion_sent_embs=None, + word_embedding_extractor=None, + path_to_xvect=None, + static_speaker_embed=False ): if type(datasets) != list: datasets = [datasets] @@ -60,4 +65,9 @@ def train_loop(net, # an already initialized ToucanTTS model that should be tra steps=steps, use_wandb=use_wandb, postnet_start_steps=postnet_start_steps, - use_discriminator=use_discriminator) + use_discriminator=use_discriminator, + sent_embs=sent_embs, + emotion_sent_embs=emotion_sent_embs, + word_embedding_extractor=word_embedding_extractor, + path_to_xvect=path_to_xvect, + static_speaker_embed=static_speaker_embed) diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Finetuning.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Finetuning.py new file mode 100644 index 00000000..997bb270 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Finetuning.py @@ -0,0 +1,96 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset +from tqdm import tqdm + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config 
import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Baseline_Finetuning_2" + print("base finetuning") + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + ''' + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_EmoV_DB_Speaker(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "emovdb_speaker"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_CREMA_D(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "cremad"), + lang="en", + save_imgs=False)) + ''' + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_TESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "tess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=train_set, + device=device, + save_directory=save_dir, + batch_size=32, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + path_to_xvect=None, + static_speaker_embed=True, + steps=200000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Pretraining.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Pretraining.py new file mode 100644 index 00000000..ce5a21e4 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Baseline_Pretraining.py @@ -0,0 +1,116 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset +from tqdm import tqdm + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Baseline_Pretraining_2" + print("base") + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + try: + transcript_dict_ljspeech = torch.load(os.path.join(PREPROCESSING_DIR, "ljspeech", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_ljspeech = build_path_to_transcript_dict_ljspeech() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_ljspeech, + corpus_dir=os.path.join(PREPROCESSING_DIR, "ljspeech"), + lang="en", + save_imgs=False)) + + try: + transcript_dict_librittsr = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_librittsr = build_path_to_transcript_dict_libritts_all_clean() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_librittsr, + corpus_dir=os.path.join(PREPROCESSING_DIR, "librittsr"), + lang="en", + save_imgs=False)) + + ''' + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_EmoV_DB_Speaker(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "emovdb_speaker"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_CREMA_D(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "cremad"), + lang="en", + save_imgs=False)) + ''' + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + 
corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_TESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "tess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=train_set, + device=device, + save_directory=save_dir, + batch_size=32, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + path_to_xvect=None, + static_speaker_embed=True, + steps=120000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS.py new file mode 100644 index 00000000..28df617a --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS.py @@ -0,0 +1,88 @@ +import time + +import torch +#import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_04_ESDS_static" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False) + + if "_xvect" in name: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_emomulti', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt"), map_location='cpu') + else: + path_to_xvect = None + + if "_ecapa" in name: + print(f"Loading ecapa embeddings from {os.path.join(PREPROCESSING_DIR, 'ecapa_emomulti', 'ecapa.pt')}") + path_to_ecapa = torch.load(os.path.join(PREPROCESSING_DIR, "ecapa_emomulti", "ecapa.pt"), map_location='cpu') + else: + path_to_ecapa = None + if path_to_ecapa is not None: + path_to_xvect = path_to_ecapa + + if "_xvect" in name: + utt_embed_dim = 512 + elif "_ecapa" in name: + utt_embed_dim = 192 + else: + utt_embed_dim = 64 + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + if use_wandb: + import wandb + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + 
id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + path_to_xvect=path_to_xvect, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS_sent_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS_sent_emb.py new file mode 100644 index 00000000..66581875 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_ESDS_sent_emb.py @@ -0,0 +1,232 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_06s_ESDS_sent_emb_a11_emoBERTcls_xvect" + """ + a01: integrate before encoder + a02: integrate before encoder and decoder + a03: integrate before encoder and decoder and postnet + a04: integrate before each encoder layer + a05: integrate before each encoder and decoder layer + a06: integrate before each encoder and decoder layer and postnet + a07: concatenate with style embedding and apply projection + a08: concatenate with style embedding + a09: a06 + a07 + a10: replace style embedding with sentence embedding (no style embedding, no language embedding, single speaker single language case) + a11: a01 + a07 + a12: integrate before encoder and use sentence embedding instead of style embedding (can be constrained with loss) + a13: use sentence embedding instead of style embedding (can be constrained with loss or adaptor) + loss: additionally use sentence style loss + """ + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + if "_xvect" in name: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_emomulti', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt"), map_location='cpu') + if "_static" in name: + import torchaudio + from speechbrain.pretrained import EncoderClassifier + classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", savedir="./Models/Embedding/spkrec-xvect-voxceleb", 
run_opts={"device": device}) + xvect_list = [] + audio_paths = [] + for path in audio_paths: + wave, sr = torchaudio.load(path) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + xvect_list.append(embedding) + else: + xvect_list = None + else: + path_to_xvect = None + xvect_list = None + + if "_ecapa" in name: + print(f"Loading ecapa embeddings from {os.path.join(PREPROCESSING_DIR, 'ecapa_emomulti', 'ecapa.pt')}") + path_to_ecapa = torch.load(os.path.join(PREPROCESSING_DIR, "ecapa_emomulti", "ecapa.pt"), map_location='cpu') + else: + path_to_ecapa = None + if path_to_ecapa is not None: + path_to_xvect = path_to_ecapa + + if "laser" in name: + embed_type = "laser" + sent_embed_dim = 1024 + if "lealla" in name: + embed_type = "lealla" + sent_embed_dim = 192 + if "para" in name: + embed_type = "para" + sent_embed_dim = 768 + if "mpnet" in name: + embed_type = "mpnet" + sent_embed_dim = 768 + if "bertcls" in name: + embed_type = "bertcls" + sent_embed_dim = 768 + if "bertlm" in name: + embed_type = "bertlm" + sent_embed_dim = 768 + if "emoBERTcls" in name: + embed_type = "emoBERTcls" + sent_embed_dim = 768 + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_large_sent_embs_{embed_type}.pt")}') + sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_large_sent_embs_{embed_type}.pt"), map_location='cpu') + + sent_embed_encoder=False + sent_embed_decoder=False + sent_embed_each=False + sent_embed_postnet=False + concat_sent_style=False + use_concat_projection=False + replace_utt_sent_emb = False + style_sent = False + + lang_embs=None + if "_xvect" in name and "_adapted" not in name: + utt_embed_dim = 512 + elif "_ecapa" in name and "_adapted" not in name: + utt_embed_dim = 192 + else: + utt_embed_dim = 64 + + if "a01" in name: + sent_embed_encoder=True + if "a02" in name: + sent_embed_encoder=True + sent_embed_decoder=True + if "a03" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_postnet=True + if "a04" in name: + sent_embed_encoder=True + sent_embed_each=True + if "a05" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + if "a06" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + if "a07" in name: + concat_sent_style=True + use_concat_projection=True + if "a08" in name: + concat_sent_style=True + if "a09" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + concat_sent_style=True + use_concat_projection=True + if "a10" in name: + lang_embs = None + utt_embed_dim = 192 + sent_embed_dim = None + replace_utt_sent_emb = True + if "a11" in name: + sent_embed_encoder=True + concat_sent_style=True + use_concat_projection=True + if "a12" in name: + sent_embed_encoder=True + style_sent=True + if "noadapt" in name and "adapted" not in name: + utt_embed_dim = 768 + if "a13" in name: + style_sent=True + utt_embed_dim = sent_embed_dim + if "noadapt" in name and "adapted" not in name: + utt_embed_dim = 768 + + + model = ToucanTTS(lang_embs=lang_embs, + utt_embed_dim=utt_embed_dim, + sent_embed_dim=64 if "adapted" in name else sent_embed_dim, + sent_embed_adaptation="noadapt" not in name, + sent_embed_encoder=sent_embed_encoder, + sent_embed_decoder=sent_embed_decoder, 
+ sent_embed_each=sent_embed_each, + sent_embed_postnet=sent_embed_postnet, + concat_sent_style=concat_sent_style, + use_concat_projection=use_concat_projection, + use_sent_style_loss="loss" in name, + pre_embed="_pre" in name, + style_sent=style_sent, + static_speaker_embed="_static" in name) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "EmoMulti_Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + random_emb=True, + emovdb=True, + replace_utt_sent_emb=replace_utt_sent_emb, + use_adapted_embs="adapted" in name, + path_to_xvect=path_to_xvect, + static_speaker_embed="_static" in name) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech.py new file mode 100644 index 00000000..46040ee6 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech.py @@ -0,0 +1,62 @@ +import time + +import torch +import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_03_LJSpeech" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ljspeech(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ljspeech"), + lang="en", + save_imgs=False) + + model = ToucanTTS(lang_embs=None, utt_embed_dim=None) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=8, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_sent_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_sent_emb.py new file mode 100644 index 00000000..30e375be --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_sent_emb.py @@ -0,0 +1,102 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.sent_emb_extraction import extract_sent_embs +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Sent_LJSpeech" + + ''' + concat speaker embedding and sentence embedding + input for encoder, pitch, energy, variance predictors and decoder + ''' + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + try: + transcript_dict_ljspeech = torch.load(os.path.join(PREPROCESSING_DIR, "ljspeech", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_ljspeech = build_path_to_transcript_dict_ljspeech() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_ljspeech, + corpus_dir=os.path.join(PREPROCESSING_DIR, "ljspeech"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt")): + from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor(pooling="cls") + sent_embs = extract_sent_embs(train_set=train_set, sent_emb_extractor=sentence_embedding_extractor) + atf = ArticulatoryCombinedTextFrontend(language="en") + example_sentence = atf.get_example_sentence(lang="en") + sent_embs[example_sentence] = sentence_embedding_extractor.encode(sentences=[example_sentence]).squeeze() + torch.save(sent_embs, os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt")) + print(f'Saved sentence embeddings in {os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt")}') + del sentence_embedding_extractor + else: + print(f'Loading 
sentence embeddings from {os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt")}.') + sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt"), map_location='cpu') + + return + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=True) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=32, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + path_to_xvect=None, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_word_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_word_emb.py new file mode 100644 index 00000000..4d5f3af9 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LJSpeech_word_emb.py @@ -0,0 +1,76 @@ +import time + +import torch +import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_02_LJSpeech_word_emb_bert" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ljspeech(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ljspeech"), + lang="en", + save_imgs=False) + + word_embedding_extractor = None + word_embed_dim = None + + if "_bert" in name: + from Preprocessing.word_embeddings.BERTWordEmbeddingExtractor import BERTWordEmbeddingExtractor + word_embedding_extractor = BERTWordEmbeddingExtractor() + word_embed_dim = 768 + if "_emoBERT" in name: + from Preprocessing.word_embeddings.EmotionRoBERTaWordEmbeddingExtractor import EmotionRoBERTaWordEmbeddingExtractor + word_embedding_extractor = EmotionRoBERTaWordEmbeddingExtractor() + word_embed_dim = 768 + + model = ToucanTTS(lang_embs=None, utt_embed_dim=None, word_embed_dim=word_embed_dim) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
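ToucanTTS_LJSpeech_sent_emb.py above and ToucanTTS_LibriTTSR_sent_emb.py further down share the same cache-or-extract step for EmotionRoBERTa sentence embeddings. The sketch below only condenses that step into one helper for readability; the imports and call signatures are taken from the diff itself, and the cache path is whatever the caller chooses.

import os
import torch

def load_or_extract_sent_embs(train_set, cache_path):
    # returns a dict mapping raw sentences to 768-dim EmotionRoBERTa "cls" embeddings
    if os.path.exists(cache_path):
        return torch.load(cache_path, map_location="cpu")
    from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
    from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor
    from Utility.sent_emb_extraction import extract_sent_embs
    extractor = EmotionRoBERTaSentenceEmbeddingExtractor(pooling="cls")
    sent_embs = extract_sent_embs(train_set=train_set, sent_emb_extractor=extractor)
    # the example sentence used for evaluation synthesis also needs an embedding
    example_sentence = ArticulatoryCombinedTextFrontend(language="en").get_example_sentence(lang="en")
    sent_embs[example_sentence] = extractor.encode(sentences=[example_sentence]).squeeze()
    torch.save(sent_embs, cache_path)
    return sent_embs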
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=8, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + word_embedding_extractor=word_embedding_extractor) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS.py new file mode 100644 index 00000000..723f004a --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS.py @@ -0,0 +1,65 @@ +import time + +import torch +import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + +import sys + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_01_LibriTTS" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_libritts_all_clean(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "libritts"), + lang="en", + save_imgs=False) + + model = ToucanTTS() + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=4, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + steps=200000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR.py new file mode 100644 index 00000000..89cdf1fc --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR.py @@ -0,0 +1,102 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset +from tqdm import tqdm + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Baseline_LibriTTSR" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + try: + transcript_dict_librittsr = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_librittsr = build_path_to_transcript_dict_libritts_all_clean() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_librittsr, + corpus_dir=os.path.join(PREPROCESSING_DIR, "librittsr"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + if "_xvect" in name: + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt")): + print("Extracting xvect from audio") + import torchaudio + from speechbrain.pretrained import EncoderClassifier + classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", savedir="./Models/Embedding/spkrec-xvect-voxceleb", run_opts={"device": device}) + path_to_xvect = {} + for index in tqdm(range(len(train_set))): + path = train_set[index][10] + wave, sr = torchaudio.load(path) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + path_to_xvect[path] = embedding + torch.save(path_to_xvect, os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt")) + del classifier + else: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_emomulti', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt"), map_location='cpu') + else: + path_to_xvect = 
None + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=train_set, + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + path_to_xvect=None, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR_sent_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR_sent_emb.py new file mode 100644 index 00000000..640d6e28 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTSR_sent_emb.py @@ -0,0 +1,100 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.sent_emb_extraction import extract_sent_embs +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Sent_LibriTTSR" + + ''' + concat speaker embedding and sentence embedding + input for encoder, pitch, energy, variance predictors and decoder + ''' + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + try: + transcript_dict_librittsr = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_librittsr = build_path_to_transcript_dict_libritts_all_clean() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_librittsr, + corpus_dir=os.path.join(PREPROCESSING_DIR, "librittsr"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt")): + from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor(pooling="cls") + sent_embs = extract_sent_embs(train_set=train_set, sent_emb_extractor=sentence_embedding_extractor) + atf = ArticulatoryCombinedTextFrontend(language="en") + example_sentence = atf.get_example_sentence(lang="en") + sent_embs[example_sentence] = 
sentence_embedding_extractor.encode(sentences=[example_sentence]).squeeze() + torch.save(sent_embs, os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt")) + print(f'Saved sentence embeddings in {os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt")}') + del sentence_embedding_extractor + else: + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt")}.') + sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt"), map_location='cpu') + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=True) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + path_to_xvect=None, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS_sent_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS_sent_emb.py new file mode 100644 index 00000000..cfd75417 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_LibriTTS_sent_emb.py @@ -0,0 +1,200 @@ +import time + +import torch +import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.sent_emb_extraction import extract_sent_embs +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_03_LibriTTS_sent_emb_a11_mpnet" + """ + a01: integrate before encoder + a02: integrate before encoder and decoder + a03: integrate before encoder and decoder and postnet + a04: integrate before each encoder layer + a05: integrate before each encoder and decoder layer + a06: integrate before each encoder and decoder layer and postnet + a07: concatenate with style embedding and apply projection + a08: concatenate with style embedding + a09: a06 + a07 + a10: replace style embedding with sentence embedding (no style embedding, no language embedding, single speaker single language case) + a11: a01 + a07 + loss: additionally use sentence style loss + """ + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = 
prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_libritts_all_clean(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "libritts"), + lang="en", + save_imgs=False) + + if "laser" in name: + embed_type = "laser" + sent_embed_dim = 1024 + if "lealla" in name: + embed_type = "lealla" + sent_embed_dim = 192 + if "para" in name: + embed_type = "para" + sent_embed_dim = 768 + if "mpnet" in name: + embed_type = "mpnet" + sent_embed_dim = 768 + if "bertcls" in name: + embed_type = "bertcls" + sent_embed_dim = 768 + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "libritts", f"sent_emb_cache_{embed_type}.pt")): + if embed_type == "lealla": + import tensorflow as tf + gpus = tf.config.experimental.list_physical_devices('GPU') + tf.config.experimental.set_visible_devices(gpus[0], 'GPU') + from Preprocessing.sentence_embeddings.LEALLASentenceEmbeddingExtractor import LEALLASentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor() + if embed_type == "laser": + from Preprocessing.sentence_embeddings.LASERSentenceEmbeddingExtractor import LASERSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor() + if embed_type == "para": + from Preprocessing.sentence_embeddings.STSentenceEmbeddingExtractor import STSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor(model="para") + if embed_type == "mpnet": + from Preprocessing.sentence_embeddings.STSentenceEmbeddingExtractor import STSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor(model="mpnet") + if embed_type == "bertcls": + from Preprocessing.sentence_embeddings.BERTSentenceEmbeddingExtractor import BERTSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sentence_embedding_extractor = SentenceEmbeddingExtractor(pooling="cls") + + sent_embs = extract_sent_embs(train_set=train_set, sent_emb_extractor=sentence_embedding_extractor) + atf = ArticulatoryCombinedTextFrontend(language="en") + example_sentence = atf.get_example_sentence(lang="en") + sent_embs[example_sentence] = sentence_embedding_extractor.encode(sentences=[example_sentence]).squeeze() + torch.save(sent_embs, os.path.join(PREPROCESSING_DIR, "libritts", f"sent_emb_cache_{embed_type}.pt")) + print(f'Saved sentence embeddings in {os.path.join(PREPROCESSING_DIR, "libritts", f"sent_emb_cache_{embed_type}.pt")}') + if embed_type == "lealla": + print("Please restart and use saved sentence embeddings because tensorflow won't release GPU memory for training.") + return + else: + del sentence_embedding_extractor + else: + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "libritts", f"sent_emb_cache_{embed_type}.pt")}.') + sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "libritts", f"sent_emb_cache_{embed_type}.pt"), map_location='cpu') + + if sent_embs is None: + raise TypeError("Sentence embeddings are None.") + + sent_embed_encoder=False + sent_embed_decoder=False + sent_embed_each=False + sent_embed_postnet=False + concat_sent_style=False + use_concat_projection=False + replace_utt_sent_emb = False + + lang_embs=8000 + utt_embed_dim=64 + + if "a01" in name: + sent_embed_encoder=True + if "a02" in name: + sent_embed_encoder=True + sent_embed_decoder=True + if "a03" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_postnet=True + if "a04" in name: + 
sent_embed_encoder=True + sent_embed_each=True + if "a05" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + if "a06" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + if "a07" in name: + concat_sent_style=True + use_concat_projection=True + if "a08" in name: + concat_sent_style=True + if "a09" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + concat_sent_style=True + use_concat_projection=True + if "a10" in name: + lang_embs = None + utt_embed_dim = 192 + sent_embed_dim = None + replace_utt_sent_emb = True + if "a11" in name: + sent_embed_encoder=True + concat_sent_style=True + use_concat_projection=True + + model = ToucanTTS(lang_embs=lang_embs, + utt_embed_dim=utt_embed_dim, + sent_embed_dim=sent_embed_dim, + sent_embed_adaptation="noadapt" not in name, + sent_embed_encoder=sent_embed_encoder, + sent_embed_decoder=sent_embed_decoder, + sent_embed_each=sent_embed_each, + sent_embed_postnet=sent_embed_postnet, + concat_sent_style=concat_sent_style, + use_concat_projection=use_concat_projection, + use_sent_style_loss="loss" in name) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=4, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + replace_utt_sent_emb=replace_utt_sent_emb, + steps=200000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess.py new file mode 100644 index 00000000..c9ede9f8 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess.py @@ -0,0 +1,65 @@ +import time + +import torch +import wandb + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_04_Ravdess_static" + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False) + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + 
static_speaker_embed=True) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess_sent_emb.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess_sent_emb.py new file mode 100644 index 00000000..c45ef401 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Ravdess_sent_emb.py @@ -0,0 +1,214 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_06_Ravdess_sent_emb_a11_emoBERTcls_static" + """ + a01: integrate before encoder + a02: integrate before encoder and decoder + a03: integrate before encoder and decoder and postnet + a04: integrate before each encoder layer + a05: integrate before each encoder and decoder layer + a06: integrate before each encoder and decoder layer and postnet + a07: concatenate with style embedding and apply projection + a08: concatenate with style embedding + a09: a06 + a07 + a10: replace style embedding with sentence embedding (no style embedding, no language embedding, single speaker single language case) + a11: a01 + a07 + a12: integrate before encoder and use sentence embedding instead of style embedding (can be constrained with loss) + a13: use sentence embedding instead of style embedding (can be constrained with loss or adaptor) + loss: additionally use sentence style loss + """ + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + if "_xvect" in name: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_emomulti', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_emomulti", "xvect.pt"), map_location='cpu') + else: + path_to_xvect = None + + if "_ecapa" in name: + print(f"Loading ecapa embeddings from 
{os.path.join(PREPROCESSING_DIR, 'ecapa_emomulti', 'ecapa.pt')}") + path_to_ecapa = torch.load(os.path.join(PREPROCESSING_DIR, "ecapa_emomulti", "ecapa.pt"), map_location='cpu') + else: + path_to_ecapa = None + if path_to_ecapa is not None: + path_to_xvect = path_to_ecapa + + if "laser" in name: + embed_type = "laser" + sent_embed_dim = 1024 + if "lealla" in name: + embed_type = "lealla" + sent_embed_dim = 192 + if "para" in name: + embed_type = "para" + sent_embed_dim = 768 + if "mpnet" in name: + embed_type = "mpnet" + sent_embed_dim = 768 + if "bertcls" in name: + embed_type = "bertcls" + sent_embed_dim = 768 + if "bertlm" in name: + embed_type = "bertlm" + sent_embed_dim = 768 + if "emoBERTcls" in name: + embed_type = "emoBERTcls" + sent_embed_dim = 768 + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_large_sent_embs_{embed_type}.pt")}') + sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_large_sent_embs_{embed_type}.pt"), map_location='cpu') + + sent_embed_encoder=False + sent_embed_decoder=False + sent_embed_each=False + sent_embed_postnet=False + concat_sent_style=False + use_concat_projection=False + replace_utt_sent_emb = False + style_sent = False + + lang_embs=None + if "_xvect" in name and "_adapted" not in name: + utt_embed_dim = 512 + elif "_ecapa" in name and "_adapted" not in name: + utt_embed_dim = 192 + else: + utt_embed_dim = 64 + + if "a01" in name: + sent_embed_encoder=True + if "a02" in name: + sent_embed_encoder=True + sent_embed_decoder=True + if "a03" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_postnet=True + if "a04" in name: + sent_embed_encoder=True + sent_embed_each=True + if "a05" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + if "a06" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + if "a07" in name: + concat_sent_style=True + use_concat_projection=True + if "a08" in name: + concat_sent_style=True + if "a09" in name: + sent_embed_encoder=True + sent_embed_decoder=True + sent_embed_each=True + sent_embed_postnet=True + concat_sent_style=True + use_concat_projection=True + if "a10" in name: + lang_embs = None + utt_embed_dim = 192 + sent_embed_dim = None + replace_utt_sent_emb = True + if "a11" in name: + sent_embed_encoder=True + concat_sent_style=True + use_concat_projection=True + if "a12" in name: + sent_embed_encoder=True + style_sent=True + if "noadapt" in name and "adapted" not in name: + utt_embed_dim = 768 + if "a13" in name: + style_sent=True + utt_embed_dim = sent_embed_dim + if "noadapt" in name and "adapted" not in name: + utt_embed_dim = 768 + + + model = ToucanTTS(lang_embs=lang_embs, + utt_embed_dim=utt_embed_dim, + sent_embed_dim=64 if "adapted" in name else sent_embed_dim, + sent_embed_adaptation="noadapt" not in name, + sent_embed_encoder=sent_embed_encoder, + sent_embed_decoder=sent_embed_decoder, + sent_embed_each=sent_embed_each, + sent_embed_postnet=sent_embed_postnet, + concat_sent_style=concat_sent_style, + use_concat_projection=use_concat_projection, + use_sent_style_loss="loss" in name, + pre_embed="_pre" in name, + style_sent=style_sent, + static_speaker_embed="_static" in name) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
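The *_sent_emb pipelines encode their architecture variant in the run name (the a01–a13 codes documented in the docstrings above) and translate it into ToucanTTS flags through chains of substring checks. The mapping below is only a condensed restatement of those checks for the purely boolean variants; a10, a12 and a13 additionally change embedding dimensions inside the pipelines themselves.

ABLATION_FLAGS = {
    "a01": {"sent_embed_encoder"},
    "a02": {"sent_embed_encoder", "sent_embed_decoder"},
    "a03": {"sent_embed_encoder", "sent_embed_decoder", "sent_embed_postnet"},
    "a04": {"sent_embed_encoder", "sent_embed_each"},
    "a05": {"sent_embed_encoder", "sent_embed_decoder", "sent_embed_each"},
    "a06": {"sent_embed_encoder", "sent_embed_decoder", "sent_embed_each", "sent_embed_postnet"},
    "a07": {"concat_sent_style", "use_concat_projection"},
    "a08": {"concat_sent_style"},
    "a09": {"sent_embed_encoder", "sent_embed_decoder", "sent_embed_each", "sent_embed_postnet",
            "concat_sent_style", "use_concat_projection"},
    "a11": {"sent_embed_encoder", "concat_sent_style", "use_concat_projection"},
}

def flags_from_name(name):
    # e.g. "ToucanTTS_06_Ravdess_sent_emb_a11_emoBERTcls_static" activates the a11 flags
    active = set()
    for code, flags in ABLATION_FLAGS.items():
        if code in name:
            active.update(flags)
    return {flag: True for flag in sorted(active)}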
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=16, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=os.path.join(MODELS_DIR, "EmoMulti_Embedding", "embedding_function.pt"), + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + random_emb=True, + emovdb=True, + replace_utt_sent_emb=replace_utt_sent_emb, + use_adapted_embs="adapted" in name, + path_to_xvect=path_to_xvect, + static_speaker_embed="_static" in name) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Finetuning.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Finetuning.py new file mode 100644 index 00000000..8d18a922 --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Finetuning.py @@ -0,0 +1,112 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Sent_Finetuning_2" + print("sent finetuning") + + ''' + concat speaker embedding and sentence embedding + input for encoder, pitch, energy, variance predictors and decoder + ''' + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + ''' + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_EmoV_DB_Speaker(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "emovdb_speaker"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_CREMA_D(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "cremad"), + lang="en", + save_imgs=False)) + ''' + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_TESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "tess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt")}') + emotion_sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", 
f"emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt"), map_location='cpu') + + if "_xvect" in name: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_all2', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_all2", "xvect.pt"), map_location='cpu') + else: + path_to_xvect = None + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=True) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. + resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=32, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + emotion_sent_embs=emotion_sent_embs, + path_to_xvect=None, + static_speaker_embed=True, + steps=200000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Pretraining.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Pretraining.py new file mode 100644 index 00000000..574797de --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_Sent_Pretraining.py @@ -0,0 +1,160 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device(f"cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Sent_Pretraining_2" + print("pretraining") + + ''' + concat speaker embedding and sentence embedding + input for encoder, pitch, energy, variance predictors and decoder + ''' + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + try: + transcript_dict_ljspeech = torch.load(os.path.join(PREPROCESSING_DIR, "ljspeech", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_ljspeech = build_path_to_transcript_dict_ljspeech() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_ljspeech, + corpus_dir=os.path.join(PREPROCESSING_DIR, "ljspeech"), + lang="en", + save_imgs=False)) + + try: + transcript_dict_librittsr = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "path_to_transcript_dict.pt"), map_location='cpu') + except FileNotFoundError: + transcript_dict_librittsr = build_path_to_transcript_dict_libritts_all_clean() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=transcript_dict_librittsr, + 
corpus_dir=os.path.join(PREPROCESSING_DIR, "librittsr"), + lang="en", + save_imgs=False)) + ''' + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_EmoV_DB_Speaker(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "emovdb_speaker"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_CREMA_D(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "cremad"), + lang="en", + save_imgs=False)) + ''' + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ESDS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "esds"), + lang="en", + save_imgs=False)) + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_TESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "tess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt")}.') + sent_embs_lj = torch.load(os.path.join(PREPROCESSING_DIR, "ljspeech", "sent_embs_emoBERTcls.pt"), map_location='cpu') + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt")}.') + sent_embs_libri = torch.load(os.path.join(PREPROCESSING_DIR, "librittsr", "sent_embs_emoBERTcls.pt"), map_location='cpu') + + sent_embs = sent_embs_libri | sent_embs_lj + + print(f'Loading sentence embeddings from {os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt")}') + emotion_sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt"), map_location='cpu') + + if "_xvect" in name: + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "xvect_all2", "xvect.pt")): + print("Extracting xvect from audio") + os.makedirs(os.path.join(PREPROCESSING_DIR, "xvect_all2"), exist_ok=True) + import torchaudio + from speechbrain.pretrained import EncoderClassifier + classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", savedir="./Models/Embedding/spkrec-xvect-voxceleb", run_opts={"device": device}) + path_to_xvect = {} + for index in tqdm(range(len(train_set))): + path = train_set[index][10] + wave, sr = torchaudio.load(path) + # mono + wave = torch.mean(wave, dim=0, keepdim=True) + # resampling + wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000) + wave = wave.squeeze(0) + embedding = classifier.encode_batch(wave).squeeze(0).squeeze(0) + path_to_xvect[path] = embedding + torch.save(path_to_xvect, os.path.join(PREPROCESSING_DIR, "xvect_all2", "xvect.pt")) + del classifier + else: + print(f"Loading xvect embeddings from {os.path.join(PREPROCESSING_DIR, 'xvect_all2', 'xvect.pt')}") + path_to_xvect = torch.load(os.path.join(PREPROCESSING_DIR, "xvect_all2", "xvect.pt"), map_location='cpu') + else: + path_to_xvect = None + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=True) + + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
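This pretraining pipeline and ToucanTTS_LibriTTSR.py above cache SpeechBrain x-vectors per audio file before training. The sketch below is condensed from those extraction loops; audio_paths and the cache location are placeholders to fill in.

import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier

classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb",
                                            savedir="./Models/Embedding/spkrec-xvect-voxceleb",
                                            run_opts={"device": "cpu"})  # the pipelines use the training GPU here

def extract_xvector(path):
    wave, sr = torchaudio.load(path)
    wave = torch.mean(wave, dim=0, keepdim=True)                               # downmix to mono
    wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000)  # the x-vector model expects 16 kHz
    return classifier.encode_batch(wave.squeeze(0)).squeeze(0).squeeze(0)

audio_paths = []  # e.g. the file paths stored in the prepared corpus
path_to_xvect = {path: extract_xvector(path) for path in audio_paths}
torch.save(path_to_xvect, "xvect.pt")  # the pipelines keep this cache under PREPROCESSING_DIR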
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=[train_set], + device=device, + save_directory=save_dir, + batch_size=32, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + sent_embs=sent_embs, + emotion_sent_embs=emotion_sent_embs, + path_to_xvect=None, + static_speaker_embed=True, + steps=120000) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/TrainingInterfaces/TrainingPipelines/ToucanTTS_TESS.py b/TrainingInterfaces/TrainingPipelines/ToucanTTS_TESS.py new file mode 100644 index 00000000..084af90c --- /dev/null +++ b/TrainingInterfaces/TrainingPipelines/ToucanTTS_TESS.py @@ -0,0 +1,73 @@ +import time + +import torch +import wandb +from torch.utils.data import ConcatDataset +from tqdm import tqdm + +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS +from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.toucantts_train_loop_arbiter import train_loop +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * +from Utility.storage_config import MODELS_DIR +from Utility.storage_config import PREPROCESSING_DIR + + +def run(gpu_id, resume_checkpoint, finetune, model_dir, resume, use_wandb, wandb_resume_id): + if gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" + device = torch.device("cuda") + + torch.manual_seed(131714) + random.seed(131714) + torch.random.manual_seed(131714) + + print("Preparing") + + name = "ToucanTTS_Baseline_TESS" + print("tess") + + if model_dir is not None: + save_dir = model_dir + else: + save_dir = os.path.join(MODELS_DIR, name) + os.makedirs(save_dir, exist_ok=True) + + datasets = list() + + datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_TESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "tess"), + lang="en", + save_imgs=False)) + + train_set = ConcatDataset(datasets) + + model = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + if use_wandb: + wandb.init( + name=f"{name}_{time.strftime('%Y%m%d-%H%M%S')}" if wandb_resume_id is None else None, + id=wandb_resume_id, # this is None if not specified in the command line arguments. 
+ resume="must" if wandb_resume_id is not None else None) + print("Training model") + train_loop(net=model, + datasets=train_set, + device=device, + save_directory=save_dir, + batch_size=8, + eval_lang="en", + path_to_checkpoint=resume_checkpoint, + path_to_embed_model=None, + fine_tune=finetune, + resume=resume, + use_wandb=use_wandb, + path_to_xvect=None, + static_speaker_embed=True) + if use_wandb: + wandb.finish() \ No newline at end of file diff --git a/Utility/Scorer.py b/Utility/Scorer.py index b7257a48..40359129 100644 --- a/Utility/Scorer.py +++ b/Utility/Scorer.py @@ -19,27 +19,39 @@ from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.ToucanTTS import ToucanTTS from Utility.corpus_preparation import prepare_fastspeech_corpus from Utility.storage_config import MODELS_DIR +from Utility.utils import get_speakerid_from_path_all, get_speakerid_from_path class AlignmentScorer: def __init__(self, path_to_aligner_model, device): self.path_to_score = dict() + self.path_to_id = dict() self.device = device self.nans = list() + self.nan_indexes = list() self.aligner = Aligner() self.aligner.load_state_dict(torch.load(path_to_aligner_model, map_location='cpu')["asr_model"]) self.aligner.to(self.device) + self.datapoints = None + self.path_to_aligner_dataset = None def score(self, path_to_aligner_dataset): """ call this to update the path_to_score dict with scores for this dataset """ datapoints = torch.load(path_to_aligner_dataset, map_location='cpu') + self.path_to_aligner_dataset = path_to_aligner_dataset + self.datapoints = datapoints[0] + self.norm_waves = datapoints[1] + self.speaker_embeddings = datapoints[2] + self.filepaths = datapoints[3] dataset = datapoints[0] filepaths = datapoints[3] self.nans = list() + self.nan_indexes = list() self.path_to_score = dict() + self.path_to_id = dict() for index in tqdm(range(len(dataset))): text = dataset[index][0] melspec = dataset[index][2] @@ -54,11 +66,14 @@ def score(self, path_to_aligner_dataset): return_ctc=True) if math.isnan(ctc_loss): self.nans.append(filepaths[index]) + self.nan_indexes.append(index) self.path_to_score[filepaths[index]] = ctc_loss + self.path_to_id[filepaths[index]] = index if len(self.nans) > 0: print("The following filepaths had an infinite loss:") for path in self.nans: print(path) + self.save_scores() def show_samples_with_highest_loss(self, n=-1): """ @@ -75,13 +90,50 @@ def show_samples_with_highest_loss(self, n=-1): if index < n or n == -1: print(f"Loss: {round(self.path_to_score[path], 3)} - Path: {path}") + def save_scores(self): + if self.path_to_score is None: + print("Please run the scoring first.") + else: + torch.save((self.path_to_score, self.path_to_id, self.nan_indexes), + os.path.join(os.path.dirname(self.path_to_aligner_dataset), 'alignment_scores.pt')) + + def remove_samples_with_highest_loss(self, path_to_aligner_dataset, n=10): + if self.datapoints is None: + self.path_to_aligner_dataset = path_to_aligner_dataset + datapoints = torch.load(self.path_to_aligner_dataset, map_location='cpu') + self.datapoints = datapoints[0] + self.norm_waves = datapoints[1] + self.speaker_embeddings = datapoints[2] + self.filepaths = datapoints[3] + try: + alignment_scores = torch.load(os.path.join(os.path.dirname(self.path_to_aligner_dataset), 'alignment_scores.pt'), map_location='cpu') + self.path_to_score = alignment_scores[0] + self.path_to_id = alignment_scores[1] + self.nan_indexes = alignment_scores[2] + except FileNotFoundError: + print("Please run the scoring first.") + return + remove_ids = list() + 
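# gather the dataset indexes to drop: every sample whose CTC loss was NaN plus the n samples with the highest (worst) alignment loss + 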
remove_ids.extend(self.nan_indexes) + for index, path in enumerate(sorted(self.path_to_score, key=self.path_to_score.get, reverse=True)): + if index < n: + remove_ids.append(self.path_to_id[path]) + for remove_id in sorted(remove_ids, reverse=True): + self.datapoints.pop(remove_id) + self.norm_waves.pop(remove_id) + self.speaker_embeddings.pop(remove_id) + self.filepaths.pop(remove_id) + torch.save((self.datapoints, self.norm_waves, self.speaker_embeddings, self.filepaths), + self.path_to_aligner_dataset) + print("Dataset updated!") class TTSScorer: def __init__(self, path_to_model, device, - path_to_embedding_checkpoint=os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt") + path_to_embedding_checkpoint=None, + static_speaker_embed=False, ): self.device = device self.path_to_score = dict() @@ -98,13 +150,21 @@ def __init__(self, self.tts = ToucanTTS(lang_embs=None) self.tts.load_state_dict(weights) except RuntimeError: - self.tts = ToucanTTS(lang_embs=None, utt_embed_dim=None) - self.tts.load_state_dict(weights) - self.style_embedding_function = StyleEmbedding().to(device) - check_dict = torch.load(path_to_embedding_checkpoint, map_location=device) - self.style_embedding_function.load_state_dict(check_dict["style_emb_func"]) + try: + self.tts = ToucanTTS(lang_embs=None, utt_embed_dim=None) + self.tts.load_state_dict(weights) + except RuntimeError: + self.tts = ToucanTTS(lang_embs=None, utt_embed_dim=512, static_speaker_embed=True) + self.tts.load_state_dict(weights) + if path_to_embedding_checkpoint is not None: + self.style_embedding_function = StyleEmbedding().to(device) + check_dict = torch.load(path_to_embedding_checkpoint, map_location=device) + self.style_embedding_function.load_state_dict(check_dict["style_emb_func"]) + self.style_embedding_function.to(device) + else: + self.style_embedding_function = None self.tts.to(self.device) - self.style_embedding_function.to(device) + self.static_speaker_embed = static_speaker_embed self.nans_removed = False self.current_dset = None @@ -118,23 +178,38 @@ def score(self, path_to_toucantts_dataset, lang_id): self.nan_indexes = list() self.path_to_score = dict() self.path_to_id = dict() + if self.static_speaker_embed: + with open("/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/librittsr/libri_speakers.txt") as f: + libri_speakers = sorted([int(line.rstrip()) for line in f]) for index in tqdm(range(len(dataset.datapoints))): text, text_len, spec, spec_len, duration, energy, pitch, embed, filepath = dataset.datapoints[index] - style_embedding = self.style_embedding_function(batch_of_spectrograms=spec.unsqueeze(0).to(self.device), - batch_of_spectrogram_lengths=spec_len.unsqueeze(0).to(self.device)) + if self.style_embedding_function is not None: + style_embedding = self.style_embedding_function(batch_of_spectrograms=spec.unsqueeze(0).to(self.device), + batch_of_spectrogram_lengths=spec_len.unsqueeze(0).to(self.device)) + else: + style_embedding = None + if self.static_speaker_embed: + speaker_id = torch.LongTensor([get_speakerid_from_path(filepath, libri_speakers)]).to(self.device) + else: + speaker_id = None try: - l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss = self.tts(text_tensors=text.unsqueeze(0).to(self.device), - text_lengths=text_len.to(self.device), - gold_speech=spec.unsqueeze(0).to(self.device), - speech_lengths=spec_len.to(self.device), - gold_durations=duration.unsqueeze(0).to(self.device), - gold_pitch=pitch.unsqueeze(0).to(self.device), - gold_energy=energy.unsqueeze(0).to(self.device), - 
utterance_embedding=style_embedding.to(self.device), - lang_ids=get_language_id(lang_id).unsqueeze(0).to(self.device), - return_mels=False, - run_glow=False) - loss = l1_loss + duration_loss + pitch_loss + energy_loss # we omit the glow loss + l1_loss, \ + duration_loss, \ + pitch_loss, \ + energy_loss, \ + glow_loss = self.tts(text_tensors=text.unsqueeze(0).to(self.device), + text_lengths=text_len.to(self.device), + gold_speech=spec.unsqueeze(0).to(self.device), + speech_lengths=spec_len.to(self.device), + gold_durations=duration.unsqueeze(0).to(self.device), + gold_pitch=pitch.unsqueeze(0).to(self.device), + gold_energy=energy.unsqueeze(0).to(self.device), + utterance_embedding=style_embedding.to(self.device) if style_embedding is not None else None, + speaker_id=speaker_id, + lang_ids=get_language_id(lang_id).unsqueeze(0).to(self.device), + return_mels=False, + run_glow=False) + loss = l1_loss + duration_loss + pitch_loss + energy_loss except TypeError: loss = torch.tensor(torch.nan) if torch.isnan(loss): diff --git a/Utility/path_to_transcript_dicts.py b/Utility/path_to_transcript_dicts.py index 35fd4904..37a29227 100644 --- a/Utility/path_to_transcript_dicts.py +++ b/Utility/path_to_transcript_dicts.py @@ -2,6 +2,10 @@ import os import random from pathlib import Path +from tqdm import tqdm + +import torch +from Utility.storage_config import PREPROCESSING_DIR def limit_to_n(path_to_transcript_dict, n=40000): @@ -221,9 +225,9 @@ def build_path_to_transcript_dict_libritts(): def build_path_to_transcript_dict_libritts_all_clean(): - path_train = "/mount/resources/speech/corpora/LibriTTS/all_clean" + path_train = "/mount/resources/speech/corpora/LibriTTS_R/" # using all files from the "clean" subsets from LibriTTS-R https://arxiv.org/abs/2305.18802 path_to_transcript = dict() - for speaker in os.listdir(path_train): + for speaker in tqdm(os.listdir(path_train)): for chapter in os.listdir(os.path.join(path_train, speaker)): for file in os.listdir(os.path.join(path_train, speaker, chapter)): if file.endswith("normalized.txt"): @@ -231,8 +235,87 @@ def build_path_to_transcript_dict_libritts_all_clean(): transcript = tf.read() wav_file = file.split(".")[0] + ".wav" path_to_transcript[os.path.join(path_train, speaker, chapter, wav_file)] = transcript + torch.save(path_to_transcript, os.path.join(PREPROCESSING_DIR, "librittsr", "path_to_transcript_dict.pt")) + return path_to_transcript + +def build_path_to_transcript_dict_promptspeech(): + # PromptSpeech uses a subset of LibriTTS + path_train = "/mount/resources/speech/corpora/LibriTTS/all_clean" + import pandas as pd + # TODO make promptspeech train file available in resources + promptspeech_train_df = pd.read_csv('/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/PromptSpeech_training.csv') + promptspeech_files = list(promptspeech_train_df['item_name']) + path_to_transcript = dict() + for speaker in tqdm(os.listdir(path_train)): + for chapter in os.listdir(os.path.join(path_train, speaker)): + for file in os.listdir(os.path.join(path_train, speaker, chapter)): + if file.endswith("normalized.txt"): + # only process files in filenames + if file.split(".")[0] in promptspeech_files: + with open(os.path.join(path_train, speaker, chapter, file), 'r', encoding='utf8') as tf: + transcript = tf.read() + wav_file = file.split(".")[0] + ".wav" + path_to_transcript[os.path.join(path_train, speaker, chapter, wav_file)] = transcript return path_to_transcript +def build_sent_to_prompt_dict_promptspeech(): + # PromptSpeech uses a subset of LibriTTS + 
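# returns a mapping from each LibriTTS transcript to its PromptSpeech style prompt, restricted to the items listed in the PromptSpeech training csv + 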
path_train = "/mount/resources/speech/corpora/LibriTTS/all_clean" + import pandas as pd + # TODO make promptspeech train file available in resources + promptspeech_train_df = pd.read_csv('/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/PromptSpeech_training.csv') + promptspeech_files = list(promptspeech_train_df['item_name']) + promptspeech_prompts = list(promptspeech_train_df['style_prompt']) + sent_to_prompt = dict() + for speaker in tqdm(os.listdir(path_train)): + for chapter in os.listdir(os.path.join(path_train, speaker)): + for file in os.listdir(os.path.join(path_train, speaker, chapter)): + if file.endswith("normalized.txt"): + # only process files in filenames + if file.split(".")[0] in promptspeech_files: + with open(os.path.join(path_train, speaker, chapter, file), 'r', encoding='utf8') as tf: + transcript = tf.read() + prompt = promptspeech_prompts[promptspeech_files.index(file.split(".")[0])] + sent_to_prompt[transcript] = prompt + return sent_to_prompt + +def build_path_to_transcript_dict_emovdb_sam(): + import csv, glob + path_train = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/EmoVDB_Sam" + with open("/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/EmoVDB_Sam/transcripts.csv", 'r') as f: + reader = csv.reader(f) + id_to_transcript_dict = {rows[0]:rows[1] for rows in reader} + path_to_transcript = dict() + for file in glob.glob(os.path.join(path_train, "*.wav")): + sentence_id = os.path.splitext(os.path.basename(file))[0].split("-16bit")[0].split("_")[-1] + path_to_transcript[file] = id_to_transcript_dict[sentence_id] + return path_to_transcript + +def build_path_to_prompt_dict_emovdb_sam(): + import csv, glob, json + path_train = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/EmoVDB_Sam" + path_to_prompt = dict() + with open("/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/EmotionLines/Friends/friends_train.json") as f: + d = json.load(f) + emotion_prompts = {"amused": list(), "anger": list(), "disgust": list(), "neutral": list()} + for dialogue in d: + for utterance in dialogue: + prompt = utterance["utterance"] + emotion = utterance["emotion"] + if emotion == "joy": + emotion_prompts["amused"].append(prompt) + if emotion == "anger": + emotion_prompts["anger"].append(prompt) + if emotion == "disgust": + emotion_prompts["disgust"].append(prompt) + if emotion == "neutral": + emotion_prompts["neutral"].append(prompt) + for file in glob.glob(os.path.join(path_train, "*.wav")): + emotion = os.path.splitext(os.path.basename(file))[0].split("-16bit")[0].split("_")[0].lower() + prompt = random.choice(emotion_prompts[emotion]) + path_to_prompt[file] = prompt + return path_to_prompt + def build_path_to_transcript_dict_libritts_other500(): path_train = "/mount/resources/asr-data/LibriTTS/train-other-500" @@ -250,11 +333,22 @@ def build_path_to_transcript_dict_libritts_other500(): def build_path_to_transcript_dict_ljspeech(): path_to_transcript = dict() - for transcript_file in os.listdir("/mount/resources/speech/corpora/LJSpeech/16kHz/txt"): + for transcript_file in tqdm(os.listdir("/mount/resources/speech/corpora/LJSpeech/16kHz/txt")): with open("/mount/resources/speech/corpora/LJSpeech/16kHz/txt/" + transcript_file, 'r', encoding='utf8') as tf: transcript = tf.read() wav_path = "/mount/resources/speech/corpora/LJSpeech/16kHz/wav/" + transcript_file.rstrip(".txt") + ".wav" path_to_transcript[wav_path] = transcript + + torch.save(limit_to_n(path_to_transcript), os.path.join(PREPROCESSING_DIR, "ljspeech", "path_to_transcript_dict.pt")) + 
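# the capped dict saved above can presumably be reloaded by later preprocessing steps instead of re-walking the corpus + 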
return limit_to_n(path_to_transcript) + +def build_path_to_transcript_dict_3xljspeech(): + path_to_transcript = dict() + for transcript_file in os.listdir("/mount/arbeitsdaten/synthesis/attention_projects/LJSpeech_3xlong_stripped/txt_long"): + with open("/mount/arbeitsdaten/synthesis/attention_projects/LJSpeech_3xlong_stripped/txt_long/" + transcript_file, 'r', encoding='utf8') as tf: + transcript = tf.read() + wav_path = "/mount/arbeitsdaten/synthesis/attention_projects/LJSpeech_3xlong_stripped/wav_long/" + transcript_file.rstrip(".txt") + ".wav" + path_to_transcript[wav_path] = transcript return limit_to_n(path_to_transcript) @@ -525,6 +619,68 @@ def build_path_to_transcript_dict_ESDS(): path_to_transcript_dict[f"{root}/{speaker_dir}/{emo_dir}/{filename}.wav"] = text return path_to_transcript_dict +def build_path_to_transcript_dict_CREMA_D(): + identifier_to_sent = {"IEO": "It's eleven o'clock.", + "TIE": "That is exactly what happened.", + "IOM": "I'm on my way to the meeting.", + "IWW": "I wonder what this is about.", + "TAI": "The airplane is almost full.", + "MTI": "Maybe tomorrow it will be cold.", + "IWL": "I would like a new alarm clock.", + "ITH": "I think, I have a doctor's appointment.", + "DFA": "Don't forget a jacket.", + "ITS": "I think, I've seen this before.", + "TSI": "The surface is slick.", + "WSI": "We'll stop in a couple of minutes."} + root = "/mount/resources/speech/corpora/CREMA_D/" + path_to_transcript = dict() + for file in os.listdir(root): + if file.endswith(".wav"): + path_to_transcript[root + file] = identifier_to_sent[file.split("_")[1]] + return path_to_transcript + +def build_path_to_transcript_dict_TESS(): + root = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/TESS" + path_to_transcript = dict() + for subdir in os.listdir(root): + for file in os.listdir(os.path.join(root, subdir)): + if file.endswith(".wav"): + word = file.split('_')[1] + transcript = f"Say the word {word}." 
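+ # TESS recordings use the carrier phrase "Say the word ...", so the target word is the second underscore-separated field of the filename (e.g. OAF_back_angry.wav)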
+ path_to_transcript[os.path.join(root, subdir, file)] = transcript + return path_to_transcript + +def build_path_to_transcript_dict_EmoV_DB(): + root = "/mount/resources/speech/corpora/EmoV_DB/" + path_to_transcript = dict() + with open(os.path.join(root, "labels.txt"), "r", encoding="utf8") as file: + lookup = file.read() + identifier_to_sent = dict() + for line in lookup.split("\n"): + if line.strip() != "": + identifier_to_sent[line.split()[0]] = " ".join(line.split()[1:]) + for file in os.listdir(root): + if file.endswith(".wav"): + path_to_transcript[root + file] = identifier_to_sent[file[-14:-10]] + return path_to_transcript + +def build_path_to_transcript_dict_EmoV_DB_Speaker(): + import csv, glob + root = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/EmoVDB" + with open(os.path.join(root, "labels.txt"), "r", encoding="utf8") as file: + lookup = file.read() + id_to_transcript_dict = dict() + for line in lookup.split("\n"): + if line.strip() != "": + id_to_transcript_dict[line.split()[0]] = " ".join(line.split()[1:]) + path_to_transcript = dict() + for speaker_dir in os.listdir(root): + if speaker_dir != "labels.txt": + for audio_file in os.listdir(os.path.join(root, speaker_dir)): + sentence_id = os.path.splitext(os.path.basename(audio_file))[0].split("-16bit")[0].split("_")[-1] + path_to_transcript[os.path.join(root, speaker_dir, audio_file)] = id_to_transcript_dict[sentence_id] + return path_to_transcript + def build_path_to_transcript_dict_blizzard_2013(): path_to_transcript = dict() diff --git a/Utility/sent_emb_extraction.py b/Utility/sent_emb_extraction.py new file mode 100644 index 00000000..3bb660a5 --- /dev/null +++ b/Utility/sent_emb_extraction.py @@ -0,0 +1,29 @@ +from tqdm import tqdm +import torch +import os +from Utility.storage_config import PREPROCESSING_DIR + +def extract_sent_embs(train_set, sent_emb_extractor, promptspeech=False, emovdb=False): + sent_embs = {} + print("Extracting sentence embeddings.") + if promptspeech: + from Utility.path_to_transcript_dicts import build_sent_to_prompt_dict_promptspeech + sent_to_prompt_dict = build_sent_to_prompt_dict_promptspeech() + if emovdb: + from Utility.path_to_transcript_dicts import build_path_to_prompt_dict_emovdb_sam + path_to_prompt_dict = build_path_to_prompt_dict_emovdb_sam() + for index in tqdm(range(len(train_set))): + sentence = train_set[index][9] + if promptspeech: + prompt = sent_to_prompt_dict[sentence] + sent_emb = sent_emb_extractor.encode(sentences=[prompt]).squeeze() + sent_embs[sentence] = sent_emb + elif emovdb: + filename = train_set[index][10] + prompt = path_to_prompt_dict[filename] + sent_emb = sent_emb_extractor.encode(sentences=[prompt]).squeeze() + sent_embs[filename] = sent_emb + else: + sent_emb = sent_emb_extractor.encode(sentences=[sentence]).squeeze() + sent_embs[sentence] = sent_emb + return sent_embs \ No newline at end of file diff --git a/Utility/utils.py b/Utility/utils.py index 748ba5ae..3f9efb01 100644 --- a/Utility/utils.py +++ b/Utility/utils.py @@ -200,23 +200,49 @@ def plot_progress_spec_toucantts(net, step, lang, default_emb, + static_speaker_embed=False, + sent_embs=None, + word_embedding_extractor=None, run_postflow=True): tf = ArticulatoryCombinedTextFrontend(language=lang) sentence = tf.get_example_sentence(lang=lang) if sentence is None: return None + if sent_embs is not None: + try: + sentence_embedding = sent_embs[sentence] + except KeyError: + sentence_embedding = sent_embs["neutral"][0] + else: + sentence_embedding = None + if static_speaker_embed: + 
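# use a fixed speaker id (0) for the progress-plot example sentence; with the speaker tables defined further down in this file, 0 presumably maps to the LJSpeech voice + 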
speaker_id = torch.LongTensor([0]) + else: + speaker_id = None + if word_embedding_extractor is not None: + word_embedding, _ = word_embedding_extractor.encode([sentence]) + word_embedding = word_embedding.squeeze() + else: + word_embedding = None phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device) if run_postflow: spec_before, spec_after, durations, pitch, energy = net.inference(text=phoneme_vector, return_duration_pitch_energy=True, utterance_embedding=default_emb, + speaker_id=speaker_id, + sentence_embedding=sentence_embedding, + word_embedding=word_embedding, lang_id=get_language_id(lang).to(device), run_postflow=run_postflow) else: spec_before, spec_after, durations, pitch, energy = net.inference(text=phoneme_vector, return_duration_pitch_energy=True, utterance_embedding=default_emb, - lang_id=get_language_id(lang).to(device)) + speaker_id=speaker_id, + sentence_embedding=sentence_embedding, + word_embedding=word_embedding, + lang_id=get_language_id(lang).to(device), + run_postflow=False) spec = spec_before.transpose(0, 1).to("cpu").numpy() duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) os.makedirs(os.path.join(save_dir, "spec_before"), exist_ok=True) @@ -668,6 +694,287 @@ def curve_smoother(curve): new_curve.append(0) return new_curve +def get_speakerid_from_path(path): + speaker_id = None + if "EmoVDB" in path: + if "bea" in path: + speaker_id = 0 + if "jenie" in path: + speaker_id = 1 + if "josh" in path: + speaker_id = 2 + if "sam" in path: + speaker_id = 3 + if "Emotional_Speech_Dataset_Singapore" in path: + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + if speaker == "0011": + speaker_id = 0 + if speaker == "0012": + speaker_id = 1 + if speaker == "0013": + speaker_id = 2 + if speaker == "0014": + speaker_id = 3 + if speaker == "0015": + speaker_id = 4 + if speaker == "0016": + speaker_id = 5 + if speaker == "0017": + speaker_id = 6 + if speaker == "0018": + speaker_id = 7 + if speaker == "0019": + speaker_id = 8 + if speaker == "0020": + speaker_id = 9 + if "CREMA_D" in path: + speaker = os.path.basename(path).split('_')[0] + for i, sp_id in enumerate(range(1001, 1092)): + if int(speaker) == sp_id: + speaker_id = i + if "RAVDESS" in path: + speaker = os.path.split(os.path.dirname(path))[1].split('_')[1] + speaker_id = int(speaker) - 1 + + if speaker_id is None: + raise TypeError('speaker id could not be extracted from filename') + + return speaker_id + +def get_emotion_from_path(path): + emotion = None + if "EmoV_DB" in path or "EmoVDB" in path or "EmoVDB_Sam" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split("-16bit")[0].split("_")[0].lower() + if emotion == "amused": + emotion = "joy" + if emotion == "sleepiness": + raise NameError("emotion sleepiness should not be included") + if "CREMA_D" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split('_')[2] + if emotion == "ANG": + emotion = "anger" + if emotion == "DIS": + emotion = "disgust" + if emotion == "FEA": + emotion = "fear" + if emotion == "HAP": + emotion = "joy" + if emotion == "NEU": + emotion = "neutral" + if emotion == "SAD": + emotion = "sadness" + if "Emotional_Speech_Dataset_Singapore" in path: + emotion = os.path.basename(os.path.dirname(path)).lower() + if emotion == "angry": + emotion = "anger" + if emotion == "happy": + emotion = "joy" + if emotion == "sad": + emotion = "sadness" + if "RAVDESS" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split('-')[2] + if emotion == "01": + 
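# RAVDESS filenames encode the emotion as the third dash-separated field: 01 neutral, 02 calm, 03 happy, 04 sad, 05 angry, 06 fearful, 07 disgust, 08 surprised + 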
emotion = "neutral" + if emotion == "02": + raise NameError("emotion calm should not be included") + if emotion == "03": + emotion = "joy" + if emotion == "04": + emotion = "sadness" + if emotion == "05": + emotion = "anger" + if emotion == "06": + emotion = "fear" + if emotion == "07": + emotion = "disgust" + if emotion == "08": + emotion = "surprise" + if "TESS" in path: + emotion = os.path.split(os.path.dirname(path))[1].split('_')[1].lower() + if emotion == "angry": + emotion = "anger" + if emotion == "happy": + emotion = "joy" + if emotion == "sad": + emotion = "sadness" + if "LJSpeech" in path: + emotion = "neutral" + + if emotion is None: + raise TypeError('emotion could not be extracted from filename') + + return emotion + +def get_speakerid_from_path(path, libri_speakers): + speaker_id = None + if "LJSpeech" in path: + speaker_id = 0 + if "EmoVDB" in path: + if "bea" in path: + speaker_id = 0 + if "jenie" in path: + speaker_id = 1 + if "josh" in path: + speaker_id = 2 + if "sam" in path: + speaker_id = 3 + if "Emotional_Speech_Dataset_Singapore" in path: + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + if speaker == "0011": + speaker_id = 0 + if speaker == "0012": + speaker_id = 1 + if speaker == "0013": + speaker_id = 2 + if speaker == "0014": + speaker_id = 3 + if speaker == "0015": + speaker_id = 4 + if speaker == "0016": + speaker_id = 5 + if speaker == "0017": + speaker_id = 6 + if speaker == "0018": + speaker_id = 7 + if speaker == "0019": + speaker_id = 8 + if speaker == "0020": + speaker_id = 9 + if "CREMA_D" in path: + speaker = os.path.basename(path).split('_')[0] + for i, sp_id in enumerate(range(1001, 1092)): + if int(speaker) == sp_id: + speaker_id = i + if "RAVDESS" in path: + speaker = os.path.split(os.path.dirname(path))[1].split('_')[1] + speaker_id = int(speaker) - 1 + if "LibriTTS_R" in path: + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + speaker_id = libri_speakers.index(int(speaker)) + if "TESS" in path: + speaker = os.path.split(os.path.dirname(path))[1].split('_')[0] + if speaker == "OAF": + speaker_id = 0 + if speaker == "YAF": + speaker_id = 1 + + if speaker_id is None: + raise TypeError('speaker id could not be extracted from filename') + + return int(speaker_id) + +def get_speakerid_from_path_all(path, libri_speakers): + speaker_id = None + if "LJSpeech" in path: # 1 speaker + # 0 + speaker_id = 0 + if "CREMA_D" in path: # 91 speakers + # 1 - 91 + speaker = os.path.basename(path).split('_')[0] + speaker_id = int(speaker) - 1001 + 1 + if "RAVDESS" in path: # 24 speakers + # 92 - 115 + speaker = os.path.split(os.path.dirname(path))[1].split('_')[1] + speaker_id = int(speaker) - 1 + 1 + 91 + if "EmoVDB" in path: # 4 speakers + # 116 - 119 + if "bea" in path: + speaker_id = 0 + 1 + 91 + 24 + if "jenie" in path: + speaker_id = 1 + 1 + 91 + 24 + if "josh" in path: + speaker_id = 2 + 1 + 91 + 24 + if "sam" in path: + speaker_id = 3 + 1 + 91 + 24 + if "Emotional_Speech_Dataset_Singapore" in path: # 10 speakers + # 120 - 129 + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + if speaker == "0011": + speaker_id = 0 + 1 + 91 + 24 + 4 + if speaker == "0012": + speaker_id = 1 + 1 + 91 + 24 + 4 + if speaker == "0013": + speaker_id = 2 + 1 + 91 + 24 + 4 + if speaker == "0014": + speaker_id = 3 + 1 + 91 + 24 + 4 + if speaker == "0015": + speaker_id = 4 + 1 + 91 + 24 + 4 + if speaker == "0016": + speaker_id = 5 + 1 + 91 + 24 + 4 + if speaker == "0017": + speaker_id = 6 + 1 + 91 + 24 + 4 + if speaker == 
"0018": + speaker_id = 7 + 1 + 91 + 24 + 4 + if speaker == "0019": + speaker_id = 8 + 1 + 91 + 24 + 4 + if speaker == "0020": + speaker_id = 9 + 1 + 91 + 24 + 4 + if "LibriTTS_R" in path: # 1230 speakers + # 130 - 1359 + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + speaker_id = libri_speakers.index(int(speaker)) + 1 + 91 + 24 + 4 + 10 + if "TESS" in path: # 2 speakers + # 1360 - 1361 + speaker = os.path.split(os.path.dirname(path))[1].split('_')[0] + if speaker == "OAF": + speaker_id = 0 + 1 + 91 + 24 + 4 + 10 + 1230 + if speaker == "YAF": + speaker_id = 1 + 1 + 91 + 24 + 4 + 10 + 1230 + + if speaker_id is None: + raise TypeError('speaker id could not be extracted from filename') + + return speaker_id + +def get_speakerid_from_path_all2(path, libri_speakers): + speaker_id = None + if "LJSpeech" in path: # 1 speaker + # 0 + speaker_id = 0 + if "RAVDESS" in path: # 24 speakers + # 1 - 24 + speaker = os.path.split(os.path.dirname(path))[1].split('_')[1] + speaker_id = int(speaker) - 1 + 1 + if "Emotional_Speech_Dataset_Singapore" in path: # 10 speakers + # 25 - 34 + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + if speaker == "0011": + speaker_id = 0 + 1 + 24 + if speaker == "0012": + speaker_id = 1 + 1 + 24 + if speaker == "0013": + speaker_id = 2 + 1 + 24 + if speaker == "0014": + speaker_id = 3 + 1 + 24 + if speaker == "0015": + speaker_id = 4 + 1 + 24 + if speaker == "0016": + speaker_id = 5 + 1 + 24 + if speaker == "0017": + speaker_id = 6 + 1 + 24 + if speaker == "0018": + speaker_id = 7 + 1 + 24 + if speaker == "0019": + speaker_id = 8 + 1 + 24 + if speaker == "0020": + speaker_id = 9 + 1 + 24 + if "LibriTTS_R" in path: # 1230 speakers + # 35 - 1264 + speaker = os.path.split(os.path.split(os.path.dirname(path))[0])[1] + speaker_id = libri_speakers.index(int(speaker)) + 1 + 24 + 10 + if "TESS" in path: # 2 speakers + # 1265, 1266 + speaker = os.path.split(os.path.dirname(path))[1].split('_')[0] + if speaker == "OAF": + speaker_id = 0 + 1 + 24 + 10 + 1230 + if speaker == "YAF": + speaker_id = 1 + 1 + 24 + 10 + 1230 + + if speaker_id is None: + raise TypeError('speaker id could not be extracted from filename') + + return speaker_id + if __name__ == '__main__': data = np.random.randn(50) diff --git a/extract_dailydialogues_sentences.py b/extract_dailydialogues_sentences.py new file mode 100644 index 00000000..a0e4b1dc --- /dev/null +++ b/extract_dailydialogues_sentences.py @@ -0,0 +1,17 @@ +from Utility.storage_config import PREPROCESSING_DIR +import torch +import os +from datasets import load_dataset + +if __name__ == '__main__': + device = 'cuda:5' + + dataset = load_dataset("daily_dialog", split="train", cache_dir=os.path.join(PREPROCESSING_DIR, 'DailyDialogues')) + id_to_emotion = {0: "neutral", 1: "anger", 2: "disgust", 3: "fear", 4: "joy", 5: "sadness", 6: "surprise"} + emotion_to_sents = emotion_to_sents = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + for dialog, emotions in zip(dataset["dialog"], dataset["emotion"]): + for sent, emotion in zip(dialog, emotions): + emotion_to_sents[id_to_emotion[emotion]].append(sent.strip()) + + torch.save(emotion_to_sents, os.path.join(PREPROCESSING_DIR, "DailyDialogues", "emotion_sentences.")) diff --git a/extract_tales_sentences.py b/extract_tales_sentences.py new file mode 100644 index 00000000..9954e416 --- /dev/null +++ b/extract_tales_sentences.py @@ -0,0 +1,40 @@ +from Utility.storage_config import PREPROCESSING_DIR +from transformers import 
pipeline +import torch +import os +from tqdm import tqdm + +def data(sentences): + for sent in sentences: + yield sent + +if __name__ == '__main__': + device = 'cuda:5' + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Tales", "emotion_sentences_full.pt")): + data_dir = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/Tales" + sentences = [] + for author in tqdm(os.listdir(data_dir)): + for file in os.listdir(os.path.join(data_dir, author, "sent")): + with open(os.path.join(data_dir, author, "sent", file)) as f: + sentences.extend([line.rstrip() for line in f]) + print(f"Extracted {len(sentences)} sentences.") + + classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1, device=device) + emotion_to_sents = {"anger":set(), "disgust":set(), "fear":set(), "joy":set(), "neutral":set(), "sadness":set(), "surprise":set()} + + for i, result in tqdm(enumerate(classifier(data(sentences), truncation=True, max_length=512, padding=True, batch_size=64)), total=len(sentences)): + score = result[0]["score"] + if score > 0.9: + emotion = result[0]["label"] + emotion_to_sents[emotion].add((sentences[i], score)) + for emotion, sents in emotion_to_sents.items(): + emotion_to_sents[emotion] = sorted(list(sents), key=lambda x: x[1], reverse=True) + torch.save(emotion_to_sents, os.path.join(PREPROCESSING_DIR, "Tales", "emotion_sentences_full.pt")) + else: + emotion_to_sents = torch.load(os.path.join(PREPROCESSING_DIR, "Tales", "emotion_sentences_full.pt"), map_location='cpu') + + top_k = dict() + for emotion, sents in emotion_to_sents.items(): + top_k[emotion] = [sent[0] for sent in sents[:20]] + torch.save(top_k, os.path.join(PREPROCESSING_DIR, "Tales", "emotion_sentences_top20.pt")) diff --git a/extract_tales_sentences_annotated.py b/extract_tales_sentences_annotated.py new file mode 100644 index 00000000..e0626c49 --- /dev/null +++ b/extract_tales_sentences_annotated.py @@ -0,0 +1,23 @@ +from Utility.storage_config import PREPROCESSING_DIR +import torch +import os +from tqdm import tqdm +import pandas as pd +import csv + +if __name__ == '__main__': + device = 'cuda:5' + data_dir = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/Tales" + id_to_emotion = {"N": "neutral", "A": "anger", "D": "disgust", "F": "fear", "H": "joy", "Sa": "sadness", "Su+": "surprise", "Su-": "surprise"} + emotion_to_sents = emotion_to_sents = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + for author in tqdm(os.listdir(data_dir)): + if not author.endswith(".pt"): + for file in os.listdir(os.path.join(data_dir, author, "emmood")): + df = pd.read_csv(os.path.join(data_dir, author, "emmood", file), sep="\t", header=None, quoting=csv.QUOTE_NONE) + for index, (sent_id, emo, mood, sent) in df.iterrows(): + emotions = emo.split(":") + if emotions[0] == emotions[1]: + emotion_to_sents[id_to_emotion[emotions[0]]].append(sent) + + torch.save(emotion_to_sents, os.path.join(PREPROCESSING_DIR, "Tales", "emotion_sentences.pt")) diff --git a/extract_yelp_prompts.py b/extract_yelp_prompts.py new file mode 100644 index 00000000..1802ca19 --- /dev/null +++ b/extract_yelp_prompts.py @@ -0,0 +1,45 @@ +from Utility.storage_config import PREPROCESSING_DIR +from datasets import load_dataset +from transformers import pipeline +import random +import torch +import os +from tqdm import tqdm +import nltk +nltk.download('punkt') + +def data(sentences): + for sent in sentences: + yield sent + +if __name__ == '__main__': + device = 'cuda:5' + 
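+ # classify Yelp review sentences (under 50 words) with an emotion classifier, keep predictions with confidence above 0.8, then sample 10000 prompts per emotion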
+ if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_full.pt")): + yelp = load_dataset("yelp_review_full", split="train", cache_dir=os.path.join(PREPROCESSING_DIR, 'Yelp')) + yelp_sents = [] + for review in tqdm(yelp[:]["text"]): + sentences = nltk.sent_tokenize(review) + for sent in sentences: + if len(sent.split()) < 50: + yelp_sents.append(sent) + + print(f"Extracted {len(yelp_sents)} sentences.") + + classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1, device=device) + emotion_to_prompts = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + for i, result in tqdm(enumerate(classifier(data(yelp_sents), truncation=True, max_length=512, padding=True, batch_size=256)), total=len(yelp_sents)): + score = result[0]["score"] + if score > 0.8: + emotion = result[0]["label"] + emotion_to_prompts[emotion].append(yelp_sents[i]) + torch.save(emotion_to_prompts, os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_full.pt")) + else: + emotion_to_prompts = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_full.pt"), map_location='cpu') + + emotion_to_prompts_balanced = dict() + for emotion, prompts in tqdm(emotion_to_prompts.items()): + emotion_to_prompts_balanced[emotion] = random.sample(prompts, 10000) + + torch.save(emotion_to_prompts_balanced, os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_balanced_10000.pt")) diff --git a/extract_yelp_sent_embs.py b/extract_yelp_sent_embs.py new file mode 100644 index 00000000..fc214e8d --- /dev/null +++ b/extract_yelp_sent_embs.py @@ -0,0 +1,16 @@ +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor +import torch +import os +from tqdm import tqdm + +if __name__ == '__main__': + device = 'cuda:4' + emotion_prompts = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_balanced_10000.pt"), map_location='cpu') + sent_emb_extractor = SentenceEmbeddingExtractor(pooling='cls', device=device) + emotion_prompts_sent_embs = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + for emotion in tqdm(list(emotion_prompts.keys())): + for prompt in tqdm(emotion_prompts[emotion]): + sent_emb = sent_emb_extractor.encode(sentences=[prompt]).squeeze() + emotion_prompts_sent_embs[emotion].append(sent_emb) + torch.save(emotion_prompts_sent_embs, os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt")) diff --git a/extract_yelp_sent_embs_tsne_eval.py b/extract_yelp_sent_embs_tsne_eval.py new file mode 100644 index 00000000..b01e3cc3 --- /dev/null +++ b/extract_yelp_sent_embs_tsne_eval.py @@ -0,0 +1,18 @@ +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor +from Preprocessing.sentence_embeddings.BERTSentenceEmbeddingExtractor import BERTSentenceEmbeddingExtractor +from Preprocessing.sentence_embeddings.STSentenceEmbeddingExtractor import STSentenceEmbeddingExtractor +import torch +import os +from tqdm import tqdm + +if __name__ == '__main__': + device = 'cuda:6' + emotion_prompts = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_balanced_10000.pt"), map_location='cpu') + sent_emb_extractor = STSentenceEmbeddingExtractor() + 
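# same extraction loop as extract_yelp_sent_embs.py, but with the SentenceTransformer extractor, presumably for the t-SNE comparison of the embedding spaces + 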
emotion_prompts_sent_embs = {"anger":[], "disgust":[], "fear":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + for emotion in tqdm(list(emotion_prompts.keys())): + for prompt in tqdm(emotion_prompts[emotion]): + sent_emb = sent_emb_extractor.encode(sentences=[prompt]).squeeze() + emotion_prompts_sent_embs[emotion].append(sent_emb) + torch.save(emotion_prompts_sent_embs, os.path.join(PREPROCESSING_DIR, "Evaluation", "emotion_prompts_balanced_10000_sent_embs_stpara.pt")) diff --git a/extract_yelp_sentences.py b/extract_yelp_sentences.py new file mode 100644 index 00000000..6e75fe9d --- /dev/null +++ b/extract_yelp_sentences.py @@ -0,0 +1,45 @@ +from Utility.storage_config import PREPROCESSING_DIR +from datasets import load_dataset +from transformers import pipeline +import torch +import os +from tqdm import tqdm +import nltk +nltk.download('punkt') + +def data(sentences): + for sent in sentences: + yield sent + +if __name__ == '__main__': + device = 'cuda:5' + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_sentences_full.pt")): + yelp = load_dataset("yelp_review_full", split="test", cache_dir=os.path.join(PREPROCESSING_DIR, 'Yelp')) + sentences = [] + for review in tqdm(yelp[:]["text"]): + sents = nltk.sent_tokenize(review) + for sent in sents: + if len(sent.split()) < 50: + sentences.append(sent) + + print(f"Extracted {len(sentences)} sentences.") + + classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1, device=device) + emotion_to_sents = {"anger":set(), "disgust":set(), "fear":set(), "joy":set(), "neutral":set(), "sadness":set(), "surprise":set()} + + for i, result in tqdm(enumerate(classifier(data(sentences), truncation=True, max_length=512, padding=True, batch_size=256)), total=len(sentences)): + score = result[0]["score"] + if score > 0.9: + emotion = result[0]["label"] + emotion_to_sents[emotion].add((sentences[i], score)) + for emotion, sents in emotion_to_sents.items(): + emotion_to_sents[emotion] = sorted(list(sents), key=lambda x: x[1], reverse=True) + torch.save(emotion_to_sents, os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_sentences_full.pt")) + else: + emotion_to_sents = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_sentences_full.pt"), map_location='cpu') + + top_k = dict() + for emotion, sents in emotion_to_sents.items(): + top_k[emotion] = [sent[0] for sent in sents[:20]] + torch.save(top_k, os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_sentences_top20.pt")) diff --git a/plot_objective_evaluation.py b/plot_objective_evaluation.py new file mode 100644 index 00000000..555dc465 --- /dev/null +++ b/plot_objective_evaluation.py @@ -0,0 +1,423 @@ +import os +from statistics import median, mean +import numpy as np +import scipy.stats as stats + +import torch + +from Utility.storage_config import PREPROCESSING_DIR +from Evaluation.objective_evaluation import * +from Evaluation.plotting import * + +import sys + +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] + +def get_ratings_per_speaker(data): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + if speaker not in ratings: + ratings[speaker] = [] + ratings[speaker].extend(list(emotions.values())) + return ratings + +def get_ratings_per_speaker_original(data): + ratings = {} + for speaker, emotions in data.items(): + ratings[speaker] = list(emotions.values()) + return ratings + +def get_single_rating_per_speaker(data): + rating = {} + for speaker, ratings in 
data.items(): + rating[speaker] = mean(ratings) + return rating + +def remove_outliers_per_speaker(data): + # data shape: {speaker: {ratings}} + cleaned_data = {} + for speaker, ratings_list in data.items(): + sorted_data = sorted(ratings_list) + q1, q3 = np.percentile(sorted_data, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + cleaned_data[speaker] = [x for x in sorted_data if lower_bound <= x <= upper_bound] + return cleaned_data + +def get_ratings_per_emotion(data): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + for emotion, preds in emotions.items(): + if emotion not in ratings: + ratings[emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[emotion]: + ratings[emotion][pred] = 0 + ratings[emotion][pred] += freq + return ratings + +def get_ratings_per_emotion_original(data): + ratings = {} + for speaker, emotions in data.items(): + for emotion, preds in emotions.items(): + if emotion not in ratings: + ratings[emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[emotion]: + ratings[emotion][pred] = 0 + ratings[emotion][pred] += freq + return ratings + +def get_ratings_per_speaker_emotion(data, speaker_ids): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + if int(speaker) in speaker_ids: + if speaker not in ratings: + ratings[speaker] = {} + for emotion, preds in emotions.items(): + if emotion not in ratings[speaker]: + ratings[speaker][emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[speaker][emotion]: + ratings[speaker][emotion][pred] = 0 + ratings[speaker][emotion][pred] += freq + return ratings + +def get_ratings_per_speaker_emotion_original(data, speaker_ids): + ratings = {} + for speaker, emotions in data.items(): + if int(speaker) in speaker_ids: + ratings[speaker] = {} + for emotion, preds in emotions.items(): + if emotion not in ratings[speaker]: + ratings[speaker][emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[speaker][emotion]: + ratings[speaker][emotion][pred] = 0 + ratings[speaker][emotion][pred] += freq + return ratings + +def total_accuracy(data): + count_correct = 0 + count_total = 0 + for emotion, preds in data.items(): + for pred, freq in preds.items(): + if pred == emotion: + count_correct += freq + count_total += freq + return count_correct / count_total + +def combine_sent_prompt(dict1, dict2): + combined_dict = {} + for key in dict1.keys() | dict2.keys(): + combined_dict[key] = dict1[key] + dict2[key] + return combined_dict + +def get_dict_with_rounded_values(dict, decimal_points=3): + rounded_dict = {key: round(value, decimal_points) for key, value in dict.items()} + return rounded_dict + +def cramers_v(data): + # Convert the data dictionary into a 2D array + counts = np.array([[data[emotion].get(label, 0) for emotion in EMOTIONS] for label in EMOTIONS]) + + # Compute the chi-squared statistic and p-value + chi2, p, _, _ = stats.chi2_contingency(counts) + + # Number of observations (total counts) + n = np.sum(counts) + + # Number of rows and columns in the contingency table + num_rows = len(EMOTIONS) + num_cols = len(EMOTIONS) + + # Compute Cramér's V + cramer_v = np.sqrt(chi2 / (n * (min(num_rows, num_cols) - 1))) + return p, cramer_v + +if __name__ == '__main__': + # load results + + # speaker similarity + # shape {dataset: {speaker: {emotion: speaker_similarity}}} + speaker_similarities_baseline = 
torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_baseline.pt"), map_location='cpu') + speaker_similarities_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_sent.pt"), map_location='cpu') + speaker_similarities_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_prompt.pt"), map_location='cpu') + + # wer + # shape {dataset: {speaker: {emotion: wer}}} + wers_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_original.pt"), map_location='cpu') + wers_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_baseline.pt"), map_location='cpu') + wers_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_sent.pt"), map_location='cpu') + wers_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_prompt.pt"), map_location='cpu') + + # emotion recognition + # shape {dataset: {speaker: {emotion: {pred_emotion: count}}}} + freqs_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_original.pt"), map_location='cpu') + freqs_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_baseline.pt"), map_location='cpu') + freqs_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_sent.pt"), map_location='cpu') + freqs_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_prompt.pt"), map_location='cpu') + + # extract ratings + + # speaker similarity + + # per speaker + speaker_similarities_baseline_speaker = get_ratings_per_speaker(speaker_similarities_baseline) + speaker_similarities_sent_speaker = get_ratings_per_speaker(speaker_similarities_sent) + speaker_similarities_prompt_speaker = get_ratings_per_speaker(speaker_similarities_prompt) + speaker_similarities_proposed_speaker = combine_sent_prompt(speaker_similarities_sent_speaker, speaker_similarities_prompt_speaker) + + speaker_similarities_baseline_speaker_cleaned = remove_outliers_per_speaker(speaker_similarities_baseline_speaker) + speaker_similarities_proposed_speaker_cleaned = remove_outliers_per_speaker(speaker_similarities_proposed_speaker) + + #mean + speaker_similarity_baseline_speaker = get_single_rating_per_speaker(speaker_similarities_baseline_speaker_cleaned) + speaker_similarity_proposed_speaker = get_single_rating_per_speaker(speaker_similarities_proposed_speaker_cleaned) + + print(dict(sorted(get_dict_with_rounded_values(speaker_similarity_baseline_speaker).items()))) + print() + print(dict(sorted(get_dict_with_rounded_values(speaker_similarity_proposed_speaker).items()))) + print() + + # total + speaker_similarity_baseline_total = mean(list(speaker_similarity_baseline_speaker.values())) + speaker_similarity_proposed_total = mean(list(speaker_similarity_proposed_speaker.values())) + + print("Speaker Similarity") + print(speaker_similarity_baseline_total) + print(speaker_similarity_proposed_total) + + # word error rate + + # per speaker + wers_original_speaker = get_ratings_per_speaker_original(wers_original) + wers_baseline_speaker = get_ratings_per_speaker(wers_baseline) + wers_sent_speaker = get_ratings_per_speaker(wers_sent) + wers_prompt_speaker = get_ratings_per_speaker(wers_prompt) + wers_proposed_speaker = combine_sent_prompt(wers_sent_speaker, wers_prompt_speaker) + + wers_original_speaker_cleaned = remove_outliers_per_speaker(wers_original_speaker) + wers_baseline_speaker_cleaned = remove_outliers_per_speaker(wers_baseline_speaker) + wers_proposed_speaker_cleaned = 
remove_outliers_per_speaker(wers_proposed_speaker) + + # mean + wer_original_speaker = get_single_rating_per_speaker(wers_original_speaker_cleaned) + wer_baseline_speaker = get_single_rating_per_speaker(wers_baseline_speaker_cleaned) + wer_proposed_speaker = get_single_rating_per_speaker(wers_proposed_speaker_cleaned) + + print(dict(sorted(get_dict_with_rounded_values(wer_original_speaker).items()))) + print() + print(dict(sorted(get_dict_with_rounded_values(wer_baseline_speaker).items()))) + print() + print(dict(sorted(get_dict_with_rounded_values(wer_proposed_speaker).items()))) + print() + + # total + wer_original_total = mean(list(wer_original_speaker.values())) + wer_baseline_total = mean(list(wer_baseline_speaker.values())) + wer_proposed_total = mean(list(wer_proposed_speaker.values())) + + print("Word Error Rate") + print(wer_original_total) + print(wer_baseline_total) + print(wer_proposed_total) + + # emotion recognition + + # per emotion + freqs_original_emotion = get_ratings_per_emotion_original(freqs_original) + freqs_baseline_emotion = get_ratings_per_emotion(freqs_baseline) + freqs_sent_emotion = get_ratings_per_emotion(freqs_sent) + freqs_prompt_emotion = get_ratings_per_emotion(freqs_prompt) + + # per speaker per emotion + freqs_original_speaker = get_ratings_per_speaker_emotion_original(freqs_original, [14, 15]) + freqs_baseline_speaker = get_ratings_per_speaker_emotion(freqs_baseline, [14, 15]) + freqs_sent_speaker = get_ratings_per_speaker_emotion(freqs_sent, [14, 15]) + freqs_prompt_speaker = get_ratings_per_speaker_emotion(freqs_prompt, [14, 15]) + + # total accuracy + accuracy_original = total_accuracy(freqs_original_emotion) + accuracy_baseline = total_accuracy(freqs_baseline_emotion) + accuracy_sent = total_accuracy(freqs_sent_emotion) + accuracy_prompt = total_accuracy(freqs_prompt_emotion) + + print("Emotion Recognition Accuracy") + print(accuracy_original) + print(accuracy_baseline) + print(accuracy_sent) + print(accuracy_prompt) + + + # plotting + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation", "plots"), exist_ok=True) + save_dir = os.path.join(PREPROCESSING_DIR, "Evaluation", "plots") + + boxplot_objective(speaker_similarities_baseline_speaker, os.path.join(save_dir, f"box_speaker_similarities_baseline.png")) + boxplot_objective(speaker_similarities_proposed_speaker, os.path.join(save_dir, f"box_speaker_similarities_proposed.png")) + + barplot_speaker_similarity([speaker_similarity_baseline_total, + speaker_similarity_proposed_total + ], + os.path.join(save_dir, f"speaker_similarity_total.png")) + + boxplot_objective2(wers_original_speaker, os.path.join(save_dir, f"box_wers_original_speaker.png")) + boxplot_objective2(wers_baseline_speaker, os.path.join(save_dir, f"box_wers_baseline_speaker.png")) + boxplot_objective2(wers_proposed_speaker, os.path.join(save_dir, f"box_wers_proposed_speaker.png")) + + barplot_wer([wer_original_total, + wer_baseline_total, + wer_proposed_total + ], + os.path.join(save_dir, f"wer_total.png")) + + heatmap_emotion(freqs_original_emotion, os.path.join(save_dir, f"emotion_objective_original.png")) + heatmap_emotion(freqs_baseline_emotion, os.path.join(save_dir, f"emotion_objective_baseline.png")) + heatmap_emotion(freqs_sent_emotion, os.path.join(save_dir, f"emotion_objective_sent.png")) + heatmap_emotion(freqs_prompt_emotion, os.path.join(save_dir, f"emotion_objective_prompt.png")) + + heatmap_emotion(freqs_original_speaker['15'], os.path.join(save_dir, f"emotion_objective_original_female.png")) + 
heatmap_emotion(freqs_baseline_speaker['15'], os.path.join(save_dir, f"emotion_objective_baseline_female.png")) + heatmap_emotion(freqs_sent_speaker['15'], os.path.join(save_dir, f"emotion_objective_sent_female.png")) + heatmap_emotion(freqs_prompt_speaker['15'], os.path.join(save_dir, f"emotion_objective_prompt_female.png")) + heatmap_emotion(freqs_original_speaker['14'], os.path.join(save_dir, f"emotion_objective_original_male.png")) + heatmap_emotion(freqs_baseline_speaker['14'], os.path.join(save_dir, f"emotion_objective_baseline_male.png")) + heatmap_emotion(freqs_sent_speaker['14'], os.path.join(save_dir, f"emotion_objective_sent_male.png")) + heatmap_emotion(freqs_prompt_speaker['14'], os.path.join(save_dir, f"emotion_objective_prompt_male.png")) + + barplot_emotion_recognition([accuracy_original, + accuracy_baseline, + accuracy_sent, + accuracy_prompt], + os.path.join(save_dir, f"emotion_accuracy.png")) + + print("Cramers V") + print(cramers_v(freqs_original_emotion)) + print(cramers_v(freqs_baseline_emotion)) + print(cramers_v(freqs_sent_emotion)) + print(cramers_v(freqs_prompt_emotion)) + + sys.exit() + accuracies_emotion_original = {} # per speaker per emotion + accuracies_speaker_original = {} # per speaker + for speaker, emotions in freqs_original.items(): + accuracies_emotion_original[speaker] = {} + accuracies_speaker_original[speaker] = sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) / sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_original[speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + + accuracy_original = sum([freqs_original[speaker][emotion][pred] + for speaker, emotions in freqs_original.items() + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_original[speaker][emotion][pred] + for speaker, emotions in freqs_original.items() + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_baseline = {} # per dataset per speaker per emotion + accuracies_speaker_baseline = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_baseline.items(): + accuracies_emotion_baseline[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_baseline[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_baseline[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_baseline[speaker] = freq / count_total[speaker] + + accuracy_baseline = sum([freqs_baseline[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_baseline.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_baseline[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_baseline.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_sent = {} # 
per dataset per speaker per emotion + accuracies_speaker_sent = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_sent.items(): + accuracies_emotion_sent[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_sent[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_sent[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_sent[speaker] = freq / count_total[speaker] + + accuracy_sent = sum([freqs_sent[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_sent.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_sent[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_sent.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_prompt = {} # per dataset per speaker per emotion + accuracies_speaker_prompt = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_prompt.items(): + accuracies_emotion_prompt[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_prompt[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_prompt[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_prompt[speaker] = freq / count_total[speaker] + + accuracy_prompt = sum([freqs_prompt[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_prompt.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_prompt[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_prompt.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds]) \ No newline at end of file diff --git a/plot_objective_evaluation_paper.py b/plot_objective_evaluation_paper.py new file mode 100644 index 00000000..3b652fed --- /dev/null +++ b/plot_objective_evaluation_paper.py @@ -0,0 +1,334 @@ +import os +from statistics import median, mean +import numpy as np +import scipy.stats as stats + +import torch + +from Utility.storage_config import PREPROCESSING_DIR +from Evaluation.objective_evaluation import * +from Evaluation.plotting import * + +import sys + +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] + +def get_ratings_per_speaker(data): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + if speaker not in ratings: + ratings[speaker] = [] + 
ratings[speaker].extend(list(emotions.values())) + return ratings + +def get_ratings_per_speaker_original(data): + ratings = {} + for speaker, emotions in data.items(): + ratings[speaker] = list(emotions.values()) + return ratings + +def get_single_rating_per_speaker(data): + rating = {} + for speaker, ratings in data.items(): + rating[speaker] = mean(ratings) + return rating + +def remove_outliers_per_speaker(data): + # data shape: {speaker: {ratings}} + cleaned_data = {} + for speaker, ratings_list in data.items(): + sorted_data = sorted(ratings_list) + q1, q3 = np.percentile(sorted_data, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + cleaned_data[speaker] = [x for x in sorted_data if lower_bound <= x <= upper_bound] + return cleaned_data + +def get_ratings_per_emotion(data): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + for emotion, preds in emotions.items(): + if emotion not in ratings: + ratings[emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[emotion]: + ratings[emotion][pred] = 0 + ratings[emotion][pred] += freq + return ratings + +def get_ratings_per_emotion_original(data): + ratings = {} + for speaker, emotions in data.items(): + for emotion, preds in emotions.items(): + if emotion not in ratings: + ratings[emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[emotion]: + ratings[emotion][pred] = 0 + ratings[emotion][pred] += freq + return ratings + +def get_ratings_per_speaker_emotion(data, speaker_ids): + ratings = {} + for dataset, speakers in data.items(): + for speaker, emotions in speakers.items(): + if int(speaker) in speaker_ids: + if speaker not in ratings: + ratings[speaker] = {} + for emotion, preds in emotions.items(): + if emotion not in ratings[speaker]: + ratings[speaker][emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[speaker][emotion]: + ratings[speaker][emotion][pred] = 0 + ratings[speaker][emotion][pred] += freq + return ratings + +def get_ratings_per_speaker_emotion_original(data, speaker_ids): + ratings = {} + for speaker, emotions in data.items(): + if int(speaker) in speaker_ids: + ratings[speaker] = {} + for emotion, preds in emotions.items(): + if emotion not in ratings[speaker]: + ratings[speaker][emotion] = {} + for pred, freq in preds.items(): + if pred not in ratings[speaker][emotion]: + ratings[speaker][emotion][pred] = 0 + ratings[speaker][emotion][pred] += freq + return ratings + +def total_accuracy(data): + count_correct = 0 + count_total = 0 + for emotion, preds in data.items(): + for pred, freq in preds.items(): + if pred == emotion: + count_correct += freq + count_total += freq + return count_correct / count_total + +def combine_sent_prompt(dict1, dict2): + combined_dict = {} + for key in dict1.keys() | dict2.keys(): + combined_dict[key] = dict1[key] + dict2[key] + return combined_dict + +def get_dict_with_rounded_values(dict, decimal_points=3): + rounded_dict = {key: round(value, decimal_points) for key, value in dict.items()} + return rounded_dict + +def cramers_v(data): + # Convert the data dictionary into a 2D array + counts = np.array([[data[emotion].get(label, 0) for emotion in EMOTIONS] for label in EMOTIONS]) + + # Compute the chi-squared statistic and p-value + chi2, p, _, _ = stats.chi2_contingency(counts) + + # Number of observations (total counts) + n = np.sum(counts) + + # Number of rows and columns in the contingency table + num_rows = len(EMOTIONS) + num_cols = 
len(EMOTIONS) + + # Compute Cramér's V + cramer_v = np.sqrt(chi2 / (n * (min(num_rows, num_cols) - 1))) + return p, cramer_v + +if __name__ == '__main__': + # load results + + # emotion recognition + # shape {dataset: {speaker: {emotion: {pred_emotion: count}}}} + freqs_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_original.pt"), map_location='cpu') + freqs_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_baseline.pt"), map_location='cpu') + freqs_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_sent.pt"), map_location='cpu') + freqs_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_prompt.pt"), map_location='cpu') + freqs_emospeech = torch.load("/mount/arbeitsdaten/synthesis/bottts/emospeech/Evaluation/freqs.pt", map_location='cpu') + + speaker_similarities_emospeech = torch.load("/mount/arbeitsdaten/synthesis/bottts/emospeech/Evaluation/speaker_similarities.pt", map_location='cpu') + + # Extracting all the scores + all_scores = [] + for key in speaker_similarities_emospeech: + for sub_key in speaker_similarities_emospeech[key]: + all_scores.extend(list(speaker_similarities_emospeech[key][sub_key].values())) + + # Computing mean and standard deviation + mean_all = np.mean(all_scores) + std_dev_all = np.std(all_scores) + print(mean_all) + print(std_dev_all) + + + + # extract ratings + + # emotion recognition + + # per emotion + freqs_original_emotion = get_ratings_per_emotion_original(freqs_original) + freqs_baseline_emotion = get_ratings_per_emotion(freqs_baseline) + freqs_sent_emotion = get_ratings_per_emotion(freqs_sent) + freqs_prompt_emotion = get_ratings_per_emotion(freqs_prompt) + freqs_emospeech_emotion = get_ratings_per_emotion(freqs_emospeech) + + # plotting + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation", "plots_paper"), exist_ok=True) + save_dir = os.path.join(PREPROCESSING_DIR, "Evaluation", "plots_paper") + + #heatmap_emotion(freqs_original_emotion, os.path.join(save_dir, f"emotion_objective_original.png")) + #heatmap_emotion(freqs_baseline_emotion, os.path.join(save_dir, f"emotion_objective_baseline.png")) + #heatmap_emotion(freqs_sent_emotion, os.path.join(save_dir, f"emotion_objective_sent.png")) + #heatmap_emotion(freqs_prompt_emotion, os.path.join(save_dir, f"emotion_objective_prompt.png")) + titles = ["Ground Truth", "Baseline", "Prompt Conditioned Same", "Prompt Conditioned Other", "EmoSpeech"] + heatmap_emotion_multiple([freqs_original_emotion, freqs_baseline_emotion, freqs_sent_emotion, freqs_prompt_emotion, freqs_emospeech_emotion], titles, os.path.join(save_dir, f"emotion_objective_all2.pdf")) + + print("Cramers V") + data1 = freqs_original_emotion + data2 = freqs_emospeech_emotion + print(freqs_emospeech_emotion) + _, v1 = cramers_v(data1) + _, v2 = cramers_v(data2) + print(v1) + print(v2) + + # Compute the standard error of the difference + n1 = sum(sum(emotion.values()) for emotion in data1.values()) + n2 = sum(sum(emotion.values()) for emotion in data2.values()) + se_diff = np.sqrt((v1**2 / (n1 - 1)) + (v2**2 / (n2 - 1))) + + # Compute the t-statistic + t_statistic = (v1 - v2) / se_diff + + # Compute the degrees of freedom + df = min(len(data1) - 1, len(data2) - 1) + + # Compute the p-value + p_value = stats.t.cdf(t_statistic, df) + print(t_statistic) + print(p_value) + + sys.exit() + accuracies_emotion_original = {} # per speaker per emotion + accuracies_speaker_original = {} # per speaker + for speaker, emotions in freqs_original.items(): + 
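cramers_v expects confusion counts in the {true_emotion: {predicted_emotion: count}} layout produced by the get_ratings_per_emotion helpers; the contingency table uses predicted labels as rows and true emotions as columns, both ordered by EMOTIONS. A self-contained usage sketch with dummy counts (50 utterances per true emotion, 42 of them classified correctly):

import numpy as np
import scipy.stats as stats

EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"]

def cramers_v(data):
    # Same computation as above: chi-squared test on the 5x5 contingency
    # table, then V = sqrt(chi2 / (n * (min(rows, cols) - 1))).
    counts = np.array([[data[emotion].get(label, 0) for emotion in EMOTIONS]
                       for label in EMOTIONS])
    chi2, p, _, _ = stats.chi2_contingency(counts)
    n = counts.sum()
    return p, np.sqrt(chi2 / (n * (len(EMOTIONS) - 1)))

# Dummy confusion counts: 42 correct and 2 of each confusion per true emotion.
dummy = {emo: {pred: (42 if pred == emo else 2) for pred in EMOTIONS} for emo in EMOTIONS}
p_value, v = cramers_v(dummy)
print(round(v, 3))  # 0.8 -> strong association between true and predicted emotion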
accuracies_emotion_original[speaker] = {} + accuracies_speaker_original[speaker] = sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) / sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_original[speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + + accuracy_original = sum([freqs_original[speaker][emotion][pred] + for speaker, emotions in freqs_original.items() + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_original[speaker][emotion][pred] + for speaker, emotions in freqs_original.items() + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_baseline = {} # per dataset per speaker per emotion + accuracies_speaker_baseline = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_baseline.items(): + accuracies_emotion_baseline[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_baseline[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_baseline[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_baseline[speaker] = freq / count_total[speaker] + + accuracy_baseline = sum([freqs_baseline[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_baseline.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_baseline[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_baseline.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_sent = {} # per dataset per speaker per emotion + accuracies_speaker_sent = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_sent.items(): + accuracies_emotion_sent[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_sent[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_sent[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_sent[speaker] = freq / count_total[speaker] + + accuracy_sent = sum([freqs_sent[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_sent.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_sent[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_sent.items() + for speaker, emotions in 
speakers + for emotion, preds in emotions.items() + for pred in preds]) + + accuracies_emotion_prompt = {} # per dataset per speaker per emotion + accuracies_speaker_prompt = {} # per speaker + count_correct = {} + count_total = {} + for dataset, speakers in freqs_prompt.items(): + accuracies_emotion_prompt[dataset] = {} + for speaker, emotions in speakers: + accuracies_emotion_prompt[dataset][speaker] = {} + if speaker not in count_correct: + count_correct[speaker] = 0 + if speaker not in count_total: + count_total[speaker] = 0 + count_correct[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds if pred == emo]) + count_total[speaker] += sum([emotions[emo][pred] + for emo, preds in emotions.items() + for pred in preds]) + for emotion, pred_emotions in emotions.items(): + accuracies_emotion_prompt[dataset][speaker][emotion] = pred_emotions[emotion] / sum(list(pred_emotions.values())) + for speaker, freq in count_correct.items(): + accuracies_speaker_prompt[speaker] = freq / count_total[speaker] + + accuracy_prompt = sum([freqs_prompt[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_prompt.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds if pred == emotion]) / sum([freqs_prompt[dataset][speaker][emotion][pred] + for dataset, speakers in freqs_prompt.items() + for speaker, emotions in speakers + for emotion, preds in emotions.items() + for pred in preds]) \ No newline at end of file diff --git a/remove_emotion_files.py b/remove_emotion_files.py new file mode 100644 index 00000000..c2b4fc42 --- /dev/null +++ b/remove_emotion_files.py @@ -0,0 +1,70 @@ +from tqdm import tqdm + +from Utility.storage_config import PREPROCESSING_DIR +from Utility.corpus_preparation import prepare_fastspeech_corpus +from Utility.path_to_transcript_dicts import * + +def get_emotion_from_path(path): + if "EmoV_DB" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split("-16bit")[0].split("_")[0].lower() + if emotion == "amused": + emotion = "joy" + if "CREMA_D" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split('_')[2] + if emotion == "ANG": + emotion = "anger" + if emotion == "DIS": + emotion = "disgust" + if emotion == "FEA": + emotion = "fear" + if emotion == "HAP": + emotion = "joy" + if emotion == "NEU": + emotion = "neutral" + if emotion == "SAD": + emotion = "sadness" + if "Emotional_Speech_Dataset_Singapore" in path: + emotion = os.path.basename(os.path.dirname(path)).lower() + if emotion == "angry": + emotion = "anger" + if emotion == "happy": + emotion = "joy" + if emotion == "sad": + emotion = "sadness" + if "RAVDESS" in path: + emotion = os.path.splitext(os.path.basename(path))[0].split('-')[2] + if emotion == "01": + emotion = "neutral" + if emotion == "02": + emotion = "calm" + if emotion == "03": + emotion = "joy" + if emotion == "04": + emotion = "sadness" + if emotion == "05": + emotion = "anger" + if emotion == "06": + emotion = "fear" + if emotion == "07": + emotion = "disgust" + if emotion == "08": + emotion = "surprise" + return emotion + +if __name__ == '__main__': + train_set = prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_RAVDESS(), + corpus_dir=os.path.join(PREPROCESSING_DIR, "ravdess"), + lang="en", + save_imgs=False) + + remove_ids = [] + for index in tqdm(range(len(train_set))): + path = train_set[index][10] + emotion = get_emotion_from_path(path) + if emotion == "sleepiness" or emotion == "calm": + 
remove_ids.append(index) + + for remove_id in sorted(remove_ids, reverse=True): + print(train_set[remove_id][10]) + + #train_set.remove_samples(remove_ids) diff --git a/requirements.txt b/requirements.txt index cf056958..a621e947 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run_gradio_demo.py b/run_gradio_demo.py new file mode 100644 index 00000000..6f929b02 --- /dev/null +++ b/run_gradio_demo.py @@ -0,0 +1,78 @@ +import gradio as gr +import numpy as np + +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + + +def float2pcm(sig, dtype='int16'): + """ + https://gist.github.com/HudsonHuang/fbdf8e9af7993fe2a91620d3fb86a182 + """ + sig = np.asarray(sig) + if sig.dtype.kind != 'f': + raise TypeError("'sig' must be a float array") + dtype = np.dtype(dtype) + if dtype.kind not in 'iu': + raise TypeError("'dtype' must be an integer type") + i = np.iinfo(dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype) + + +class TTSWebUI: + + def __init__(self, gpu_id="cpu", title="Prompting ToucanTTS", article=""): + sent_emb_extractor = SentenceEmbeddingExtractor(pooling="cls", device=gpu_id) + self.speaker_to_id = {'Female 1': 29, + 'Female 2': 30, + 'Female 3': 31, + 'Female 4': 32, + 'Male 1': 25, + 'Male 2': 26, + 'Male 3': 27, + 'Male 4': 28, + 'Male 5': 33, + 'Male 6': 34} + self.tts_interface = ToucanTTSInterface(device=gpu_id, + tts_model_path='Proposed', + faster_vocoder=True, + sent_emb_extractor=sent_emb_extractor) + self.iface = gr.Interface(fn=self.read, + inputs=[gr.Textbox(lines=2, + placeholder="write what you want the synthesis to read here...", + value="Today is a beautiful day.", + label="Text input"), + gr.Textbox(lines=2, + placeholder="write a (emotional) prompt in order to control the speaking style...", + value="I am so angry!", + label="Prompt"), + gr.Dropdown(['Female 1', + 'Female 2', + 'Female 3', + 'Female 4', + 'Male 1', + 'Male 2', + 'Male 3', + 'Male 4', + 'Male 5', + 'Male 6'], type="value", + value='Female 1', label="Select a Speaker")], + outputs=[gr.Audio(type="numpy", label="Speech"), + gr.Image(label="Visualization")], + title=title, + theme="default", + allow_flagging="never", + article=article) + self.iface.launch(enable_queue=True) + + def read(self, input, prompt, speaker): + self.tts_interface.set_language("en") + self.tts_interface.set_speaker_id(self.speaker_to_id[speaker]) + self.tts_interface.set_sentence_embedding(prompt) + wav, fig = self.tts_interface(input, return_plot_as_filepath=True) + return (24000, float2pcm(wav.cpu().numpy())), fig + +if __name__ == '__main__': + TTSWebUI(gpu_id="cpu") \ No newline at end of file diff --git a/run_model_downloader.py b/run_model_downloader.py index 71ac8f50..d9fb82f4 100644 --- a/run_model_downloader.py +++ b/run_model_downloader.py @@ -62,6 +62,22 @@ def download_models(): url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v2.5/embedding_gan.pt", filename=os.path.abspath(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt")), reporthook=report) + + ############# + print("Downloading Baseline Model (Prompting)") + os.makedirs(os.path.join(MODELS_DIR, "ToucanTTS_Baseline"), exist_ok=True) + filename, headers = urllib.request.urlretrieve( + 
url="https://github.com/Thommy96/IMS-Toucan/releases/download/v1.0/baseline.pt", + filename=os.path.abspath(os.path.join(MODELS_DIR, "ToucanTTS_Baseline", "best.pt")), + reporthook=report) + + ############# + print("Downloading Proposed Model (Prompting)") + os.makedirs(os.path.join(MODELS_DIR, "ToucanTTS_Proposed"), exist_ok=True) + filename, headers = urllib.request.urlretrieve( + url="https://github.com/Thommy96/IMS-Toucan/releases/download/v1.0/proposed.pt", + filename=os.path.abspath(os.path.join(MODELS_DIR, "ToucanTTS_Proposed", "best.pt")), + reporthook=report) if __name__ == '__main__': diff --git a/run_objective_evaluation.py b/run_objective_evaluation.py new file mode 100644 index 00000000..13801219 --- /dev/null +++ b/run_objective_evaluation.py @@ -0,0 +1,220 @@ +import os +import argparse +from statistics import median, mean + +import torch +from transformers import pipeline +from speechbrain.pretrained import EncoderClassifier +from speechbrain.pretrained.interfaces import foreign_class +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + +from Utility.storage_config import PREPROCESSING_DIR, MODELS_DIR +from Evaluation.objective_evaluation import * +from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor +from InferenceInterfaces.InferenceArchitectures.InferenceAvocodo import HiFiGANGenerator + +import sys + +NUM_TEST_SENTENCES = 50 # 50 sentences per emotion category +EMOTIONS = ["anger", "joy", "neutral", "sadness", "surprise"] + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluation') + parser.add_argument('--gpu_id', + type=str, + help="Which GPU to run on. If not specified runs on CPU.", + default="cpu") + args = parser.parse_args() + if args.gpu_id == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + device = torch.device("cpu") + print(f"No GPU specified, using CPU.") + + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}" + device = torch.device("cuda") + print(f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device.") + + # extract test sentences + # test sentences are a dict with shape {dataset: {emotion: [sentences]}} + print("Loading test senteces...") + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_sentences.pt")): + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation"), exist_ok=True) + + emotion_to_sents_dialog = extract_dailydialogue_sentences() + + tales_data_dir = "/mount/arbeitsdaten/synthesis/bottts/IMS-Toucan/Corpora/Tales" + emotion_to_sents_tales = extract_tales_sentences(tales_data_dir) + + classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=1, device=device) + test_sentences = {} + test_sentences["dailydialogues"] = get_sorted_test_sentences(emotion_to_sents_dialog, classifier) + test_sentences["tales"] = get_sorted_test_sentences(emotion_to_sents_tales, classifier) + + torch.save(test_sentences, os.path.join(PREPROCESSING_DIR, "Evaluation", "test_sentences.pt")) + else: + test_sentences = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_sentences.pt"), map_location='cpu') + + for dataset, emotion_to_sents in test_sentences.items(): + for emotion, sentences in emotion_to_sents.items(): + test_sentences[dataset][emotion] = sentences[:NUM_TEST_SENTENCES] + + for dataset, emotion_to_sents in test_sentences.items(): + for emotion, sentences in 
emotion_to_sents.items(): + if len(sentences) != NUM_TEST_SENTENCES: + raise ValueError(f"Number of sentences is not {NUM_TEST_SENTENCES} for dataset {dataset} and emotion {emotion}.") + + # synthesize test sentences + if not os.path.exists(os.path.join("./audios/Evaluation")): + print("Synthesizing Baseline...") + synthesize_test_sentences(version="Baseline", + exec_device=device, + biggan=False, + sent_emb_extractor=None, + test_sentences=test_sentences, + silent=True) + print("Synthesizing Proposed...") + sent_emb_extractor = EmotionRoBERTaSentenceEmbeddingExtractor(pooling="cls") + synthesize_test_sentences(version="Sent", + exec_device=device, + biggan=False, + sent_emb_extractor=sent_emb_extractor, + test_sentences=test_sentences, + silent=True) + print("Synthesizing Prompt...") + synthesize_test_sentences(version="Prompt", + exec_device=device, + biggan=False, + sent_emb_extractor=sent_emb_extractor, + test_sentences=test_sentences, + silent=True) + + # get vocoded original sentences + if not os.path.exists(os.path.join("./audios/Evaluation/Original")): + print("Vocoding Original...") + os.makedirs(os.path.join("./audios/Evaluation/Original"), exist_ok=True) + vocoder_model_path = os.path.join(MODELS_DIR, "Avocodo", "best.pt") + mel2wav = HiFiGANGenerator(path_to_weights=vocoder_model_path).to(device) + mel2wav.remove_weight_norm() + mel2wav.eval() + vocode_original(mel2wav, num_sentences=NUM_TEST_SENTENCES, device=device) + + # extract speaker embeddings + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_original.pt")): + print("Extracting speaker embeddings...") + classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", + savedir="./Models/Embedding/spkrec-xvect-voxceleb", + run_opts={"device": device}) + # shape {speaker: {emotion: {file_id: embedding}}} + speaker_embeddings_original = extract_speaker_embeddings("./audios/Evaluation", classifier, version='Original') + torch.save(speaker_embeddings_original, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_original.pt")) + # shape {dataset: {speaker: {emotion: {file_id: embedding}}}} + speaker_embeddings_baseline = extract_speaker_embeddings("./audios/Evaluation", classifier, version='Baseline') + torch.save(speaker_embeddings_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_baseline.pt")) + # shape {dataset: {speaker: {emotion: {file_id: embedding}}}} + speaker_embeddings_sent = extract_speaker_embeddings("./audios/Evaluation", classifier, version='Sent') + torch.save(speaker_embeddings_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_sent.pt")) + # shape {dataset: {speaker: {emotion: {prompt_emotion: {file_id: embedding}}}}} + speaker_embeddings_prompt = extract_speaker_embeddings("./audios/Evaluation", classifier, version='Prompt') + torch.save(speaker_embeddings_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_prompt.pt")) + else: + speaker_embeddings_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_original.pt"), map_location='cpu') + speaker_embeddings_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_baseline.pt"), map_location='cpu') + speaker_embeddings_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_sent.pt"), map_location='cpu') + speaker_embeddings_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_embeddings_prompt.pt"), 
map_location='cpu') + + # calculate speaker similarity + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_baseline.pt")): + print("Calculating speaker similarity...") + # shape {dataset: {speaker: {emotion: speaker_similarity}}} + speaker_similarities_baseline = compute_speaker_similarity(speaker_embeddings_original, speaker_embeddings_baseline, version='Baseline') + # shape {dataset: {speaker: {emotion: speaker_similarity}}} + speaker_similarities_sent = compute_speaker_similarity(speaker_embeddings_original, speaker_embeddings_sent, version='Sent') + # shape {dataset: {speaker: {prompt_emotion: speaker_similarity}}} + speaker_similarities_prompt = compute_speaker_similarity(speaker_embeddings_original, speaker_embeddings_prompt, version='Prompt') + + torch.save(speaker_similarities_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_baseline.pt")) + torch.save(speaker_similarities_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_sent.pt")) + torch.save(speaker_similarities_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "speaker_similarities_prompt.pt")) + + # calculate word error rate + print("Calculating word error rate...") + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_original.pt")): + print("Transcribing...") + processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", cache_dir=os.path.join(MODELS_DIR, "ASR")) + model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", cache_dir=os.path.join(MODELS_DIR, "ASR")).to(device) + # shape {speaker: {emotion: {sentence_id: transcription}}} + transcriptions_original = asr_transcribe("./audios/Evaluation", processor, model, version='Original') + torch.save(transcriptions_original, os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_original.pt")) + # shape {dataset: {speaker: {emotion: {sentence_id: transcription}}}} + transcriptions_baseline = asr_transcribe("./audios/Evaluation", processor, model, version='Baseline') + torch.save(transcriptions_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_baseline.pt")) + # shape {dataset: {speaker: {emotion: {sentence_id: transcription}}}} + transcriptions_sent = asr_transcribe("./audios/Evaluation", processor, model, version='Sent') + torch.save(transcriptions_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_sent.pt")) + # shape {dataset: {speaker: {emotion: {prompt_emotion: {file_id: embedding}}}}} + transcriptions_prompt = asr_transcribe("./audios/Evaluation", processor, model, version='Prompt') + torch.save(transcriptions_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_prompt.pt")) + else: + transcriptions_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_original.pt"), map_location='cpu') + transcriptions_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_baseline.pt"), map_location='cpu') + transcriptions_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_sent.pt"), map_location='cpu') + transcriptions_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "transcriptions_prompt.pt"), map_location='cpu') + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_original.pt")): + # shape {speaker: {emotion: wer}} + wers_original = compute_word_error_rate(transcriptions_original, test_sentences, version='Original') + # shape {dataset: 
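The speaker-similarity step relies on extract_speaker_embeddings and compute_speaker_similarity from Evaluation.objective_evaluation, which are not shown in this diff. The underlying operations are x-vector extraction with the SpeechBrain classifier loaded above and a cosine similarity between embeddings; a hedged single-pair sketch follows, where the file paths are placeholders and the per-emotion pairing and averaging are left to the repository code:

import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier

classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb",
                                            savedir="./Models/Embedding/spkrec-xvect-voxceleb",
                                            run_opts={"device": "cpu"})

def xvector(path):
    # Load a wav, resample to the 16 kHz expected by the x-vector model,
    # and return a single (512,) speaker embedding.
    wave, sr = torchaudio.load(path)
    if sr != 16000:
        wave = torchaudio.functional.resample(wave, sr, 16000)
    return classifier.encode_batch(wave).squeeze()

# Placeholder paths: one vocoded original and one synthesized utterance.
original = xvector("audios/Evaluation/Original/speaker/emotion/0.wav")
synthesized = xvector("audios/Evaluation/Baseline/dataset/speaker/emotion/0.wav")
similarity = torch.nn.CosineSimilarity(dim=0)(original, synthesized)
print(float(similarity))  # closer to 1.0 means more similar speaker characteristics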
{speaker: {emotion: wer}}} + wers_baseline = compute_word_error_rate(transcriptions_baseline, test_sentences, version='Baseline') + # shape {dataset: {speaker: {emotion: wer}}} + wers_sent = compute_word_error_rate(transcriptions_sent, test_sentences, version='Sent') + # shape {dataset: {speaker: {prompt_emotion: wer}}} + wers_prompt = compute_word_error_rate(transcriptions_prompt, test_sentences, version='Prompt') + + torch.save(wers_original, os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_original.pt")) + torch.save(wers_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_baseline.pt")) + torch.save(wers_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_sent.pt")) + torch.save(wers_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "wers_prompt.pt")) + + # speech emotion recognition + print("Calculating speech emotion recognition...") + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_original.pt")): + print("Speech emotion recognition...") + classifier = foreign_class(source=os.path.join(MODELS_DIR, "Emotion_Recognition"), + pymodule_file="custom_interface.py", + classname="CustomEncoderWav2vec2Classifier", + savedir=os.path.join(MODELS_DIR, "Emotion_Recognition"), + run_opts={"device":device}) + + # shape {speaker: {emotion: {sentence_id: predicted emotion}}} + predicted_emotions_original = classify_speech_emotion("./audios/Evaluation", classifier, version='Original') + torch.save(predicted_emotions_original, os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_original.pt")) + # shape {dataset: {speaker: {emotion: {sentence_id: predicted emotion}}}} + predicted_emotions_baseline = classify_speech_emotion("./audios/Evaluation", classifier, version='Baseline') + torch.save(predicted_emotions_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_baseline.pt")) + # shape {dataset: {speaker: {emotion: {sentence_id: predicted emotion}}}} + predicted_emotions_sent = classify_speech_emotion("./audios/Evaluation", classifier, version='Sent') + torch.save(predicted_emotions_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_sent.pt")) + # shape {dataset: {speaker: {emotion: {prompt_emotion: {sentence_id: predicted emotion}}}}} + predicted_emotions_prompt = classify_speech_emotion("./audios/Evaluation", classifier, version='Prompt') + torch.save(predicted_emotions_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_prompt.pt")) + else: + predicted_emotions_original = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_original.pt"), map_location='cpu') + predicted_emotions_baseline = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_baseline.pt"), map_location='cpu') + predicted_emotions_sent = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_sent.pt"), map_location='cpu') + predicted_emotions_prompt = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "predicted_emotions_prompt.pt"), map_location='cpu') + + # shape {speaker: {emotion: {pred_emotion: count}}} + freqs_original = compute_predicted_emotions_frequencies(predicted_emotions_original, version='Original') + # shape {dataset: {speaker: {emotion: {pred_emotion: count}}}} + freqs_baseline = compute_predicted_emotions_frequencies(predicted_emotions_baseline, version='Baseline') + # shape {dataset: {speaker: {emotion: {pred_emotion: count}}}} + freqs_sent = 
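asr_transcribe and compute_word_error_rate (imported from Evaluation.objective_evaluation, not shown here) wrap greedy CTC decoding with the Wav2Vec2 model loaded above and a word-error-rate metric over the synthesized audio tree. A hedged single-file sketch; the path and reference text are placeholders, and the repository code handles directory traversal and whatever text normalization it applies:

import string
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from torchmetrics import WordErrorRate

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()

def transcribe(path):
    # Greedy CTC decoding of a single utterance, resampled to 16 kHz.
    wave, sr = torchaudio.load(path)
    if sr != 16000:
        wave = torchaudio.functional.resample(wave, sr, 16000)
    inputs = processor(wave.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    return processor.batch_decode(torch.argmax(logits, dim=-1))[0]

def normalize(text):
    # Wav2Vec2 outputs uppercase text without punctuation, so strip both sides.
    return text.lower().translate(str.maketrans("", "", string.punctuation))

reference = "They will arrive tomorrow."                                              # placeholder reference
hypothesis = transcribe("audios/Evaluation/Baseline/dataset/speaker/emotion/0.wav")   # placeholder path
print(float(WordErrorRate()(normalize(hypothesis), normalize(reference))))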
compute_predicted_emotions_frequencies(predicted_emotions_sent, version='Sent') + # shape {dataset: {speaker: {prompt_emotion: {pred_emotion: count}}}} + freqs_prompt = compute_predicted_emotions_frequencies(predicted_emotions_prompt, version='Prompt') + + torch.save(freqs_original, os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_original.pt")) + torch.save(freqs_baseline, os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_baseline.pt")) + torch.save(freqs_sent, os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_sent.pt")) + torch.save(freqs_prompt, os.path.join(PREPROCESSING_DIR, "Evaluation", "freqs_prompt.pt")) diff --git a/run_sent_emb_test_suite.py b/run_sent_emb_test_suite.py new file mode 100644 index 00000000..17b124ac --- /dev/null +++ b/run_sent_emb_test_suite.py @@ -0,0 +1,601 @@ +import os +import time + +import torch + +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Utility.storage_config import PREPROCESSING_DIR + +def test_sentence(version, + model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + #sentence = "The football teams give a tea party." + sentence1 = "You can write an email." + sentence2 = "They will arrive tomorrow." 
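compute_predicted_emotions_frequencies (also from Evaluation.objective_evaluation, not shown in this diff) turns the per-file emotion predictions saved above into the {emotion: {pred_emotion: count}} tables that the plotting scripts consume. A self-contained sketch of that tally with dummy predictions; the nesting mirrors the shape comments above, and this is an illustration of the counting step rather than the repository's implementation:

from collections import Counter

predicted = {"15": {"anger": {"0": "anger", "1": "anger", "2": "neutral"},
                    "joy":   {"0": "joy",   "1": "surprise"}}}

freqs = {}
for speaker, emotions in predicted.items():
    # Count how often each emotion was predicted for files of a given true emotion.
    freqs[speaker] = {emotion: dict(Counter(per_file.values()))
                      for emotion, per_file in emotions.items()}

print(freqs)
# {'15': {'anger': {'anger': 2, 'neutral': 1}, 'joy': {'joy': 1, 'surprise': 1}}}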
+ tts.read_to_file(text_list=[sentence1], + file_location=f"audios/{version}/paper1.wav", + increased_compatibility_mode=True, + view_contours=True, + plot_name="paper1") + start_time = time.time() + tts.read_to_file(text_list=[sentence2], + file_location=f"audios/{version}/paper2.wav", + increased_compatibility_mode=True, + view_contours=False, + plot_name="paper2") + end_time = time.time() + elapsed_time = end_time - start_time + print("Elapsed time:", elapsed_time, "seconds") + +def test_tales_emotion(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Tales", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + emotion_to_sents = torch.load(os.path.join(PREPROCESSING_DIR, "Tales", f"emotion_sentences_top20.pt"), map_location='cpu') + for emotion, sents in emotion_to_sents.items(): + for i, sent in enumerate(sents): + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Tales/{emotion}_{i}.wav", increased_compatibility_mode=True) + +def test_yelp_emotion(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Yelp", exist_ok=True) + os.makedirs(f"audios/{version}/Yelp_Prompt", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + emotion_to_sents = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", f"emotion_sentences_top20.pt"), map_location='cpu') + for emotion, sents in emotion_to_sents.items(): + for i, sent in enumerate(sents): + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Yelp/{emotion}_{i}.wav", increased_compatibility_mode=True) + +def test_gne_emotion(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Headlines", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + 
sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + emotion_to_sents = torch.load(os.path.join(PREPROCESSING_DIR, "Headlines", f"emotion_sentences_top20.pt"), map_location='cpu') + for emotion, sents in emotion_to_sents.items(): + for i, sent in enumerate(sents): + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Headlines/{emotion}_{i}.wav", increased_compatibility_mode=True) + +def test_controllable(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + for i, sentence in enumerate(['I am so happy to see you!', + 'Today is a beautiful day and the sun is shining.', + 'He seemed to be quite lucky as he was smiling at me.', + 'She laughed and said: This is so funny.', + 'No, this is horrible!', + 'I am so sad, why is this so depressing?', + 'Be careful, cried the woman.', + 'This makes me feel bad.', + 'Oh happy day!', + 'Well, this sucks.', + 'That smell is disgusting.', + 'I am so angry!', + 'What a surprise!', + 'I am so scared, I fear the worst.', + 'This is a neutral test sentence with medium length, which should have relatively neutral prosody, and can be used to test the controllability through textual prompts.']): + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/Controllable_{i}.wav", increased_compatibility_mode=True) + +def test_study(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + emotion_to_sents = {"anger": ["You can't be serious, how dare you not tell me you were going to marry her?", + "I'm so angry, I feel like killing someone!", + "It’s infuriating, I have to be in Rome by five!", + "The bear, in great fury, ran after the carriage." 
+ ], + "disgust": ["I can't help myself, it's just so disgusting in here!", + "What a stink, this place stinks like rotten eggs.", + "I hate to complain, but this soup is too salty.", + "The rabbits could not bear him, they could smell him half a mile off." + ], + "sadness": ["I know that , mom, but sometimes I'm just sad.", + "My uncle passed away last night.", + "Lily broke up with me last week, in fact, she dumped me.", + "Here he remained the whole night, feeling very tired and sorrowful." + ], + "joy": ["I am very happy to know my work could be recognized by you and our company.", + "I really enjoy the beach in the summer.", + "I had a wonderful time.", + "Then she saw that her deliverance was near, and her heart leapt with joy." + ], + "surprise":["I’m shocked he actually won.", + "Wow, why so much, I thought they were getting you an assistant.", + "Really, I can't believe it, it's like a dream come true, I never expected that I would win The Nobel Prize!", + "He was astonished when he saw them come alone, and asked what had happened to them." + ], + "fear": ["I feel very nervous about it.", + "I'm scared that she might not come back.", + "Well , I just saw a horror movie last night, it almost frightened me to death.", + "Peter sat down to rest, he was out of breath and trembling with fright, and he had not the least idea which way to go." + ], + "neutral": ["You must specify an address of the place where you will spend most of your time.", + "Just a second, I'll see if I can find them for you.", + "You can go to the Employment Development Office and pick it up.", + "So the queen gave him the letter, and said that he might see for himself what was written in it." + ] + } + for emotion, sents in emotion_to_sents.items(): + for i, sent in enumerate(sents): + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/{emotion}_{i}.wav", increased_compatibility_mode=True) + +def test_study2(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + if prompt is not None: + tts.set_sentence_embedding(prompt) + + emotion_to_sents = {"anger": ["You can't be serious, how dare you not tell me you were going to marry her?", + "The king grew angry, and cried: That is not allowed, he must appear before me and tell his name!" + ], + "disgust": ["What a stink, this place stinks like rotten eggs.", + "The rabbits could not bear him, they could smell him half a mile off." + ], + "sadness": ["Lily broke up with me last week, in fact, she dumped me.", + "The sisters mourned as young hearts can mourn, and were especially grieved at the sight of their parents' sorrow." + ], + "joy": ["I really enjoy the beach in the summer.", + "Then she saw that her deliverance was near, and her heart leapt with joy." + ], + "surprise":["Really? I can't believe it! 
It's like a dream come true, I never expected that I would win The Nobel Prize!", + "He was astonished when he saw them come alone, and asked what had happened to them." + ], + "fear": ["I'm scared that she might not come back.", + "Peter sat down to rest, he was out of breath and trembling with fright, and he had not the least idea which way to go." + ], + "neutral": ["You can go to the Employment Development Office and pick it up.", + "So the queen gave him the letter, and said that he might see for himself what was written in it." + ] + } + for emotion, sents in emotion_to_sents.items(): + for i, sent in enumerate(sents): + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/{emotion}_{i}.flac", increased_compatibility_mode=True) + +def test_study2_male(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + + emotion = "anger" + sent = "The king grew angry, and cried: That is not allowed, he must appear before me and tell his name!" + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "joy" + sent = "Then she saw that her deliverance was near, and her heart leapt with joy." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "neutral" + sent = "So the queen gave him the letter, and said that he might see for himself what was written in it." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "sadness" + sent = "The sisters mourned as young hearts can mourn, and were especially grieved at the sight of their parents' sorrow." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "surprise" + sent = "Really? I can't believe it! It's like a dream come true, I never expected that I would win The Nobel Prize!" 
+ tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{0}.flac", increased_compatibility_mode=True) + +def test_study2_male_prompt(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + + emotion = "anger" + sent = "The king grew angry, and cried: That is not allowed, he must appear before me and tell his name!" + prompt = "Really? I can't believe it! It's like a dream come true, I never expected that I would win The Nobel Prize!" + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "joy" + sent = "Then she saw that her deliverance was near, and her heart leapt with joy." + prompt = "The sisters mourned as young hearts can mourn, and were especially grieved at the sight of their parents' sorrow." + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "neutral" + sent = "So the queen gave him the letter, and said that he might see for himself what was written in it." + prompt = "The king grew angry, and cried: That is not allowed, he must appear before me and tell his name!" + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "sadness" + sent = "The sisters mourned as young hearts can mourn, and were especially grieved at the sight of their parents' sorrow." + prompt = "Then she saw that her deliverance was near, and her heart leapt with joy." + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{1}.flac", increased_compatibility_mode=True) + + emotion = "surprise" + sent = "Really? I can't believe it! It's like a dream come true, I never expected that I would win The Nobel Prize!" + prompt = "So the queen gave him the letter, and said that he might see for himself what was written in it." 
+ tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{0}.flac", increased_compatibility_mode=True) + +def test_study2_female(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + + emotion = "anger" + sent = "You can't be serious, how dare you not tell me you were going to marry her?" + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "joy" + sent = "I really enjoy the beach in the summer." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "neutral" + sent = "You can go to the Employment Development Office and pick it up." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "sadness" + sent = "Lily broke up with me last week, in fact, she dumped me." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "surprise" + sent = "He was astonished when he saw them come alone, and asked what had happened to them." + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_{emotion}_{1}.flac", increased_compatibility_mode=True) + +def test_study2_female_prompt(version, model_id="Meta", + exec_device="cpu", + speaker_reference=None, + vocoder_model_path=None, + biggan=False, + sent_emb_extractor=None, + word_emb_extractor=None, + prompt:str=None, + xvect_model=None, + speaker_id=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + os.makedirs(f"audios/{version}/Study", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, + tts_model_path=model_id, + vocoder_model_path=vocoder_model_path, + faster_vocoder=not biggan, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_emb_extractor, + xvect_model=xvect_model) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if speaker_id is not None: + tts.set_speaker_id(speaker_id) + + emotion = "anger" + sent = "You can't be serious, how dare you not tell me you were going to marry her?" + prompt = "You can go to the Employment Development Office and pick it up." + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "joy" + sent = "I really enjoy the beach in the summer." + prompt = "He was astonished when he saw them come alone, and asked what had happened to them." 
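The test_study2_* functions repeat the same set_sentence_embedding / read_to_file block once per emotion, and in the *_prompt variants each target sentence is deliberately conditioned on a prompt taken from a different emotion (for example, the anger sentence on the surprise prompt). A data-driven sketch of that pattern; tts stands for an already configured ToucanTTSInterface, the output naming is illustrative, and the two example pairs are copied from test_study2_male_prompt above:

def synthesize_prompt_study(tts, version, items):
    # items maps emotion -> (target sentence, mismatched prompt).
    for emotion, (sent, prompt) in items.items():
        tts.set_sentence_embedding(prompt)
        tts.read_to_file(text_list=[sent],
                         file_location=f"audios/{version}/Study/sent_prompt_{emotion}_1.flac",
                         increased_compatibility_mode=True)

study_items = {
    "anger": ("The king grew angry, and cried: That is not allowed, he must appear before me and tell his name!",
              "Really? I can't believe it! It's like a dream come true, I never expected that I would win The Nobel Prize!"),
    "joy":   ("Then she saw that her deliverance was near, and her heart leapt with joy.",
              "The sisters mourned as young hearts can mourn, and were especially grieved at the sight of their parents' sorrow."),
}
# synthesize_prompt_study(tts, "Proposed", study_items)  # illustrative call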
+ tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "neutral" + sent = "You can go to the Employment Development Office and pick it up." + prompt = "I really enjoy the beach in the summer." + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "sadness" + sent = "Lily broke up with me last week, in fact, she dumped me." + prompt = "You can't be serious, how dare you not tell me you were going to marry her?" + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{0}.flac", increased_compatibility_mode=True) + + emotion = "surprise" + sent = "He was astonished when he saw them come alone, and asked what had happened to them." + prompt = "Lily broke up with me last week, in fact, she dumped me." + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sent], file_location=f"audios/{version}/Study/sent_prompt_{emotion}_{1}.flac", increased_compatibility_mode=True) + + +if __name__ == '__main__': + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"0" + exec_device = "cuda:0" if torch.cuda.is_available() else "cpu" + print(f"running on {exec_device}") + + use_speaker_reference = False + use_sent_emb = True + use_word_emb = False + use_prompt = True + use_xvect = False + use_ecapa = False + use_speaker_id = True + + if use_speaker_id: + speaker_id = 4 + 1 + 24 + else: + speaker_id = None + + if use_sent_emb: + from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + sent_emb_extractor = SentenceEmbeddingExtractor(pooling="cls") + else: + sent_emb_extractor = None + + if use_word_emb: + from Preprocessing.word_embeddings.EmotionRoBERTaWordEmbeddingExtractor import EmotionRoBERTaWordEmbeddingExtractor + word_embedding_extractor = EmotionRoBERTaWordEmbeddingExtractor() + else: + word_embedding_extractor = None + + if use_speaker_reference: + speaker_reference = "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0018/Surprise/0018_001431.wav" + else: + speaker_reference = None + + if use_prompt: + #prompt = "I am so angry!" + #prompt = "Roar with laughter, this is funny." + #prompt = "Ew, this is disgusting." + #prompt = "What a surprise!" + #prompt = "This is very sad." + #prompt = "I am so scared." + #prompt = "I love that." + #prompt = "He was furious." + #prompt = "She didn't expect that." + prompt = "That's ok." + #prompt = "Oh, really?" 
+ else: + prompt = None + + if use_xvect: + from speechbrain.pretrained import EncoderClassifier + xvect_model = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", savedir="./Models/Embedding/spkrec-xvect-voxceleb", run_opts={"device": exec_device}) + else: + xvect_model = None + + if use_ecapa: + from speechbrain.pretrained import EncoderClassifier + ecapa_model = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", savedir="./Models/Embedding/spkrec-ecapa-voxceleb", run_opts={"device": exec_device}) + else: + ecapa_model = None + if ecapa_model is not None: + xvect_model = ecapa_model + + test_sentence(version="ToucanTTS_Sent_Finetuning_2_80k", + model_id="Sent_Finetuning_2_80k", + exec_device=exec_device, + vocoder_model_path=None, + biggan=False, + speaker_reference=speaker_reference, + sent_emb_extractor=sent_emb_extractor, + word_emb_extractor=word_embedding_extractor, + prompt=prompt, + xvect_model=xvect_model, + speaker_id=speaker_id) diff --git a/run_sent_word_emb_test_suite.py b/run_sent_word_emb_test_suite.py new file mode 100644 index 00000000..bcb376c7 --- /dev/null +++ b/run_sent_word_emb_test_suite.py @@ -0,0 +1,170 @@ +import os + +import torch + +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface + + +def the_raven_and_the_fox(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, sent_emb_extractor=None, prompt:str=None, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, sent_emb_extractor=sent_emb_extractor, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(["Master Raven, on a perched tree, was holding a cheese in his beak.", + "Master Fox, attracted by the smell, told him more or less this language:", + "And hello, Master Raven, how pretty you are! How beautiful you seem to me!", + "No lie, if your bearing is anything like your plumage, you are the Phoenix of the hosts in these woods.", + "At these words the Raven does not feel happy, and to show his beautiful voice, he opens a wide beak, drops his prey.", + "The Fox seized it, and said: My good Sir, learn that every flatterer lives at the expense of the one who listens to him.", + "This lesson is worth a cheese without doubt.", + "The ashamed and confused Raven swore, but a little late, that he would not be caught again.", + "Master Raven, on a perched tree, was holding a cheese in his beak. Master Fox, attracted by the smell, told him more or less this language: And hello, Master Raven, how pretty you are! How beautiful you seem to me! No lie, if your bearing is anything like your plumage, you are the Phoenix of the hosts in these woods. At these words the Raven does not feel happy, and to show his beautiful voice, he opens a wide beak, drops his prey. The Fox seized it, and said: My good Sir, learn that every flatterer lives at the expense of the one who listens to him. This lesson is worth a cheese without doubt. The ashamed and confused Raven swore, but a little late, that he would not be caught again." 
+ ]): + if prompt is not None: + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/The_raven_and_the_fox_{i}.wav") + +def poem(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, sent_emb_extractor=None, prompt:str=None, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, sent_emb_extractor=sent_emb_extractor, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(['Once upon a midnight dreary, while I pondered, weak, and weary,', + 'Over many a quaint, and curious volume of forgotten lore,', + 'While I nodded, nearly napping, suddenly, there came a tapping,', + 'As of someone gently rapping, rapping at my chamber door.', + 'Tis some visitor, I muttered, tapping at my chamber door,', + 'Only this, and nothing more.', + 'Ah, distinctly, I remember, it was in the bleak December,', + 'And each separate dying ember, wrought its ghost upon the floor.', + 'Eagerly, I wished the morrow, vainly, I had sought to borrow', + 'From my books surcease of sorrow, sorrow, for the lost Lenore,', + 'For the rare and radiant maiden, whom the angels name Lenore,', + 'Nameless here, for evermore.', + 'And the silken, sad, uncertain, rustling of each purple curtain', + 'Thrilled me, filled me, with fantastic terrors, never felt before.']): + if prompt is not None: + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/Poem_{i}.wav") + +def test_sentence(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, sent_emb_extractor=None, prompt:str=None, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, sent_emb_extractor=sent_emb_extractor, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + sentence = "Well, she said, if I had had your bringing up I might have had as good a temper as you, but now I don't believe I ever shall." + #sentence = "But yours, your regard was new compared with; Fanny, think of me!" 
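+ # only the active sentence assignment above is synthesized; when a prompt is supplied, it is set as the sentence embedding right before the sentence is read to file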
+ if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + if prompt is not None: + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/test_sentence.wav") + +def test_controllable(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, sent_emb_extractor=None, prompt:str=None, make_preprompt=False, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, sent_emb_extractor=sent_emb_extractor, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(['I am so happy to see you!', + 'Today is a beautiful day and the sun is shining.', + 'He seemed to be quite lucky as he was smiling at me.', + 'She laughed and said: This is so funny.', + 'No, this is horrible!', + 'I am so sad, why is this so depressing?', + 'Be careful!, Cried the woman', + 'This makes me feel bad.']): + if prompt is not None: + if make_preprompt: + prompt = prompt + ' ' + sentence + tts.set_sentence_embedding(prompt) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/Controllable_{i}.wav") + +def test_promptspeech(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, sent_emb_extractor=None, prompt:str=None, make_preprompt=False): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, sent_emb_extractor=sent_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + prompts = ["Women's voice reminds whispers to be fast and high tone", + "Lady loud but bass tone saying", + "A girl shouted in a low speed", + "Seek a loud female voice with a treble, fast speed", + "A male bass voice said lowly and quietly", + "Minor men whisper and fast, talk with high tone", + "The male voice loud and high tone and the speed is slow", + "His speaking rate is rapidly and loudly"] + + for i, sentence in enumerate(['The next day I left Marsh End for Morton.', + 'It was not, she knew, that night had come, but something as dark as night had come.', + "He immediately answered in Hook's voice:", + 'Give me some brandy.', + "Perhaps I'd better quit talking.", + 'Up and down the street he went, and in and out the lanes, but no traces of the pig could he find anywhere.', + 'There were no windows whatever, and only one or two slight crevices through which the light came.', + 'When Wendy returned diffidently she found peter sitting on the bed post crowing gloriously, while Jane in her nighty was flying round the room in solemn ecstasy.']): + if make_preprompt: + prompt = prompts[i] + ' ' + sentence + tts.set_sentence_embedding(prompt) + else: + tts.set_sentence_embedding(prompts[i]) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/test_promptspeech_{i}.wav") + + +if __name__ == '__main__': + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"1,2" + exec_device = "cuda:0" if torch.cuda.is_available() else "cpu" + #exec_device = 
"cpu" + print(f"running on {exec_device}") + + use_speaker_reference = False + use_sent_emb = True + use_word_emb = True + use_prompt = False + + if use_sent_emb: + #import tensorflow + #gpus = tensorflow.config.experimental.list_physical_devices('GPU') + #tensorflow.config.experimental.set_visible_devices(gpus[1], 'GPU') + #from Preprocessing.sentence_embeddings.LEALLASentenceEmbeddingExtractor import LEALLASentenceEmbeddingExtractor as SentenceEmbeddingExtractor + + #from Preprocessing.sentence_embeddings.LASERSentenceEmbeddingExtractor import LASERSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + #from Preprocessing.sentence_embeddings.STSentenceEmbeddingExtractor import STSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + from Preprocessing.sentence_embeddings.BERTSentenceEmbeddingExtractor import BERTSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + + #sent_emb_extractor = SentenceEmbeddingExtractor(model="mpnet") + sent_emb_extractor = SentenceEmbeddingExtractor(pooling="cls") + #sent_emb_extractor = SentenceEmbeddingExtractor() + else: + sent_emb_extractor = None + + if use_word_emb: + from Preprocessing.word_embeddings.BERTWordEmbeddingExtractor import BERTWordEmbeddingExtractor + word_embedding_extractor = BERTWordEmbeddingExtractor() + else: + word_embedding_extractor = None + + if use_speaker_reference: + speaker_reference = "/mount/resources/speech/corpora/Blizzard2013/train/segmented/wavn/CA-BB-07-04.wav" + #speaker_reference = "/mount/resources/speech/corpora/LibriTTS/all_clean/1638/84448/1638_84448_000057_000006.wav" + else: + speaker_reference = None + + if use_prompt: + #prompt = "Well, she said, if I had had your bringing up I might have had as good a temper as you, but now I don't believe I ever shall." + prompt = "No, this is horrible!" 
+ else: + prompt = None + + test_controllable(version="ToucanTTS_03_Blizzard2013_sent_word_emb_a12_loss_bertcls_keep", model_id="03_Blizzard2013_sent_word_emb_a12_loss_bertcls_keep", exec_device=exec_device, vocoder_model_path=None, biggan=True, speaker_reference=speaker_reference, sent_emb_extractor=sent_emb_extractor, word_emb_extractor=word_embedding_extractor, prompt=prompt) \ No newline at end of file diff --git a/run_subjective_evaluation.py b/run_subjective_evaluation.py new file mode 100644 index 00000000..79bf50a8 --- /dev/null +++ b/run_subjective_evaluation.py @@ -0,0 +1,223 @@ +import os +from statistics import mean + +from Evaluation.subjective_evaluation import * +from Evaluation.plotting import * +from Utility.storage_config import PREPROCESSING_DIR + +import sys + +if __name__ == '__main__': + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation", "plots"), exist_ok=True) + save_dir = os.path.join(PREPROCESSING_DIR, "Evaluation", "plots") + + # load data + data = read_data(os.path.join(PREPROCESSING_DIR, "Evaluation", "data_listeningtestmaster_2023-07-30_16-07.csv")) + + sd = sociodemographics(data) + + pref = preference(data) + + sim = similarity(data) + + mos_original = mean_opinion_score(data, "O") + mos_baseline = mean_opinion_score(data, "B") + mos_proposed = mean_opinion_score(data, "S") + mos_prompt = mean_opinion_score(data, "P") + + emotion_original = emotion(data, "O") + emotion_baseline = emotion(data, "B") + emotion_proposed = emotion(data, "S") + emotion_prompt = emotion(data, "P") + + valence_original = valence(data, "O") + valence_baseline = valence(data, "B") + valence_proposed = valence(data, "S") + valence_prompt = valence(data, "P") + + arousal_original = arousal(data, "O") + arousal_baseline = arousal(data, "B") + arousal_proposed = arousal(data, "S") + arousal_prompt = arousal(data, "P") + + # transform data + ################# + ################# + + # A/B pref + ############## + pref_female, pref_male = split_female_male(pref) + pref_female_total = collapse_subdicts(pref_female) + pref_male_total = collapse_subdicts(pref_male) + pref_total = {} + for key, value in pref_female_total.items(): + if key not in pref_total: + pref_total[key] = 0 + pref_total[key] += value + for key, value in pref_male_total.items(): + if key not in pref_total: + pref_total[key] = 0 + pref_total[key] += value + + # similarity + ############# + sim_female, sim_male = split_female_male(sim) + msim_female = get_mean_rating_nested(remove_outliers(sim_female)) + msim_male = get_mean_rating_nested(remove_outliers(sim_male)) + msim_total = {'female' : mean(list(msim_female.values())), + 'male' : mean(list(msim_male.values()))} + + # mos + ############## + mos_original_female, mos_original_male = split_female_male(mos_original) + mos_baseline_female, mos_baseline_male = split_female_male(mos_baseline) + mos_proposed_female, mos_proposed_male = split_female_male(mos_proposed) + mos_prompt_female, mos_prompt_male = split_female_male(mos_prompt) + # combine proposed and prompt since it is the same system and the distinction is not really needed for mos + mos_proposed_female = combine_dicts(mos_proposed_female, mos_prompt_female) + mos_proposed_male = combine_dicts(mos_proposed_male, mos_prompt_male) + + omos_original_female = get_mean_rating_nested(remove_outliers(mos_original_female)) + omos_original_male = get_mean_rating_nested(remove_outliers(mos_original_male)) + omos_baseline_female = get_mean_rating_nested(remove_outliers(mos_baseline_female)) + omos_baseline_male = 
get_mean_rating_nested(remove_outliers(mos_baseline_male)) + omos_proposed_female = get_mean_rating_nested(remove_outliers(mos_proposed_female)) + omos_proposed_male = get_mean_rating_nested(remove_outliers(mos_proposed_male)) + + omos_female = [mean(list(omos_original_female.values())), mean(list(omos_baseline_female.values())), mean(list(omos_proposed_female.values()))] + omos_male = [mean(list(omos_original_male.values())), mean(list(omos_baseline_male.values())), mean(list(omos_proposed_male.values()))] + omos_all = [mean([m1, m2]) for m1, m2 in zip(omos_female, omos_male)] + + print(omos_female) + + _, p_value_mos = independent_samples_t_test(mos_proposed_female, mos_proposed_male, mos_baseline_female, mos_baseline_male) + print(f'p value MOS proposed-baseline: {p_value_mos}') + + # emotion + ########### + emotion_original_female, emotion_original_male = split_female_male(emotion_original) + emotion_baseline_female, emotion_baseline_male = split_female_male(emotion_baseline) + emotion_proposed_female, emotion_proposed_male = split_female_male(emotion_proposed) + emotion_prompt_female, emotion_prompt_male = split_female_male(emotion_prompt) + emotion_prompt_female = make_emotion_prompts(emotion_prompt_female, "f") + emotion_prompt_male = make_emotion_prompts(emotion_prompt_male, "m") + + print(cramers_v(emotion_original_female)) + print(cramers_v(emotion_baseline_female)) + print(cramers_v(emotion_proposed_female)) + print(cramers_v(emotion_prompt_female)) + print(cramers_v(emotion_original_male)) + print(cramers_v(emotion_baseline_male)) + print(cramers_v(emotion_proposed_male)) + print(cramers_v(emotion_prompt_male)) + + # valence/arousal + ################### + valence_original_female, valence_original_male = split_female_male(valence_original) + arousal_original_female, arousal_original_male = split_female_male(arousal_original) + valence_baseline_female, valence_baseline_male = split_female_male(valence_baseline) + arousal_baseline_female, arousal_baseline_male = split_female_male(arousal_baseline) + valence_proposed_female, valence_proposed_male = split_female_male(valence_proposed) + arousal_proposed_female, arousal_proposed_male = split_female_male(arousal_proposed) + valence_prompt_female, valence_prompt_male = split_female_male(valence_prompt) + arousal_prompt_female, arousal_prompt_male = split_female_male(arousal_prompt) + valence_prompt_female = make_emotion_prompts(valence_prompt_female, "f") + valence_prompt_male = make_emotion_prompts(valence_prompt_male, "m") + arousal_prompt_female = make_emotion_prompts(arousal_prompt_female, "f") + arousal_prompt_male = make_emotion_prompts(arousal_prompt_male, "m") + + mvalence_original_female = get_mean_rating_nested(remove_outliers(valence_original_female)) + mvalence_original_male = get_mean_rating_nested(remove_outliers(valence_original_male)) + mvalence_baseline_female = get_mean_rating_nested(remove_outliers(valence_baseline_female)) + mvalence_baseline_male = get_mean_rating_nested(remove_outliers(valence_baseline_male)) + mvalence_proposed_female = get_mean_rating_nested(remove_outliers(valence_proposed_female)) + mvalence_proposed_male = get_mean_rating_nested(remove_outliers(valence_proposed_male)) + mvalence_prompt_female = get_mean_rating_nested(remove_outliers(valence_prompt_female)) + mvalence_prompt_male = get_mean_rating_nested(remove_outliers(valence_prompt_male)) + + marousal_original_female = get_mean_rating_nested(remove_outliers(arousal_original_female)) + marousal_original_male = 
get_mean_rating_nested(remove_outliers(arousal_original_male)) + marousal_baseline_female = get_mean_rating_nested(remove_outliers(arousal_baseline_female)) + marousal_baseline_male = get_mean_rating_nested(remove_outliers(arousal_baseline_male)) + marousal_proposed_female = get_mean_rating_nested(remove_outliers(arousal_proposed_female)) + marousal_proposed_male = get_mean_rating_nested(remove_outliers(arousal_proposed_male)) + marousal_prompt_female = get_mean_rating_nested(remove_outliers(arousal_prompt_female)) + marousal_prompt_male = get_mean_rating_nested(remove_outliers(arousal_prompt_male)) + + print(marousal_original_female) + print(marousal_original_male) + + print(marousal_baseline_female) + print(marousal_baseline_male) + + print(marousal_proposed_female) + print(marousal_proposed_male) + + print(marousal_prompt_female) + print(marousal_prompt_male) + + + # make plots + ################# + ################# + + for v, d in sd.items(): + pie_chart_counts(d, v, os.path.join(save_dir, f"{v}.png")) + + barplot_pref3(pref_female, os.path.join(save_dir, f"pref_female.png")) + barplot_pref3(pref_male, os.path.join(save_dir, f"pref_male.png")) + barplot_pref_total(pref_female_total, os.path.join(save_dir, f"pref_female_total.png")) + barplot_pref_total(pref_male_total, os.path.join(save_dir, f"pref_male_total.png")) + barplot_pref_total(pref_total, os.path.join(save_dir, f"pref_total.png")) + + boxplot_rating(sim_female, os.path.join(save_dir, f"box_sim_female.png")) + boxplot_rating(sim_male, os.path.join(save_dir, f"box_sim_male.png")) + barplot_sim(msim_female, os.path.join(save_dir, f"sim_female.png")) + barplot_sim(msim_male, os.path.join(save_dir, f"sim_male.png")) + barplot_sim_total(msim_total, os.path.join(save_dir, f"sim_total.png")) + + boxplot_rating(mos_original_female, os.path.join(save_dir, f"box_mos_original_female.png")) + boxplot_rating(mos_original_male, os.path.join(save_dir, f"box_mos_original_male.png")) + boxplot_rating(mos_baseline_female, os.path.join(save_dir, f"box_mos_baseline_female.png")) + boxplot_rating(mos_baseline_male, os.path.join(save_dir, f"box_mos_baseline_male.png")) + boxplot_rating(mos_proposed_female, os.path.join(save_dir, f"box_mos_proposed_female.png")) + boxplot_rating(mos_proposed_male, os.path.join(save_dir, f"box_mos_proposed_male.png")) + barplot_mos(omos_female, os.path.join(save_dir, f"mos_female.png")) + barplot_mos(omos_male, os.path.join(save_dir, f"mos_male.png")) + barplot_mos(omos_all, os.path.join(save_dir, f"mos.png")) + + heatmap_emotion(emotion_original_female, os.path.join(save_dir, f"emotion_original_female.png")) + heatmap_emotion(emotion_original_male, os.path.join(save_dir, f"emotion_original_male.png")) + heatmap_emotion(emotion_baseline_female, os.path.join(save_dir, f"emotion_baseline_female.png")) + heatmap_emotion(emotion_baseline_male, os.path.join(save_dir, f"emotion_baseline_male.png")) + heatmap_emotion(emotion_proposed_female, os.path.join(save_dir, f"emotion_proposed_female.png")) + heatmap_emotion(emotion_proposed_male, os.path.join(save_dir, f"emotion_proposed_male.png")) + heatmap_emotion(emotion_prompt_female, os.path.join(save_dir, f"emotion_prompt_female.png")) + heatmap_emotion(emotion_prompt_male, os.path.join(save_dir, f"emotion_prompt_male.png")) + + boxplot_rating(valence_original_female, os.path.join(save_dir, f"box_v_original_female.png")) + boxplot_rating(valence_original_male, os.path.join(save_dir, f"box_v_original_male.png")) + boxplot_rating(valence_baseline_female, os.path.join(save_dir, 
f"box_v_baseline_female.png")) + boxplot_rating(valence_baseline_male, os.path.join(save_dir, f"box_v_baseline_male.png")) + boxplot_rating(valence_proposed_female, os.path.join(save_dir, f"box_v_proposed_female.png")) + boxplot_rating(valence_proposed_male, os.path.join(save_dir, f"box_v_proposed_male.png")) + boxplot_rating(valence_prompt_female, os.path.join(save_dir, f"box_v_prompt_female.png")) + boxplot_rating(valence_prompt_male, os.path.join(save_dir, f"box_v_prompt_male.png")) + + boxplot_rating(arousal_original_female, os.path.join(save_dir, f"box_a_original_female.png")) + boxplot_rating(arousal_original_male, os.path.join(save_dir, f"box_a_original_male.png")) + boxplot_rating(arousal_baseline_female, os.path.join(save_dir, f"box_a_baseline_female.png")) + boxplot_rating(arousal_baseline_male, os.path.join(save_dir, f"box_a_baseline_male.png")) + boxplot_rating(arousal_proposed_female, os.path.join(save_dir, f"box_a_proposed_female.png")) + boxplot_rating(arousal_proposed_male, os.path.join(save_dir, f"box_a_proposed_male.png")) + boxplot_rating(arousal_prompt_female, os.path.join(save_dir, f"box_a_prompt_female.png")) + boxplot_rating(arousal_prompt_male, os.path.join(save_dir, f"box_a_prompt_male.png")) + + scatterplot_va(mvalence_original_female, marousal_original_female, os.path.join(save_dir, f"va_original_female.png")) + scatterplot_va(mvalence_original_male, marousal_original_male, os.path.join(save_dir, f"va_original_male.png")) + scatterplot_va(mvalence_baseline_female, marousal_baseline_female, os.path.join(save_dir, f"va_baseline_female.png")) + scatterplot_va(mvalence_baseline_male, marousal_baseline_male, os.path.join(save_dir, f"va_baseline_male.png")) + scatterplot_va(mvalence_proposed_female, marousal_proposed_female, os.path.join(save_dir, f"va_proposed_female.png")) + scatterplot_va(mvalence_proposed_male, marousal_proposed_male, os.path.join(save_dir, f"va_proposed_male.png")) + scatterplot_va(mvalence_prompt_female, marousal_prompt_female, os.path.join(save_dir, f"va_prompt_female.png")) + scatterplot_va(mvalence_prompt_male, marousal_prompt_male, os.path.join(save_dir, f"va_prompt_male.png")) diff --git a/run_subjective_evaluation_paper.py b/run_subjective_evaluation_paper.py new file mode 100644 index 00000000..8f40e570 --- /dev/null +++ b/run_subjective_evaluation_paper.py @@ -0,0 +1,218 @@ +import os +from statistics import mean +import math + +from Evaluation.subjective_evaluation import * +from Evaluation.plotting import * +from Utility.storage_config import PREPROCESSING_DIR + +import sys + +if __name__ == '__main__': + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation", "plots_paper"), exist_ok=True) + save_dir = os.path.join(PREPROCESSING_DIR, "Evaluation", "plots_paper") + + # load data + data = read_data(os.path.join(PREPROCESSING_DIR, "Evaluation", "data_listeningtestmaster_2023-07-30_16-07.csv")) + + sd = sociodemographics(data) + + pref = preference(data) + + sim = similarity(data) + + mos_original = mean_opinion_score(data, "O") + mos_baseline = mean_opinion_score(data, "B") + mos_proposed = mean_opinion_score(data, "S") + mos_prompt = mean_opinion_score(data, "P") + + #print(sum(sum(inner_dict.values()) for inner_dict in mos_original.values())) + #print(sum(sum(inner_dict.values()) for inner_dict in mos_baseline.values())) + #print(sum(sum(inner_dict.values()) for inner_dict in mos_proposed.values())) + #print(sum(sum(inner_dict.values()) for inner_dict in mos_prompt.values())) + + emotion_original = emotion(data, "O") + 
emotion_baseline = emotion(data, "B") + emotion_proposed = emotion(data, "S") + emotion_prompt = emotion(data, "P") + + valence_original = valence(data, "O") + valence_baseline = valence(data, "B") + valence_proposed = valence(data, "S") + valence_prompt = valence(data, "P") + + arousal_original = arousal(data, "O") + arousal_baseline = arousal(data, "B") + arousal_proposed = arousal(data, "S") + arousal_prompt = arousal(data, "P") + + # transform data + ################# + ################# + + # A/B pref + ############## + pref_female, pref_male = split_female_male(pref) + pref_female_total = collapse_subdicts(pref_female) + pref_male_total = collapse_subdicts(pref_male) + pref_total = {} + for key, value in pref_female_total.items(): + if key not in pref_total: + pref_total[key] = 0 + pref_total[key] += value + for key, value in pref_male_total.items(): + if key not in pref_total: + pref_total[key] = 0 + pref_total[key] += value + + #print(pref_total) + + # similarity + ############# + sim_female, sim_male = split_female_male(sim) + msim_female = get_mean_rating_nested(remove_outliers(sim_female)) + msim_male = get_mean_rating_nested(remove_outliers(sim_male)) + msim_total = {'female' : mean(list(msim_female.values())), + 'male' : mean(list(msim_male.values()))} + + sim_female = remove_outliers(sim_female) + sim_male = remove_outliers(sim_male) + print(sim_male) + + def calculate_stats(data): + # per-emotion mean and sample standard deviation of the ratings, plus overall mean and standard deviation across all emotions + emotions = {} + total_count = 0 + grand_sum = 0 + for emo, ratings in data.items(): + count_emotion = sum(ratings.values()) + sum_emotion = sum(rating * count for rating, count in ratings.items()) + mean_emotion = sum_emotion / count_emotion + squared_deviations = sum((rating - mean_emotion) ** 2 * count for rating, count in ratings.items()) + std_dev = math.sqrt(squared_deviations / (count_emotion - 1)) + emotions[emo] = {"mean": mean_emotion, "std_dev": std_dev} + total_count += count_emotion + grand_sum += sum_emotion + + overall_mean = grand_sum / total_count + overall_squared_deviations = 0 + for emo, ratings in data.items(): + for rating, count in ratings.items(): + overall_squared_deviations += (rating - overall_mean) ** 2 * count + overall_std_dev = math.sqrt(overall_squared_deviations / (total_count - 1)) + + return {"emotions": emotions, "overall_mean": overall_mean, "overall_std_dev": overall_std_dev} + + # Calculate the total count + total_count = 0 + for e, ratings in sim_female.items(): + total_count += sum(ratings.values()) + print(total_count) + # Calculate the total count + total_count = 0 + for e, ratings in sim_male.items(): + total_count += sum(ratings.values()) + print(total_count) + + print(calculate_stats(sim_female)) + + # mos + ############## + mos_original_female, mos_original_male = split_female_male(mos_original) + mos_baseline_female, mos_baseline_male = split_female_male(mos_baseline) + mos_proposed_female, mos_proposed_male = split_female_male(mos_proposed) + mos_prompt_female, mos_prompt_male = split_female_male(mos_prompt) + # combine proposed and prompt since it is the same system and the distinction is not really needed for mos + mos_proposed_female = combine_dicts(mos_proposed_female, mos_prompt_female) + mos_proposed_male = combine_dicts(mos_proposed_male, mos_prompt_male) + + omos_original_female = get_mean_rating_nested(remove_outliers(mos_original_female)) + omos_original_male = get_mean_rating_nested(remove_outliers(mos_original_male)) + omos_baseline_female = 
get_mean_rating_nested(remove_outliers(mos_baseline_female)) + omos_baseline_male = get_mean_rating_nested(remove_outliers(mos_baseline_male)) + omos_proposed_female = get_mean_rating_nested(remove_outliers(mos_proposed_female)) + omos_proposed_male = get_mean_rating_nested(remove_outliers(mos_proposed_male)) + + omos_female = [mean(list(omos_original_female.values())), mean(list(omos_baseline_female.values())), mean(list(omos_proposed_female.values()))] + omos_male = [mean(list(omos_original_male.values())), mean(list(omos_baseline_male.values())), mean(list(omos_proposed_male.values()))] + omos_all = [mean([m1, m2]) for m1, m2 in zip(omos_female, omos_male)] + + print(omos_female) + + _, p_value_mos = independent_samples_t_test(mos_proposed_female, mos_proposed_male, mos_baseline_female, mos_baseline_male) + print(f'p value MOS proposed-baseline: {p_value_mos}') + + # emotion + ########### + emotion_original_female, emotion_original_male = split_female_male(emotion_original) + emotion_baseline_female, emotion_baseline_male = split_female_male(emotion_baseline) + emotion_proposed_female, emotion_proposed_male = split_female_male(emotion_proposed) + emotion_prompt_female, emotion_prompt_male = split_female_male(emotion_prompt) + emotion_prompt_female = make_emotion_prompts(emotion_prompt_female, "f") + emotion_prompt_male = make_emotion_prompts(emotion_prompt_male, "m") + + print(cramers_v(emotion_original_female)) + print(cramers_v(emotion_baseline_female)) + print(cramers_v(emotion_proposed_female)) + print(cramers_v(emotion_prompt_female)) + print(cramers_v(emotion_original_male)) + print(cramers_v(emotion_baseline_male)) + print(cramers_v(emotion_proposed_male)) + print(cramers_v(emotion_prompt_male)) + + # valence/arousal + ################### + valence_original_female, valence_original_male = split_female_male(valence_original) + arousal_original_female, arousal_original_male = split_female_male(arousal_original) + valence_baseline_female, valence_baseline_male = split_female_male(valence_baseline) + arousal_baseline_female, arousal_baseline_male = split_female_male(arousal_baseline) + valence_proposed_female, valence_proposed_male = split_female_male(valence_proposed) + arousal_proposed_female, arousal_proposed_male = split_female_male(arousal_proposed) + valence_prompt_female, valence_prompt_male = split_female_male(valence_prompt) + arousal_prompt_female, arousal_prompt_male = split_female_male(arousal_prompt) + valence_prompt_female = make_emotion_prompts(valence_prompt_female, "f") + valence_prompt_male = make_emotion_prompts(valence_prompt_male, "m") + arousal_prompt_female = make_emotion_prompts(arousal_prompt_female, "f") + arousal_prompt_male = make_emotion_prompts(arousal_prompt_male, "m") + + mvalence_original_female = get_mean_rating_nested(remove_outliers(valence_original_female)) + mvalence_original_male = get_mean_rating_nested(remove_outliers(valence_original_male)) + mvalence_baseline_female = get_mean_rating_nested(remove_outliers(valence_baseline_female)) + mvalence_baseline_male = get_mean_rating_nested(remove_outliers(valence_baseline_male)) + mvalence_proposed_female = get_mean_rating_nested(remove_outliers(valence_proposed_female)) + mvalence_proposed_male = get_mean_rating_nested(remove_outliers(valence_proposed_male)) + mvalence_prompt_female = get_mean_rating_nested(remove_outliers(valence_prompt_female)) + mvalence_prompt_male = get_mean_rating_nested(remove_outliers(valence_prompt_male)) + + marousal_original_female = 
get_mean_rating_nested(remove_outliers(arousal_original_female)) + marousal_original_male = get_mean_rating_nested(remove_outliers(arousal_original_male)) + marousal_baseline_female = get_mean_rating_nested(remove_outliers(arousal_baseline_female)) + marousal_baseline_male = get_mean_rating_nested(remove_outliers(arousal_baseline_male)) + marousal_proposed_female = get_mean_rating_nested(remove_outliers(arousal_proposed_female)) + marousal_proposed_male = get_mean_rating_nested(remove_outliers(arousal_proposed_male)) + marousal_prompt_female = get_mean_rating_nested(remove_outliers(arousal_prompt_female)) + marousal_prompt_male = get_mean_rating_nested(remove_outliers(arousal_prompt_male)) + + print(marousal_original_female) + print(marousal_original_male) + + print(marousal_baseline_female) + print(marousal_baseline_male) + + print(marousal_proposed_female) + print(marousal_proposed_male) + + print(marousal_prompt_female) + print(marousal_prompt_male) + + + # make plots + ################# + ################# + + pie_barplot_pref_total(pref_total, os.path.join(save_dir, f"pref_total_pie.pdf")) + + pie_barplot_pref(pref_female, os.path.join(save_dir, f"pref_pie.pdf")) diff --git a/run_training_pipeline.py b/run_training_pipeline.py index 91dfa591..f2ba98b6 100644 --- a/run_training_pipeline.py +++ b/run_training_pipeline.py @@ -13,6 +13,21 @@ from TrainingInterfaces.TrainingPipelines.ToucanTTS_IntegrationTest import run as tt_integration_test from TrainingInterfaces.TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta from TrainingInterfaces.TrainingPipelines.ToucanTTS_Nancy import run as nancy +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Ravdess import run as ravdess +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Ravdess_sent_emb import run as ravdess_sent +from TrainingInterfaces.TrainingPipelines.ToucanTTS_ESDS import run as esds +from TrainingInterfaces.TrainingPipelines.ToucanTTS_ESDS_sent_emb import run as esds_sent +from TrainingInterfaces.TrainingPipelines.ToucanTTS_TESS import run as tess +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LibriTTS import run as libri +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LibriTTS_sent_emb import run as libri_sent +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LibriTTSR import run as librir +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LibriTTSR_sent_emb import run as librir_sent +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LJSpeech import run as lj +from TrainingInterfaces.TrainingPipelines.ToucanTTS_LJSpeech_sent_emb import run as lj_sent +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Sent_Pretraining import run as sent_pre +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Baseline_Pretraining import run as base_pre +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Sent_Finetuning import run as sent_fine +from TrainingInterfaces.TrainingPipelines.ToucanTTS_Baseline_Finetuning import run as base_fine from TrainingInterfaces.TrainingPipelines.finetuning_example import run as fine_tuning_example from TrainingInterfaces.TrainingPipelines.pretrain_aligner import run as aligner @@ -23,9 +38,22 @@ "fs_it" : fs_integration_test, "tt_it" : tt_integration_test, # regular ToucanTTS pipelines - "nancy" : nancy, - "nancystoch" : nancystoch, - "meta" : meta, + "nancy" : nancy, + "nancystoch": nancystoch, + "meta" : meta, + "libri" : libri, + "librir" : librir, + "libri_sent" : libri_sent, + "librir_sent" : librir_sent, + "ravdess" : ravdess, + "ravdess_sent" : ravdess_sent, + "esds" : 
esds, + "esds_sent" : esds_sent, + "tess" : tess, + "sent_pre" : sent_pre, + "base_pre" : base_pre, + "sent_fine" : sent_fine, + "base_fine" : base_fine, # training vocoders (not recommended, best to use provided checkpoint) "avocodo" : hifi_codo, "bigvgan" : bigvgan, diff --git a/run_weight_averaging.py b/run_weight_averaging.py index 14d111f8..5d51a5b2 100644 --- a/run_weight_averaging.py +++ b/run_weight_averaging.py @@ -24,8 +24,23 @@ def load_net_toucan(path): net = ToucanTTS(lang_embs=None) net.load_state_dict(check_dict["model"]) except RuntimeError: - net = ToucanTTS(lang_embs=None, utt_embed_dim=None) - net.load_state_dict(check_dict["model"]) + try: + net = ToucanTTS(lang_embs=None, utt_embed_dim=None) + net.load_state_dict(check_dict["model"]) + except RuntimeError: + try: + print("Loading baseline architecture") + net = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + static_speaker_embed=True) + net.load_state_dict(check_dict["model"]) + except RuntimeError: + print("Loading sent emb architecture") + net = ToucanTTS(lang_embs=None, + utt_embed_dim=512, + sent_embed_dim=768, + static_speaker_embed=True) + net.load_state_dict(check_dict["model"]) except RuntimeError: try: net = StochasticToucanTTS() @@ -132,7 +147,7 @@ def make_best_in_all(): continue averaged_model, _ = average_checkpoints(checkpoint_paths, load_func=load_net_bigvgan) save_model_for_use(model=averaged_model, name=os.path.join(MODELS_DIR, model_dir, "best.pt"), dict_name="generator") - elif "ToucanTTS" in model_dir: + elif "ToucanTTS_Sent" in model_dir: checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=os.path.join(MODELS_DIR, model_dir), n=3) if checkpoint_paths is None: continue diff --git a/run_word_emb_test_suite.py b/run_word_emb_test_suite.py new file mode 100644 index 00000000..d32b6c22 --- /dev/null +++ b/run_word_emb_test_suite.py @@ -0,0 +1,106 @@ +import os + +import torch + +from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface + + +def the_raven_and_the_fox(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(["Master Raven, on a perched tree, was holding a cheese in his beak.", + "Master Fox, attracted by the smell, told him more or less this language:", + "And hello, Master Raven, how pretty you are! How beautiful you seem to me!", + "No lie, if your bearing is anything like your plumage, you are the Phoenix of the hosts in these woods.", + "At these words the Raven does not feel happy, and to show his beautiful voice, he opens a wide beak, drops his prey.", + "The Fox seized it, and said: My good Sir, learn that every flatterer lives at the expense of the one who listens to him.", + "This lesson is worth a cheese without doubt.", + "The ashamed and confused Raven swore, but a little late, that he would not be caught again.", + "Master Raven, on a perched tree, was holding a cheese in his beak. Master Fox, attracted by the smell, told him more or less this language: And hello, Master Raven, how pretty you are! How beautiful you seem to me! 
No lie, if your bearing is anything like your plumage, you are the Phoenix of the hosts in these woods. At these words the Raven does not feel happy, and to show his beautiful voice, he opens a wide beak, drops his prey. The Fox seized it, and said: My good Sir, learn that every flatterer lives at the expense of the one who listens to him. This lesson is worth a cheese without doubt. The ashamed and confused Raven swore, but a little late, that he would not be caught again." + ]): + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/The_raven_and_the_fox_{i}.wav") + +def poem(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(['Once upon a midnight dreary, while I pondered, weak, and weary,', + 'Over many a quaint, and curious volume of forgotten lore,', + 'While I nodded, nearly napping, suddenly, there came a tapping,', + 'As of someone gently rapping, rapping at my chamber door.', + 'Tis some visitor, I muttered, tapping at my chamber door,', + 'Only this, and nothing more.', + 'Ah, distinctly, I remember, it was in the bleak December,', + 'And each separate dying ember, wrought its ghost upon the floor.', + 'Eagerly, I wished the morrow, vainly, I had sought to borrow', + 'From my books surcease of sorrow, sorrow, for the lost Lenore,', + 'For the rare and radiant maiden, whom the angels name Lenore,', + 'Nameless here, for evermore.', + 'And the silken, sad, uncertain, rustling of each purple curtain', + 'Thrilled me, filled me, with fantastic terrors, never felt before.']): + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/Poem_{i}.wav") + +def test_sentence(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + #sentence = "Well, she said, if I had had your bringing up I might have had as good a temper as you, but now I don't believe I ever shall." + sentence = "result in some degree of interference with the personal liberty of those involved." 
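+ # this suite exercises word embeddings only: no sentence-embedding extractor or prompt is involved here, so the word_emb_extractor handed to the interface is the sole additional conditioning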
+ if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/test_sentence.wav") + +def test_controllable(version, model_id="Meta", exec_device="cpu", speaker_reference=None, vocoder_model_path=None, biggan=False, word_emb_extractor=None): + os.makedirs("audios", exist_ok=True) + os.makedirs(f"audios/{version}", exist_ok=True) + tts = ToucanTTSInterface(device=exec_device, tts_model_path=model_id, vocoder_model_path=vocoder_model_path, faster_vocoder=not biggan, word_emb_extractor=word_emb_extractor) + tts.set_language("en") + if speaker_reference is not None: + tts.set_utterance_embedding(speaker_reference) + + for i, sentence in enumerate(['I am so happy to see you!', + 'Today is a beautiful day and the sun is shining.', + 'He seemed to be quite lucky as he was smiling at me.', + 'She laughed and said: This is so funny.', + 'No, this is horrible!', + 'I am so sad, why is this so depressing?', + 'Be careful, cried the woman.', + 'This makes me feel bad.']): + tts.read_to_file(text_list=[sentence], file_location=f"audios/{version}/Controllable_{i}.wav") + + +if __name__ == '__main__': + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = f"5" + exec_device = "cuda:0" if torch.cuda.is_available() else "cpu" + #exec_device = "cpu" + print(f"running on {exec_device}") + + use_speaker_reference = False + use_word_emb = True + + if use_word_emb: + #from Preprocessing.word_embeddings.BERTWordEmbeddingExtractor import BERTWordEmbeddingExtractor + #word_embedding_extractor = BERTWordEmbeddingExtractor() + from Preprocessing.word_embeddings.EmotionRoBERTaWordEmbeddingExtractor import EmotionRoBERTaWordEmbeddingExtractor + word_embedding_extractor = EmotionRoBERTaWordEmbeddingExtractor() + else: + word_embedding_extractor = None + + if use_speaker_reference: + #speaker_reference = "/mount/resources/speech/corpora/Blizzard2013/train/segmented/wavn/CA-BB-05-19.wav" + speaker_reference = "/mount/resources/speech/corpora/LibriTTS/all_clean/1638/84448/1638_84448_000057_000006.wav" + else: + speaker_reference = None + + test_controllable(version="ToucanTTS_02_Blizzard2013_word_emb_emoBERT", model_id="02_Blizzard2013_word_emb_emoBERT", exec_device=exec_device, vocoder_model_path=None, biggan=True, speaker_reference=speaker_reference, word_emb_extractor=word_embedding_extractor) diff --git a/visualize_sent_embs_tsne.py b/visualize_sent_embs_tsne.py new file mode 100644 index 00000000..6a8dd74f --- /dev/null +++ b/visualize_sent_embs_tsne.py @@ -0,0 +1,87 @@ +import os + +from tqdm import tqdm +import torch +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE + +from Utility.storage_config import PREPROCESSING_DIR +from Preprocessing.sentence_embeddings.EmotionRoBERTaSentenceEmbeddingExtractor import EmotionRoBERTaSentenceEmbeddingExtractor as SentenceEmbeddingExtractor + +def visualize_sent_embs(sent_embs, save_dir): + # Prepare the data for t-SNE + data_points = np.vstack([embedding.numpy() for embeddings in sent_embs.values() for embedding in embeddings]) + labels = np.concatenate([[i] * len(sent_embs[emotion]) for i, emotion in enumerate(sent_embs)]) + + # Apply t-SNE to reduce dimensionality to 2D + tsne = TSNE(n_components=2, random_state=42, init='pca', learning_rate='auto') + tsne_result = tsne.fit_transform(data_points) + + # Plot the t-SNE points with colors corresponding to emotions + color_mapping = { + "anger": "red", + "disgust": 
"purple", + "fear": "black", + "joy": "green", + "neutral": "blue", + "sadness": "gray", + "surprise": "orange" + } + plt.figure(figsize=(10, 8)) + for i, emotion in enumerate(sent_embs): + indices = np.where(labels == i)[0] + plt.scatter(tsne_result[indices, 0], tsne_result[indices, 1], label=emotion, color=color_mapping[emotion]) + + plt.legend() + # Save the figure + plt.savefig(save_dir, bbox_inches='tight') + plt.close() + +if __name__ == '__main__': + os.makedirs(os.path.join(PREPROCESSING_DIR, "Evaluation", "plots"), exist_ok=True) + save_dir = os.path.join(PREPROCESSING_DIR, "Evaluation", "plots") + + #train_sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Yelp", "emotion_prompts_balanced_10000_sent_embs_emoBERTcls.pt"), map_location='cpu') + #visualize_sent_embs(train_sent_embs, os.path.join(save_dir, 'tsne_train_sent_embs.png')) + + #train_sent_embs_BERT = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "emotion_prompts_balanced_10000_sent_embs_BERT.pt"), map_location='cpu') + #visualize_sent_embs(train_sent_embs_BERT, os.path.join(save_dir, 'tsne_train_sent_embs_BERT.png')) + + train_sent_embs_BERT = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "emotion_prompts_balanced_10000_sent_embs_stpara.pt"), map_location='cpu') + visualize_sent_embs(train_sent_embs_BERT, os.path.join(save_dir, 'tsne_train_sent_embs_stpara.png')) + + if not os.path.exists(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_dailydialogues_sent_embs_emoBERTcls.pt")): + print("Extracting test sent embs...") + sent_emb_extractor = SentenceEmbeddingExtractor(pooling='cls', device='cuda:5') + test_sentences = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_sentences.pt"), map_location='cpu') + for dataset, emotion_to_sents in test_sentences.items(): + for emotion, sentences in emotion_to_sents.items(): + test_sentences[dataset][emotion] = sentences[:50] + + test_dailydialogues_sent_embs = {"anger":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + test_tales_sent_embs = {"anger":[], "joy":[], "neutral":[], "sadness":[], "surprise":[]} + + # dailydialogues + for emotion, sents in tqdm(list(test_sentences['dailydialogues'].items())): + for sent in sents: + sent_emb = sent_emb_extractor.encode(sentences=[sent]).squeeze() + test_dailydialogues_sent_embs[emotion].append(sent_emb) + torch.save(test_dailydialogues_sent_embs, os.path.join(PREPROCESSING_DIR, "Evaluation", "test_dailydialogues_sent_embs_emoBERTcls.pt")) + + # tales + for emotion, sents in tqdm(list(test_sentences['tales'].items())): + for sent in sents: + sent_emb = sent_emb_extractor.encode(sentences=[sent]).squeeze() + test_tales_sent_embs[emotion].append(sent_emb) + torch.save(test_tales_sent_embs, os.path.join(PREPROCESSING_DIR, "Evaluation", "test_tales_sent_embs_emoBERTcls.pt")) + else: + test_dailydialogues_sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_dailydialogues_sent_embs_emoBERTcls.pt"), map_location='cpu') + test_tales_sent_embs = torch.load(os.path.join(PREPROCESSING_DIR, "Evaluation", "test_tales_sent_embs_emoBERTcls.pt"), map_location='cpu') + #visualize_sent_embs(test_dailydialogues_sent_embs, os.path.join(save_dir, 'tsne_test_dailydialogues_sent_embs.png')) + #visualize_sent_embs(test_tales_sent_embs, os.path.join(save_dir, 'tsne_test_tales_sent_embs.png')) + + test_combined_sent_embs = {} + for emotion in test_dailydialogues_sent_embs: + test_combined_sent_embs[emotion] = test_dailydialogues_sent_embs[emotion] + test_tales_sent_embs[emotion] + 
#visualize_sent_embs(test_combined_sent_embs, os.path.join(save_dir, 'tsne_test_combined_sent_embs.png')) diff --git a/vocode_original.py b/vocode_original.py new file mode 100644 index 00000000..1fcc632a --- /dev/null +++ b/vocode_original.py @@ -0,0 +1,47 @@ +from Preprocessing.AudioPreprocessor import AudioPreprocessor +import soundfile as sf +import torch +from numpy import trim_zeros +from InferenceInterfaces.InferenceArchitectures.InferenceBigVGAN import BigVGAN +from InferenceInterfaces.InferenceArchitectures.InferenceAvocodo import HiFiGANGenerator +import soundfile +from Utility.utils import float2pcm +import os +from Utility.storage_config import MODELS_DIR + +if __name__ == '__main__': + paths_female = ["/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0015/Angry/0015_000605.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0015/Happy/0015_000814.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0015/Neutral/0015_000148.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0015/Sad/0015_001088.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0015/Surprise/0015_001604.wav"] + ids_female = [0, 0, 0, 0, 1] + + paths_male = ["/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0014/Angry/0014_000479.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0014/Happy/0014_001048.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0014/Neutral/0014_000061.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0014/Sad/0014_001169.wav", + "/mount/resources/speech/corpora/Emotional_Speech_Dataset_Singapore/0014/Surprise/0014_001639.wav"] + ids_male = [1, 1, 1, 1, 0] + + emotions = ["anger", "joy", "neutral", "sadness", "surprise"] + + vocoder_model_path = os.path.join(MODELS_DIR, "Avocodo", "best.pt") + mel2wav = HiFiGANGenerator(path_to_weights=vocoder_model_path).to(torch.device('cpu')) + mel2wav.remove_weight_norm() + mel2wav.eval() + + for i, path in enumerate(paths_male): + wave, sr = sf.read(path) + ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=True, device='cpu') + norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave) + norm_wave = torch.tensor(trim_zeros(norm_wave.numpy())) + spec = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).cpu() + + wave = mel2wav(spec) + silence = torch.zeros([10600]) + wav = silence.clone() + wav = torch.cat((wav, wave, silence), 0) + + wav = [val for val in wav.detach().numpy() for _ in (0, 1)] # doubling the sampling rate for better compatibility (24kHz is not as standard as 48kHz) + soundfile.write(file=f"./audios/Original/male/orig_{emotions[i]}_{ids_male[i]}.flac", data=float2pcm(wav), samplerate=48000, subtype="PCM_16") \ No newline at end of file