From 8b99de9b41deb6751300dd415731581cccd731ac Mon Sep 17 00:00:00 2001
From: SinanAkkoyun
Date: Sat, 30 Sep 2023 20:59:11 +0200
Subject: [PATCH] Fixed speculative generator: Updated to new sampler usage.

---
 exllamav2/generator/speculative.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/exllamav2/generator/speculative.py b/exllamav2/generator/speculative.py
index fc3ef54a..ada4b73c 100644
--- a/exllamav2/generator/speculative.py
+++ b/exllamav2/generator/speculative.py
@@ -79,7 +79,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,
 
             logits = self.draft_model.forward(predict_ids[:, -1:], self.draft_cache).float().cpu()
             past = torch.cat((self.sequence_ids[:, :-1], predict_ids), dim = 1)
-            token, prob = ExLlamaV2Sampler.sample(logits, draft_settings, past, randoms[i])
+            token, prob, _ = ExLlamaV2Sampler.sample(logits, draft_settings, past, randoms[i], self.tokenizer)
             predict_ids = torch.cat([predict_ids, token], dim = 1)
             used_predict_len += 1
             if prob < self.prob_threshold: break
@@ -91,7 +91,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,
         tokens = 0
         while True:
 
-            token, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens])
+            token, _, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens], self.tokenizer)
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
             tokens += 1
             if tokens == used_predict_len or token != predict_ids[:, tokens]: break
@@ -101,7 +101,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,
         if tokens == used_predict_len and token == predict_ids[:, tokens]:
 
-            token, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens])
+            token, _, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens], self.tokenizer)
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
             tokens += 1
 
         self.draft_model.forward(self.sequence_ids[:, -1:], self.draft_cache, preprocess_only = True)
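
Note (not part of the patch): the change at each call site is mechanical. As the hunks above show, the updated ExLlamaV2Sampler.sample takes the tokenizer as an additional positional argument and returns a 3-tuple rather than a 2-tuple; this patch discards the third value at every call site. A minimal sketch of the migration follows; the helper name sample_next, the placeholder arguments, and the import path are illustrative assumptions, not part of this patch:

    # Assumed import path; the patched file accesses ExLlamaV2Sampler
    # through the exllamav2.generator package.
    from exllamav2.generator import ExLlamaV2Sampler

    def sample_next(logits, settings, past_ids, rand, tokenizer):
        # Old usage (pre-patch), two return values, no tokenizer:
        #   token, prob = ExLlamaV2Sampler.sample(logits, settings, past_ids, rand)
        # New usage: pass the tokenizer and unpack three values; the third
        # is unused at these call sites, so it is discarded.
        token, prob, _ = ExLlamaV2Sampler.sample(logits, settings, past_ids, rand, tokenizer)
        return token, prob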