feat: chat-template argument. (#100)
b4rtaz authored Jul 12, 2024
1 parent 4c38a2e commit 90d7ebd
Showing 8 changed files with 105 additions and 21 deletions.
42 changes: 42 additions & 0 deletions converter/convert-tokenizer-llama2.py
@@ -0,0 +1,42 @@
+import sys
+import os
+from sentencepiece import SentencePieceProcessor
+writer = __import__('tokenizer-writer') # module name contains a dash, so a plain import would fail
+
+chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
+
+def printUsage():
+    print('Usage: python convert-tokenizer-llama2.py <llama2FolderPath>')
+    print()
+    print('Options:')
+    print('  <llama2FolderPath>  The path to the Llama 2 model folder (must contain tokenizer.model)')
+
+if __name__ == '__main__':
+    if (len(sys.argv) < 2):
+        printUsage()
+        exit(1)
+
+    dirPath = sys.argv[1]
+    modelPath = os.path.join(dirPath, 'tokenizer.model')
+    processor = SentencePieceProcessor(model_file=modelPath)
+
+    vocabSize = processor.vocab_size()
+    tokens = []
+    scores = []
+    for i in range(vocabSize):
+        t = processor.id_to_piece(i)
+        s = processor.get_score(i)
+        t = t.replace('▁', ' ') # sentencepiece uses this character as a whitespace marker
+        b = t.encode('utf-8')
+        tokens.append(b)
+        scores.append(s)
+
+    outputFileName = 'dllama_tokenizer_llama2.t'
+    with open(outputFileName, 'wb') as outputFile:
+        writer.writeTokenizer(outputFile, {
+            'bos_id': processor.bos_id(),
+            'eos_id': processor.eos_id(),
+            'chat_eos_id': processor.eos_id(),
+        }, tokens, scores, chatTemplate.encode('utf-8'), None)
+
+    print(f'✅ Created {outputFileName}')
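The chatTemplate string above is the familiar Jinja-style Llama 2 chat template; the converter simply embeds it verbatim in the tokenizer file. A minimal sketch of what it renders, assuming the jinja2 package (not a dependency of the converter itself) and illustrative <s>/</s> token strings:

    from jinja2 import Template

    # chatTemplate is the Jinja string defined in the converter above.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Hello!'},
    ]
    prompt = Template(chatTemplate).render(
        messages=messages, bos_token='<s>', eos_token='</s>')
    print(prompt)
    # <s>[INST] <<SYS>>
    # You are a helpful assistant.
    # <</SYS>>
    #
    # Hello! [/INST]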
12 changes: 12 additions & 0 deletions src/app.cpp
@@ -16,6 +12,15 @@ FloatType parseFloatType(char* val) {
     exit(EXIT_FAILURE);
 }
 
+ChatTemplateType parseChatTemplateType(char* val) {
+    if (strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2;
+    if (strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3;
+    if (strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR;
+    if (strcmp(val, "chatml") == 0) return TEMPLATE_CHATML;
+    throw std::runtime_error("Invalid chat template type");
+}
+
 AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
     AppArgs args;
     args.mode = NULL;
@@ -31,6 +40,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
     args.topp = 0.9f;
     args.steps = 0;
     args.seed = (unsigned long long)time(NULL);
+    args.chatTemplateType = TEMPLATE_UNKNOWN;
 
     int i = 1;
     if (hasMode && argc > 1) {
@@ -84,6 +94,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
             args.topp = atof(argv[i + 1]);
         } else if (strcmp(argv[i], "--seed") == 0) {
             args.seed = atoll(argv[i + 1]);
+        } else if (strcmp(argv[i], "--chat-template") == 0) {
+            args.chatTemplateType = parseChatTemplateType(argv[i + 1]);
         } else {
             printf("Unknown option %s\n", argv[i]);
             exit(EXIT_FAILURE);
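The new --chat-template option overrides the substring auto-detection in ChatTemplate; valid values are llama2, llama3, zephyr, and chatml, and anything else is rejected by parseChatTemplateType with a runtime error. An illustrative invocation (the model path is a placeholder, remaining flags as usual):

    ./dllama chat --model dllama_model_llama2_7b_q40.m --tokenizer dllama_tokenizer_llama2.t --chat-template llama2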
1 change: 1 addition & 0 deletions src/app.hpp
@@ -32,6 +32,7 @@ class AppArgs {
     pos_t steps;
     bool benchmark;
     unsigned long long seed;
+    ChatTemplateType chatTemplateType;
 
     // worker
     int port;
2 changes: 1 addition & 1 deletion src/apps/dllama-api/dllama-api.cpp
@@ -396,7 +396,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
     SocketServer* server = new SocketServer(args->port);
 
     TokenizerChatStops stops(tokenizer);
-    ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
+    ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
     ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate);
 
2 changes: 1 addition & 1 deletion src/apps/dllama/dllama.cpp
@@ -195,7 +195,7 @@ class Chat {
 
 void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
     TokenizerChatStops stops(tokenizer);
-    ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
+    ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
     EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
 
     Chat chat(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate);
6 changes: 3 additions & 3 deletions src/tokenizer-test.cpp
@@ -12,13 +12,13 @@
 #define EOS_ID 10000
 
 void testChatTemplate() {
-    ChatTemplate t0("{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "<eos>");
+    ChatTemplate t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "<eos>");
     assert(t0.type == TEMPLATE_LLAMA3);
 
-    ChatTemplate t1("{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", "<eos>");
+    ChatTemplate t1(TEMPLATE_UNKNOWN, "{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", "<eos>");
     assert(t1.type == TEMPLATE_CHATML);
 
-    ChatTemplate t2("{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", "<eos>");
+    ChatTemplate t2(TEMPLATE_UNKNOWN, "{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", "<eos>");
     assert(t2.type == TEMPLATE_ZEPHYR);
 
     printf("✅ ChatTemplate\n");
51 changes: 39 additions & 12 deletions src/tokenizer.cpp
@@ -433,27 +433,54 @@ TokenizerChatStops::~TokenizerChatStops() {
     delete[] stops;
 }
 
-ChatTemplate::ChatTemplate(const char* chatTemplate, const char* eos) {
-    if (chatTemplate == NULL)
-        throw std::runtime_error("The tokenizer does not include chat template");
+ChatTemplate::ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos) {
+    if (type == TEMPLATE_UNKNOWN) {
+        // No explicit type was passed; fall back to detecting the template by substring.
+        if (chatTemplate == NULL)
+            throw std::runtime_error("The tokenizer does not include a chat template");
+        if (strstr(chatTemplate, "[INST]") != NULL) {
+            this->type = TEMPLATE_LLAMA2;
+        } else if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
+            this->type = TEMPLATE_LLAMA3;
+        } else if (strstr(chatTemplate, "<|user|>") != NULL) {
+            this->type = TEMPLATE_ZEPHYR;
+        } else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
+            this->type = TEMPLATE_CHATML;
+        } else {
+            throw std::runtime_error("Not supported chat template");
+        }
+    } else {
+        this->type = type;
+    }
+    this->eos = eos;
 
     printf("⭐ chat template: ");
-    if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
-        type = TEMPLATE_LLAMA3;
+    if (this->type == TEMPLATE_LLAMA2) {
+        printf("llama2\n");
+    } else if (this->type == TEMPLATE_LLAMA3) {
         printf("llama3\n");
-    } else if (strstr(chatTemplate, "<|user|>") != NULL) {
-        type = TEMPLATE_ZEPHYR;
+    } else if (this->type == TEMPLATE_ZEPHYR) {
         printf("zephyr\n");
-    } else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
-        type = TEMPLATE_CHATML;
+    } else if (this->type == TEMPLATE_CHATML) {
         printf("chatml\n");
-    } else throw new std::runtime_error("Not supported chat template");
-    this->eos = eos;
+    }
 }
 
 std::string ChatTemplate::generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt) {
     std::ostringstream buffer;
-    if (type == TEMPLATE_LLAMA3) {
+    if (type == TEMPLATE_LLAMA2) {
+        unsigned int i = 0;
+        // A leading system message is merged into the first user turn, llama2-style.
+        if (nMessages >= 2 && items[0].role == "system" && items[1].role == "user") {
+            buffer << "[INST] <<SYS>>\n" << items[0].message << "\n<</SYS>>\n\n" << items[1].message << " [/INST]" << eos;
+            i += 2;
+        }
+        for (; i < nMessages; i++) {
+            if (items[i].role == "assistant") {
+                buffer << items[i].message << eos;
+            } else if (items[i].role == "user") {
+                buffer << "[INST] " << items[i].message << " [/INST]" << eos;
+            }
+        }
+    } else if (type == TEMPLATE_LLAMA3) {
         for (unsigned int i = 0; i < nMessages; i++)
             buffer << "<|start_header_id|>" << items[i].role << "<|end_header_id|>\n\n" << items[i].message << eos;
         if (appendGenerationPrompt)
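The new TEMPLATE_LLAMA2 branch folds a leading system message into the first user turn and terminates every turn with the eos string. A hypothetical Python mirror of that branch, useful for eyeballing the exact prompt it builds (eos shown as '</s>' for illustration):

    def generate_llama2(messages, eos='</s>'):
        # messages: list of {'role', 'content'} dicts, mirroring ChatItem's role/message fields.
        out, i = [], 0
        # A leading system + user pair is merged into a single [INST] block.
        if len(messages) >= 2 and messages[0]['role'] == 'system' and messages[1]['role'] == 'user':
            out.append('[INST] <<SYS>>\n' + messages[0]['content']
                       + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' + eos)
            i = 2
        for m in messages[i:]:
            if m['role'] == 'assistant':
                out.append(m['content'] + eos)
            elif m['role'] == 'user':
                out.append('[INST] ' + m['content'] + ' [/INST]' + eos)
        return ''.join(out)

    print(generate_llama2([{'role': 'system', 'content': 'Be brief.'},
                           {'role': 'user', 'content': 'Hi'}]))
    # [INST] <<SYS>>
    # Be brief.
    # <</SYS>>
    #
    # Hi [/INST]</s>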
10 changes: 6 additions & 4 deletions src/tokenizer.hpp
@@ -89,9 +89,11 @@ class TokenizerChatStops {
 };
 
 enum ChatTemplateType {
-    TEMPLATE_LLAMA3 = 0,
-    TEMPLATE_ZEPHYR = 1,
-    TEMPLATE_CHATML = 2,
+    TEMPLATE_UNKNOWN = 0,
+    TEMPLATE_LLAMA2 = 1,
+    TEMPLATE_LLAMA3 = 2,
+    TEMPLATE_ZEPHYR = 3,
+    TEMPLATE_CHATML = 4,
 };
 
 struct ChatItem {
@@ -103,7 +105,7 @@ class ChatTemplate {
 public:
     const char* eos;
     ChatTemplateType type;
-    ChatTemplate(const char* chatTemplate, const char* eos);
+    ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos);
     std::string generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt);
 };
 
