From 212d4cdc40467d6e2ab8bcf66db2ada2254a0992 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 4 Oct 2024 16:55:34 -0400 Subject: [PATCH] consistent shape between logits and generated_ids --- exllamav2/generator/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index a6bc113e..ef3bc311 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -364,7 +364,9 @@ def sample( # Apply logits processor if settings.logits_processor: - generated_ids = sequence_ids[:, input_ids.shape[1]:] + generated_ids = sequence_ids[:, input_ids.shape[1]:].view( + logits.shape[:-1] + sequence_ids.shape[-1:] # ensure consistent batch dimensions + ) logits = settings.logits_processor(generated_ids, logits) # Prepare filter