Skip to content

Commit

Permalink
Check for EOS and \n
Browse files Browse the repository at this point in the history
  • Loading branch information
oobabooga committed Sep 3, 2024
1 parent af7b57c commit 0f62744
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions modules/sampler_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def __init__(self, threshold: float, probability: float, filter_value: float = -
self.threshold = threshold
self.probability = probability
self.filter_value = filter_value
self.special_token_ids = [
shared.tokenizer.encode("\n")[-1],
]

if shared.tokenizer.eos_token_id is not None:
self.special_token_ids.append(shared.tokenizer.eos_token_id)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# `random` returns values in the half-open range [0, 1), so setting `probability`
Expand All @@ -221,7 +227,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
# of all tokens that meet the threshold, *except* the least probable one.
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

# Convert sorted_indices_to_remove to the original indices
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

# If newline or EOS tokens would be removed, return the original scores
if indices_to_remove[:, self.special_token_ids].any()
return scores

# Otherwise, remove tokens with the mask
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores

Expand Down

0 comments on commit 0f62744

Please sign in to comment.