Skip to content

Commit

Permalink
Merge pull request #75 from SinanAkkoyun/fix-speculative-gen
Browse files Browse the repository at this point in the history
Fixed Speculative Generator
  • Loading branch information
turboderp authored Sep 30, 2023
2 parents 500d0c6 + 4838ef5 commit 90d29b0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions exllamav2/generator/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,

logits = self.draft_model.forward(predict_ids[:, -1:], self.draft_cache).float().cpu()
past = torch.cat((self.sequence_ids[:, :-1], predict_ids), dim = 1)
token, prob = ExLlamaV2Sampler.sample(logits, draft_settings, past, randoms[i])
token, prob, _ = ExLlamaV2Sampler.sample(logits, draft_settings, past, randoms[i], self.tokenizer)
predict_ids = torch.cat([predict_ids, token], dim = 1)
used_predict_len += 1
if prob < self.prob_threshold: break
Expand All @@ -91,7 +91,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,
tokens = 0
while True:

token, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens])
token, _, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens], self.tokenizer)
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
tokens += 1
if tokens == used_predict_len or token != predict_ids[:, tokens]: break
Expand All @@ -101,7 +101,7 @@ def generate_simple(self, prompt: str, gen_settings: ExLlamaV2Sampler.Settings,

if tokens == used_predict_len and token == predict_ids[:, tokens]:

token, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens])
token, _, _ = ExLlamaV2Sampler.sample(logits[:, tokens : tokens + 1, :], gen_settings, self.sequence_ids, randoms[tokens], self.tokenizer)
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
tokens += 1
self.draft_model.forward(self.sequence_ids[:, -1:], self.draft_cache, preprocess_only = True)
Expand Down

0 comments on commit 90d29b0

Please sign in to comment.