
support for multiple codecs, context patterns and multilingual model in the old codebase

Signed-off-by: Paarth Neekhara <[email protected]>
paarthneekhara committed Dec 24, 2023
1 parent 8c30c24 commit 4296e43
Showing 3 changed files with 151 additions and 70 deletions.
@@ -74,19 +74,13 @@ def _get_default_text_tokenizer_conf():
return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer))


def pad_text_to_speech_dims(text_tensor, pad_id):
token_len = text_tensor.shape[0]
empty_padding = torch.ones((7, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id
return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0)


tokenizer_config = _get_default_text_tokenizer_conf()
phoneme_tokenizer = instantiate(tokenizer_config).text_tokenizer


def pad_text_to_speech_dims(text_tensor, pad_id):
def pad_text_to_speech_dims(text_tensor, pad_id, pad_size=7):
token_len = text_tensor.shape[0]
empty_padding = torch.ones((7, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id
empty_padding = torch.ones((pad_size, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id
return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0)
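
For illustration only (not part of this commit), a minimal sketch of what the updated helper produces; the values and variable names below are made up:

# Sketch: the updated helper stacks pad_size rows of pad_id beneath a 1-D text
# tensor so text tokens line up with the speech codebook layout.
import torch

def pad_text_to_speech_dims(text_tensor, pad_id, pad_size=7):
    token_len = text_tensor.shape[0]
    empty_padding = torch.ones((pad_size, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id
    return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0)

text = torch.tensor([11, 12, 13])                          # 3 text token ids
out = pad_text_to_speech_dims(text, pad_id=0, pad_size=3)  # e.g. num_speech_codebooks=4 -> pad_size=3
# out.shape == (4, 3): row 0 holds the text ids, rows 1-3 are pad_id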


@@ -129,6 +123,12 @@ def __init__(
attention_prior_scaling_factor: Optional[float] = 1.0,
cross_attention_epsilon: Optional[float] = 0.0,
lm_vocab_size: Optional[int] = 30000,
num_speech_codebooks: Optional[int] = 8,
codebook_fps: Optional[int] = 75,
add_special_tokens_to_only_first_codebook: Optional[bool] = False,
context_pattern: Optional[str] = "parallel",
context_duration_min: Optional[float] = 3.0,
context_duration_max: Optional[float] = 5.0,
**kwargs,
):
"""
@@ -172,6 +172,13 @@ def __init__(
self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1)
assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0
self.lm_vocab_size = lm_vocab_size
self.num_speech_codebooks = num_speech_codebooks
self.codebook_fps = codebook_fps
self.add_special_tokens_to_only_first_codebook = add_special_tokens_to_only_first_codebook
# context_pattern and duration arguments are supported only if context_type is REFSPEAKERCODEC in the manifest
self.context_pattern = context_pattern
self.context_duration_min = context_duration_min
self.context_duration_max = context_duration_max
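
As a rough sketch of how the new options might be supplied (the defaults below are simply the ones introduced in this commit; the surrounding dataset kwargs are not shown in the diff):

# Hypothetical config sketch: only the options added in this commit, with their
# defaults. These would be merged with the dataset's existing keyword arguments.
speechlm_dataset_overrides = {
    "num_speech_codebooks": 8,         # codebooks per codec frame
    "codebook_fps": 75,                # codec frames per second of audio
    "add_special_tokens_to_only_first_codebook": False,
    "context_pattern": "parallel",     # or "delay_parallel"
    "context_duration_min": 3.0,       # seconds of reference-speaker context
    "context_duration_max": 5.0,
}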

# Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
if sup_data_path is not None:
@@ -261,15 +268,15 @@ def load_data(self, dataset):

if doc["context_type"] == "SPEECH":
assert "context_duration" in doc, f"context_duration key not in document {doc}"
approx_context_len = min(doc["context_duration"] * 76 * 0.3, 400)
approx_context_len = self.context_duration_max * (self.codebook_fps + 1)
elif "Remove Noise" in question_in_manifest:
approx_context_len = doc["answer_duration"] * 76
approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1)
elif "Extract Speaker Audio" in question_in_manifest:
approx_context_len = doc["answer_duration"] * 76 + 400 # 400 is the max ref speaker audio
approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + 400 # 400 is the max ref speaker audio
elif ("Text to speech this" in question_in_manifest) or ('Phoneme TTS' in question_in_manifest):
approx_context_len = 400 # Max length of Ref TTS audio
elif "Edit Speech" in question_in_manifest:
approx_context_len = doc["answer_duration"] * 76
approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1)
else:
raise NotImplementedError(f"Unknown context type {doc['context_type']}")

@@ -280,10 +287,10 @@ def load_data(self, dataset):

if doc["answer_type"] in ["SPEECH", "AUDIOCODEC"]:
assert "answer_duration" in doc, f"answer_duration key not in document {doc}"
approx_answer_len = doc["answer_duration"] * 76
approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1)
if self.seq_pattern == "delay_parallel":
# In delay parallel, there is padding so add 8 frames
approx_answer_len = approx_answer_len + 8
approx_answer_len = approx_answer_len + self.num_speech_codebooks
else:
approx_answer_len = len(doc["answer"].split(' ')) + 3
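
For a concrete sense of these length heuristics, a small worked sketch (the numbers are just the commit's default settings plus an arbitrary example duration, nothing prescribed by the code):

# Worked sketch of the length estimates with the default settings.
codebook_fps = 75                  # codec frames per second
context_duration_max = 5.0         # seconds
num_speech_codebooks = 8

# SPEECH context: bounded by the maximum context duration
# (the +1 mirrors the previously hard-coded 76)
approx_context_len = context_duration_max * (codebook_fps + 1)    # 5.0 * 76 = 380 frames

# SPEECH / AUDIOCODEC answer: proportional to the answer duration
answer_duration = 4.2                                              # example value
approx_answer_len = answer_duration * (codebook_fps + 1)           # ~319.2 frames
approx_answer_len += num_speech_codebooks                          # delay_parallel padding, ~327.2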

@@ -385,9 +392,14 @@ def __getitem__(self, idx):
question_tokens, question_tokens_len = self.list_to_tensor(question_tokens)

if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT":
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id)
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id)
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT":
# When context is text for speaker id conditioning.
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)

context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1)
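
A shape-only sketch (toy sizes, not repository code) of what this concatenation yields once both TEXT streams have been lifted to the codebook layout:

# Toy shapes: both streams are (num_speech_codebooks, T) after padding,
# so they can be joined along the time axis.
import torch

num_speech_codebooks = 8
context_tokens = torch.zeros(num_speech_codebooks, 40, dtype=torch.long)
question_tokens = torch.zeros(num_speech_codebooks, 25, dtype=torch.long)
context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1)
# context_and_question_tokens.shape == (8, 65)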

# get answer ids
@@ -458,7 +470,7 @@ def __getitem__(self, idx):
)
dec_input_new = []
dec_labels_new = []
for _c in range(8):
for _c in range(self.num_speech_codebooks):
st = num_codebooks - _c
et_decoder_input = dec_input_padded.shape[1] - _c
et_decoder_labels = dec_labels_padded.shape[1] - _c
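
The appends fall outside this hunk, but assuming the slicing shown here (the same pattern reappears in the delay_parallel context branch later in this diff), each codebook row ends up delayed by its index. A toy sketch with made-up values, not repository code:

# Toy delay-parallel layout for 3 codebooks and 4 frames: codebook c is delayed by c steps.
import torch

num_codebooks, T = 3, 4
codec = torch.arange(1, num_codebooks * T + 1).reshape(num_codebooks, T)
pad = torch.zeros(num_codebooks, num_codebooks, dtype=codec.dtype)
padded = torch.cat([pad, codec, pad], dim=1)          # shape (3, 10)

rows = []
for c in range(num_codebooks):
    st = num_codebooks - c
    et = padded.shape[1] - c
    rows.append(padded[c, st:et])                      # each slice has length T + num_codebooks
delayed = torch.stack(rows, dim=0)
# delayed == tensor([[ 1,  2,  3,  4,  0,  0,  0],
#                    [ 0,  5,  6,  7,  8,  0,  0],
#                    [ 0,  0,  9, 10, 11, 12,  0]])
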
@@ -527,8 +539,11 @@ def list_to_tensor(self, element, fill=False):
ret = []
for e in element:
if isinstance(e, int):
tmp = torch.full((8, 1), e if fill else -1)
tmp[7] = e
tmp = torch.full((self.num_speech_codebooks, 1), e if fill else -1)
tmp[self.num_speech_codebooks-1] = e
if self.add_special_tokens_to_only_first_codebook:
# Fill zeros in all other codebooks (to avoid out of range when getting embeddings)
tmp[1:] = 0
else:
tmp = e
ret.append(tmp)
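
A toy sketch (not repository code) of what the int branch produces for a scalar special-token id, assuming num_speech_codebooks = 4 and fill=True:

# Toy sketch: expanding a scalar special-token id to the codebook dimension.
import torch

num_speech_codebooks, e = 4, 1023
tmp = torch.full((num_speech_codebooks, 1), e)   # fill=True: every codebook row gets the id
tmp[num_speech_codebooks - 1] = e                # last row always carries the id
# With add_special_tokens_to_only_first_codebook, rows 1.. are zeroed so only the
# first codebook keeps the special id (avoids out-of-range embedding lookups):
tmp[1:] = 0
# tmp == tensor([[1023], [0], [0], [0]])
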
@@ -695,8 +710,27 @@ def _get_tokens(self, doc, field, field_data):
reference_codec_path = rng.choice(reference_codec_paths)
field_tokens = torch.load(reference_codec_path).long()
field_tokens[0] = (field_tokens[0] + self.speech_offset).long()
reference_codec_len = rng.randint(240, 400)
field_tokens = [field_tokens[:, :reference_codec_len]]
_min_len = int(self.context_duration_min * self.codebook_fps)
_max_len = int(self.context_duration_max * self.codebook_fps)
reference_codec_len = rng.randint(_min_len, _max_len)
reference_codec_len = min(reference_codec_len, field_tokens.shape[1])
si = rng.randint(0, field_tokens.shape[1] - reference_codec_len)
field_tokens = field_tokens[:, si:si+reference_codec_len]
if self.context_pattern == "delay_parallel":
field_tokens = torch.cat([
torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(),
field_tokens,
torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long()
], dim=1)
new_field_tokens = []
for _c in range(self.num_speech_codebooks):
st = self.num_speech_codebooks - _c
et = field_tokens.shape[1] - _c
new_field_tokens.append(field_tokens[_c, st:et])
field_tokens = torch.stack(new_field_tokens, dim=0)

field_tokens = [field_tokens]
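
To illustrate the new duration-driven window selection (a standalone sketch with assumed defaults and a hypothetical total length; the repository's rng object is assumed to behave like random.Random):

# Sketch: picking a random reference-speaker window between the min and max durations.
import random

rng = random.Random(0)
codebook_fps = 75
context_duration_min, context_duration_max = 3.0, 5.0
_min_len = int(context_duration_min * codebook_fps)    # 225 frames
_max_len = int(context_duration_max * codebook_fps)    # 375 frames
total_len = 600                                         # frames available in the codec file (example)
ref_len = min(rng.randint(_min_len, _max_len), total_len)
start = rng.randint(0, total_len - ref_len)             # random window start
window = (start, start + ref_len)                       # slice applied along the time axis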

elif doc[f"{field}_type"] == 'SEPARATIONCODECS':
mixed_codec_path, reference_codec_paths = field_data.split(",")
reference_codec_paths = reference_codec_paths.split(";")
@@ -706,7 +740,7 @@ def _get_tokens(self, doc, field, field_data):
reference_codec_len = rng.randint(240, 400)
reference_codec = reference_codec[:, :reference_codec_len]
# MIXED AUDIO AND REF AUDIO ARE SEPARATED BY 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS
mask_tokens = (torch.ones(8, 8) * 1023).long()
mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long()
field_tokens = torch.cat([mixed_codec, mask_tokens, reference_codec], dim=1)
field_tokens[0] = (field_tokens[0] + self.speech_offset).long()
field_tokens = [field_tokens]
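
A shape-level sketch (toy sizes, not repository code) of how the mixed and reference codecs are joined by the separator block of 1023 tokens:

# Toy shapes: mixed codec, an 8x8 block of 1023s, then the reference codec.
import torch

num_speech_codebooks = 8
mixed_codec = torch.zeros(num_speech_codebooks, 300, dtype=torch.long)
reference_codec = torch.zeros(num_speech_codebooks, 250, dtype=torch.long)
mask_tokens = (torch.ones(num_speech_codebooks, num_speech_codebooks) * 1023).long()
field_tokens = torch.cat([mixed_codec, mask_tokens, reference_codec], dim=1)
# field_tokens.shape == (8, 558): 300 + 8 + 250 time steps
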
@@ -718,7 +752,7 @@ def _get_tokens(self, doc, field, field_data):
mask_len = min(mask_len, reference_codec.shape[1] - 80)
mask_start = rng.randint(0, reference_codec.shape[1] - mask_len)
mask_end = mask_start + mask_len
mask_tokens = (torch.ones(8, 8) * 1023).long()
mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long()
seg1 = reference_codec[:, :mask_start]
seg2 = reference_codec[:, mask_end:]
field_tokens = torch.cat([seg1, mask_tokens, seg2], dim=1)
