diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt
index 2fbf6ba70..76cdcf303 100644
--- a/engine/CMakeLists.txt
+++ b/engine/CMakeLists.txt
@@ -70,7 +70,6 @@ endif()
 find_package(jsoncpp CONFIG REQUIRED)
 find_package(Drogon CONFIG REQUIRED)
 find_package(yaml-cpp CONFIG REQUIRED)
-find_package(jinja2cpp CONFIG REQUIRED)
 find_package(httplib CONFIG REQUIRED)
 find_package(nlohmann_json CONFIG REQUIRED)
 find_package(CLI11 CONFIG REQUIRED)
@@ -87,7 +86,6 @@ add_executable(${TARGET_NAME} main.cc
 
 target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)
 target_link_libraries(${TARGET_NAME} PRIVATE nlohmann_json::nlohmann_json)
-target_link_libraries(${TARGET_NAME} PRIVATE jinja2cpp)
 target_link_libraries(${TARGET_NAME} PRIVATE CLI11::CLI11)
 target_link_libraries(${TARGET_NAME} PRIVATE unofficial::minizip::minizip)
 target_link_libraries(${TARGET_NAME} PRIVATE LibArchive::LibArchive)
diff --git a/engine/config/chat_template_renderer.h b/engine/config/chat_template_renderer.h
new file mode 100644
index 000000000..f40894f7b
--- /dev/null
+++ b/engine/config/chat_template_renderer.h
@@ -0,0 +1,431 @@
+/*
+ * This file contains code derived from the llama.cpp project.
+ * Original project: https://github.com/ggerganov/llama.cpp
+ *
+ * Original work Copyright (c) 2023 Georgi Gerganov
+ * Modified work Copyright (c) 2024 [Homebrew.ltd]
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * This file incorporates work covered by the above copyright and permission notice.
+ * Any modifications made to this file are covered under the copyright of the modifying party.
+ *
+ * Modifications:
+ * [Brief description of modifications made to the original code, if any]
+ *
+ * For more information about the llama.cpp project and its license, please visit:
+ * https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
+ */
+
+//
+// chat templates
+//
+
+#include <cctype>
+#include <cstdint>
+#include <cstring>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+namespace config {
+
+#if (defined(_MSC_VER) && _MSC_VER >= 1900 && defined(__cpp_char8_t)) || __cplusplus >= 202002L
+  #define LU8(x) reinterpret_cast<const char*>(u8##x)
+#else
+  #define LU8(x) u8##x
+#endif
+
+typedef struct llama_chat_message {
+  const char* role;
+  const char* content;
+} llama_chat_message;
+
+struct llama_chat_msg {
+  std::string role;
+  std::string content;
+};
+
+static std::string trim(const std::string& str) {
+  size_t start = 0;
+  size_t end = str.size();
+  while (start < end && isspace(str[start])) {
+    start += 1;
+  }
+  while (end > start && isspace(str[end - 1])) {
+    end -= 1;
+  }
+  return str.substr(start, end - start);
+}
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string& tmpl, const std::vector<const llama_chat_message*>& chat,
+    std::string& dest, bool add_ass) {
+  // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+  std::stringstream ss;
+  auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
+    return tmpl.find(haystack) != std::string::npos;
+  };
+  if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
+    // chatml template
+    for (auto message : chat) {
+      ss << "<|im_start|>" << message->role << "\n"
+         << message->content << "<|im_end|>\n";
+    }
+    if (add_ass) {
+      ss << "<|im_start|>assistant\n";
+    }
+  } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
+    // llama2 template and its variants
+    // [variant] support system message
+    bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
+    // [variant] space before + after response
+    bool space_around_response = tmpl_contains("' ' + eos_token");
+    // [variant] add BOS inside history
+    bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
+    // [variant] trim spaces from the input message
+    bool strip_message = tmpl_contains("content.strip()");
+    // construct the prompt
+    bool is_inside_turn = true;  // skip BOS at the beginning
+    ss << "[INST] ";
+    for (auto message : chat) {
+      std::string content =
+          strip_message ? trim(message->content) : message->content;
+      std::string role(message->role);
+      if (!is_inside_turn) {
+        is_inside_turn = true;
+        ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+      }
+      if (role == "system") {
+        if (support_system_message) {
+          ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+        } else {
+          // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+          ss << content << "\n";
+        }
+      } else if (role == "user") {
+        ss << content << " [/INST]";
+      } else {
+        ss << (space_around_response ? " " : "") << content
+           << (space_around_response ? " " : "") << "</s>";
" " : "") << ""; + is_inside_turn = false; + } + } + // llama2 templates seem to not care about "add_generation_prompt" + } else if (tmpl == "phi3" || + (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" + << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == "monarch" || + tmpl_contains("bos_token + message['role']")) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = + (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == "gemma" || tmpl == "gemma2" || + tmpl_contains("")) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? "model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } + } else if (tmpl == "orion" || + tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message support, we will merge it with user prompt + system_prompt = message->content; + continue; + } else if (role == "user") { + ss << "Human: "; + if (!system_prompt.empty()) { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << message->content << "\n\nAssistant: "; + } else { + ss << message->content << ""; + } + } + } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106, + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "<|end_of_turn|>"; + } else { + role[0] = toupper(role[0]); + ss << "GPT4 Correct " << role << ": " << message->content + << "<|end_of_turn|>"; + } + } + if (add_ass) { + ss << "GPT4 Correct Assistant:"; + } + } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || + (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // Orca-Vicuna variant uses a system prefix + if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + ss << "SYSTEM: " << message->content << "\n"; + } else { + ss << message->content << "\n\n"; + } + } else if (role == "user") { + ss << "USER: " << message->content << "\n"; + } else if (role == "assistant") { + ss << "ASSISTANT: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "ASSISTANT:"; + } + } else if (tmpl == "deepseek" || + (tmpl_contains("### Instruction:") && 
tmpl_contains("<|EOT|>"))) { + // deepseek-ai/deepseek-coder-33b-instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content; + } else if (role == "user") { + ss << "### Instruction:\n" << message->content << "\n"; + } else if (role == "assistant") { + ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; + } + } + if (add_ass) { + ss << "### Response:\n"; + } + } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && + tmpl_contains("<|USER_TOKEN|>"))) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" + << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) + << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && + tmpl_contains("<|end_header_id|>"))) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" + << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << LU8("<用户>"); + ss << trim(message->content); + ss << ""; + } else { + ss << trim(message->content); + } + } + } else if (tmpl == "deepseek2" || + tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + // DeepSeek-V2 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == "exaone3" || + (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && + tmpl_contains("[|endofturn|]"))) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << 
"[|assistant|]"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +int32_t llama_chat_apply_template(const char* tmpl, + const struct llama_chat_message* chat, + size_t n_msg, bool add_ass, char* buf, + int32_t length) { + std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); + + // format the chat to string + std::vector chat_vec; + chat_vec.resize(n_msg); + for (size_t i = 0; i < n_msg; i++) { + chat_vec[i] = &chat[i]; + } + + std::string formatted_chat; + int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, + formatted_chat, add_ass); + if (res < 0) { + return res; + } + if (buf && length > 0) { + strncpy(buf, formatted_chat.c_str(), length); + } + return res; +} + +std::string llama_chat_apply_template(const std::string& tmpl, + const std::vector& msgs, + bool add_ass) { + int alloc_size = 0; + bool fallback = false; // indicate if we must fallback to default chatml + std::vector chat; + for (auto& msg : msgs) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + const char* ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + std::vector buf(alloc_size); + + // run the first time to get the total output length + int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), + add_ass, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + if (ptr_tmpl != nullptr) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } else { + // If the built-in template is not supported, we default to chatml + res = llama_chat_apply_template("chatml", chat.data(), chat.size(), + add_ass, buf.data(), buf.size()); + fallback = true; + } + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t)res > buf.size()) { + buf.resize(res); + res = + llama_chat_apply_template(fallback ? 
"chatml" : ptr_tmpl, chat.data(), + chat.size(), add_ass, buf.data(), buf.size()); + } + + std::string formatted_chat(buf.data(), res); + return formatted_chat; +} +} // namespace config \ No newline at end of file diff --git a/engine/config/gguf_parser.cc b/engine/config/gguf_parser.cc index 160734468..178e4b652 100644 --- a/engine/config/gguf_parser.cc +++ b/engine/config/gguf_parser.cc @@ -19,7 +19,7 @@ #include // For file descriptors -#include +#include "chat_template_renderer.h" #include "gguf_parser.h" #include "trantor/utils/Logger.h" @@ -361,19 +361,10 @@ void GGUFHandler::PrintMetadata() { if (key.compare("tokenizer.chat_template") == 0) { LOG_INFO << key << ": " << "\n" << value << "\n"; - jinja2::Template chat_template; - chat_template.Load(value); - jinja2::ValuesMap params{ - {"add_generation_prompt", true}, - {"bos_token", "<|begin_of_text|>"}, - {"eos_token", "<|eot_id|>"}, - {"messages", - jinja2::ValuesList{ - jinja2::ValuesMap{{"role", "system"}, - {"content", "{system_message}"}}, - jinja2::ValuesMap{{"role", "user"}, {"content", "{prompt}"}}}}}; - std::string result = chat_template.RenderAsString(params).value(); - + std::vector messages{ + llama_chat_msg{"system", "{system_message}"}, + llama_chat_msg{"user", "{prompt}"}}; + std::string result = llama_chat_apply_template(value, messages, true); LOG_INFO << "result jinja render: " << result << "\n"; } else { LOG_INFO << key << ": " << value << "\n"; @@ -555,19 +546,10 @@ void GGUFHandler::ModelConfigFromMetadata() { ">\n\n"; } else { try { - jinja2::Template jinja2_chat_template; - jinja2_chat_template.Load(value); - jinja2::ValuesMap params{ - {"add_generation_prompt", true}, - {"bos_token", tokens[bos_token]}, - {"eos_token", tokens[eos_token]}, - {"messages", - jinja2::ValuesList{ - jinja2::ValuesMap{{"role", "system"}, - {"content", "{system_message}"}}, - jinja2::ValuesMap{{"role", "user"}, - {"content", "{prompt}"}}}}}; - chat_template = jinja2_chat_template.RenderAsString(params).value(); + std::vector messages{ + llama_chat_msg{"system", "{system_message}"}, + llama_chat_msg{"user", "{prompt}"}}; + chat_template = llama_chat_apply_template(value, messages, true); } catch (const std::exception& e) { std::cerr << "Error render chat template: " << e.what() << ". Using default template: \n[INST] " diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index fa1c5477e..f89881118 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -8,10 +8,9 @@ add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/m find_package(Drogon CONFIG REQUIRED) find_package(GTest CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) -find_package(jinja2cpp CONFIG REQUIRED) find_package(httplib CONFIG REQUIRED) -target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp +target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib) diff --git a/engine/vcpkg.json b/engine/vcpkg.json index 8f7729524..40abc186e 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -10,7 +10,6 @@ ] }, "drogon", - "jinja2cpp", "jsoncpp", "minizip", "nlohmann-json",