diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 9f130dfa2e9c..6b1093761fb9 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -485,14 +485,65 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 class EtaLogitsWarper(LogitsWarper):
     r"""
-    [`LogitsWarper`] that performs eta-sampling, i.e. calculates a dynamic cutoff `eta := min(epsilon, sqrt(epsilon,
-    e^-entropy(probabilities)))` and restricts to tokens with `prob >= eta`. Takes the largest min_tokens_to_keep
-    tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
-    Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
+    [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
+    cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
+    the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon) * e^-entropy(probabilities))`. Takes the largest
+    min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
+    samples of text generated by neural language models, leading to more coherent and fluent text. See [Truncation
+    Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
+    must be set to `True` for this `LogitsWarper` to work.
+
+    Args:
+        epsilon (`float`):
+            A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
+            suggested values from the paper range from 3e-4 to 4e-3, depending on the size of the model.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
+            parameter is useful when logits need to be modified for very low probability tokens that should be excluded
+            from generation entirely.
         min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimum number of tokens that cannot be filtered."""
+            Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
+            For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
+            even if all tokens have probabilities below the cutoff `eta`.
+
+    Examples:
+    ```python
+    >>> # Import required libraries
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+
+    >>> # Set the model name
+    >>> model_name = "gpt2"
+
+    >>> # Initialize the model and tokenizer
+    >>> model = AutoModelForCausalLM.from_pretrained(model_name)
+    >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    >>> # Set the pad token to the eos token
+    >>> model.config.pad_token_id = model.config.eos_token_id
+    >>> model.generation_config.pad_token_id = model.config.eos_token_id
+
+    >>> # The sequence below intentionally contains two subjects to show the difference between the two approaches
+    >>> sequence = "a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered"
+
+    >>> # Tokenize the sequence
+    >>> inputs = tokenizer(sequence, return_tensors="pt")
+
+    >>> set_seed(0)
+
+    >>> # With greedy decoding, the model generates repeating text and is not able to continue the sequence properly
+    >>> outputs = model.generate(inputs["input_ids"], max_length=128)
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered muscle mass. The patient was diagnosed with a severe erythema and a severe erythema-like condition. The patient was treated with a combination
+
+    >>> # The result is much more coherent when we use the `eta_cutoff` parameter
+    >>> outputs = model.generate(
+    ...     inputs["input_ids"], max_length=128, do_sample=True, eta_cutoff=2e-2
+    ... )  # need to set do_sample=True for eta_cutoff to work
+    >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    a quadcopter flight controller (RTFQ Flip MWC) that supports I2C sensors for adding things like a barometer, magnetometer, and GPS system. The officially supported sensor block (BMP180, HMC5883L on one board) is discontinued, as far as I know, everyone involved lived to sing another day... disorder and an extreme state of dysmetabolism characterized by extensive erythema and a significant reduction in uncovered fatty acids. A significant loss of brain development. The individual also experienced high levels of a common psychiatric condition called schizophrenia, with an important and life threatening consequence.
+    ```
+    """
 
     def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         epsilon = float(epsilon)
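For reference, the dynamic cutoff the new docstring describes can be sketched directly in PyTorch. This is a minimal illustration, assuming `scores` holds the raw logits for a single decoding step; the tensor names, batch size, and vocabulary size below are illustrative and not taken from the library's implementation.

```python
import torch

# Hyperparameter from the docstring's suggested range (3e-4 to 4e-3).
epsilon = torch.tensor(3e-4)

# Hypothetical raw logits for one decoding step (batch of 2, GPT-2 vocab size).
scores = torch.randn(2, 50257)

probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(logits=scores).entropy()

# Dynamic cutoff: eta := min(epsilon, sqrt(epsilon) * e^-entropy(probabilities)).
eta = torch.min(epsilon, torch.sqrt(epsilon) * torch.exp(-entropy))[..., None]

# Tokens below the cutoff are pushed to filter_value (-inf by default), so
# sampling can never pick them. The real warper would additionally keep the
# `min_tokens_to_keep` highest-probability tokens as a fallback.
filtered_scores = scores.masked_fill(probabilities < eta, -float("Inf"))
```

Because `eta` shrinks with `e^-entropy`, the cutoff is strict when the distribution is peaked (high confidence) and lenient when it is flat, which is what lets eta-sampling avoid both degenerate repetition and incoherent tails.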