diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index cf5e1f3aea24d2..f7cf2e89682a4d 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -133,14 +133,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
+    Note that for decoder-only models, such as Llama2, `min_length` will compute the length of `prompt + newly
+    generated tokens`, whereas for other models it will behave as `min_new_tokens`, that is, taking into account
+    only the newly generated ones.
 
     Args:
         prompt_length_to_skip (`int`):
-            The input tokens length.
+            The input tokens length. Not a valid argument when used with `generate`, as `generate` automatically
+            assigns it the input length.
         min_new_tokens (`int`):
             The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+
+    Examples:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+    >>> model.config.pad_token_id = model.config.eos_token_id
+    >>> model.generation_config.pad_token_id = model.config.eos_token_id
+    >>> input_context = "Hugging Face Company is"
+    >>> input_ids = tokenizer.encode(input_context, return_tensors="pt")
+
+    >>> # Without passing `eos_token_id`, generation runs to the default `max_length` of 20, so `min_new_tokens` has no visible effect
+    >>> outputs = model.generate(input_ids=input_ids, min_new_tokens=30)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company that has been working on a new product for the past year.
+
+    >>> # If `eos_token_id` is set to ` company`, generation cannot stop there until at least `min_new_tokens` tokens
+    >>> # have been generated. Note that ` Company` (5834) and ` company` (1664) are not the same token, and even if
+    >>> # they were, ` Company` would be ignored by `min_new_tokens`, which excludes the prompt.
+    >>> outputs = model.generate(input_ids=input_ids, min_new_tokens=1, eos_token_id=1664)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a company
+
+    >>> # Increasing `min_new_tokens` suppresses the first occurrence of ` company`, yielding a different sequence.
+    >>> outputs = model.generate(input_ids=input_ids, min_new_tokens=2, eos_token_id=1664)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a new company
+
+    >>> # If ` company` does not occur again once `min_new_tokens` is reached, generation falls back to the default 20 tokens.
+    >>> outputs = model.generate(input_ids=input_ids, min_new_tokens=10, eos_token_id=1664)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    Hugging Face Company is a new and innovative brand of facial recognition technology that is designed to help you
+    ```
     """
 
     def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
@@ -194,7 +233,6 @@ class TemperatureLogitsWarper(LogitsWarper):
 
     >>> import torch
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
-
     >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
     >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
     >>> model.config.pad_token_id = model.config.eos_token_id
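
Since the new docstring notes that `prompt_length_to_skip` cannot be passed through `generate`, a short sketch of the manual path may be useful alongside this patch. The snippet below is illustrative only, not part of the diff: it assumes the public `LogitsProcessorList` and `MinNewTokensLengthLogitsProcessor` imports from `transformers`, builds the processor by hand with the prompt length that `generate` would otherwise fill in, and reuses token id 1664 (` company`) from the docstring example. No generated text is reproduced, as the output depends on the model version.

```python
# Illustrative sketch: constructing MinNewTokensLengthLogitsProcessor by hand,
# supplying `prompt_length_to_skip` explicitly instead of relying on `generate`.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer.encode("Hugging Face Company is", return_tensors="pt")

# `generate` derives this value from the input automatically; here we pass it ourselves.
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=input_ids.shape[-1],
    min_new_tokens=2,
    eos_token_id=1664,  # ` company`, the same token id used in the docstring example
)

outputs = model.generate(
    input_ids=input_ids,
    logits_processor=LogitsProcessorList([processor]),
    eos_token_id=1664,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

For everyday use, `model.generate(input_ids=input_ids, min_new_tokens=2, eos_token_id=1664)` should produce the same behavior; the manual form only matters when composing a custom `LogitsProcessorList`.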