update text context pair script
Signed-off-by: Shehzeen Hussain <[email protected]>
shehzeen committed Jan 7, 2025
1 parent d848f02 commit 6e98969
Showing 3 changed files with 109 additions and 314 deletions.
13 changes: 7 additions & 6 deletions examples/tts/t5tts_commands.md
@@ -299,12 +299,13 @@ We pair a list of challenging texts with context audios from Riva and Libri

```
python scripts/t5tts/dpo/create_text_contextpairs.py \
-   --challenging_texts /Data/challenging_texts_nemollm.txt \
-   --riva_manifest /Data/CodecDatasets/speechllm_codecdatasets_new/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json \
-   --libri_manifest /Data/CodecDatasets/speechllm_codecdatasets_new/manifests/libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json \
-   --riva_textcontext_manifest /Data/CodecDatasets/speechllm_codecdatasets_new/manifests/rivaLindyRodneyTextContext__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json \
-   --output_manifest /Data/CodecDatasets/speechllm_codecdatasets_new/manifests/dpo_textcontext_pairs.json \
-   --nsamples_perpair 6 \
+   --challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \
+   --regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \
+   --regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \
+   --audio_contexts /Data/DPOPairsInputData/audio_context_list.json \
+   --text_contexts /Data/DPOPairsInputData/text_context_list.txt \
+   --output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json \
+   --nsamples_perpair 6 ;
```
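The `--audio_contexts` manifest is a JSON-lines file: per the script's argument help, each line needs `context_audio_filepath` and `context_audio_duration`, and may optionally carry `context_audio_codes_path`. Below is a minimal sketch, not part of the repository, of how such a manifest could be assembled; the paths are placeholders.

```
import json

# Hypothetical audio contexts. Only the field names come from the script's
# argument help ('context_audio_filepath', 'context_audio_duration', and the
# optional 'context_audio_codes_path'); the paths below are placeholders.
audio_contexts = [
    {
        "context_audio_filepath": "/Data/DPOPairsInputData/contexts/spk1_neutral_0001.wav",
        "context_audio_duration": 5.2,
        # Optional: pre-extracted codec codes for this context audio
        "context_audio_codes_path": "/Data/DPOPairsInputData/contexts/spk1_neutral_0001.pt",
    },
    {
        "context_audio_filepath": "/Data/DPOPairsInputData/contexts/spk2_happy_0007.wav",
        "context_audio_duration": 4.8,
    },
]

with open("/Data/DPOPairsInputData/audio_context_list.json", "w") as f:
    for ctx in audio_contexts:
        f.write(json.dumps(ctx) + "\n")
```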
Each pair is repeated `nsamples_perpair` times; this value specifies how many samples we want to generate for each pair. The output manifest serves as the input for the next step.

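As a quick sanity check, the output manifest can be inspected to confirm the repetition. The sketch below is only an illustrative assumption (it reuses the output path from the command above and the record fields written by the script); it verifies that every (text, context) pair occurs a multiple of `nsamples_perpair` times.

```
import json
from collections import Counter

nsamples_perpair = 6  # must match the value passed to the script above

# Read the JSON-lines output manifest written by create_text_contextpairs.py
with open("/Data/DPOPairsInputData/text_context_pairs_v2.json") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Key each record by its text and its context (audio filepath or text context)
pair_counts = Counter(
    (r["text"], r.get("context_audio_filepath") or r.get("context_text"))
    for r in records
)

# The same context can be sampled more than once for a given text, so counts
# are multiples of nsamples_perpair rather than exactly nsamples_perpair.
assert all(count % nsamples_perpair == 0 for count in pair_counts.values())
print(f"{len(pair_counts)} unique pairs, {len(records)} total records")
```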
290 changes: 102 additions & 188 deletions scripts/t5tts/dpo/create_text_contextpairs.py
@@ -1,206 +1,120 @@
import argparse
import json
import random
import copy
import argparse
import os

def corrupt_text(question_text):
# randomly repeat a word or delete a word from the question
question_words = question_text.split(" ")
if random.random() < 0.5:
# repeat a word
word_idx = random.randint(0, len(question_words) - 1)
word = question_words[word_idx]
# Repeat one occurrence of the word
question_text = question_text.replace(word, word + " " + word, 1)
else:
# delete a word
word_idx = random.randint(0, len(question_words) - 1)
word = question_words[word_idx]
question_text = question_text.replace(word, "", 1)

return question_text

def read_records(manifest_path):
with open(manifest_path, 'r') as f:
lines = f.readlines()
records = []
for line in lines:
records.append(json.loads(line.strip()))
return records

def write_records(fp, records):
def write_manifest(fp, records):
with open(fp, "w") as f:
for record in records:
f.write(json.dumps(record) + "\n")
print("Wrote {} records to: {}".format(len(records), fp))


def get_audio_filepath_from_codecpath(codec_path):
# "target_codes_1638_84447_1638_84447_000105_000001.pt"
# 1638/84447/1638_84447_000105_000001.wav
codec_file_name = codec_path.split("/")[-1].split(".")[0]
if "Lindy" in codec_file_name or "Rodney" in codec_file_name:
# target_codes_Rodney_22khz_DROP_RODNEY_DROP_001185.pt
# Rodney/44khz/DROP/RODNEY_DROP_000953.wav
speaker = "RODNEY" if "Rodney" in codec_file_name else "LINDY"
speaker_lowercase = "Rodney" if "Rodney" in codec_file_name else "Lindy"
emotion_dir = codec_file_name.split("_22khz_")[1].split("_{}".format(speaker))[0] # DROP
remaining_file_name = codec_file_name.split("_22khz_{}_".format(emotion_dir))[1]
audio_file_path = "{}/22khz/{}/{}.wav".format(speaker_lowercase, emotion_dir, remaining_file_name)
audio_file_path = os.path.join("/Data/RivaData/riva", audio_file_path)
# import ipdb; ipdb.set_trace()
# assert os.path.exists(audio_file_path), "File does not exist: {}".format(audio_file_path)
return audio_file_path
else:
speaker_name = codec_file_name.split("target_codes_")[1].split("_")[0]
chapter_name = codec_file_name.split("target_codes_")[1].split("_")[1]
remaining_file_name = codec_file_name.split("target_codes_{}_{}_".format(speaker_name, chapter_name))[1]
audio_file_path = "{}/{}/{}.wav".format(speaker_name, chapter_name, remaining_file_name)
audio_file_path = os.path.join("/Data/LibriTTS/train-clean-360/", audio_file_path)
# import ipdb; ipdb.set_trace()
assert os.path.exists(audio_file_path), "File does not exist: {}".format(audio_file_path)
return audio_file_path


parser = argparse.ArgumentParser()
parser.add_argument("--challenging_texts", type=str, default="/Data/challenging_texts_nemollm.txt")
parser.add_argument("--riva_manifest", type=str, default="/Data/CodecDatasets/speechllm_codecdatasets_new/manifests/rivaLindyRodney__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json")
parser.add_argument("--libri_manifest", type=str, default="/Data/CodecDatasets/speechllm_codecdatasets_new/manifests/libri360__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json")
parser.add_argument("--riva_textcontext_manifest", type=str, default="/Data/CodecDatasets/speechllm_codecdatasets_new/manifests/rivaLindyRodneyTextContext__phoneme__nemo_audio_21fps_8codebooks_2kcodes_v2bWithWavLM.json")
parser.add_argument("--tts_records", type=bool, default=False)
parser.add_argument("--output_manifest", type=str, default="/Data/CodecDatasets/speechllm_codecdatasets_new/manifests/dpo_textcontext_pairs.json")
parser.add_argument("--nsamples_perpair", type=int, default=6)
args = parser.parse_args()

challenging_texts = args.challenging_texts
riva_manifest = args.riva_manifest
libri_manifest = args.libri_manifest
riva_textcontext_manifest = args.riva_textcontext_manifest
output_manifest = args.output_manifest

riva_records = read_records(riva_manifest)
libri_records = read_records(libri_manifest)
riva_textcontext_records = read_records(riva_textcontext_manifest)

# libri_records_longer_than_8 = [ record for record in libri_records if record['answer_duration'] > 8 ]
# riva_records_longer_than_8 = [ record for record in riva_records if record['answer_duration'] > 8 ]
libri_records_longer_than_2 = [ record for record in libri_records if record['answer_duration'] > 2 ]
riva_records_longer_than_2 = [ record for record in riva_records if record['answer_duration'] > 2 ]

with open(challenging_texts, 'r') as f:
challenging_texts = f.readlines()

challenging_records = []
num_contexts_per_sample = 12
for challenging_text in challenging_texts:
challenging_text = challenging_text.strip()
for ci in range(num_contexts_per_sample):
if ci >= num_contexts_per_sample - 2:
# For the last two of the 12 contexts per challenging text, make it more challenging by corrupting the text
# Randomly drops a word or repeats a word
print("Corrupting text: {}".format(challenging_text))
challenging_text = corrupt_text(challenging_text)
print("Corrupted text: {}".format(challenging_text))

challenging_record_template = {
'text': challenging_text,
'target_audio_codes_path': None,
'duration': 6.0,
'context_audio_codes_path' : None,
'context_audio_duration': None,
'speaker': None
}
libri_challenging_record = copy.deepcopy(challenging_record_template)
riva_challenging_record = copy.deepcopy(challenging_record_template)
def main():
"""
python scripts/t5tts/dpo/create_text_contextpairs.py \
--challenging_texts /Data/DPOPairsInputData/challenging_texts_nemollm.txt \
--regular_texts_for_audiocontext /Data/DPOPairsInputData/regular_texts_for_audiocontext.txt \
--regular_texts_for_textcontext /Data/DPOPairsInputData/regular_texts_for_textcontext.txt \
--audio_contexts /Data/DPOPairsInputData/audio_context_list.json \
--text_contexts /Data/DPOPairsInputData/text_context_list.txt \
--output_manifest /Data/DPOPairsInputData/text_context_pairs_v2.json
"""
parser = argparse.ArgumentParser(description='Create text-context pairs for DPO')
parser.add_argument("--challenging_texts", type=str, help="Text file containing challenging texts")
parser.add_argument("--regular_texts_for_audiocontext", type=str, help="Text file containing regular texts")
parser.add_argument("--regular_texts_for_textcontext", type=str, help="Text file containing regular texts")
parser.add_argument("--audio_contexts", type=str, help="Manifest containing audio contexts") # This manifest should contain 'context_audio_filepath', 'context_audio_duration' and (optionally) 'context_audio_codes_path'
parser.add_argument("--text_contexts", type=str, help="Text file containing text contexts")
parser.add_argument("--n_audio_contexts_per_challenging_text", type=int, default=10)
parser.add_argument("--n_audio_contexts_per_regular_text", type=int, default=1)
parser.add_argument("--n_text_contexts_per_challenging_text", type=int, default=3)
parser.add_argument("--n_text_contexts_per_regular_text", type=int, default=1)
parser.add_argument("--nsamples_perpair", type=int, default=6)
parser.add_argument("--output_manifest", type=str)
args = parser.parse_args()

with open(args.challenging_texts, 'r') as f:
challenging_texts = f.readlines()
challenging_texts = [text.strip() for text in challenging_texts if text.strip() != '']

with open(args.regular_texts_for_audiocontext, 'r') as f:
regular_texts_for_audiocontext = f.readlines()
regular_texts_for_audiocontext = [text.strip() for text in regular_texts_for_audiocontext if text.strip() != '']

with open(args.regular_texts_for_textcontext, 'r') as f:
regular_texts_for_textcontext = f.readlines()
regular_texts_for_textcontext = [text.strip() for text in regular_texts_for_textcontext if text.strip() != '']

with open(args.audio_contexts, 'r') as f:
audio_contexts = f.readlines()
audio_contexts = [json.loads(context.strip()) for context in audio_contexts if context.strip() != '']

with open(args.text_contexts, 'r') as f:
text_contexts = f.readlines()
text_contexts = [text for text in text_contexts if text.strip() != '']

all_records = []
dummy_audio_filepath = audio_contexts[0]['context_audio_filepath']
dummy_target_audio_codes_path = audio_contexts[0].get('context_audio_codes_path', None)
for challenging_text in challenging_texts:
for _ in range(args.n_audio_contexts_per_challenging_text):
audio_context = random.choice(audio_contexts)
record = create_audio_context_record(challenging_text, audio_context, 'challenging')
all_records.append(record)


sample_libri_record = random.choice(libri_records)
libri_challenging_record['context_audio_codes_path'] = sample_libri_record['context']
libri_challenging_record['context_audio_filepath'] = get_audio_filepath_from_codecpath(sample_libri_record['context'])
libri_challenging_record['context_audio_duration'] = sample_libri_record['context_duration']
libri_challenging_record['speaker'] = sample_libri_record['speaker']
libri_challenging_record['target_audio_codes_path'] = libri_challenging_record['context_audio_codes_path']
libri_challenging_record['audio_filepath'] = libri_challenging_record['context_audio_filepath']

sample_riva_record = random.choice(riva_records)
riva_challenging_record['context_audio_codes_path'] = sample_riva_record['context']
riva_challenging_record['context_audio_filepath'] = get_audio_filepath_from_codecpath(sample_riva_record['context'])
riva_challenging_record['context_audio_duration'] = sample_riva_record['context_duration']
riva_challenging_record['speaker'] = sample_riva_record['speaker']
riva_challenging_record['target_audio_codes_path'] = riva_challenging_record['context_audio_codes_path']
riva_challenging_record['audio_filepath'] = riva_challenging_record['context_audio_filepath']

sample_riva_textcontext_record = random.choice(riva_textcontext_records)
riva_textcontext_challenging_record = copy.deepcopy(challenging_record_template)
del riva_textcontext_challenging_record['context_audio_codes_path']
del riva_textcontext_challenging_record['context_audio_duration']
riva_textcontext_challenging_record['text'] = challenging_text
riva_textcontext_challenging_record['target_audio_codes_path'] = sample_riva_textcontext_record['answer']
riva_textcontext_challenging_record['audio_filepath'] = get_audio_filepath_from_codecpath(sample_riva_textcontext_record['answer'])
riva_textcontext_challenging_record['context_text'] = sample_riva_textcontext_record['context'].replace("TEXT CONTEXT: ", "Speaker and Emotion: ")

challenging_records.append(libri_challenging_record)
challenging_records.append(riva_challenging_record)
if ci == 0:
# don't need too many text context examples
challenging_records.append(riva_textcontext_challenging_record)

# regular libri records 50% of the challenging records
libri_subset_records = random.sample(libri_records_longer_than_2, int(len(challenging_records)/2.0) )
libri_regular_records = []
for libri_subset_record in libri_subset_records:
context_record = random.choice(libri_records)
record = {
'text': libri_subset_record['text'],
'target_audio_codes_path': context_record['context'],
'audio_filepath': get_audio_filepath_from_codecpath(context_record['context']),
'duration': 6.0,
'context_audio_codes_path' : context_record['context'],
'context_audio_filepath': get_audio_filepath_from_codecpath(context_record['context']),
'context_audio_duration': context_record['context_duration'],
}
libri_regular_records.append(record)
for _ in range(args.n_text_contexts_per_challenging_text):
text_context = random.choice(text_contexts)
record = create_text_context_record(challenging_text, text_context, dummy_audio_filepath, 'challenging', dummy_target_audio_codes_path)
all_records.append(record)


# regular riva records 20% of the challenging records
riva_subset_records = random.sample(riva_records_longer_than_2, int(len(challenging_records)/5.0))
riva_regular_records = []
for riva_subset_record in riva_subset_records:
context_record = random.choice(riva_records)
for regular_text in regular_texts_for_audiocontext:
for _ in range(args.n_audio_contexts_per_regular_text):
audio_context = random.choice(audio_contexts)
record = create_audio_context_record(regular_text, audio_context, 'regular')
all_records.append(record)

for regular_text in regular_texts_for_textcontext:
for _ in range(args.n_text_contexts_per_regular_text):
text_context = random.choice(text_contexts)
record = create_text_context_record(regular_text, text_context, dummy_audio_filepath, 'regular', dummy_target_audio_codes_path)
all_records.append(record)

random.shuffle(all_records)
repeated_records = []
for record in all_records:
for _ in range(args.nsamples_perpair):
repeated_records.append(record)

write_manifest(args.output_manifest, repeated_records)
write_manifest(args.output_manifest.replace(".json", "_tinysubset.json"), repeated_records[:100*args.nsamples_perpair])

def create_audio_context_record(text, audio_context, record_type):
record = {
'text': riva_subset_record['text'],
'target_audio_codes_path': context_record['context'],
'audio_filepath': get_audio_filepath_from_codecpath(context_record['context']),
'duration': 6.0,
'context_audio_codes_path' : context_record['context'],
'context_audio_filepath': get_audio_filepath_from_codecpath(context_record['context']),
'context_audio_duration': context_record['context_duration'],
'text' : text,
'duration' : 6.0, # Does not matter; avoids filtering out in DPO
'audio_filepath': audio_context['context_audio_filepath'],
'context_audio_filepath': audio_context['context_audio_filepath'],
'context_audio_duration' : audio_context['context_audio_duration'],
'record_type' : record_type # challenging or regular
}
riva_regular_records.append(record)
if 'context_audio_codes_path' in audio_context:
record['context_audio_codes_path'] = audio_context['context_audio_codes_path']
record['target_audio_codes_path'] = audio_context['context_audio_codes_path']

return record

# riva textcontext records 5% of the challenging records
riva_textcontext_subset_records = random.sample(riva_textcontext_records, int(len(challenging_records)/20.0))
riva_textcontext_regular_records = []
for riva_textcontext_subset_record in riva_textcontext_subset_records:
context_record = random.choice(riva_textcontext_records)
def create_text_context_record(text, text_context, dummy_audio_filepath, record_type, target_audio_codes_path=None):
record = {
'text': riva_textcontext_subset_record['text'],
'target_audio_codes_path': context_record['answer'],
'audio_filepath': get_audio_filepath_from_codecpath(context_record['answer']),
'duration': 6.0,
'context_text' : context_record['context'].replace("TEXT CONTEXT: ", "Speaker and Emotion: "),
'text' : text,
'duration' : 6.0, # Does not matter; avoids filtering out in DPO
'audio_filepath': dummy_audio_filepath,
'context_text' : text_context,
'record_type' : record_type # challenging or regular
}
riva_textcontext_regular_records.append(record)

all_records = challenging_records + libri_regular_records + riva_regular_records + riva_textcontext_regular_records
random.shuffle(all_records)
if target_audio_codes_path is not None:
record['target_audio_codes_path'] = target_audio_codes_path
return record

# Repeat each record nsamples_perpair times
repeated_records = []
for record in all_records:
for i in range(args.nsamples_perpair):
repeated_records.append(record)

write_records(output_manifest, repeated_records)
write_records(output_manifest.replace(".json", "_240subset.json"), repeated_records[:240])
if __name__ == '__main__':
main()
