Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOCS] Add example and modified docs of EtaLogitsWarper #25125

Merged
merged 8 commits into from
Aug 2, 2023
66 changes: 61 additions & 5 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,70 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class EtaLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs eta-sampling, i.e. calculates a dynamic cutoff `eta := min(epsilon, sqrt(epsilon,
e^-entropy(probabilities)))` and restricts to tokens with `prob >= eta`. Takes the largest min_tokens_to_keep
tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon, e^-entropy(probabilities)))`. Takes the largest
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
must be set to `True` for this `LogitsWarper` to work.


Args:
epsilon (`float`):
A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
parameter is useful when logits need to be modified for very low probability tokens that should be excluded
from generation entirely.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered."""
Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
even if all tokens have probabilities below the cutoff `eta`.

Raises:
ValueError: If `epsilon` is not within the range (0, 1) or if `min_tokens_to_keep` is not a positive
integer.
gante marked this conversation as resolved.
Show resolved Hide resolved

Examples:
```python
>>> # Import required libraries
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> import torch

>>> # Set the model name
>>> model_name = "gpt2"

>>> # Initialize the model and tokenizer
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)

>>> # Set the pad token to eos token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> model.generation_config.pad_token_id = model.config.eos_token_id

>>> # below sequence is intentionally of two subjects to show the difference between the two approaches
agno-nymous marked this conversation as resolved.
Show resolved Hide resolved
>>> sequence = "a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding thing like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day. . . disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered"
agno-nymous marked this conversation as resolved.
Show resolved Hide resolved

>>> # Tokenize the sequence
>>> inputs = tokenizer(sequence, return_tensors="pt")

>>> torch.manual_seed(0)
gante marked this conversation as resolved.
Show resolved Hide resolved

>>> # We can see that the model is generating repeating text and also is not able to continue the sequence properly
>>> outputs = model.generate(inputs["input_ids"], max_length=128)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding thing like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered muscle mass. The patient was diagnosed with a severe erythema and a severe erythema of the right side of the body. The patient was
agno-nymous marked this conversation as resolved.
Show resolved Hide resolved

>>> # The result is much better and coherent when we use the `eta_cutoff` parameter
>>> outputs = model.generate(
... inputs["input_ids"], max_length=128, do_sample=True, eta_cutoff=2e-2
... ) # need to set do_sample=True for eta_cutoff to work
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding thing like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered fatty acids. A significant loss of brain development. The individual also experienced high levels of a common psychiatric condition called schizophrenia, with an important and life threatening consequence.
```
"""

def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
Expand Down