Commit
Merge branch 'speechllm_2309' into speechllm_2310_rebased
blisc authored Nov 21, 2023
2 parents cfd03b6 + 08f4029 commit adec6f7
Showing 6 changed files with 91 additions and 43 deletions.
@@ -116,6 +116,9 @@ def __init__(
seq_pattern: Optional[str] = "parallel",
use_attention_prior: Optional[bool] = False,
attention_prior_scaling_factor: Optional[float] = 1.0,
spec_aug=False,
spec_aug_time_width=0.2,
spec_aug_time_masks=2,
# cross_attention_epsilon: Optional[float] = 0.0,
# attention_prior_strength: Optional[float] = 0.5,
**kwargs,
@@ -135,6 +138,10 @@ def __init__(
**kwargs,
"""
# These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes.
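# SpecAugment-style time masking over the answer (speech) tokens; applied in collate_fn and only enabled for training data.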
self._rng = random.Random()
self.spec_aug = spec_aug if for_train else False
self.time_width = spec_aug_time_width
self.time_masks = spec_aug_time_masks
self.decoder_starts_with_pad = decoder_starts_with_pad
self.add_eos_to_decoder_output = add_eos_to_decoder_output
self.add_sentinel_to_input = add_sentinel_to_input
@@ -1075,6 +1082,7 @@ def collate_fn(self, batch):
# decoder_input_len = torch.stack(decoder_input_len)

decoder_mask = get_mask_from_lengths(decoder_input_len-1)
speech_mask = get_mask_from_lengths(decoder_input_len-1)
(
decoder_input_list,
decoder_labels_list,
@@ -1127,6 +1135,20 @@ def collate_fn(self, batch):
decoder_labels_list.append(decoder_labels)

decoder_mask[i, :context_tokens_len+question_tokens_len-1] = 0 # Mask out context and question
speech_mask[i, :context_tokens_len+question_tokens_len] = 0 # Mask out context and question

if self.spec_aug:
# Derive time width, sometimes based on a percentage of input length.
time_max_width = max(1, int(input_ids_len.item() * self.time_width))
time_start_upper_bound = max(1, input_ids_len.item() - time_max_width)
time_start = context_tokens_len.item() + question_tokens_len.item()
time_start_upper_bound += time_start

# Apply time masking
for _ in range(self.time_masks):
start = self._rng.randint(time_start, time_start_upper_bound)
width = self._rng.randint(0, time_max_width)
speech_mask[i, start : start + width] = 0

if self.use_attention_prior:
cross_attention_question_prior = torch.from_numpy(
@@ -1148,14 +1170,21 @@
cross_attention_prior[
i, context_tokens_len + question_tokens_len:context_tokens_len + question_tokens_len+input_ids_len-1, context_tokens_len + start_of_question_offset : context_tokens_len + question_tokens_len - end_of_question_offset
] = cross_attention_question_prior
# cross_attention_prior[
# i, context_tokens_len + start_of_question_offset : context_tokens_len + question_tokens_len - end_of_question_offset, context_tokens_len + question_tokens_len:context_tokens_len + question_tokens_len+input_ids_len-1
# ] = cross_attention_question_prior.T
# Using causal attention mask for whole input
batch_size = len(decoder_input_list)
attention_mask = torch.tril(torch.ones((batch_size, max_decoder_input_len_1, max_decoder_input_len_1))).view(
batch_size, 1, max_decoder_input_len_1, max_decoder_input_len_1
)

# Convert attention mask from float to bool
attention_mask = attention_mask < 0.5
attention_mask = attention_mask < 0.5 # Currently not used, not sure if correct either
# print(attention_mask)
# print(torch.max(torch.sum(cross_attention_prior, 2)))
# print(torch.max(torch.sum(attention_mask[:,0,:,:] * cross_attention_prior, 2)))
# import ipdb; ipdb.set_trace()


decoder_input = torch.stack(decoder_input_list)
Expand All @@ -1165,7 +1194,7 @@ def collate_fn(self, batch):
"position_ids": position_ids,
"attention_mask": attention_mask,
"labels": torch.stack(decoder_labels_list),
"speech_mask": decoder_mask, # For TTS, can just be loss_mask since answer will always be speech
"speech_mask": speech_mask, # For TTS, can just be loss_mask since answer will always be speech
"loss_mask": decoder_mask, # Mask out context and question and padding
"attention_prior": cross_attention_prior,
}
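
For reference, a minimal standalone sketch of the time-masking logic added to collate_fn above, assuming 1-D masks and the same context/question/answer layout (the function name and arguments are illustrative, not part of the dataset class):

import random
import torch

def apply_time_masks(speech_mask, answer_start, answer_len, time_width=0.2, time_masks=2, rng=None):
    # speech_mask: 1-D 0/1 tensor over context + question + answer positions.
    # Zero out `time_masks` random spans, restricted to the answer (speech) region.
    rng = rng or random.Random()
    time_max_width = max(1, int(answer_len * time_width))
    start_upper_bound = answer_start + max(1, answer_len - time_max_width)
    for _ in range(time_masks):
        start = rng.randint(answer_start, start_upper_bound)
        width = rng.randint(0, time_max_width)
        speech_mask[start : start + width] = 0
    return speech_mask

# Example: mask within an answer that starts at position 20 and is 30 tokens long.
mask = apply_time_masks(torch.ones(50), answer_start=20, answer_len=30)

Restricting the masked spans to the answer region keeps context and question tokens untouched, mirroring how the diff offsets time_start by context_tokens_len + question_tokens_len.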
@@ -371,7 +371,7 @@ def forward(
if return_logits is None:
return_logits = encoder_input is not None

lm_output, attention_probs_list = self.language_model(
lm_output, attention_probs_list, prior = self.language_model(
input_ids,
position_ids,
attention_mask,
@@ -424,7 +424,7 @@ def forward(
raise NotImplementedError("No implementation for speechllm")
return res if logits is None else res, logits
else:
return post_process_result, attention_probs_list
return post_process_result, attention_probs_list, prior
else:
if attention_probs_list is not None:
raise NotImplementedError("No implementation for speechllm")
68 changes: 37 additions & 31 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -696,6 +696,8 @@ def training_step(self, dataloader_iter, batch_idx):
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)
self.if_first_step = 1

# print(f"loss: {loss_mean}")
# import ipdb; ipdb.set_trace()
return loss_mean

def backward(self, *args, **kwargs):
@@ -858,8 +860,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
required_keys.update(('labels', 'loss_mask'))
if self.get_attention_mask_from_fusion:
required_keys.remove('attention_mask')
# if self.get_attention_mask_from_fusion:
# required_keys.remove('attention_mask')
batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch_cpu.items()}

# Model forward pass
@@ -886,9 +888,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
else:
# TODO: @eharper can we add this to mcore?
forward_args.pop('loss_mask')

# import ipdb; ipdb.set_trace()
(output_tensor, logits), attention_probs_list = model(**forward_args)
(output_tensor, logits), attention_probs_list, prior = model(**forward_args)

if self.trainer.global_step % self.train_check_interval == 0 and batch['speech_mask'][0].sum() != 0 and self.should_log and (not validation_step):
# Log every `train_check_interval` steps if the first item in the batch is speech
@@ -998,6 +998,7 @@ def loss_func(output_tensor):
return fwd_output_and_loss_func

def get_forward_output_only_func(self):
""" Used in inference / generate """
def fwd_output_only_func(dataloader_iter, model):
batch = next(dataloader_iter)
extra_arg = {}
@@ -1054,7 +1055,8 @@ def fwd_output_only_func(dataloader_iter, model):
extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item()
extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item()
extra_arg['speech_mask'] = speech_mask
output_tensor, attention_probs_list = model(tokens, position_ids, attention_mask, **extra_arg)
# extra_arg['return_all_selfattention_probs'] = True
output_tensor, attention_, prior = model(tokens, position_ids, attention_mask, **extra_arg)

# Advance inference sequence offset.
if self.inference_params:
@@ -2089,7 +2091,7 @@ def validation_step(self, dataloader_iter, batch_idx):
# TODO: @eharper can we add this to mcore?
forward_args.pop('loss_mask')

(_, logits), attention_probs_list = self.model(**forward_args)
(_, logits), attention_probs_list, prior = self.model(**forward_args)
layerwise_metrics = {}
loss_total = 0.0
all_preds = []
@@ -2173,37 +2175,38 @@ def validation_step(self, dataloader_iter, batch_idx):
attention_sliced = torch.stack(attention_sliced_list)
attention_sliced = torch.mean(attention_sliced, 0)
alignment_image_sliced = plot_alignment_to_numpy(
attention_sliced.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=2, vmin=0., vmax=1.
attention_sliced.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=2, vmin=0.
)
self.logger.experiment.add_image(
f"Val Attention Probs Average Sliced TF",
alignment_image_sliced,
self.global_step,
dataformats="HWC",
)

# phoneme_seq = [question_start, start]
# prior = batch['attention_prior'][0,:,:].T
# prior_data = plot_alignment_to_numpy(
# prior.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=1, vmin=0., vmax=1.
# )
# self.logger.experiment.add_image(
# f"Attention Prior",
# prior_data,
# self.global_step,
# dataformats="HWC",
# )
# phoneme_seq += question_ids
# prior = prior[question_start:start, start:start+length_of_speech]
# prior_data = plot_alignment_to_numpy(
# prior.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=2, vmin=0., vmax=1.
# )
# self.logger.experiment.add_image(
# f"Attention Prior Sliced",
# prior_data,
# self.global_step,
# dataformats="HWC",
# )
if prior is not None:
phoneme_seq = [question_start, start]
# prior = batch['attention_prior'][0,:,:].T
prior = torch.exp(prior[0,0,:,:].T)
prior_data = plot_alignment_to_numpy(
prior.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=1, vmin=0., vmax=1.
)
self.logger.experiment.add_image(
f"Attention Prior",
prior_data,
self.global_step,
dataformats="HWC",
)
# phoneme_seq += question_ids
# prior = prior[question_start:start, start:start+length_of_speech]
# prior_data = plot_alignment_to_numpy(
# prior.cpu().float().numpy(), phoneme_seq=phoneme_seq, phoneme_ver=2, vmin=0., vmax=1.
# )
# self.logger.experiment.add_image(
# f"Attention Prior Sliced",
# prior_data,
# self.global_step,
# dataformats="HWC",
# )

# Only for the first batch, log TF and autoregressive inference

@@ -2549,6 +2552,9 @@ def build_virtual_prompt_dataset(
context_length=self.cfg.data.get('context_length', None),
use_attention_prior=self.cfg.data.get('use_attention_prior', True),
attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.),
spec_aug=self.cfg.data.get('spec_aug', False),
spec_aug_time_width=self.cfg.data.get('spec_aug_time_width', 0.2),
spec_aug_time_masks=self.cfg.data.get('spec_aug_time_masks', 2),
# cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 1e-8),
)
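
One note on the prior-plotting change in validation_step above: the prior returned by the model is a log-space additive bias, so it is exponentiated back to probability space before plotting. A minimal sketch of that conversion (the tensor layout [batch, heads, query, key] is an assumption):

import torch

def prior_to_plot_array(prior):
    # prior: [batch, heads, query_len, key_len] log-space bias returned by the model.
    # Take the first item/head, transpose to [key_len, query_len], and undo the log.
    return torch.exp(prior[0, 0, :, :].T).cpu().float().numpy()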

8 changes: 7 additions & 1 deletion nemo/collections/nlp/modules/common/megatron/attention.py
@@ -941,7 +941,8 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a
attention_scores = matmul_result.view(b, np, sq, sk)

if attention_bias is not None:
attention_scores = torch.log_softmax(attention_scores, dim=2) + attention_bias
# saved = torch.log_softmax(attention_scores, dim=-1)
attention_scores = torch.log_softmax(attention_scores, dim=-1) + attention_bias
# # attention_bias is not None only for cross attention layers right now
# # TODO: make attention_bias type configurable: additive or multiplicative (log additive)
# eps = 1e-8
@@ -950,6 +951,11 @@
# # attention_scores += attention_bias

_attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# print(f"a: {torch.max(torch.exp(torch.logsumexp(saved, -1)))}")
# print(f"b: {torch.max(torch.exp(torch.logsumexp(attention_bias, -1)))}")
# # print(f"c: {torch.max(torch.exp(torch.logsumexp(attention_scores, -1)))}")
# print(f"d: {torch.max(torch.sum(_attention_probs, -1))}")
# # import ipdb; ipdb.set_trace()
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
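
To make the intent of the log_softmax change above concrete, here is a hedged sketch of folding a probability-space prior into attention as an additive bias in log space (a simplified standalone version that uses a plain softmax instead of the module's scale_mask_softmax path):

import torch

def attention_probs_with_prior(scores, prior, eps=1e-8):
    # scores: [b, np, sq, sk] raw attention logits
    # prior:  [b, np, sq, sk] cross-attention prior in probability space
    # Normalize the logits, add log(prior) as a bias, then renormalize.
    log_probs = torch.log_softmax(scores, dim=-1)
    biased = log_probs + torch.log(prior + eps)
    return torch.softmax(biased, dim=-1)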

16 changes: 11 additions & 5 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -838,17 +838,23 @@ def forward(
encoder_self_attention_relative_position_bias = None
else: # still using prior
logging.debug("Using prior")
prior_strength = 0.5
prior_strength = self.attn_prior_starting_strength
if global_step > self.attn_prior_scaledown_start_step: # In scaledown region
logging.debug("Scaling down prior")
total_annealing_steps = self.attn_prior_end_step - self.attn_prior_scaledown_start_step
curr_annealing_step = global_step - self.attn_prior_scaledown_start_step
prior_strength = (1. - curr_annealing_step / total_annealing_steps) * prior_strength
attention_prior = attention_prior * prior_strength + 1 - prior_strength
logging.debug(f"Modifying setup with strength: {prior_strength}")
modifier = (1-prior_strength)
# attn_len = attention_prior.shape[-1]
# modifier = (attn_len ** modifier - 1) / (attn_len - 1)
attention_prior = attention_prior + (1-attention_prior) * modifier
logging.debug(f"Modifying setup with strength: {prior_strength} and modifier: {modifier}")
# attention_prior = torch.log_softmax(attention_prior+1e-8, -2)
encoder_self_attention_relative_position_bias = attention_prior.unsqueeze(1).repeat(
1, num_attention_heads, 1, 1
)
encoder_self_attention_relative_position_bias = torch.log(encoder_self_attention_relative_position_bias + 1e-8)
# encoder_self_attention_relative_position_bias = torch.log_softmax(encoder_self_attention_relative_position_bias, dim=-1)

# import ipdb; ipdb.set_trace()
# encoder.
@@ -882,9 +888,9 @@ def forward(
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden_only:
if self.add_pooler and self.post_process:
return (encoder_output, pooled_output), attention_probs_list
return (encoder_output, pooled_output), attention_probs_list, encoder_self_attention_relative_position_bias
else:
return (encoder_output), attention_probs_list
return (encoder_output), attention_probs_list, encoder_self_attention_relative_position_bias

# Decoder Embedding
dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids)
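
Below is a hedged, standalone sketch of the prior-strength annealing in forward() above; argument names are illustrative, and the branch that drops the prior entirely after attn_prior_end_step is omitted here:

def anneal_attention_prior(attention_prior, global_step, starting_strength,
                           scaledown_start_step, end_step):
    # Keep full strength until scaledown_start_step, then decay it linearly to 0 at end_step.
    strength = starting_strength
    if global_step > scaledown_start_step:
        total_annealing_steps = end_step - scaledown_start_step
        curr_annealing_step = global_step - scaledown_start_step
        strength = (1.0 - curr_annealing_step / total_annealing_steps) * strength
    # As strength drops, blend the prior toward an all-ones (uninformative) prior.
    modifier = 1.0 - strength
    return attention_prior + (1.0 - attention_prior) * modifier

The annealed prior is then turned into the log-space bias via torch.log(prior + 1e-8) before being repeated across attention heads, matching the change shown in the diff above.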
5 changes: 3 additions & 2 deletions nemo/utils/exp_manager.py
@@ -638,12 +638,13 @@ def check_resume(
trainer.ckpt_path = str(checkpoint)
logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}')

trainer.strategy.barrier()
if is_global_rank_zero():
# Check to see if any files exist that need to be moved
files_to_move = []
if Path(log_dir).exists():
for child in Path(log_dir).iterdir():
if child.is_file():
if child.is_file() and not child.name.startswith("events.out.tfevents"):
files_to_move.append(child)

if len(files_to_move) > 0:
@@ -764,7 +765,7 @@ def get_log_dir(
os.environ[NEMO_ENV_VARNAME_VERSION] = "" if version is None else version

log_dir = Path(_exp_dir) / Path(str(name)) / Path("" if version is None else str(version))
return log_dir, str(_exp_dir), name, version
return log_dir, str(_exp_dir), name, "" if version is None else str(version)


def get_git_hash():
