diff --git a/engine/commands/chat_cmd.cc b/engine/commands/chat_cmd.cc
new file mode 100644
index 000000000..185dd60fe
--- /dev/null
+++ b/engine/commands/chat_cmd.cc
@@ -0,0 +1,120 @@
+#include "chat_cmd.h"
+#include "httplib.h"
+
+#include "trantor/utils/Logger.h"
+
+namespace commands {
+namespace {
+constexpr const char* kExitChat = "exit()";
+constexpr const auto kMinDataChunkSize = 6u;
+constexpr const char* kUser = "user";
+constexpr const char* kAssistant = "assistant";
+
+}  // namespace
+
+struct ChunkParser {
+  std::string content;
+  bool is_done = false;
+
+  ChunkParser(const char* data, size_t data_length) {
+    if (data && data_length > kMinDataChunkSize) {
+      std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize);
+      if (s.find("[DONE]") != std::string::npos) {
+        is_done = true;
+      } else {
+        content = nlohmann::json::parse(s)["choices"][0]["delta"]["content"];
+      }
+    }
+  }
+};
+
+ChatCmd::ChatCmd(std::string host, int port, const config::ModelConfig& mc)
+    : host_(std::move(host)), port_(port), mc_(mc) {}
+
+void ChatCmd::Exec(std::string msg) {
+  auto address = host_ + ":" + std::to_string(port_);
+  // Check if model is loaded
+  {
+    httplib::Client cli(address);
+    nlohmann::json json_data;
+    json_data["model"] = mc_.name;
+    json_data["engine"] = mc_.engine;
+
+    auto data_str = json_data.dump();
+
+    // TODO: move this to another message?
+    auto res = cli.Post("/inferences/server/modelstatus", httplib::Headers(),
+                        data_str.data(), data_str.size(), "application/json");
+    if (res) {
+      if (res->status != httplib::StatusCode::OK_200) {
+        LOG_INFO << res->body;
+        return;
+      }
+    } else {
+      auto err = res.error();
+      LOG_WARN << "HTTP error: " << httplib::to_string(err);
+      return;
+    }
+  }
+  // Some instruction for user here
+  std::cout << "In order to exit, type exit()" << std::endl;
+  // Model is loaded, start to chat
+  {
+    while (true) {
+      std::string user_input = std::move(msg);
+      std::cout << "> ";
+      if (user_input.empty()) {
+        std::getline(std::cin, user_input);
+      }
+      if (user_input == kExitChat) {
+        break;
+      }
+
+      if (!user_input.empty()) {
+        httplib::Client cli(address);
+        nlohmann::json json_data;
+        nlohmann::json new_data;
+        new_data["role"] = kUser;
+        new_data["content"] = user_input;
+        histories_.push_back(std::move(new_data));
+        json_data["engine"] = mc_.engine;
+        json_data["messages"] = histories_;
+        json_data["model"] = mc_.name;
+        // TODO: support non-stream
+        json_data["stream"] = true;
+        json_data["stop"] = mc_.stop;
+        auto data_str = json_data.dump();
+        // std::cout << data_str << std::endl;
+        cli.set_read_timeout(std::chrono::seconds(60));
+        // std::cout << "> ";
+        httplib::Request req;
+        req.headers = httplib::Headers();
+        req.set_header("Content-Type", "application/json");
+        req.method = "POST";
+        req.path = "/v1/chat/completions";
+        req.body = data_str;
+        std::string ai_chat;
+        req.content_receiver = [&](const char* data, size_t data_length,
+                                   uint64_t offset, uint64_t total_length) {
+          ChunkParser cp(data, data_length);
+          if (cp.is_done) {
+            std::cout << std::endl;
+            return false;
+          }
+          std::cout << cp.content;
+          ai_chat += cp.content;
+          return true;
+        };
+        cli.send(req);
+
+        nlohmann::json ai_res;
+        ai_res["role"] = kAssistant;
+        ai_res["content"] = ai_chat;
+        histories_.push_back(std::move(ai_res));
+      }
+      // std::cout << "ok Done" << std::endl;
+    }
+  }
+}
+
+};  // namespace commands
\ No newline at end of file
diff --git a/engine/commands/chat_cmd.h b/engine/commands/chat_cmd.h
new file mode 100644
index 000000000..d5b48927c
--- /dev/null
+++ b/engine/commands/chat_cmd.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <string>
+#include <vector>
+#include "config/model_config.h"
+#include "nlohmann/json.hpp"
+
+namespace commands {
+class ChatCmd {
+ public:
+  ChatCmd(std::string host, int port, const config::ModelConfig& mc);
+  void Exec(std::string msg);
+
+ private:
+  std::string host_;
+  int port_;
+  const config::ModelConfig& mc_;
+  std::vector<nlohmann::json> histories_;
+};
+}  // namespace commands
\ No newline at end of file
diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc
index d58760433..48c63611d 100644
--- a/engine/controllers/command_line_parser.cc
+++ b/engine/controllers/command_line_parser.cc
@@ -5,6 +5,7 @@
 #include "commands/start_model_cmd.h"
 #include "commands/stop_model_cmd.h"
 #include "commands/stop_server_cmd.h"
+#include "commands/chat_cmd.h"
 #include "config/yaml_config.h"
 #include "utils/cortex_utils.h"
 
@@ -66,7 +67,24 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
     models_cmd->add_subcommand("update", "Update configuration of a model");
   }
 
-  auto chat_cmd = app_.add_subcommand("chat", "Send a chat request to a model");
+  {
+    auto chat_cmd =
+        app_.add_subcommand("chat", "Send a chat request to a model");
+    std::string model_id;
+    chat_cmd->add_option("model_id", model_id, "");
+    std::string msg;
+    chat_cmd->add_option("-m,--message", msg,
+                         "Message to chat with model");
+
+    chat_cmd->callback([&model_id, &msg] {
+      // TODO(sang) switch to .yaml when implement model manager
+      config::YamlHandler yaml_handler;
+      yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() +
+                                       "/models/" + model_id + "/model.yml");
+      commands::ChatCmd cc("127.0.0.1", 3928, yaml_handler.GetModelConfig());
+      cc.Exec(msg);
+    });
+  }
 
   auto ps_cmd =
       app_.add_subcommand("ps", "Show running models and their status");
diff --git a/engine/main.cc b/engine/main.cc
index 27591d48a..a92e114fb 100644
--- a/engine/main.cc
+++ b/engine/main.cc
@@ -22,18 +22,7 @@
 #endif
 
 int main(int argc, char* argv[]) {
-  // Create logs/ folder and setup log to file
-  std::filesystem::create_directory(cortex_utils::logs_folder);
-  trantor::AsyncFileLogger asyncFileLogger;
-  asyncFileLogger.setFileName(cortex_utils::logs_base_name);
-  asyncFileLogger.startLogging();
-  trantor::Logger::setOutputFunction(
-      [&](const char* msg, const uint64_t len) {
-        asyncFileLogger.output(msg, len);
-      },
-      [&]() { asyncFileLogger.flush(); });
-  asyncFileLogger.setFileSizeLimit(cortex_utils::log_file_size_limit);
-
+
   // Check if this process is for python execution
   if (argc > 1) {
     if (strcmp(argv[1], "--run_python_file") == 0) {
@@ -61,6 +50,18 @@
     return 0;
   }
 
+  // Create logs/ folder and setup log to file
+  std::filesystem::create_directory(cortex_utils::logs_folder);
+  trantor::AsyncFileLogger asyncFileLogger;
+  asyncFileLogger.setFileName(cortex_utils::logs_base_name);
+  asyncFileLogger.startLogging();
+  trantor::Logger::setOutputFunction(
+      [&](const char* msg, const uint64_t len) {
+        asyncFileLogger.output(msg, len);
+      },
+      [&]() { asyncFileLogger.flush(); });
+  asyncFileLogger.setFileSizeLimit(cortex_utils::log_file_size_limit);
+
   int thread_num = 1;
   std::string host = "127.0.0.1";
   int port = 3928;