BPE Tokenizer: Multiple newlines don't merge into a single token #6809
Comments
I'm also running into this. It seems to degrade performance for llama-3-instruct. (Hackily replacing two single-newline tokens with the single double-newline token improves performance, anecdotally.) I'd imagine there are other cases where the tokenization is not as greedy as possible; I'm unsure how this would affect model performance, though. |
See: #5613 |
I am encountering a similar issue. For me, the model likes to generate token Same thing with various other tokens like @Lyrcaxis, I don't think your hack is sufficient. Due to the massive vocab size of Llama 3, there are many newline-related combinations that the model picks and that this bug affects; another one seems to be |
Does anyone know what regex LLaMA 3 uses to preprocess the text? See lines 12128 to 12135 (at commit aa750c1).
My guess is that we need to update that code with whatever LLaMA 3 uses. |
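To illustrate why the pre-tokenizer regex matters for this issue, here is a small C++ sketch using `std::regex` with ASCII-only approximations of two patterns: a GPT-2-style split (assumed here to resemble what the existing code does) and the Llama 3 pattern. `\p{L}` and `\p{N}` are replaced by `[A-Za-z]` and `[0-9]`, and the contractions are matched case-sensitively, so this is not a faithful Unicode implementation. `pre_tokenize`, `kGpt2Style` and `kLlama3Style` are made-up names for this example.

```cpp
#include <regex>
#include <string>
#include <vector>

// ASCII-only approximations: \p{L} -> [A-Za-z], \p{N} -> [0-9].
// Contractions are case-sensitive here (Llama 3 wraps them in (?i:...)).

// GPT-2-style split (assumed to resemble the existing preprocessing).
const std::regex kGpt2Style(
    R"re('s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+)re");

// Llama 3 split; note the extra \s*[\r\n]+ alternative for newline runs.
const std::regex kLlama3Style(
    R"re('s|'t|'re|'ve|'m|'ll|'d|[^\r\nA-Za-z0-9]?[A-Za-z]+|[0-9]{1,3}| ?[^\sA-Za-z0-9]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+)re");

// Collect successive matches. ECMAScript alternation is ordered:
// at each position the first alternative that matches wins.
std::vector<std::string> pre_tokenize(const std::string& text, const std::regex& re) {
    std::vector<std::string> pieces;
    for (std::sregex_iterator it(text.begin(), text.end(), re), end; it != end; ++it) {
        pieces.push_back(it->str());
    }
    return pieces;
}
```

With these patterns, "Hello\n\nWorld" splits into "Hello", "\n", "\n", "World" under the GPT-2-style regex, because `\s+(?!\S)` backs off one character before a non-space, but into "Hello", "\n\n", "World" under the Llama 3 regex, whose `\s*[\r\n]+` alternative keeps the whole newline run together. Four newlines likewise become "\n\n\n" + "\n" versus a single "\n\n\n\n" pre-token, which lines up with the [198, 198] and [1432, 198] sequences reported in this issue.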
I'm not sure how accurate this is, but here is a possible reference which appears to at least merge However, it's not just implemented as a clean regex but appears to have some additional processing too. |
I see, this is useful. We'll need to support that. There has been some work started in #6252 to improve BPE preprocessing. I guess we have to prioritize this, since this likely leads to poor generation quality |
Is this what you'd be looking for? |
I have a Llama 3 regex implementation. I tested it by generating texts (randomly merging strings from tokenizer.json) and comparing the encodings to tiktoken's. The main idea is to first annotate all matched character lengths in
If this is useful, I can do a PR.

```cpp
// Relies on llama.cpp internals of that period (assumed): unicode.h for
// unicode_cpts_from_utf8, unicode_cpt_type, unicode_cpt_to_utf8,
// unicode_byte_to_utf8 and CODEPOINT_TYPE_*, and ggml.h for GGML_ASSERT.
#include <cstdint>
#include <string>
#include <vector>

#include "ggml.h"
#include "unicode.h"

std::vector<std::string> bpe_llama3_preprocess(const std::string & text) {
    // LLAMA3 Regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
    const auto cpts = unicode_cpts_from_utf8(text);
    const int num_cpts = (int)cpts.size();

    // ASCII lowercasing is enough for the contraction suffixes below
    auto _tolower = [] (const int cpt) -> int {
        if ('A' <= cpt && cpt <= 'Z')
            return cpt + ('a'-'A');
        return cpt;
    };

    // Out-of-range reads return 0 / CODEPOINT_TYPE_UNIDENTIFIED, so every scan loop stops at the end
    auto _get_cpt = [&] (const int pos) -> uint32_t {
        return (0 <= pos && pos < num_cpts) ? cpts[pos] : 0;
    };

    auto _get_cpt_type = [&] (const int pos) -> int {
        return (0 <= pos && pos < num_cpts) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
    };

    // Pass 1: record the length (in codepoints) of every matched pre-token
    std::vector<int> tokens_length;
    tokens_length.reserve(cpts.size()/3+4);

    int _prev_end = 0;
    auto _add_token = [&] (const int end) -> int {
        GGML_ASSERT(_prev_end <= end && end <= num_cpts);
        int len = end - _prev_end;
        if (len > 0)
            tokens_length.push_back(len);
        _prev_end = end;
        //if( len && true ) {
        //    std::string s = "";
        //    for( int p = end-len; p < end; p++ )
        //        s += unicode_cpt_to_utf8(cpts[p]);
        //    printf( ">>> '%s'\n", s.c_str() );
        //}
        return len;
    };

    int pos = 0;
    while (pos < num_cpts) {
        const uint32_t cpt = _get_cpt(pos);
        const int cpt_type = _get_cpt_type(pos);

        // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
        if (cpt == '\'' && pos+1 < num_cpts) {
            uint32_t cpt_next = _tolower(_get_cpt(pos+1));
            if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                pos += _add_token(pos+2);
                continue;
            } else if (pos+2 < num_cpts) {
                uint32_t cpt_next_next = _tolower(_get_cpt(pos+2));
                if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                    (cpt_next == 'v' && cpt_next_next == 'e') ||
                    (cpt_next == 'l' && cpt_next_next == 'l')) {
                    pos += _add_token(pos+3);
                    continue;
                }
            }
        }

        // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
        if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_DIGIT) {
            if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
                pos++;
                while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER)
                    pos++;
                _add_token(pos);
                continue;
            }
        }

        // regex: \p{N}{1,3}  (long digit runs are cut every 3 digits)
        if (cpt_type == CODEPOINT_TYPE_DIGIT) {
            int ini = pos;
            while (_get_cpt_type(pos) == CODEPOINT_TYPE_DIGIT) {
                if (++pos - ini >= 3) {
                    _add_token(pos);
                    ini = pos;
                }
            }
            _add_token(pos);
            continue;
        }

        // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
        uint32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
        int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
        if (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
            pos += (cpt == ' ');
            while (cpt2_type != CODEPOINT_TYPE_WHITESPACE && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_DIGIT && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED)
                cpt2_type = _get_cpt_type(++pos);
            cpt2 = _get_cpt(pos);
            while (cpt2 == '\r' || cpt2 == '\n')
                cpt2 = _get_cpt(++pos);
            _add_token(pos);
            continue;
        }

        // Measure the whitespace run at pos and remember the last \r or \n inside it
        int num_whitespaces = 0;
        int last_pos_r_or_n = -1;
        while (_get_cpt_type(pos+num_whitespaces) == CODEPOINT_TYPE_WHITESPACE) {
            cpt2 = _get_cpt(pos+num_whitespaces);
            if (cpt2 == '\r' || cpt2 == '\n')
                last_pos_r_or_n = pos+num_whitespaces;
            num_whitespaces++;
        }

        // regex: \s*[\r\n]+  (keeps a run like "\n\n" in one pre-token)
        if (last_pos_r_or_n >= 0) {
            pos = last_pos_r_or_n + 1;
            _add_token(pos);
            continue;
        }

        // regex: \s+(?!\S)  (leave the last whitespace for the following word)
        if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
            pos += num_whitespaces - 1;
            _add_token(pos);
            continue;
        }

        // regex: \s+
        if (num_whitespaces > 0) {
            pos += num_whitespaces;
            _add_token(pos);
            continue;
        }

        // no matches
        _add_token(++pos);
    }

    GGML_ASSERT(pos == num_cpts);
    _add_token(pos);

    // Pass 2: turn each codepoint span into a byte-level BPE "word"
    pos = 0;
    std::vector<std::string> bpe_encoded_words(tokens_length.size());
    for (int n = 0; n < (int)tokens_length.size(); n++) {
        std::string &encoded_token = bpe_encoded_words[n];
        const int length = tokens_length[n];
        GGML_ASSERT(length > 0);
        for (int i = 0; i < length; i++) {
            std::string char_utf8 = unicode_cpt_to_utf8(cpts[pos++]);
            for (char c : char_utf8) {
                encoded_token += unicode_byte_to_utf8(c);
            }
        }
    }

    GGML_ASSERT(pos == num_cpts);
    return bpe_encoded_words;
}
```
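The newline handling in the function above comes down to its last three branches. As a standalone illustration, here is a self-contained sketch of that logic on a plain byte string. `whitespace_pretoken_len` and `is_ascii_space` are hypothetical names, only ASCII whitespace is recognized, and the letter, digit, and punctuation alternatives that the full function tries first are not modeled.

```cpp
#include <cstddef>
#include <string>

// Hypothetical helper: the ASCII characters matched by \s.
bool is_ascii_space(char c) {
    return c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\v' || c == '\f';
}

// Length of the whitespace pre-token starting at `pos`, using the same
// three-step logic as the tail of bpe_llama3_preprocess:
//   \s*[\r\n]+  -> cut right after the last \r or \n in the run
//   \s+(?!\S)   -> if the run is not at the end, leave its last char behind
//   \s+         -> otherwise take the whole run
// Returns 0 if text[pos] is not whitespace.
std::size_t whitespace_pretoken_len(const std::string& text, std::size_t pos) {
    std::size_t run = 0;          // length of the whitespace run at pos
    std::size_t last_nl_end = 0;  // one past the last \r or \n seen (0 = none)
    while (pos + run < text.size() && is_ascii_space(text[pos + run])) {
        const char c = text[pos + run];
        if (c == '\r' || c == '\n') {
            last_nl_end = pos + run + 1;
        }
        ++run;
    }
    if (last_nl_end > 0) {
        return last_nl_end - pos;  // \s*[\r\n]+
    }
    if (run > 1 && pos + run < text.size()) {
        return run - 1;            // \s+(?!\S)
    }
    return run;                    // \s+
}
```

For example, on "  \n  x" the helper returns 3 (two spaces plus the newline), and on "   x" it returns 2, leaving the last space to be attached to "x" by the letter alternative.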
The issue should be fixed with #6920 |
Awesome! Thanks. |
So, I found out that \n\n, if followed by a character, tokenizes as ['\n', '\n'] ([198, 198]) instead of ['\n\n'] ([271]). (I'm using Llama 3 for this example, but this extends to other models as well.)
Here's an example prompt:
And the tokenized text:
If I switch the template to use \n\n\n\n (1038), it tokenizes as ['\n\n\n', '\n'] ([1432, 198]):
(Note: I know there have been efforts to make special tokens render, but right now I understand they don't have a textual representation, so you can ignore tokens like 128000, 128006 and 128007 in the sequences above.)
In C#, I patch the issue like so:
(ignoring all \n merges except \n\n, which is common in the template)
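As a rough C++ sketch of the same kind of workaround (not the author's C# code), assuming the tokenizer output is already a vector of Llama 3 token ids and that 198 is "\n" and 271 is "\n\n" as in the examples in this issue, each adjacent pair of 198s can be collapsed into a 271. `merge_double_newlines` and the two constants are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Llama 3 ids from the examples in this issue (vocabulary-specific).
constexpr std::int32_t kNewline = 198;        // "\n"
constexpr std::int32_t kDoubleNewline = 271;  // "\n\n"

// Post-processing workaround: replace each adjacent pair of "\n" tokens
// with the single "\n\n" token. Only this one merge is handled; other
// newline merges are ignored.
std::vector<std::int32_t> merge_double_newlines(const std::vector<std::int32_t>& ids) {
    std::vector<std::int32_t> out;
    out.reserve(ids.size());
    for (std::size_t i = 0; i < ids.size(); ++i) {
        if (ids[i] == kNewline && i + 1 < ids.size() && ids[i + 1] == kNewline) {
            out.push_back(kDoubleNewline);
            ++i;  // skip the second "\n" of the pair
        } else {
            out.push_back(ids[i]);
        }
    }
    return out;
}
```

This only patches the symptom for one token pair; the underlying fix belongs in the pre-tokenizer regex (#6920).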