
Question about Encoder Logic #87

Open
JackxTong opened this issue Jul 20, 2024 · 5 comments
@JackxTong

JackxTong commented Jul 20, 2024

I noticed the encode() method has extra logic with a while loop to find the lowest merge index:

    def encode(self, text):
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged anymore
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

Can we simplify it like this:

    def encode(self, text):
        tokens = text.encode("utf-8")
        tokens = list(map(int, tokens))
        for pair, index in self.merges.items():
            tokens = merge(tokens, pair, index)
        return tokens

Since merge() merges all occurrences of a pair, it seems like a simple for loop suffices (a quick illustration of that behavior is below). Is there a reason for the more complex logic?
I trained my tokenizer against the BasicTokenizer on some text data and got exactly the same vocab and encoded output.
Maybe I missed something. Could you clarify?
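
For reference, merge() replaces every consecutive occurrence of the pair in a single pass, roughly like this (a sketch of the behavior, not necessarily the repo's exact implementation):

    def merge(ids, pair, idx):
        # replace every consecutive occurrence of `pair` in `ids` with the new token `idx`
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

    # every (1, 2) pair is replaced in one pass:
    assert merge([1, 2, 3, 1, 2], (1, 2), 256) == [256, 3, 256]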

Thanks!

Update:
I added a pytest in my forked repo just to show that my version is also correct, for anyone interested in trying it out.

@202030481266

I think it's a little different, but the effect should be the same (as long as your Python version is 3.7 or higher). Your implementation always iterates over all of the merge items, whereas the original code can break out early. I think the original code was written this way to guard against the dictionary order not matching insertion order; Karpathy seems to mention this in the video, but insertion order has been a language guarantee since Python 3.7.
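
For reference, the insertion-order guarantee being relied on here can be checked directly (plain Python, nothing from the repo; the pair values are made up):

    # since Python 3.7, dicts preserve insertion order as a language guarantee
    merges = {}
    merges[(101, 32)] = 256   # first learned merge
    merges[(256, 116)] = 257  # second learned merge
    assert list(merges) == [(101, 32), (256, 116)]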

@alexandermorgan

As @202030481266 mentioned, your simpler version iterates over all of the merges in the vocabulary. For a realistic tokenizer that is a lot of merges (~50k for GPT-2, 200k+ for GPT-4o), so at a practical scale your approach does far more work, and most of the merges it applies won't even occur in the chunk of text being processed.
There is one slightly over-complex thing in Karpathy's code here, though. He calls get_stats inside encode but only uses the keys of the returned dictionary. Since only the keys are needed, there's no point computing the counts (which is the whole point of get_stats). So instead of get_stats(ids), it would be a lot less work to line up the consecutive pairs with zip(ids, ids[1:]). Even if the ids list is only one element long, that still works correctly without raising an out-of-range error. A sketch of encode() with that change is below.
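
Assuming the same self.merges dict and merge() helper as the original, that variant might look like this (a sketch, not code from the repo):

    def encode(self, text):
        ids = list(text.encode("utf-8"))  # raw bytes as ints in 0..255
        while len(ids) >= 2:
            # consecutive pairs only; no occurrence counts needed
            pairs = zip(ids, ids[1:])
            # pick the pair that was learned earliest during training
            pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no remaining pair is a known merge
            ids = merge(ids, pair, self.merges[pair])
        return ids

The selection of which merge to apply next is identical to the original; the only difference is that pair frequencies are never counted.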

@demouo

demouo commented Oct 15, 2024

I prefer to do this using:

def encode2(text):
    tokens = list(text.encode('utf-8'))
    # sort the merge pairs by their token id, i.e. the order they were learned
    pairs = sorted(merges, key=merges.get)
    for pair in pairs:
        tokens = merge(tokens, pair, merges[pair])
    return tokens

Hope it will be helpful to you.

@satyagupte

> I think it's a little different, but the effect should be the same (as long as your Python version is 3.7 or higher). Your implementation always iterates over all of the merge items, whereas the original code can break out early. I think the original code was written this way to guard against the dictionary order not matching insertion order; Karpathy seems to mention this in the video, but insertion order has been a language guarantee since Python 3.7.

The code is logically equivalent but differs in efficiency. Both Karpathy's code and encode2 depend on dict order being insertion order.

@satyagupte

satyagupte commented Dec 10, 2024

I came up with the same encode2 function. Yes, you are right that encode2 is simpler and logically equivalent, but it is slower (by ~77% on my randomly generated test cases):

import random
import time

# note: `text`, `merges`, `encode`, and `encode2` are assumed to be defined
# already in the surrounding session (e.g. from training the tokenizer)
encode_time = 0
encode2_time = 0

test_cases = []
for _ in range(500000):
    l = random.randint(0, 30)
    s = random.randint(0, len(text))
    e = min(len(text)-1, s+l)
    test = text[s:e]
    t1 = time.time()
    e1 = encode(test)
    t2 = time.time()
    encode_time += (t2-t1)
    t1 = time.time()
    e2 = encode2(test)
    #print(e1, e2)
    t2 = time.time() 
    encode2_time += (t2-t1)
    
    assert e1 == e2
    #test_cases.append(text[s:e])

print(f'encode time: {encode_time:5f}s')
print(f'encode2 time: {encode2_time:5f}s')
print(f'encode2/encode ratio: {encode2_time/encode_time:2f}')

encode time: 10.358455s
encode2 time: 18.415130s
encode2/encode ratio: 1.777787

def encode(text):
    """ encode text into tokens"""
    # first get the int repr of unicode bytes 
    ids = list(text.encode('utf-8'))
    while len(ids) >= 2: 
        stats = get_stats(ids)
        # the keys in merges are ordered by insertion order
        # get the pair from stats that was the earliest to be merged
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids

def encode2(text):
    """ naive encode that is slower"""
    ids = list(text.encode('utf-8'))
    for pair, merged_id in merges.items():
        ids = merge(ids, pair, merged_id)
    return ids
