allow delay pattern in context, allow context to be just text (for speaker id conditioning)

Signed-off-by: Paarth Neekhara <[email protected]>
paarthneekhara committed Dec 21, 2023
1 parent d673a27 commit 007e8e4
Showing 2 changed files with 59 additions and 20 deletions.
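For context, a minimal sketch (not part of this commit) of how the new data options might be configured, assuming an OmegaConf-style config like the cfg.data reads in the model changes below; the values shown are simply the defaults from the diff.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "data": {
            "context_pattern": "parallel",     # or "delay_parallel" to stagger context codebooks
            "context_duration_min": 3.0,       # seconds, lower bound for the reference context crop
            "context_duration_max": 5.0,       # seconds, upper bound for the reference context crop
        }
    }
)
print(cfg.data.get("context_pattern", "parallel"))  # -> "parallel"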
@@ -127,6 +127,9 @@ def __init__(
num_speech_codebooks: Optional[int] = 8,
codebook_fps: Optional[int] = 75,
add_special_tokens_to_only_first_codebook: Optional[bool] = False,
context_pattern: Optional[str] = "parallel",
context_duration_min: Optional[float] = 3.0,
context_duration_max: Optional[float] = 5.0,
**kwargs,
):
"""
@@ -177,6 +180,11 @@ def __init__(
self.num_speech_codebooks = num_speech_codebooks
self.codebook_fps = codebook_fps
self.add_special_tokens_to_only_first_codebook = add_special_tokens_to_only_first_codebook

# context_pattern and duration arguments are supported only if context_type is REFSPEAKERCODEC in the manifest
self.context_pattern = context_pattern
self.context_duration_min = context_duration_min
self.context_duration_max = context_duration_max

# Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
if sup_data_path is not None:
@@ -295,7 +303,7 @@ def load_data(self, dataset):
approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1)
if self.seq_pattern == "delay_parallel":
# In delay parallel, there is padding so add num_speech_codebooks frames
approx_answer_len = approx_answer_len + 8
approx_answer_len = approx_answer_len + self.num_speech_codebooks
else:
approx_answer_len = len(doc["answer"].split(' ')) + 3

@@ -400,6 +408,9 @@ def __getitem__(self, idx):
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1)

# get answer ids
@@ -710,8 +721,27 @@ def _get_tokens(self, doc, field, field_data):
reference_codec_path = rng.choice(reference_codec_paths)
field_tokens = torch.load(reference_codec_path).long()
field_tokens[0] = (field_tokens[0] + self.speech_offset).long()
reference_codec_len = rng.randint(240, 400)
field_tokens = [field_tokens[:, :reference_codec_len]]
_min_len = int(self.context_duration_min * self.codebook_fps)
_max_len = int(self.context_duration_max * self.codebook_fps)
reference_codec_len = rng.randint(_min_len, _max_len)
reference_codec_len = min(reference_codec_len, field_tokens.shape[1])
si = rng.randint(0, field_tokens.shape[1] - reference_codec_len)
field_tokens = field_tokens[:, si:si+reference_codec_len]
if self.context_pattern == "delay_parallel":
field_tokens = torch.cat([
torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(),
field_tokens,
torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long()
], dim=1)
new_field_tokens = []
for _c in range(self.num_speech_codebooks):
st = self.num_speech_codebooks - _c
et = field_tokens.shape[1] - _c
new_field_tokens.append(field_tokens[_c, st:et])
field_tokens = torch.stack(new_field_tokens, dim=0)

field_tokens = [field_tokens]

elif doc[f"{field}_type"] == 'SEPARATIONCODECS':
mixed_codec_path, reference_codec_paths = field_data.split(",")
reference_codec_paths = reference_codec_paths.split(";")
@@ -1062,6 +1092,9 @@ def __getitem__(self, idx):
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT":
context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks-1)
# context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1)

# get answer ids
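To make the new context delay pattern concrete, here is a hedged, self-contained sketch of the transformation the context_pattern == "delay_parallel" branch above applies to the reference-speaker codec tokens; the function name and the toy tensor are illustrative, not code from the commit.

import torch

def delay_parallel_context(field_tokens: torch.Tensor, num_speech_codebooks: int) -> torch.Tensor:
    # Mirror of the dataset branch above: pad num_speech_codebooks zero frames on
    # each side, then keep a per-codebook window so that codebook c ends up
    # delayed by c frames relative to codebook 0.
    C = num_speech_codebooks
    padded = torch.cat(
        [torch.zeros(C, C).long(), field_tokens.long(), torch.zeros(C, C).long()], dim=1
    )
    rows = []
    for c in range(C):
        st = C - c
        et = padded.shape[1] - c
        rows.append(padded[c, st:et])
    return torch.stack(rows, dim=0)

# Toy example: 4 codebooks, 6 context frames -> every row becomes 6 + 4 frames,
# with c leading zeros for codebook c and (4 - c) trailing zeros.
toy = torch.arange(1, 25).reshape(4, 6)
print(delay_parallel_context(toy, 4).shape)  # torch.Size([4, 10])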
@@ -119,6 +119,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
return_all_crossattention_probs = cfg.get('return_all_crossattention_probs', False)
num_cross_attention_heads = cfg.get('num_cross_attention_heads', 12)
self.lm_vocab_size = cfg.get('lm_vocab_size', 30000)
self.context_pattern = cfg.data.get('context_pattern', 'parallel')

self.speech_offset = speech_offset
self.speech_codebook_size = speech_codebook_size
@@ -381,13 +382,15 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):

return loss_mean

def convert_tokens_to_range(self, tokens, apply_offset_correction=True):
def convert_tokens_to_range(self, tokens, apply_offset_correction=True, pattern=None):
# convert tokens to range [0, 1023]
output_tokens = tokens.clone()
if apply_offset_correction:
output_tokens[0] = output_tokens[0] - self.speech_offset
output_tokens = torch.clamp(output_tokens, min=0, max=1023)
if self.cfg.get('seq_pattern', 'delay_parallel') == "delay_parallel":
if pattern is None:
pattern = self.cfg.get('seq_pattern', 'delay_parallel')
if pattern == "delay_parallel":
output_tokens_new = []
for _c in range(output_tokens.shape[0]):
si = _c
@@ -475,11 +478,11 @@ def fwd_output_and_loss_func(dataloader_iter, model):
(ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset
]
context_end_step = input_token_list[0][0]
_context_tokens = context_and_question_tokens[0][:, :context_end_step].clone()
_context_tokens[0] = _context_tokens[0] - self.speech_offset
_context_tokens = torch.clamp(_context_tokens, min=0, max=1023)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("train_context_wav", _context_wav, self.global_step, self.sample_rate)
if context_end_step > self.num_speech_codebooks:
_context_tokens = context_and_question_tokens[0][:, :context_end_step].clone()
_context_tokens = self.convert_tokens_to_range(_context_tokens, pattern=self.context_pattern)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("train_context_wav", _context_wav, self.global_step, self.sample_rate)

question_si = (
input_token_list[0][0] + virtual_tokens.shape[1]
@@ -768,11 +771,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
(ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset
]
context_end_step = input_token_list[0][0]
_context_tokens = context_and_question_tokens[0][:, :context_end_step].clone()
_context_tokens[0] = _context_tokens[0] - self.speech_offset
_context_tokens = torch.clamp(_context_tokens, min=0, max=1023)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("val_context_wav", _context_wav, self.global_step, self.sample_rate)
if context_end_step > self.num_speech_codebooks:
_context_tokens = context_and_question_tokens[0][:, :context_end_step].clone()
_context_tokens = self.convert_tokens_to_range(_context_tokens, pattern=self.context_pattern)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("val_context_wav", _context_wav, self.global_step, self.sample_rate)

question_si = (
input_token_list[0][0] + virtual_tokens.shape[1]
@@ -1117,6 +1120,9 @@ def build_virtual_prompt_dataset(
num_speech_codebooks=self.num_speech_codebooks,
codebook_fps=self.cfg.data.get('codebook_fps', 75),
add_special_tokens_to_only_first_codebook=self.cfg.data.get('add_special_tokens_to_only_first_codebook', False),
context_pattern=self.cfg.data.get('context_pattern', 'parallel'),
context_duration_min=self.cfg.data.get('context_duration_min', 3.0),
context_duration_max=self.cfg.data.get('context_duration_max', 5.0),
)

rank = parallel_state.get_data_parallel_rank()
@@ -1399,11 +1405,11 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
(ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset
]
context_end_step = input_token_list[0][0]
_context_tokens = context_and_question_tokens[i][:, :context_end_step].clone()
_context_tokens[0] = _context_tokens[0] - self.speech_offset
_context_tokens = torch.clamp(_context_tokens, min=0, max=1023)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("Context Wav", _context_wav, step, self.sample_rate)
if context_end_step > self.num_speech_codebooks:
_context_tokens = context_and_question_tokens[i][:, :context_end_step].clone()
_context_tokens = self.convert_tokens_to_range(_context_tokens, pattern=self.context_pattern)
_context_wav = self.decode_wav_from_codec_model(_context_tokens)
self.logger.experiment.add_audio("Context Wav", _context_wav, step, self.sample_rate)

task_question = self.frozen_model.tokenizer.ids_to_text(
[v[1] for v in input_token_list if v[1] < self.lm_vocab_size]
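In the same spirit, a hedged sketch of how such a delay-parallel stagger can be undone before decoding, loosely following the pattern == "delay_parallel" branch of convert_tokens_to_range; the exact slicing is truncated in this diff, so the indices below are an assumption.

import torch

def undo_delay_parallel(tokens: torch.Tensor) -> torch.Tensor:
    # Assumed inverse of the delay stagger: for codebook c, skip its c leading
    # delay frames and trim the trailing ones so every row shares one length.
    C, T = tokens.shape
    return torch.stack([tokens[c, c : T - (C - c)] for c in range(C)], dim=0)

# Toy staggered input: 2 codebooks, 2 real frames each, codebook 1 delayed by one step.
staggered = torch.tensor([
    [11, 12, 0, 0],   # codebook 0: real frames, then 2 trailing pads
    [0, 21, 22, 0],   # codebook 1: 1 leading pad, real frames, 1 trailing pad
])
print(undo_delay_parallel(staggered))  # tensor([[11, 12], [21, 22]])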
