Question about Encoder Logic #87
I think it's a little bit different, but the effect should be the same (assuming your Python version is 3.7 or newer). Your implementation iterates over every merge in the vocabulary, whereas the original code can break out early. I think the original code was written this way to guard against dict iteration order not matching insertion order; Karpathy seems to mention this in the video, but insertion order has been guaranteed since Python 3.7.
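As a quick illustration of the ordering point (the merges entries below are made up, just to show the behavior): in Python 3.7+ a dict iterates in insertion order, and since new token ids are assigned in increasing order during training, sorting the pairs by their id gives the same sequence:

merges = {(101, 32): 256, (256, 116): 257, (116, 104): 258}  # toy example, not real trained merges
assert list(merges) == [(101, 32), (256, 116), (116, 104)]   # insertion order is preserved (3.7+)
assert sorted(merges, key=merges.get) == list(merges)        # same as sorting by token id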
As @202030481266 mentioned, your simpler version iterates over all of the merges in the vocabulary. For a realistic tokenizer this is a lot of merges (~50k for GPT-2, 200k+ for GPT-4o), so at a practical scale your approach would do a lot more work, and most of the merges applied would not even occur in the chunk of text being processed.
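One way to make the difference concrete is to count how many merge passes each strategy performs. This sketch assumes the repo's merge(), get_stats() and a trained merges dict are already in scope; the counting wrappers are only for illustration:

def count_passes_early_break(ids, merges):
    # mirrors encode(): stop as soon as no mergeable pair remains in this chunk
    passes = 0
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
        passes += 1
    return passes  # at most len(ids) - 1 passes, independent of vocab size

def count_passes_full_scan(ids, merges):
    # mirrors encode2(): one pass per learned merge, whether or not it applies
    passes = 0
    for pair, idx in merges.items():
        ids = merge(ids, pair, idx)
        passes += 1
    return passes  # always len(merges) passes, e.g. ~50k for GPT-2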
I prefer to do this using:

def encode2(text):
    tokens = list(text.encode('utf-8'))
    # sort the merge pairs by their token id, i.e. the order they were learned
    pairs = sorted(merges, key=merges.get)
    for pair in pairs:
        tokens = merge(tokens, pair, merges[pair])
    return tokens

Hope it will be helpful to you.
The code is logically equivalent but differs in efficiency. Both Karpathy's code and encode2 depend on the dict order being the order of insertion.
I came up with the same simplification and benchmarked it against the original:

import random
import time

encode_time = 0
encode2_time = 0
test_cases = []
for _ in range(500000):
    l = random.randint(0, 30)
    s = random.randint(0, len(text))
    e = min(len(text)-1, s+l)
    test = text[s:e]
    t1 = time.time()
    e1 = encode(test)
    t2 = time.time()
    encode_time += (t2-t1)
    t1 = time.time()
    e2 = encode2(test)
    #print(e1, e2)
    t2 = time.time()
    encode2_time += (t2-t1)
    assert e1 == e2
    #test_cases.append(text[s:e])
print(f'encode time: {encode_time:5f}s')
print(f'encode2 time: {encode2_time:5f}s')
print(f'encode2/encode ratio: {encode2_time/encode_time:2f}')

Output:

encode time: 10.358455s

The two implementations being compared:

def encode(text):
    """ encode text into tokens"""
    # first get the int repr of unicode bytes
    ids = list(text.encode('utf-8'))
    while len(ids) >= 2:
        stats = get_stats(ids)
        # the keys in merges are ordered by insertion order;
        # get the pair from stats that was the earliest to be merged
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids

def encode2(text):
    """ naive encode that is slower"""
    ids = list(text.encode('utf-8'))
    for pair, merged_id in merges.items():
        ids = merge(ids, pair, merged_id)
    return ids
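For completeness, the helpers these snippets assume, get_stats() and merge(), follow the usual minbpe pattern; a minimal sketch for anyone running the code standalone:

def get_stats(ids):
    # count how often each adjacent pair of ids occurs
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # replace every occurrence of `pair` in ids with the new token id `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids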
I noticed the encode() method has extra logic with a while loop to find the lowest merge index:
Can we simplify it like this:
Since merge() merges all occurrences, it seems a simple for loop suffices. Is there a reason for the more complex logic?
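For reference, a sketch of the simplification being described (essentially the encode2 shown in the comments above, assuming the repo's merges dict and merge() helper):

def encode_simple(text):  # name made up for this sketch
    ids = list(text.encode('utf-8'))
    # apply every learned merge once, in the order it was learned
    for pair, idx in merges.items():
        ids = merge(ids, pair, idx)
    return ids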
I have trained both my tokenizer and the BasicTokenizer on some text data and got exactly the same vocab and encodings.
Maybe I missed something. Could you clarify?
Thanks!
Update:
I added a pytest in my forked repo just to show that my version is also correct, for anyone interested in trying it out.
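Not the fork's actual test, but a test along these lines could look like the following; the module name and the text fixture are made up for the sketch:

import random
from my_tokenizer import encode, encode2, text  # hypothetical module exposing both versions

def test_encode2_matches_encode():
    # compare the two encoders on many random substrings of the training text
    rng = random.Random(0)
    for _ in range(1000):
        start = rng.randint(0, len(text) - 1)
        end = min(len(text), start + rng.randint(0, 30))
        chunk = text[start:end]
        assert encode(chunk) == encode2(chunk)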