diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6e6d5b8bdce7..53cd2df3a49c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1452,10 +1452,11 @@ def _prepare_generated_length( ): generation_config.max_length -= inputs_tensor.shape[1] elif has_default_max_length: # by default let's always generate 20 new tokens - generation_config.max_length = generation_config.max_length + input_ids_length - max_position_embeddings = getattr(self.config, "max_position_embeddings", None) - if max_position_embeddings is not None: - generation_config.max_length = min(generation_config.max_length, max_position_embeddings) + if generation_config.max_length == GenerationConfig().max_length: + generation_config.max_length = generation_config.max_length + input_ids_length + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + if max_position_embeddings is not None: + generation_config.max_length = min(generation_config.max_length, max_position_embeddings) # same for min length if generation_config.min_new_tokens is not None: