From 018061a11647d523447cb52ffaebfd0643b55488 Mon Sep 17 00:00:00 2001
From: hukuda222
Date: Thu, 10 Aug 2023 01:28:57 +0900
Subject: [PATCH] aligned sample_beam output selection with beam_search
 (#25375)

* aligned sample_beam specs with beam_search

* pull origin main

* Revert "pull origin main"

This reverts commit 06d356f1137bb52272e120a03636598c44449cf3.

* update test_utils.py

* fix format

* remove comment

---------

Co-authored-by: Shogo Fujita
---
 src/transformers/generation/utils.py |  5 ++--
 tests/generation/test_utils.py       | 34 +++++++---------------------
 2 files changed, 11 insertions(+), 28 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index eaee4d029d42..0a32785ef66e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1691,18 +1691,19 @@ def generate(
 
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
-                batch_size=batch_size * generation_config.num_return_sequences,
+                batch_size=batch_size,
                 num_beams=generation_config.num_beams,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
 
             # 13. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
-                expand_size=generation_config.num_beams * generation_config.num_return_sequences,
+                expand_size=generation_config.num_beams,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 0cdc92398ba8..e6faf5babd1d 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -438,7 +438,6 @@ def _beam_sample_generate(
         input_ids,
         attention_mask,
         max_length,
-        num_return_sequences,
         beam_scorer,
         beam_kwargs,
         logits_warper,
@@ -463,7 +462,7 @@
             **logits_warper_kwargs,
             **model_kwargs,
         )
-        # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
+        # beam_search does not automatically interleave `batch_size` dim for `num_beams`
         torch.manual_seed(0)
         kwargs = {}
         if model.config.is_encoder_decoder:
@@ -471,13 +470,13 @@
                 model,
                 input_ids,
                 attention_mask,
-                num_interleave=beam_scorer.num_beams * num_return_sequences,
+                num_interleave=beam_scorer.num_beams,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
         elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
 
         # prevent flaky generation test failures
         logits_processor = LogitsProcessorList()
@@ -486,7 +485,7 @@
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
-                input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
                 logits_warper=logits_warper,
@@ -891,13 +890,9 @@ def test_beam_search_generate(self):
             self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
 
-            # check `generate()` and `beam_search()` are equal for `num_return_sequences`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
-            )
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
             output_generate, output_beam_search = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -1036,21 +1031,15 @@ def test_beam_sample_generate(self):
             model = model_class(config).to(torch_device).eval()
 
             # check `generate()` and `beam_search()` are equal
-            # change `num_return_sequences = 2` but not for `beam_scorer`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
 
             output_generate, output_beam_sample = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
@@ -1074,20 +1063,15 @@ def test_beam_sample_generate_dict_output(self):
             model = model_class(config).to(torch_device).eval()
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
 
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
 
             output_beam_sample, output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
@@ -1113,9 +1097,7 @@ def test_beam_sample_generate_dict_output(self):
             self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
 
             for output in (output_beam_sample, output_generate):
-                self._check_outputs(
-                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
-                )
+                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
 
     def test_generate_without_input_ids(self):
         config, _, _, max_length = self._get_input_ids_and_config()
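
Usage sketch (illustrative, not part of the patch; the checkpoint name and generation settings below are example choices): with this change, beam-sample generation keeps batch_size as-is and hands num_return_sequences to BeamSearchScorer via num_beam_hyps_to_keep, so the returned hypotheses are selected by the scorer the same way beam search selects them, and generate() still returns batch_size * num_return_sequences sequences.

# Illustrative sketch only -- "t5-small" and these generation arguments are examples,
# not taken from the patch.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(["translate English to German: How old are you?"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,          # sampling with num_beams > 1 selects beam-sample generation
    num_beams=4,
    num_return_sequences=2,  # kept by BeamSearchScorer via num_beam_hyps_to_keep
    max_new_tokens=20,
)
# One row per returned hypothesis: (batch_size * num_return_sequences, sequence_length)
print(outputs.shape)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))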