diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index a958228d9be..52371d94dc5 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -150,6 +150,12 @@ def __init__( self.generation_config.return_dict_in_generate = True self.generation_config.output_scores = True + # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant + # greedily to maximize matches. Disables sampling-related flags to prevent warnings + self.generation_config.do_sample = False + for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"): + setattr(self.generation_config, attr, None) + # avoid unnecessary warnings that min_length is larger than max_new_tokens # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`) self.main_model_min_length = self.generation_config.min_length diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 85fcc055948..2bdf20c6861 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -496,6 +496,11 @@ def validate(self, is_init=False): greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p), UserWarning, ) + if self.min_p is not None: + warnings.warn( + greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p), + UserWarning, + ) if self.typical_p is not None and self.typical_p != 1.0: warnings.warn( greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),