Skip to content

Commit

Permalink
Generate: unset GenerationConfig parameters do not raise warning (#29119
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gante authored Feb 20, 2024
1 parent 7d312ad commit a7755d2
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 24 deletions.
28 changes: 16 additions & 12 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):

def __init__(self, **kwargs):
# Parameters that control the length of the output
# if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
self.min_length = kwargs.pop("min_length", 0)
Expand Down Expand Up @@ -407,32 +406,34 @@ def validate(self, is_init=False):
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature != 1.0:
if self.temperature is not None and self.temperature != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p != 1.0:
if self.top_p is not None and self.top_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.typical_p != 1.0:
if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if self.top_k != 50 and self.penalty_alpha is None: # contrastive search uses top_k
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff != 0.0:
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff != 0.0:
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
Expand All @@ -453,21 +454,21 @@ def validate(self, is_init=False):
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups != 1:
if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
),
UserWarning,
)
if self.diversity_penalty != 0.0:
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
),
UserWarning,
)
if self.length_penalty != 1.0:
if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
Expand All @@ -491,7 +492,7 @@ def validate(self, is_init=False):
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups != 1:
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
Expand Down Expand Up @@ -1000,6 +1001,9 @@ def update(self, **kwargs):
setattr(self, key, value)
to_remove.append(key)

# remove all the attributes that were updated, without modifying the input dict
# Confirm that the updated instance is still valid
self.validate()

# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
1 change: 0 additions & 1 deletion src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ def generate(

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
Expand Down
1 change: 0 additions & 1 deletion src/transformers/generation/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,6 @@ def generate(

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
Expand Down
1 change: 0 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,6 @@ def generate(

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

# 2. Set generation parameters if not already defined
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def to_json_string(self, use_diff: bool = True) -> str:
config_dict = self.to_dict()
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

# Copied from transformers.generation.configuration_utils.GenerationConfig.update
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes,
Expand All @@ -171,7 +170,7 @@ def update(self, **kwargs):
setattr(self, key, value)
to_remove.append(key)

# remove all the attributes that were updated, without modifying the input dict
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs

Expand Down
32 changes: 25 additions & 7 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,44 @@ def test_validate(self):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
# A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)

# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
GenerationConfig(do_sample=False, temperature=0.5)
self.assertEqual(len(captured_warnings), 1)

# Case 3: Impossible sets of contraints/parameters will raise an exception
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_warnings), 1)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
self.assertEqual(len(captured_warnings), 0)
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)
with warnings.catch_warnings(record=True) as captured_warnings:
# OK - None means it is unset, nothing to warn about
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_warnings), 0)

# Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)

# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
# Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")

# Case 5: Model-specific parameters will NOT raise an exception or a warning
# Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)
Expand Down

0 comments on commit a7755d2

Please sign in to comment.