BPE Tokenizer: Multiple newlines don't merge into a single token #6809

Closed · Lyrcaxis opened this issue Apr 21, 2024 · 10 comments

@Lyrcaxis commented Apr 21, 2024

So, I found out that \n\n, when followed by a character, tokenizes as ['\n', '\n'] ([198, 198]) instead of ['\n\n'] ([271]).
(I'm using Llama3 for this example, but this extends to other models as well)

Here's an example prompt:

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're Psy, user's assistant, and a master of concise replies.<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a short poem<|eot_id|><|start_header_id|>assistant<|end_header_id|>


And the tokenized text:
[screenshot of the tokenized text]

If I switch the template to use \n\n\n\n (1038), it tokenizes as ['\n\n\n', '\n'] ([1432, 198]):
[screenshot of the tokenized text]

(Note: I know there have been efforts to make special tokens render, but right now I understand they don't have a textual representation, so you can ignore tokens like 128000, 128006, and 128007 in the sequences above.)

In C# I patch the issue like so:

var tokensCount = NativeApi.llama_tokenize(model, bytesPtr, bytes.Length, tokensPtr, tokenBuffer.Length, add_bos, special);
var list = new List<LLamaToken>();
for (int i = 0; i < tokensCount; i++) { // Hack: ['\n','\n'] --> ['\n\n']
    // 198 = '\n' and 271 = '\n\n' in the Llama 3 vocab; guard i + 1 so the
    // lookahead never reads past the last valid token.
    if (i + 1 < tokensCount && tokenBuffer[i] == 198 && tokenBuffer[i + 1] == 198) { list.Add(271); i++; }
    else { list.Add(tokenBuffer[i]); }
}
return list.ToArray();

(ignoring all other \n merges except \n\n, which is the common one in the template)
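
For reference, here is a minimal C++ sketch of the same workaround written directly against the llama.cpp C API (assuming the llama_tokenize signature from this era of the library; the ids 198 and 271 are Llama 3 specific, and error handling for an undersized buffer is omitted):

#include <string>
#include <vector>

#include "llama.h"

std::vector<llama_token> tokenize_with_newline_merge(
        const llama_model * model, const std::string & text,
        bool add_special, bool parse_special) {
    std::vector<llama_token> buf(text.size() + 8); // rough upper bound on token count
    const int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                                 buf.data(), (int) buf.size(),
                                 add_special, parse_special);
    std::vector<llama_token> out;
    out.reserve(n > 0 ? n : 0);
    for (int i = 0; i < n; i++) {
        // Hack: merge adjacent '\n' tokens (198) into the single '\n\n' token (271),
        // guarding the lookahead so it never reads past the last token.
        if (i + 1 < n && buf[i] == 198 && buf[i + 1] == 198) {
            out.push_back(271);
            i++;
        } else {
            out.push_back(buf[i]);
        }
    }
    return out;
}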

@MarcusDunn (Contributor) commented Apr 22, 2024

I'm also running into this. It seems to degrade performance for llama-3-instruct. (Hackily replacing the two newline tokens with the single merged token improves performance, anecdotally.)

I'd imagine there are other cases where the tokenization is not as greedy as possible; unsure how this would affect model performance, though.

@bullno1 (Contributor) commented Apr 23, 2024

The bpe_gpt2_preprocess function splits the string \n\nword in a bit of a strange way: \n, \nword.

See: #5613
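
To make the consequence concrete: BPE merge rules only apply within a pre-split chunk, never across chunk boundaries. A toy, self-contained illustration (plain C++, not llama.cpp code):

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // The split reported above for "\n\nword":
    const std::vector<std::string> chunks = { "\n", "\nword" };
    // The merge "\n" + "\n" -> "\n\n" (id 271) needs both newlines in ONE chunk.
    bool merge_reachable = false;
    for (const std::string & c : chunks) {
        if (c.find("\n\n") != std::string::npos) merge_reachable = true;
    }
    std::printf("\"\\n\\n\" merge reachable: %s\n", merge_reachable ? "yes" : "no"); // prints: no
    return 0;
}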

@LostRuins (Collaborator) commented Apr 24, 2024

I am encountering a similar issue. For me, the model likes to generate the token .\n (id=627) at the end of a sentence. However, when retokenizing the string afterwards, I instead get two disjoint tokens: . (id=13) and \n (id=198).

Same thing with various other tokens like .\n\n (id=382).
Something is really broken with the merging behavior related to newlines.

@Lyrcaxis I don't think your hack is sufficient. Due to the massive vocab size of Llama 3, there are many newline-related combinations that the model picks and this bug affects; another one seems to be ---\n (id=11192).
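
A somewhat broader (still hacky) client-side stopgap would be to merge any adjacent pair whose concatenated text is itself a vocab entry, rather than hardcoding 198 + 198 -> 271. A sketch with hypothetical stand-in tables (token_text and vocab_lookup are illustrative, not a real API, and every id in the input must appear in them; note this is greedy left-to-right, not true BPE merge order):

#include <cstdio>
#include <map>
#include <string>
#include <vector>

using llama_token = int;

// Hypothetical stand-ins; a host app would back these with the real vocab.
static const std::map<llama_token, std::string> token_text = {
    {13, "."}, {198, "\n"}, {271, "\n\n"}, {627, ".\n"},
};
static const std::map<std::string, llama_token> vocab_lookup = {
    {".", 13}, {"\n", 198}, {"\n\n", 271}, {".\n", 627},
};

static std::vector<llama_token> merge_pass(std::vector<llama_token> toks) {
    bool merged = true;
    while (merged) {
        merged = false;
        for (size_t i = 0; i + 1 < toks.size(); i++) {
            const std::string joined = token_text.at(toks[i]) + token_text.at(toks[i + 1]);
            const auto it = vocab_lookup.find(joined);
            if (it != vocab_lookup.end()) { // adjacent pair exists as one token
                toks[i] = it->second;
                toks.erase(toks.begin() + i + 1);
                merged = true;
            }
        }
    }
    return toks;
}

int main() {
    const std::vector<llama_token> out = merge_pass({13, 198}); // ".", "\n"
    std::printf("%zu token(s), id %d\n", out.size(), out[0]);   // 1 token(s), id 627
    return 0;
}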

@ggerganov (Owner) commented

Does anyone know what regex is used by LLaMA 3 to preprocess the text?

In llama.cpp we currently implement just this:

llama.cpp/llama.cpp, lines 12128 to 12135 at commit aa750c1:

std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
    std::vector<std::string> bpe_words;
    std::vector<std::string> bpe_encoded_words;
    std::string token = "";
    // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
    bool collecting_numeric = false;
    bool collecting_letter = false;

My guess is that we need to update it with whatever is used for LLaMA 3.
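
Worked by hand (treat this as an assumption, not library output), the two patterns split a string like "Hello\n\nworld" differently, which is exactly why the "\n\n" merge is unreachable today:

GPT-2 pattern:   "Hello" | "\n" | "\n" | "world"
                 (the "\s+(?!\S)" alternative stops before the final whitespace
                  character, so the two newlines land in separate chunks and the
                  "\n\n" merge can never apply; the hand-rolled implementation
                  reportedly deviates even from this, per the comment above)
LLaMA 3 pattern: "Hello" | "\n\n" | "world"
                 ("\s*[\r\n]+" keeps a run of newlines together in one chunk)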

@LostRuins (Collaborator) commented

I'm not sure how accurate this is, but here is a possible reference which appears to at least merge \n\n correctly: https://raw.githubusercontent.com/belladoreai/llama3-tokenizer-js/master/llama-tokenizer.js

However, it's not just implemented as a clean regex but appears to have some additional processing too.

@ggerganov (Owner) commented

I see, this is useful. We'll need to support that. There has been some work started in #6252 to improve BPE preprocessing. I guess we have to prioritize this, since it likely leads to poor generation quality.

@MarcusDunn (Contributor) commented

> Does anyone know what regex is used by LLaMA 3 to preprocess the text?

Is this what you'd be looking for?

https://github.com/meta-llama/llama3/blob/af6eedf7042fb51d00b2b26d8ef1ceaab73e1670/llama/tokenizer.py#L47

@jaime-m-p (Collaborator) commented

I have a Llama3 regex implementation.

I did some tests, generating texts (randomly merging strings from tokenizer.json) and comparing encodings to tiktoken's encoding.

The main idea is to first annotate all matched character lengths in tokens_length, then build the bpe_encoded_words.

If this is useful, I can do a PR.

    std::vector<std::string> bpe_llama3_preprocess(const std::string & text) {
        // LLAMA3 Regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
        const auto cpts = unicode_cpts_from_utf8(text);
        const int num_cpts = (int)cpts.size();

        auto _tolower = [] (const int cpt) -> int {
            if ('A' <= cpt && cpt <= 'Z')
                return cpt + ('a'-'A');
            return cpt;
        };

        auto _get_cpt = [&] (const int pos) -> uint32_t {
            return (0 <= pos && pos < num_cpts) ? cpts[pos] : 0;
        };

        auto _get_cpt_type = [&] (const int pos) -> int {
            return (0 <= pos && pos < num_cpts) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
        };

        std::vector<int> tokens_length;
        tokens_length.reserve(cpts.size()/3+4);
        int _prev_end = 0;
        auto _add_token = [&] (const int end) -> int {
            GGML_ASSERT(_prev_end <= end && end <= num_cpts);
            int len = end - _prev_end;
            if (len > 0)
                tokens_length.push_back(len);
            _prev_end = end;
            return len;
        };

        int pos = 0;
        while (pos < num_cpts) {

            const uint32_t cpt = _get_cpt(pos);
            const int cpt_type = _get_cpt_type(pos);

            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
            if (cpt == '\'' && pos+1 < num_cpts) {
                uint32_t cpt_next = _tolower(_get_cpt(pos+1));
                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                    pos += _add_token(pos+2);
                    continue;
                } else if (pos+2 < num_cpts) {
                    uint32_t cpt_next_next = _tolower(_get_cpt(pos+2));
                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                        (cpt_next == 'v' && cpt_next_next == 'e') ||
                        (cpt_next == 'l' && cpt_next_next == 'l')) {
                        pos += _add_token(pos+3);
                        continue;
                    }
                }
            }

            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
            if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_DIGIT) {
                if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
                    pos++;
                    while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER)
                        pos++;
                    _add_token(pos);
                    continue;
                }
            }

            // regex: \p{N}{1,3}
            if (cpt_type == CODEPOINT_TYPE_DIGIT) {
                int ini = pos;
                while (_get_cpt_type(pos) == CODEPOINT_TYPE_DIGIT) {
                    if (++pos - ini >= 3) {
                        _add_token(pos);
                        ini = pos;
                    }
                }
                _add_token(pos);
                continue;
            }

            // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
            uint32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
            if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
                pos += (cpt == ' ');
                while (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED)
                    cpt2_type = _get_cpt_type(++pos);
                cpt2 = _get_cpt(pos);
                while (cpt2 == '\r' || cpt2 == '\n')
                    cpt2 = _get_cpt(++pos);
                _add_token(pos);
                continue;
            }

            int num_whitespaces = 0;
            int last_pos_r_or_n = -1;
            while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
                cpt2 = _get_cpt(pos+num_whitespaces);
                if (cpt2 == '\r' || cpt2 == '\n')
                    last_pos_r_or_n = pos+num_whitespaces;
                num_whitespaces++;
            }

            // regex: \s*[\r\n]+
            if (last_pos_r_or_n >= 0) {
                pos = last_pos_r_or_n + 1;
                _add_token(pos);
                continue;
            }

            // regex: \s+(?!\S)
            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
                pos += num_whitespaces - 1;
                _add_token(pos);
                continue;
            }

            // regex: \s+
            if (num_whitespaces > 0) {
                pos += num_whitespaces;
                _add_token(pos);
                continue;
            }

            // no matches
            _add_token(++pos);
        }

        GGML_ASSERT(pos == num_cpts);
        _add_token(pos);

        pos = 0;
        std::vector<std::string> bpe_encoded_words(tokens_length.size());
        for (int n = 0; n < (int)tokens_length.size(); n++) {
            std::string &encoded_token = bpe_encoded_words[n];
            const int length = tokens_length[n];
            GGML_ASSERT(length > 0);
            for (int i = 0; i < length; i++) {
                std::string char_utf8 = unicode_cpt_to_utf8(cpts[pos++]);
                for (char c : char_utf8) {
                    encoded_token += unicode_byte_to_utf8(c);
                }
            }
        }

        GGML_ASSERT(pos == num_cpts);
        return bpe_encoded_words;
    }
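
If it helps review, here is a self-contained sketch of just the "\s*[\r\n]+" branch (ASCII-only stand-ins for the codepoint helpers, not the actual llama.cpp functions): scan a whitespace run and cut the chunk right after the last \r or \n, so consecutive newlines stay in one chunk and the "\n\n" merge becomes reachable.

#include <cctype>
#include <cstdio>
#include <string>

// Returns the end index of a "\s*[\r\n]+" match starting at pos, or pos if none.
static size_t match_ws_newlines(const std::string & s, size_t pos) {
    size_t last_rn = std::string::npos;
    for (size_t i = pos; i < s.size() && std::isspace((unsigned char) s[i]); i++) {
        if (s[i] == '\r' || s[i] == '\n') last_rn = i;
    }
    return last_rn == std::string::npos ? pos : last_rn + 1;
}

int main() {
    const size_t end = match_ws_newlines("  \n\nword", 0);
    std::printf("matched %zu chars (expect 4: two spaces + two newlines)\n", end);
    return 0;
}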

@ggerganov (Owner) commented

The issue should be fixed with #6920

@Lyrcaxis (Author) commented Apr 29, 2024

Awesome! Thanks.
