[DOCS] Add descriptive docstring to MinNewTokensLength #25196

41 changes: 39 additions & 2 deletions src/transformers/generation/logits_process.py
@@ -133,14 +133,52 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
Note that for decoder-only models, such as Llama2, `min_length` will compute the length of `prompt + newly
generated tokens` whereas for other models it will behave as `min_new_tokens`, that is, taking into account
only the newly generated ones.

Args:
prompt_length_to_skip (`int`):
The input tokens length. Not a valid argument when used with `generate` as it will automatically assign the
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

Examples:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.config.pad_token_id = model.config.eos_token_id
>>> model.generation_config.pad_token_id = model.config.eos_token_id
>>> input_context = "Hugging Face Company is"
Collaborator:

Could we use a different input string e.g. "Hugging Face is"?

The examples are a bit confusing because of "Company" being in the input and then "company" being set as the eos token

Contributor Author:

Ahh, yes that makes sense, good catch ✨

Contributor Author:

Hi @amyeroberts, I've been trying several prompts with no luck. I'm finding it kind of hard to get two instances of the same token when one buries the first occurrence, as in the 3rd example (blame my lack of experience 😅). My guess is that `Company` is promoting `company` somehow. The workaround I found is to use a couple of `eos_token_id`s, which has some educational value on its own. This is what happens for the prompt `Hugging Face Inc. is` (a sketch of the exact call follows the list):

  • `eos_token_id=1664, min_new_tokens=1`: Hugging Face Inc. is a company
  • `eos_token_id=[1664, 9856], min_new_tokens=2`: Hugging Face Inc. is a non-profit organization that provides educational
  • `eos_token_id=[1664, 9856], min_new_tokens=10`: Hugging Face Inc. is a non-profit organization that provides free, open source software to the
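
For reference, a sketch of what I ran for the second of these, with the same distilgpt2 setup as in the example (the two ids are the ones listed above):

```python
input_ids = tokenizer.encode("Hugging Face Inc. is", return_tensors="pt")

# Generation may now stop on either of the two eos ids, once at least
# `min_new_tokens` new tokens have been generated.
outputs = model.generate(input_ids=input_ids, min_new_tokens=2, eos_token_id=[1664, 9856])
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Hugging Face Inc. is a non-profit organization that provides educational
```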

I'm happy to go with the above examples, but I kind of like the existing ones as they only involve changing one parameter, i.e. `min_new_tokens`, so it's more like an RCT. In that case, what we can do is clarify this comment:

If `eos_token_id` is set to ` company` it will take into account how many `min_new_tokens` have 
been generated before stopping.

To something like:

If `eos_token_id` is set to ` company` it will take into account how many `min_new_tokens` have 
been generated before stopping. Note that ` Company` (5834) and ` company` (1664) are not 
actually the same token, and even if they were ` Company` would be ignored by `min_new_tokens`
as it excludes the prompt.

Let me know what you think 🙂

Collaborator:

Updated comment clarifying what's happening sounds good to me!

Contributor Author:

Nice, thanks for your patience 🙏
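
As an aside, the `prompt_length_to_skip` note in the Args section can be made concrete by driving the processor by hand rather than through `generate`. A minimal sketch, assuming the distilgpt2 setup from the example (standalone use has to pass the prompt length that `generate` would otherwise assign automatically):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, MinNewTokensLengthLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
input_ids = tokenizer.encode("Hugging Face Company is", return_tensors="pt")

# `generate` assigns `prompt_length_to_skip` automatically; here we supply it ourselves.
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=input_ids.shape[-1],
    min_new_tokens=2,
    eos_token_id=model.config.eos_token_id,
)

with torch.no_grad():
    next_token_scores = model(input_ids).logits[:, -1, :]

# No new tokens have been generated yet, so the EOS score is forced to -inf.
next_token_scores = processor(input_ids, next_token_scores)
assert next_token_scores[0, model.config.eos_token_id] == -float("inf")
```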

>>> input_ids = tokenizer.encode(input_context, return_tensors="pt")

>>> # Without `eos_token_id`, generation stops at the default `max_length` of 20 tokens, ignoring `min_new_tokens`
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=30)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company that has been working on a new product for the past year.

>>> # If `eos_token_id` is set to ` company` it will take into account how many `min_new_tokens` have been generated
>>> # before stopping. Note that ` Company` (5834) and ` company` (1664) are not actually the same token, and even
>>> # if they were, ` Company` would be ignored by `min_new_tokens` as it excludes the prompt.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=1, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a company

>>> # Increasing `min_new_tokens` will bury the first occurrence of ` company`, generating a different sequence.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=2, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a new company

>>> # If no more occurrences of `eos_token_id` happen after `min_new_tokens`, it returns to the default length of 20 tokens.
>>> outputs = model.generate(input_ids=input_ids, min_new_tokens=10, eos_token_id=1664)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Hugging Face Company is a new and innovative brand of facial recognition technology that is designed to help you
```
"""

def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
@@ -194,7 +232,6 @@ class TemperatureLogitsWarper(LogitsWarper):
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM


>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> model.config.pad_token_id = model.config.eos_token_id