From 90d7ebd32e6a8345ab23152c7e11df5fe67e4cca Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Fri, 12 Jul 2024 22:09:44 +0200 Subject: [PATCH] feat: chat-template argument. (#100) --- converter/convert-tokenizer-llama2.py | 42 ++++++++++++++++++++++ src/app.cpp | 12 +++++++ src/app.hpp | 1 + src/apps/dllama-api/dllama-api.cpp | 2 +- src/apps/dllama/dllama.cpp | 2 +- src/tokenizer-test.cpp | 6 ++-- src/tokenizer.cpp | 51 ++++++++++++++++++++------- src/tokenizer.hpp | 10 +++--- 8 files changed, 105 insertions(+), 21 deletions(-) create mode 100644 converter/convert-tokenizer-llama2.py diff --git a/converter/convert-tokenizer-llama2.py b/converter/convert-tokenizer-llama2.py new file mode 100644 index 0000000..9856dfb --- /dev/null +++ b/converter/convert-tokenizer-llama2.py @@ -0,0 +1,42 @@ +import sys +import os +from sentencepiece import SentencePieceProcessor +writer = __import__('tokenizer-writer') + +chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + +def printUsage(): + print('Usage: python convert-tokenizer-llama2.py ') + print() + print('Options:') + print(' The path to the folder with llama2 folder path') + +if __name__ == '__main__': + if (len(sys.argv) < 2): + printUsage() + exit(1) + + dirPath = sys.argv[1] + modelPath = os.path.join(dirPath, 'tokenizer.model') + processor = SentencePieceProcessor(model_file=modelPath) + + vocabSize = processor.vocab_size() + tokens = [] + scores = [] + for i in range(vocabSize): + t = processor.id_to_piece(i) + s = processor.get_score(i) + t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace + b = t.encode('utf-8') + tokens.append(b) + scores.append(s) + + outputFileName = 'dllama_tokenizer_llama2.t' + with open(outputFileName, 'wb') as outputFile: + writer.writeTokenizer(outputFile, { + 'bos_id': processor.bos_id(), + 'eos_id': processor.eos_id(), + 'chat_eos_id': processor.eos_id(), + }, tokens, scores, chatTemplate.encode('utf-8'), None) + + print(f'✅ Created {outputFileName}') diff --git a/src/app.cpp b/src/app.cpp index 517e325..018ca2b 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -16,6 +16,15 @@ FloatType parseFloatType(char* val) { exit(EXIT_FAILURE); } +ChatTemplateType parseChatTemplateType(char* val) { + if (strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2; + if (strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3; + if (strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR; + if (strcmp(val, "chatml") == 0) return TEMPLATE_CHATML; + throw std::runtime_error("Invalid chat template type"); + +} + AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { AppArgs args; args.mode = NULL; @@ -31,6 +40,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { args.topp = 0.9f; args.steps = 0; args.seed = (unsigned long long)time(NULL); + args.chatTemplateType = TEMPLATE_UNKNOWN; int i = 1; if (hasMode && argc > 1) { @@ -84,6 +94,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { args.topp = atof(argv[i + 1]); } else if (strcmp(argv[i], "--seed") == 0) { args.seed = atoll(argv[i + 1]); + } else if (strcmp(argv[i], "--chat-template") == 0) { + args.chatTemplateType = parseChatTemplateType(argv[i + 1]); } else { printf("Unknown option %s\n", argv[i]); exit(EXIT_FAILURE); diff --git a/src/app.hpp b/src/app.hpp index af717d6..d53e9e7 100644 --- a/src/app.hpp +++ b/src/app.hpp @@ -32,6 +32,7 @@ class AppArgs { pos_t steps; bool benchmark; unsigned long long seed; + ChatTemplateType chatTemplateType; // worker int port; diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp index 002206f..b13b6d6 100644 --- a/src/apps/dllama-api/dllama-api.cpp +++ b/src/apps/dllama-api/dllama-api.cpp @@ -396,7 +396,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, SocketServer* server = new SocketServer(args->port); TokenizerChatStops stops(tokenizer); - ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]); + ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]); EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength); ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate); diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp index 69d7814..f28c123 100644 --- a/src/apps/dllama/dllama.cpp +++ b/src/apps/dllama/dllama.cpp @@ -195,7 +195,7 @@ class Chat { void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) { TokenizerChatStops stops(tokenizer); - ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]); + ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]); EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength); Chat chat(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate); diff --git a/src/tokenizer-test.cpp b/src/tokenizer-test.cpp index b4809d4..df56d92 100644 --- a/src/tokenizer-test.cpp +++ b/src/tokenizer-test.cpp @@ -12,13 +12,13 @@ #define EOS_ID 10000 void testChatTemplate() { - ChatTemplate t0("{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", ""); + ChatTemplate t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", ""); assert(t0.type == TEMPLATE_LLAMA3); - ChatTemplate t1("{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", ""); + ChatTemplate t1(TEMPLATE_UNKNOWN, "{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", ""); assert(t1.type == TEMPLATE_CHATML); - ChatTemplate t2("{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", ""); + ChatTemplate t2(TEMPLATE_UNKNOWN, "{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", ""); assert(t2.type == TEMPLATE_ZEPHYR); printf("✅ ChatTemplate\n"); diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 74d8ab5..f81f1a7 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -433,27 +433,54 @@ TokenizerChatStops::~TokenizerChatStops() { delete[] stops; } -ChatTemplate::ChatTemplate(const char* chatTemplate, const char* eos) { - if (chatTemplate == NULL) - throw std::runtime_error("The tokenizer does not include chat template"); +ChatTemplate::ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos) { + if (type == TEMPLATE_UNKNOWN) { + if (chatTemplate == NULL) + throw std::runtime_error("The tokenizer does not include chat template"); + if (strstr(chatTemplate, "[INST]") != NULL) { + this->type = TEMPLATE_LLAMA2; + } else if (strstr(chatTemplate, "<|start_header_id|>") != NULL) { + this->type = TEMPLATE_LLAMA3; + } else if (strstr(chatTemplate, "<|user|>") != NULL) { + this->type = TEMPLATE_ZEPHYR; + } else if (strstr(chatTemplate, "<|im_start|>") != NULL) { + this->type = TEMPLATE_CHATML; + } else { + throw new std::runtime_error("Not supported chat template"); + } + } else { + this->type = type; + } + this->eos = eos; printf("⭐ chat template: "); - if (strstr(chatTemplate, "<|start_header_id|>") != NULL) { - type = TEMPLATE_LLAMA3; + if (this->type == TEMPLATE_LLAMA2) { + printf("llama2\n"); + } else if (this->type == TEMPLATE_LLAMA3) { printf("llama3\n"); - } else if (strstr(chatTemplate, "<|user|>") != NULL) { - type = TEMPLATE_ZEPHYR; + } else if (this->type == TEMPLATE_ZEPHYR) { printf("zephyr\n"); - } else if (strstr(chatTemplate, "<|im_start|>") != NULL) { - type = TEMPLATE_CHATML; + } else if (this->type == TEMPLATE_CHATML) { printf("chatml\n"); - } else throw new std::runtime_error("Not supported chat template"); - this->eos = eos; + } } std::string ChatTemplate::generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt) { std::ostringstream buffer; - if (type == TEMPLATE_LLAMA3) { + if (type == TEMPLATE_LLAMA2) { + unsigned int i = 0; + if (nMessages >= 2 && items[0].role == "system" && items[1].role == "user") { + buffer << "[INST] <>\n" << items[0].message << "\n<>\n\n" << items[1].message << " [/INST]" << eos; + i += 2; + } + for (; i < nMessages; i++) { + if (items[i].role == "assistant") { + buffer << items[i].message << eos; + } else if (items[i].role == "user") { + buffer << "[INST] " << items[i].message << " [/INST]" << eos; + } + } + } else if (type == TEMPLATE_LLAMA3) { for (unsigned int i = 0; i < nMessages; i++) buffer << "<|start_header_id|>" << items[i].role << "<|end_header_id|>\n\n" << items[i].message << eos; if (appendGenerationPrompt) diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp index 5ad7b11..5323eae 100644 --- a/src/tokenizer.hpp +++ b/src/tokenizer.hpp @@ -89,9 +89,11 @@ class TokenizerChatStops { }; enum ChatTemplateType { - TEMPLATE_LLAMA3 = 0, - TEMPLATE_ZEPHYR = 1, - TEMPLATE_CHATML = 2, + TEMPLATE_UNKNOWN = 0, + TEMPLATE_LLAMA2 = 1, + TEMPLATE_LLAMA3 = 2, + TEMPLATE_ZEPHYR = 3, + TEMPLATE_CHATML = 4, }; struct ChatItem { @@ -103,7 +105,7 @@ class ChatTemplate { public: const char* eos; ChatTemplateType type; - ChatTemplate(const char* chatTemplate, const char* eos); + ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos); std::string generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt); };