aligned sample_beam output selection with beam_search (huggingface#25375)
* aligned sample_beam specs with beam_search

* pull origin main

* Revert "pull origin main"

This reverts commit 06d356f.

* update test_utils.py

* fix format

* remove comment

---------

Co-authored-by: Shogo Fujita <[email protected]>
2 people authored and EduardoPach committed Aug 9, 2023
1 parent c66415a commit 018061a
Showing 2 changed files with 11 additions and 28 deletions.
5 changes: 3 additions & 2 deletions src/transformers/generation/utils.py
@@ -1691,18 +1691,19 @@ def generate(

             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
-                batch_size=batch_size * generation_config.num_return_sequences,
+                batch_size=batch_size,
                 num_beams=generation_config.num_beams,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )

             # 13. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
-                expand_size=generation_config.num_beams * generation_config.num_return_sequences,
+                expand_size=generation_config.num_beams,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
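For context, a minimal sketch of the scorer configuration that beam sampling now shares with beam search; the values below are arbitrary placeholders and the snippet is illustrative, not part of this commit.

# Illustrative sketch, not taken from the diff: beam sampling now keeps
# `num_return_sequences` hypotheses via the scorer instead of inflating the batch.
import torch
from transformers import BeamSearchScorer

batch_size, num_beams, num_return_sequences = 2, 4, 2

beam_scorer = BeamSearchScorer(
    batch_size=batch_size,                       # no longer batch_size * num_return_sequences
    num_beams=num_beams,
    device=torch.device("cpu"),
    length_penalty=1.0,
    do_early_stopping=False,
    num_beam_hyps_to_keep=num_return_sequences,  # scorer returns the top hypotheses per input
)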
34 changes: 8 additions & 26 deletions tests/generation/test_utils.py
@@ -438,7 +438,6 @@ def _beam_sample_generate(
         input_ids,
         attention_mask,
         max_length,
-        num_return_sequences,
         beam_scorer,
         beam_kwargs,
         logits_warper,
@@ -463,21 +462,21 @@ def _beam_sample_generate(
             **logits_warper_kwargs,
             **model_kwargs,
         )
-        # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
+        # beam_search does not automatically interleave `batch_size` dim for `num_beams`
         torch.manual_seed(0)
         kwargs = {}
         if model.config.is_encoder_decoder:
             encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                 model,
                 input_ids,
                 attention_mask,
-                num_interleave=beam_scorer.num_beams * num_return_sequences,
+                num_interleave=beam_scorer.num_beams,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
         elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)

         # prevent flaky generation test failures
         logits_processor = LogitsProcessorList()
@@ -486,7 +485,7 @@ def _beam_sample_generate(
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
-                input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
                 logits_warper=logits_warper,
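For reference, a toy shape check of the interleaving the test helper now performs, repeating each input row by `num_beams` only; the tensors and sizes are illustrative and not from the test suite.

# Toy example (shapes only): each input row is repeated num_beams times,
# with no extra num_return_sequences factor.
import torch

batch_size, seq_len, num_beams = 2, 5, 3
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)

expanded = input_ids.repeat_interleave(num_beams, dim=0)
print(expanded.shape)  # torch.Size([6, 5]) == (batch_size * num_beams, seq_len)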
@@ -891,13 +890,9 @@ def test_beam_search_generate(self):

             self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

-            # check `generate()` and `beam_search()` are equal for `num_return_sequences`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
-            )
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_generate, output_beam_search = self._beam_search_generate(
                 model=model,
@@ -1036,21 +1031,15 @@ def test_beam_sample_generate(self):
             model = model_class(config).to(torch_device).eval()

             # check `generate()` and `beam_search()` are equal
-            # change `num_return_sequences = 2` but not for `beam_scorer`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_generate, output_beam_sample = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
@@ -1074,20 +1063,15 @@ def test_beam_sample_generate_dict_output(self):
             model = model_class(config).to(torch_device).eval()
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_beam_sample, output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
@@ -1113,9 +1097,7 @@ def test_beam_sample_generate_dict_output(self):
             self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

             for output in (output_beam_sample, output_generate):
-                self._check_outputs(
-                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
-                )
+                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)

     def test_generate_without_input_ids(self):
         config, _, _, max_length = self._get_input_ids_and_config()
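End to end, `generate()` still returns `batch_size * num_return_sequences` sequences for beam sampling; what changes is how those hypotheses are selected, per the commit title. A hedged usage sketch follows; the model choice ("gpt2") is arbitrary and not part of this commit.

# Illustrative usage only, not from the diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Beam sampling still returns"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,          # sampling with num_beams > 1 selects the beam-sample path
    num_beams=4,
    num_return_sequences=2,  # now handled via num_beam_hyps_to_keep, as in beam search
    max_new_tokens=10,
)
print(outputs.shape[0])      # batch_size * num_return_sequences == 2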