From 4b423e607455a7aca1edc4beaa713da58e78ef0b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 17 Oct 2023 10:32:49 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20Generate:=20change?=
 =?UTF-8?q?=20order=20of=20ops=20in=20beam=20sample=20to=20avoid=20nans=20?=
 =?UTF-8?q?(#26843)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/generation/tf_utils.py | 23 ++++++++++++++---------
 src/transformers/generation/utils.py    | 18 ++++++++++++------
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 65906dc139cbf2..59848c3c85905d 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -1430,14 +1430,22 @@ def _get_logits_warper(
         # instantiate warpers list
         warpers = TFLogitsProcessorList()
 
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(generation_config.eos_token_id) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
         if generation_config.top_k is not None and generation_config.top_k != 0:
-            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
-            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
+            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
         return warpers
 
     def _get_logits_processor(
@@ -2366,14 +2374,11 @@ def beam_search_body_fn(
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
             log_probs = unflatten_beam_dim(log_probs, num_beams)
-            log_probs_processed = log_probs
-            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             if do_sample:
-                # Note: logits warpers are intentionally applied after adding running beam scores. On some logits
-                # warpers (like top_p) this is indiferent, but on others (like temperature) it is not. For reference,
-                # see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
                 log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
                 log_probs = unflatten_beam_dim(log_probs, num_beams)
+            log_probs_processed = log_probs
+            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             vocab_size = log_probs.shape[2]
             log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c9791653286bbb..606fbbe7060f93 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -820,11 +820,20 @@ def _get_logits_warper(
         # instantiate warpers list
         warpers = LogitsProcessorList()
 
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TemperatureLogitsWarper(generation_config.temperature))
-        min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
         if generation_config.top_k is not None and generation_config.top_k != 0:
             warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
@@ -3406,18 +3415,15 @@ def beam_sample(
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
-            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
-            # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
-            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
-            next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
-                    scores += (logits_warper(input_ids, next_token_scores_processed),)
+                    scores += (next_token_scores_processed,)
                 if output_attentions:
                     decoder_attentions += (
                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
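
The `_get_logits_warper` hunks above encode a simple rule: beam methods must leave the warped distribution with at least one non-eos candidate, so `min_tokens_to_keep` becomes `len(eos_token_id) + 1` when `eos_token_id` is a list, `2` for a single eos token, and `1` outside beam search. The sketch below is not part of the patch; the helper name `_min_tokens_to_keep` and the example token ids are illustrative stand-ins that restate the rule as a standalone function.

    # Hypothetical helper mirroring the min_tokens_to_keep logic added in this patch;
    # not a transformers API.
    from typing import List, Optional, Union


    def _min_tokens_to_keep(num_beams: int, eos_token_id: Optional[Union[int, List[int]]]) -> int:
        """How many tokens the top-k/top-p warpers must keep per step."""
        if num_beams > 1:
            # Beam methods: one slot per eos token plus at least one non-eos token,
            # so the search can still explore continuations that might score better.
            if isinstance(eos_token_id, list):
                return len(eos_token_id) + 1
            return 2
        # Greedy/plain sampling: the warpers may narrow the distribution to a single token.
        return 1


    assert _min_tokens_to_keep(num_beams=1, eos_token_id=2) == 1
    assert _min_tokens_to_keep(num_beams=4, eos_token_id=2) == 2
    assert _min_tokens_to_keep(num_beams=4, eos_token_id=[2, 32000]) == 3

The remaining hunks reorder `beam_sample` (and its TF counterpart) so that `logits_warper` runs on the processed per-step log-probabilities before the running beam scores are added, and the already-warped tensor is what gets stored in `scores`; per the commit title, that reordering is what avoids the NaNs.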