🚨🚨 Generate: change order of ops in beam sample to avoid nans (#26843)
Co-authored-by: Arthur <[email protected]>
gante and ArthurZucker authored Oct 17, 2023
1 parent 0b8604d commit 4b423e6
Showing 2 changed files with 26 additions and 15 deletions.
src/transformers/generation/tf_utils.py: 14 additions & 9 deletions
@@ -1430,14 +1430,22 @@ def _get_logits_warper(
         # instantiate warpers list
         warpers = TFLogitsProcessorList()
 
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(generation_config.eos_token_id) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
         if generation_config.top_k is not None and generation_config.top_k != 0:
-            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
-            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
+            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
         return warpers
 
     def _get_logits_processor(
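The new min_tokens_to_keep block exists so that, when sampling with beams, the warpers can never leave a beam with nothing but EOS candidates: if every surviving token ended the sequence, the search could not explore continuations that might score better. A minimal sketch of the effect (not part of this commit), using the PyTorch TopKLogitsWarper for brevity since the TF warpers touched in this hunk take the same constructor arguments, and a made-up 4-token vocabulary in which id 3 is the single EOS token:

import torch
from transformers import TopKLogitsWarper

# Toy scores over a 4-token vocabulary; assume id 3 is EOS and currently scores highest.
scores = torch.tensor([[0.1, 0.2, 0.3, 2.0]])
input_ids = torch.zeros((1, 1), dtype=torch.long)  # dummy prompt, ignored by the warper

# Old default (min_tokens_to_keep=1): top_k=1 masks everything except EOS,
# so a sampled beam can only terminate.
print(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1)(input_ids, scores))  # only id 3 stays finite

# New beam-method default (a single EOS id, so keep 1 + 1 = 2 tokens):
# at least one non-EOS candidate survives the warper.
print(TopKLogitsWarper(top_k=1, min_tokens_to_keep=2)(input_ids, scores))  # ids 2 and 3 stay finite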
@@ -2366,14 +2374,11 @@ def beam_search_body_fn(
             log_probs = tf.nn.log_softmax(logits)
             log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
             log_probs = unflatten_beam_dim(log_probs, num_beams)
-            log_probs_processed = log_probs
-            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             if do_sample:
-                # Note: logits warpers are intentionally applied after adding running beam scores. On some logits
-                # warpers (like top_p) this is indiferent, but on others (like temperature) it is not. For reference,
-                # see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
                 log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
                 log_probs = unflatten_beam_dim(log_probs, num_beams)
+            log_probs_processed = log_probs
+            log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             vocab_size = log_probs.shape[2]
             log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
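The deleted note records why the warpers used to run after the running scores were added; this commit flips that order, so the warpers now see only the per-step log-probs. For row-wise filters such as top_p the two orders select the same tokens, but temperature is not invariant to the per-beam offset, so the reorder changes what beam sample draws (hence the 🚨🚨 markers in the commit title). A small sketch with made-up numbers, two beams and a 2-token vocabulary:

import torch

# Per-step log-probs for two beams over a tiny vocabulary, plus each beam's running score.
step_log_probs = torch.tensor([[-0.5, -1.5],   # beam 0
                               [-1.0, -1.0]])  # beam 1
running_scores = torch.tensor([[-1.0], [-4.0]])
temperature = 0.5

# Old order: add the running scores first, then rescale the sum by the temperature.
old = torch.softmax(((step_log_probs + running_scores) / temperature).reshape(1, -1), dim=-1)

# New order (this commit): rescale only the per-step log-probs, then add the running scores.
new = torch.softmax((step_log_probs / temperature + running_scores).reshape(1, -1), dim=-1)

# The flattened (beam, token) sampling distributions differ, because the old order also
# rescales the accumulated beam scores and so reweights the beams relative to each other.
print(old)
print(new)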
src/transformers/generation/utils.py: 12 additions & 6 deletions
@@ -820,11 +820,20 @@ def _get_logits_warper(
         # instantiate warpers list
         warpers = LogitsProcessorList()
 
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config.eos_token_id, list):
+                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         if generation_config.temperature is not None and generation_config.temperature != 1.0:
             warpers.append(TemperatureLogitsWarper(generation_config.temperature))
-        min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
         if generation_config.top_k is not None and generation_config.top_k != 0:
             warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
         if generation_config.top_p is not None and generation_config.top_p < 1.0:
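utils.py gets the same guard as the TF file, with one wrinkle visible in the condition above: generation_config.eos_token_id may be a list of ids, in which case len(eos_token_id) + 1 tokens must survive warping so that at least one of them is not an EOS. A sketch with made-up scores and a hypothetical pair of EOS ids, using TopPLogitsWarper from the same module:

import torch
from transformers import TopPLogitsWarper

# Toy 4-token vocabulary; assume ids 2 and 3 are both EOS ids and currently dominate the nucleus.
eos_token_id = [2, 3]
scores = torch.tensor([[0.0, 0.0, 5.0, 5.0]])
input_ids = torch.zeros((1, 1), dtype=torch.long)  # dummy prompt, ignored by the warper

# With the old min_tokens_to_keep=1, top_p=0.5 can keep only the two EOS ids:
# every sampled continuation of this beam would terminate.
print(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=1)(input_ids, scores))

# With min_tokens_to_keep = len(eos_token_id) + 1 = 3, at least one non-EOS token
# keeps a finite score, so the beam can still be extended.
print(TopPLogitsWarper(top_p=0.5, min_tokens_to_keep=len(eos_token_id) + 1)(input_ids, scores))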
@@ -3406,18 +3415,15 @@ def beam_sample(
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
-            # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
-            # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
-            # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
-            next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
-                    scores += (logits_warper(input_ids, next_token_scores_processed),)
+                    scores += (next_token_scores_processed,)
                 if output_attentions:
                     decoder_attentions += (
                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
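This last hunk is the one the commit title describes: logits_warper now runs on the processed per-step scores before the running beam scores are added, and the stored scores tuple reuses that already-warped tensor instead of warping a second time. The diff does not show a concrete failing input, but the class of problem is easy to reproduce: if a row whose entries are all -inf (the warpers' filter value) reaches the sampling softmax, the probabilities become NaN. A hedged sketch with made-up values:

import torch

# If every entry of a beam's row has been pushed to the filter value (-inf) by the time
# it is softmaxed for sampling, the resulting probabilities are NaN.
all_filtered = torch.full((1, 4), float("-inf"))
print(torch.softmax(all_filtered, dim=-1))  # tensor of NaNs

# After this commit the warpers only see the per-step scores, and min_tokens_to_keep
# guarantees at least that many finite entries per row before the beam scores are added.
step_scores = torch.tensor([[-1.0, float("-inf"), float("-inf"), -2.0]])  # two tokens kept
beam_score = -10.0
print(torch.softmax(step_scores + beam_score, dim=-1))  # finite probabilities, no NaNs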
