From 9c4a7c1553f0a01d71306b69af4b8df5ef511ea7 Mon Sep 17 00:00:00 2001
From: Vakarva <151787968+Vakarva@users.noreply.github.com>
Date: Sun, 7 Apr 2024 13:46:39 -0700
Subject: [PATCH] Updated decode() method in GPT4Tokenizer so that it handles special tokens.

---
 minbpe/gpt4.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/minbpe/gpt4.py b/minbpe/gpt4.py
index fcc65500..8c2dbba1 100644
--- a/minbpe/gpt4.py
+++ b/minbpe/gpt4.py
@@ -85,9 +85,18 @@ def _encode_chunk(self, text_bytes):
         return ids
 
     def decode(self, ids):
-        # we have to un-permute the bytes before we decode
-        text_bytes = b"".join(self.vocab[idx] for idx in ids)
-        text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes)
+        # given ids (list of integers), return Python string
+        part_bytes = []
+        for idx in ids:
+            if idx in self.vocab:
+                shuffled_bytes = self.vocab[idx]
+                unshuffled_bytes = bytes([self.inverse_byte_shuffle[b] for b in shuffled_bytes])
+                part_bytes.append(unshuffled_bytes)
+            elif idx in self.inverse_special_tokens:
+                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
+            else:
+                raise ValueError(f"invalid token id: {idx}")
+        text_bytes = b"".join(part_bytes)
         text = text_bytes.decode("utf-8", errors="replace")
         return text
 