Skip to content

Commit

Permalink
consistent shape between logits and generated_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 4, 2024
1 parent 4aa4ebd commit 212d4cd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ def sample(
# Apply logits processor

if settings.logits_processor:
generated_ids = sequence_ids[:, input_ids.shape[1]:]
generated_ids = sequence_ids[:, input_ids.shape[1]:].view(
logits.shape[:-1] + sequence_ids.shape[-1:] # ensure consistent batch dimensions
)
logits = settings.logits_processor(generated_ids, logits)

# Prepare filter
Expand Down

0 comments on commit 212d4cd

Please sign in to comment.