Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude Top Choices (XTC): A sampler that boosts creativity, breaks writing clichés, and inhibits non-verbatim repetition #6335

Merged
merged 7 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extensions/openai/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GenerationOptions(BaseModel):
dry_base: float = 1.75
dry_allowed_length: int = 2
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
xtc_threshold: float = 0.1
xtc_probability: float = 0
truncation_length: int = 0
max_tokens_second: int = 0
prompt_lookup_num_tokens: int = 0
Expand Down
6 changes: 6 additions & 0 deletions modules/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'penalty_alpha',
Expand Down Expand Up @@ -242,6 +244,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'mirostat_mode',
Expand Down Expand Up @@ -304,6 +308,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'mirostat_mode',
Expand Down
4 changes: 3 additions & 1 deletion modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def default_preset():
'dry_base': 1.75,
'dry_allowed_length': 2,
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram'
'xtc_threshold': 0.1,
'xtc_probability': 0,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram'
}


Expand Down
61 changes: 60 additions & 1 deletion modules/sampler_hijack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import math
import pprint
import random

import torch
import transformers
Expand Down Expand Up @@ -191,6 +192,53 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores


# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsWarper):
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
self.threshold = threshold
self.probability = probability
self.filter_value = filter_value
self.special_token_ids = [
shared.tokenizer.encode("\n")[-1],
]

if shared.tokenizer.eos_token_id is not None:
self.special_token_ids.append(shared.tokenizer.eos_token_id)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# `random` returns values in the half-open range [0, 1), so setting `probability`
# to 0 means the sampler never takes action, while setting it to 1 means the sampler
# always takes action.
#
# Note that while XTC is most intuitively described as "if multiple tokens meet
# the threshold, then with probability...", reversing the two conditions is logically
# equivalent, and improves performance because processing can immediately be stopped
# if the random check fails.
if random.random() >= self.probability:
return scores

sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)

sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)

# This operation sets exactly those indices to `True` for which the next index has
# probability above the threshold. Since `probs` is sorted, those are the indices
# of all tokens that meet the threshold, *except* the least probable one.
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

# Convert sorted_indices_to_remove to the original indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

# If newline or EOS tokens would be removed, return the original scores
if indices_to_remove[:, self.special_token_ids].any():
return scores

# Otherwise, remove tokens with the mask
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores


class DRYLogitsProcessor(LogitsProcessor):
def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
self.multiplier = multiplier
Expand Down Expand Up @@ -474,6 +522,14 @@ def get_logits_processor_patch(self, **kwargs):
)
)

if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
warpers_to_add.append(
XTCLogitsWarper(
threshold=generation_config.xtc_threshold,
probability=generation_config.xtc_probability,
)
)

if generation_config.dynamic_temperature:
warpers_to_add.append(
DynamicTemperatureLogitsWarper(
Expand Down Expand Up @@ -534,6 +590,7 @@ def get_logits_processor_patch(self, **kwargs):
'TopKLogitsWarper': 'top_k',
'TopPLogitsWarper': 'top_p',
'TypicalLogitsWarper': 'typical_p',
'XTCLogitsWarper': 'xtc',
'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
'PresencePenaltyLogitsProcessor': 'presence_penalty',
'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
Expand Down Expand Up @@ -588,8 +645,10 @@ def generation_config_init_patch(self, **kwargs):
self.dry_base = kwargs.pop("dry_base", 1.75)
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
self.xtc_probability = kwargs.pop("xtc_probability", 0)
self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram'])
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])


def hijack_samplers():
Expand Down
2 changes: 1 addition & 1 deletion modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
if k in state:
generate_params[k] = state[k]

Expand Down
2 changes: 2 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def list_interface_input_elements():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'do_sample',
'penalty_alpha',
'mirostat_mode',
Expand Down
4 changes: 4 additions & 0 deletions modules/ui_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def create_ui(default_preset):
shared.gradio['dry_base'] = gr.Slider(1, 4, value=generate_params['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.')
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')

with gr.Blocks():
shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
oobabooga marked this conversation as resolved.
Show resolved Hide resolved
shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')

gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)")

with gr.Column():
Expand Down