From f1232b1851966781901fe86322b52db97ccdf459 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sun, 18 Aug 2024 15:59:44 +0530 Subject: [PATCH 1/5] Add Exclude Top Choices (XTC) sampler --- extensions/openai/typing.py | 2 ++ modules/loaders.py | 6 +++++ modules/presets.py | 4 ++- modules/sampler_hijack.py | 50 +++++++++++++++++++++++++++++++++++-- modules/text_generation.py | 2 +- modules/ui.py | 2 ++ modules/ui_parameters.py | 4 +++ 7 files changed, 66 insertions(+), 4 deletions(-) diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 4015f6a1ce..f63c1f3911 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -37,6 +37,8 @@ class GenerationOptions(BaseModel): dry_base: float = 1.75 dry_allowed_length: int = 2 dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"' + xtc_threshold: float = 0.1 + xtc_probability: float = 0 truncation_length: int = 0 max_tokens_second: int = 0 prompt_lookup_num_tokens: int = 0 diff --git a/modules/loaders.py b/modules/loaders.py index 549de5fb02..16a3106e94 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -168,6 +168,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'penalty_alpha', @@ -242,6 +244,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'mirostat_mode', @@ -304,6 +308,8 @@ def transformers_samplers(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'seed', 'do_sample', 'mirostat_mode', diff --git a/modules/presets.py b/modules/presets.py index b00e829eb1..651d863433 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -44,7 +44,9 @@ def default_preset(): 'dry_base': 1.75, 'dry_allowed_length': 2, 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', - 'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat' + 'xtc_threshold': 0.1, + 'xtc_probability': 0, + 'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc' } diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 9fb661aecc..5a82e7182f 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -1,6 +1,7 @@ import json import math import pprint +import random import torch import transformers @@ -191,6 +192,40 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores +# Exclude Top Choices (XTC) +class XTCLogitsWarper(LogitsWarper): + def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")): + self.threshold = threshold + self.probability = probability + self.filter_value = filter_value + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # `random` returns values in the half-open range [0, 1), so setting `probability` + # to 0 means the sampler never takes action, while setting it to 1 means the sampler + # always takes action. + # + # Note that while XTC is most intuitively described as "if multiple tokens meet + # the threshold, then with probability...", reversing the two conditions is logically + # equivalent, and improves performance because processing can immediately be stopped + # if the random check fails. + if random.random() >= self.probability: + return scores + + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + probs = sorted_logits.softmax(dim=-1) + + sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool) + + # This operation sets exactly those indices to `True` for which the next index has + # probability above the threshold. Since `probs` is sorted, those are the indices + # of all tokens that meet the threshold, *except* the least probable one. + sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + class DRYLogitsProcessor(LogitsProcessor): def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int): self.multiplier = multiplier @@ -395,6 +430,14 @@ def get_logits_warper_patch(self, generation_config, **kwargs): ) ) + if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0: + warpers_to_add.append( + XTCLogitsWarper( + threshold=generation_config.xtc_threshold, + probability=generation_config.xtc_probability, + ) + ) + if generation_config.dynamic_temperature: warpers_to_add.append( DynamicTemperatureLogitsWarper( @@ -454,7 +497,8 @@ def get_logits_warper_patch(self, generation_config, **kwargs): 'TopALogitsWarper': 'top_a', 'TopKLogitsWarper': 'top_k', 'TopPLogitsWarper': 'top_p', - 'TypicalLogitsWarper': 'typical_p' + 'TypicalLogitsWarper': 'typical_p', + 'XTCLogitsWarper': 'xtc', } def custom_sort_key(obj): @@ -546,8 +590,10 @@ def generation_config_init_patch(self, **kwargs): self.dry_base = kwargs.pop("dry_base", 1.75) self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') + self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1) + self.xtc_probability = kwargs.pop("xtc_probability", 0) self.temperature_last = kwargs.pop("temperature_last", False) - self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat']) + self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc']) def hijack_samplers(): diff --git a/modules/text_generation.py b/modules/text_generation.py index 75e5ef36ae..7aa89674d5 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -284,7 +284,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']: + for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']: if k in state: generate_params[k] = state[k] diff --git a/modules/ui.py b/modules/ui.py index 47f92cf0f9..76b1c009c4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -158,6 +158,8 @@ def list_interface_input_elements(): 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', 'do_sample', 'penalty_alpha', 'mirostat_mode', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 234e1af2a9..020d6edfa8 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -45,6 +45,10 @@ def create_ui(default_preset): shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=generate_params['dry_allowed_length'], step=1, label='dry_allowed_length', info='Longest sequence that can be repeated without being penalized.') shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') + with gr.Blocks(): + shared.gradio['xtc_threshold'] = gr.Slider(0, 1, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If there are multiple tokens with predicted probability at least xtc_threshold...') + shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='...remove all except the least probable one from sampling, with probability xtc_probability.') + gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)") with gr.Column(): From 5176eaaaf7c823898ee4ab8aa6db63c1f821cc91 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Tue, 20 Aug 2024 07:56:02 +0530 Subject: [PATCH 2/5] Change maximum of `xtc_threshold` slider to 0.5 XTC only takes effect if at least *two* tokens are above the threshold, so values larger than 0.5 do not make sense --- modules/ui_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 020d6edfa8..05b13bfac8 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -46,7 +46,7 @@ def create_ui(default_preset): shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') with gr.Blocks(): - shared.gradio['xtc_threshold'] = gr.Slider(0, 1, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If there are multiple tokens with predicted probability at least xtc_threshold...') + shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If there are multiple tokens with predicted probability at least xtc_threshold...') shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='...remove all except the least probable one from sampling, with probability xtc_probability.') gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)") From af7b57cce01ba4549f6509e3210de5cf1580caef Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 21 Aug 2024 11:03:30 -0700 Subject: [PATCH 3/5] Update the descriptions --- modules/ui_parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 28a8d611a4..a2665e0d21 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -46,8 +46,8 @@ def create_ui(default_preset): shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') with gr.Blocks(): - shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If there are multiple tokens with predicted probability at least xtc_threshold...') - shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='...remove all except the least probable one from sampling, with probability xtc_probability.') + shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=generate_params['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.') + shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=generate_params['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.') gr.Markdown("[Learn more](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab)") From 0f62744df1d1a8c9cba83539c6d9869326864ad9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:54:47 -0700 Subject: [PATCH 4/5] Check for EOS and \n --- modules/sampler_hijack.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 5a82e7182f..e67b46a706 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -198,6 +198,12 @@ def __init__(self, threshold: float, probability: float, filter_value: float = - self.threshold = threshold self.probability = probability self.filter_value = filter_value + self.special_token_ids = [ + shared.tokenizer.encode("\n")[-1], + ] + + if shared.tokenizer.eos_token_id is not None: + self.special_token_ids.append(shared.tokenizer.eos_token_id) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # `random` returns values in the half-open range [0, 1), so setting `probability` @@ -221,7 +227,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to # of all tokens that meet the threshold, *except* the least probable one. sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold + # Convert sorted_indices_to_remove to the original indices indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + # If newline or EOS tokens would be removed, return the original scores + if indices_to_remove[:, self.special_token_ids].any() + return scores + + # Otherwise, remove tokens with the mask scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores From 29d38a74ba8d47206f2371b44a746111fac80f4c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 27 Sep 2024 18:34:07 -0700 Subject: [PATCH 5/5] Add missing : --- modules/sampler_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index bfe97e963f..6d92978e6e 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -231,7 +231,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) # If newline or EOS tokens would be removed, return the original scores - if indices_to_remove[:, self.special_token_ids].any() + if indices_to_remove[:, self.special_token_ids].any(): return scores # Otherwise, remove tokens with the mask