From 3c0ec4037d63b2e5650a0c8e4f73a61fbdd5c912 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 20 Feb 2024 11:34:31 +0000 Subject: [PATCH] Generate: unset GenerationConfig parameters do not raise warning (#29119) --- .../generation/configuration_utils.py | 28 +++++++++------- src/transformers/generation/flax_utils.py | 1 - src/transformers/generation/tf_utils.py | 1 - src/transformers/generation/utils.py | 1 - src/transformers/utils/quantization_config.py | 3 +- tests/generation/test_configuration_utils.py | 32 +++++++++++++++---- 6 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1d5d3b661e4050..87335b2667b23d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin): def __init__(self, **kwargs): # Parameters that control the length of the output - # if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030 self.max_length = kwargs.pop("max_length", 20) self.max_new_tokens = kwargs.pop("max_new_tokens", None) self.min_length = kwargs.pop("min_length", 0) @@ -407,32 +406,34 @@ def validate(self, is_init=False): "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`." + fix_location ) - if self.temperature != 1.0: + if self.temperature is not None and self.temperature != 1.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature), UserWarning, ) - if self.top_p != 1.0: + if self.top_p is not None and self.top_p != 1.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p), UserWarning, ) - if self.typical_p != 1.0: + if self.typical_p is not None and self.typical_p != 1.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p), UserWarning, ) - if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k + if ( + self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None + ): # contrastive search uses top_k warnings.warn( greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k), UserWarning, ) - if self.epsilon_cutoff != 0.0: + if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff), UserWarning, ) - if self.eta_cutoff != 0.0: + if self.eta_cutoff is not None and self.eta_cutoff != 0.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff), UserWarning, @@ -453,21 +454,21 @@ def validate(self, is_init=False): single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping), UserWarning, ) - if self.num_beam_groups != 1: + if self.num_beam_groups is not None and self.num_beam_groups != 1: warnings.warn( single_beam_wrong_parameter_msg.format( flag_name="num_beam_groups", flag_value=self.num_beam_groups ), UserWarning, ) - if self.diversity_penalty != 0.0: + if self.diversity_penalty is not None and self.diversity_penalty != 0.0: warnings.warn( single_beam_wrong_parameter_msg.format( flag_name="diversity_penalty", flag_value=self.diversity_penalty ), UserWarning, ) - if self.length_penalty != 1.0: + if self.length_penalty is not None and self.length_penalty != 1.0: warnings.warn( single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty), UserWarning, @@ -491,7 +492,7 @@ def validate(self, is_init=False): raise ValueError( constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample) ) - if self.num_beam_groups != 1: + if self.num_beam_groups is not None and self.num_beam_groups != 1: raise ValueError( constrained_wrong_parameter_msg.format( flag_name="num_beam_groups", flag_value=self.num_beam_groups @@ -1000,6 +1001,9 @@ def update(self, **kwargs): setattr(self, key, value) to_remove.append(key) - # remove all the attributes that were updated, without modifying the input dict + # Confirm that the updated instance is still valid + self.validate() + + # Remove all the attributes that were updated, without modifying the input dict unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} return unused_kwargs diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 1e063be8638650..1bdf58691a80d7 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -330,7 +330,6 @@ def generate( generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 3021e1e55945f0..8c2d9fde6ae721 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -736,7 +736,6 @@ def generate( generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6fd2c752a0a40b..08fde585076877 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1347,7 +1347,6 @@ def generate( generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d26cfca678c7b0..bcf31ebfaba0e4 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -152,7 +152,6 @@ def to_json_string(self, use_diff: bool = True) -> str: config_dict = self.to_dict() return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - # Copied from transformers.generation.configuration_utils.GenerationConfig.update def update(self, **kwargs): """ Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, @@ -171,7 +170,7 @@ def update(self, **kwargs): setattr(self, key, value) to_remove.append(key) - # remove all the attributes that were updated, without modifying the input dict + # Remove all the attributes that were updated, without modifying the input dict unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} return unused_kwargs diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 7aabee4b521552..4ff9d35aa0d2dc 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -124,26 +124,44 @@ def test_validate(self): """ Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time """ - # Case 1: A correct configuration will not throw any warning + # A correct configuration will not throw any warning with warnings.catch_warnings(record=True) as captured_warnings: GenerationConfig() self.assertEqual(len(captured_warnings), 0) - # Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling + # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling # parameters with `do_sample=False`). May be escalated to an error in the future. with warnings.catch_warnings(record=True) as captured_warnings: - GenerationConfig(temperature=0.5) + GenerationConfig(do_sample=False, temperature=0.5) self.assertEqual(len(captured_warnings), 1) - # Case 3: Impossible sets of contraints/parameters will raise an exception + # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally, + # that is done by unsetting the parameter (i.e. setting it to None) + generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) + with warnings.catch_warnings(record=True) as captured_warnings: + # BAD - 0.9 means it is still set, we should warn + generation_config_bad_temperature.update(temperature=0.9) + self.assertEqual(len(captured_warnings), 1) + generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) + with warnings.catch_warnings(record=True) as captured_warnings: + # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn + generation_config_bad_temperature.update(temperature=1.0) + self.assertEqual(len(captured_warnings), 0) + generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) + with warnings.catch_warnings(record=True) as captured_warnings: + # OK - None means it is unset, nothing to warn about + generation_config_bad_temperature.update(temperature=None) + self.assertEqual(len(captured_warnings), 0) + + # Impossible sets of contraints/parameters will raise an exception with self.assertRaises(ValueError): - GenerationConfig(num_return_sequences=2) + GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2) - # Case 4: Passing `generate()`-only flags to `validate` will raise an exception + # Passing `generate()`-only flags to `validate` will raise an exception with self.assertRaises(ValueError): GenerationConfig(logits_processor="foo") - # Case 5: Model-specific parameters will NOT raise an exception or a warning + # Model-specific parameters will NOT raise an exception or a warning with warnings.catch_warnings(record=True) as captured_warnings: GenerationConfig(foo="bar") self.assertEqual(len(captured_warnings), 0)