Exclude Top Choices (XTC): A sampler that boosts creativity, breaks writing clichés, and inhibits non-verbatim repetition (oobabooga#6335)
p-e-w authored and Olivier Gagnon committed Nov 2, 2024
1 parent b682b84 commit a64573f
Showing 7 changed files with 188 additions and 76 deletions.
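
Before the file-by-file diff, a minimal standalone sketch of the sampling rule this commit implements may help. It mirrors the logic of the new XTCLogitsWarper in modules/sampler_hijack.py; the function name and default arguments are illustrative only, and the newline/EOS safeguard of the real class is omitted:

import random

import torch

def xtc_filter(scores: torch.FloatTensor, threshold: float = 0.1, probability: float = 0.5) -> torch.FloatTensor:
    # `scores` has shape (batch, vocab_size). With probability `probability`,
    # exclude every token whose probability meets `threshold`, except the
    # least probable qualifying token.
    if random.random() >= probability:
        return scores

    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)

    # Mask position i exactly when position i+1 still meets the threshold,
    # i.e. every qualifying token except the last (least probable) one.
    mask = torch.full_like(probs, False, dtype=torch.bool)
    mask[..., :-1] = probs[..., 1:] >= threshold

    # Map the mask from sorted order back to vocabulary order and apply it.
    indices_to_remove = mask.scatter(1, sorted_indices, mask)
    return scores.masked_fill(indices_to_remove, -float("Inf"))
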
2 changes: 2 additions & 0 deletions extensions/openai/typing.py
@@ -37,6 +37,8 @@ class GenerationOptions(BaseModel):
dry_base: float = 1.75
dry_allowed_length: int = 2
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
xtc_threshold: float = 0.1
xtc_probability: float = 0
truncation_length: int = 0
max_tokens_second: int = 0
prompt_lookup_num_tokens: int = 0
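
Because the two new fields extend GenerationOptions, XTC can be driven through the webui's OpenAI-compatible API. A hedged sketch of such a request, assuming the API server is running at its default local address and port:

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "xtc_threshold": 0.1,    # tokens at or above this probability qualify for exclusion
        "xtc_probability": 0.5,  # per-step chance that XTC fires at all
    },
)
print(response.json()["choices"][0]["text"])
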
6 changes: 6 additions & 0 deletions modules/loaders.py
@@ -168,6 +168,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'penalty_alpha',
@@ -242,6 +244,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'mirostat_mode',
@@ -304,6 +308,8 @@ def transformers_samplers():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'seed',
'do_sample',
'mirostat_mode',
4 changes: 3 additions & 1 deletion modules/presets.py
@@ -44,7 +44,9 @@ def default_preset():
'dry_base': 1.75,
'dry_allowed_length': 2,
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
'xtc_threshold': 0.1,
'xtc_probability': 0,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram'
}


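
The default preset now names the penalty samplers, DRY, and XTC explicitly in sampler_priority, which is stored as a newline-separated string. A small illustrative sketch (not the webui's actual parsing code) of how such a string maps to an ordered list:

priority_string = 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature'
sampler_order = [name for name in priority_string.split('\n') if name]
# ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature']
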
244 changes: 170 additions & 74 deletions modules/sampler_hijack.py
@@ -1,6 +1,7 @@
import json
import math
import pprint
import random

import torch
import transformers
@@ -191,6 +192,51 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
return scores


# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsWarper):
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
self.threshold = threshold
self.probability = probability
self.filter_value = filter_value
self.special_token_ids = [
shared.tokenizer.encode("\n")[-1],
]

if shared.tokenizer.eos_token_id is not None:
self.special_token_ids.append(shared.tokenizer.eos_token_id)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# `random` returns values in the half-open range [0, 1), so setting `probability`
# to 0 means the sampler never takes action, while setting it to 1 means the sampler
# always takes action.
#
# Note that while XTC is most intuitively described as "if multiple tokens meet
# the threshold, then with probability...", reversing the two conditions is logically
# equivalent, and improves performance because processing can immediately be stopped
# if the random check fails.
if random.random() >= self.probability:
return scores

sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)

sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)

# This operation sets exactly those indices to `True` for which the next index has
# probability above the threshold. Since `probs` is sorted, those are the indices
# of all tokens that meet the threshold, *except* the least probable one.
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
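        # Illustrative example (assumed threshold = 0.1): for sorted probs
        # [0.6, 0.25, 0.1, 0.05], `probs[..., 1:] >= 0.1` marks the first two
        # positions, so the 0.6 and 0.25 tokens are removed and the 0.1 token,
        # the least probable one meeting the threshold, survives.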

# Convert sorted_indices_to_remove to the original indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

# If newline or EOS tokens would be removed, return the original scores
if indices_to_remove[:, self.special_token_ids].any():
return scores

# Otherwise, remove tokens with the mask
scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class DRYLogitsProcessor(LogitsProcessor):
def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
self.multiplier = multiplier
@@ -323,62 +369,141 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:


class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
'''

def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
def __init__(self, penalty: float, _range: int):
if not (penalty > 0):
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

self.penalty = penalty
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self._range = _range

def apply_repetition_penalty(self, input_ids_row, scores_row):
unique_ids = torch.unique(input_ids_row)
score = torch.gather(scores_row, 0, unique_ids)

# Apply multiplicative repetition penalty
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_row.scatter_(0, unique_ids, score)
return scores_row

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores):
scores_row = self.apply_repetition_penalty(input_ids_row, scores_row)

return scores


class PresencePenaltyLogitsProcessor(LogitsProcessor):
def __init__(self, presence_penalty: float, _range: int):
self.presence_penalty = presence_penalty
self._range = _range

def apply_presence_penalty(self, input_ids_row, scores_row):
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)

# Apply presence penalty
raw_presence_penalty = (counts > 0).to(scores_row.dtype)
presence_penalty = raw_presence_penalty * self.presence_penalty
scores_row.scatter_add_(0, unique_ids, -presence_penalty)
return scores_row

# We loop here because torch.unique() needs to process each row separately in the
# case that batch_size > 1.
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores):
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
score = torch.gather(scores_row, 0, unique_ids)
scores_row = self.apply_presence_penalty(input_ids_row, scores_row)
return scores

# multiplicative repetition penalty
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_row.scatter_(0, unique_ids, score)

# presence_penalty and frequency_penalty
raw_presence_penalty = (counts > 0).to(scores.dtype)
raw_frequency_penalty = counts.to(scores.dtype)
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
def __init__(self, frequency_penalty: float, _range: int):
self.frequency_penalty = frequency_penalty
self._range = _range

def apply_frequency_penalty(self, input_ids_row, scores_row):
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)

# Apply frequency penalty
raw_frequency_penalty = counts.to(scores_row.dtype)
frequency_penalty = raw_frequency_penalty * self.frequency_penalty
scores_row.scatter_add_(0, unique_ids, -frequency_penalty)
return scores_row

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
input_ids = input_ids[:, -self._range:]
for input_ids_row, scores_row in zip(input_ids, scores):
scores_row = self.apply_frequency_penalty(input_ids_row, scores_row)
return scores
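
# Note the practical difference between the two additive penalties above:
# presence penalty subtracts a flat amount from every token that appeared at
# least once in the window, while frequency penalty subtracts an amount
# proportional to the token's count. For example, with counts = [1, 3] and
# both penalties set to 0.5, presence subtracts 0.5 from both tokens, whereas
# frequency subtracts 0.5 and 1.5 respectively.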


def get_logits_warper_patch(self, generation_config, **kwargs):
def get_logits_processor_patch(self, **kwargs):
generation_config = kwargs['generation_config']

# Parameter sanitization
if isinstance(generation_config.temperature, int):
generation_config.temperature = float(generation_config.temperature) # Must be float

# Get the original warpers
warpers = self._get_logits_warper_old(generation_config, **kwargs)
warpers = self._get_logits_processor_old(**kwargs)

# Replace temperature with our modified class.
# Currently, it behaves identically to the original.
for i in range(len(warpers)):
for i in range(len(warpers) - 1, -1, -1):
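        # Iterating in reverse lets the loop delete entries (see below) without
        # shifting the indices of entries that have not been visited yet.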
# Replace temperature with our modified class.
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperCustom(
generation_config.temperature,
)

# Stuff we don't need
elif warpers[i].__class__.__name__ in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']:
del warpers[i]

# Add custom warpers
warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
warpers_to_add.append(
RepetitionPenaltyLogitsProcessorWithRange(
penalty=generation_config.repetition_penalty,
_range=generation_config.repetition_penalty_range
)
)

if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
warpers_to_add.append(
PresencePenaltyLogitsProcessor(
presence_penalty=generation_config.presence_penalty,
_range=generation_config.repetition_penalty_range
)
)

if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
warpers_to_add.append(
FrequencyPenaltyLogitsProcessor(
frequency_penalty=generation_config.frequency_penalty,
_range=generation_config.repetition_penalty_range
)
)

if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
dry_sequence_breakers = generation_config.dry_sequence_breakers

# Support both JSON array notation and comma-separated strings.
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"

sequence_breaker_strings = json.loads(dry_sequence_breakers)
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
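        # Example: the default breaker list is wrapped to the JSON array
        # '["\n", ":", "\"", "*"]', parsed into four strings, and each string
        # is reduced to the id of its final token when encoded after the dummy
        # prefix 'a' (the exact ids depend on the loaded tokenizer).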

warpers.append(
DRYLogitsProcessor(
multiplier=generation_config.dry_multiplier,
base=generation_config.dry_base,
allowed_length=generation_config.dry_allowed_length,
sequence_breakers=sequence_breakers,
_range=generation_config.repetition_penalty_range,
)
)

if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
warpers_to_add.append(
TailFreeLogitsWarper(
@@ -395,6 +520,14 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
)
)

if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
warpers_to_add.append(
XTCLogitsWarper(
threshold=generation_config.xtc_threshold,
probability=generation_config.xtc_probability,
)
)

if generation_config.dynamic_temperature:
warpers_to_add.append(
DynamicTemperatureLogitsWarper(
@@ -454,7 +587,14 @@ def get_logits_warper_patch(self, generation_config, **kwargs):
'TopALogitsWarper': 'top_a',
'TopKLogitsWarper': 'top_k',
'TopPLogitsWarper': 'top_p',
'TypicalLogitsWarper': 'typical_p'
'TypicalLogitsWarper': 'typical_p',
'XTCLogitsWarper': 'xtc',
'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
'PresencePenaltyLogitsProcessor': 'presence_penalty',
'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
'DRYLogitsProcessor': 'dry',
'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
}

def custom_sort_key(obj):
@@ -482,49 +622,6 @@ def custom_sort_key(obj):
return warpers


def get_logits_processor_patch(self, **kwargs):
generation_config = kwargs['generation_config']

do_rep_pen_hijack = (generation_config.repetition_penalty > 1) or (generation_config.presence_penalty != 0) or (generation_config.frequency_penalty != 0)
if do_rep_pen_hijack:
generation_config.repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created

result = self._get_logits_processor_old(**kwargs)

if do_rep_pen_hijack:
for i in range(len(result)):
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
result[i] = RepetitionPenaltyLogitsProcessorWithRange(
generation_config.repetition_penalty,
generation_config.presence_penalty,
generation_config.frequency_penalty,
generation_config.repetition_penalty_range
)

if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
dry_sequence_breakers = generation_config.dry_sequence_breakers

# Support both JSON array notation and comma-separated strings.
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"

sequence_breaker_strings = json.loads(dry_sequence_breakers)
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}

result.append(
DRYLogitsProcessor(
multiplier=generation_config.dry_multiplier,
base=generation_config.dry_base,
allowed_length=generation_config.dry_allowed_length,
sequence_breakers=sequence_breakers,
_range=generation_config.repetition_penalty_range,
)
)

return result


def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
@@ -546,14 +643,13 @@ def generation_config_init_patch(self, **kwargs):
self.dry_base = kwargs.pop("dry_base", 1.75)
self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2)
self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"')
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
self.xtc_probability = kwargs.pop("xtc_probability", 0)
self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])


def hijack_samplers():
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch

transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch

2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -284,7 +284,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'dry_multiplier', 'dry_base', 'dry_allowed_length', 'dry_sequence_breakers', 'xtc_threshold', 'xtc_probability']:
if k in state:
generate_params[k] = state[k]

2 changes: 2 additions & 0 deletions modules/ui.py
@@ -158,6 +158,8 @@ def list_interface_input_elements():
'dry_base',
'dry_allowed_length',
'dry_sequence_breakers',
'xtc_threshold',
'xtc_probability',
'do_sample',
'penalty_alpha',
'mirostat_mode',
(1 changed file not shown)