feat: chat command #1032

Merged 1 commit on Aug 28, 2024
120 changes: 120 additions & 0 deletions engine/commands/chat_cmd.cc
@@ -0,0 +1,120 @@
#include "chat_cmd.h"
#include "httplib.h"

#include "trantor/utils/Logger.h"

namespace commands {
namespace {
constexpr const char* kExitChat = "exit()";
constexpr const auto kMinDataChunkSize = 6u;
constexpr const char* kUser = "user";
constexpr const char* kAssistant = "assistant";

} // namespace

struct ChunkParser {
std::string content;
bool is_done = false;

ChunkParser(const char* data, size_t data_length) {
if (data && data_length > kMinDataChunkSize) {
std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize);
if (s.find("[DONE]") != std::string::npos) {
is_done = true;
} else {
content = nlohmann::json::parse(s)["choices"][0]["delta"]["content"];
}
}
}
};

ChatCmd::ChatCmd(std::string host, int port, const config::ModelConfig& mc)
: host_(std::move(host)), port_(port), mc_(mc) {}

void ChatCmd::Exec(std::string msg) {
auto address = host_ + ":" + std::to_string(port_);
// Check if model is loaded
{
httplib::Client cli(address);
nlohmann::json json_data;
json_data["model"] = mc_.name;
json_data["engine"] = mc_.engine;

auto data_str = json_data.dump();

// TODO: move this to another message?
auto res = cli.Post("/inferences/server/modelstatus", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status != httplib::StatusCode::OK_200) {
LOG_INFO << res->body;
return;
}
} else {
auto err = res.error();
LOG_WARN << "HTTP error: " << httplib::to_string(err);
return;
}
}
// Some instruction for user here
std::cout << "Inorder to exit, type exit()" << std::endl;
// Model is loaded, start to chat
{
while (true) {
std::string user_input = std::move(msg);
std::cout << "> ";
if (user_input.empty()) {
std::getline(std::cin, user_input);
}
if (user_input == kExitChat) {
break;
}

if (!user_input.empty()) {
httplib::Client cli(address);
nlohmann::json json_data;
nlohmann::json new_data;
new_data["role"] = kUser;
new_data["content"] = user_input;
histories_.push_back(std::move(new_data));
json_data["engine"] = mc_.engine;
json_data["messages"] = histories_;
json_data["model"] = mc_.name;
//TODO: support non-stream
json_data["stream"] = true;
json_data["stop"] = mc_.stop;
auto data_str = json_data.dump();
// std::cout << data_str << std::endl;
cli.set_read_timeout(std::chrono::seconds(60));
// std::cout << "> ";
httplib::Request req;
req.headers = httplib::Headers();
req.set_header("Content-Type", "application/json");
req.method = "POST";
req.path = "/v1/chat/completions";
req.body = data_str;
std::string ai_chat;
req.content_receiver = [&](const char* data, size_t data_length,
uint64_t offset, uint64_t total_length) {
ChunkParser cp(data, data_length);
if (cp.is_done) {
std::cout << std::endl;
return false;
}
std::cout << cp.content;
ai_chat += cp.content;
return true;
};
cli.send(req);

nlohmann::json ai_res;
ai_res["role"] = kAssistant;
ai_res["content"] = ai_chat;
histories_.push_back(std::move(ai_res));
}
// std::cout << "ok Done" << std::endl;
}
}
}

}; // namespace commands
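Note on the streaming format: ChunkParser assumes the engine replies with OpenAI-style server-sent events, where every chunk begins with the 6-byte "data: " prefix that kMinDataChunkSize skips and carries either "[DONE]" or a JSON delta. The standalone sketch below mirrors that parsing on a hypothetical chunk; the exact payload layout is inferred from the code above, not documented by this PR.

#include <cstring>
#include <iostream>
#include <string>

#include "nlohmann/json.hpp"

int main() {
  // Hypothetical chunk (assumed format): "data: " prefix + OpenAI-style delta.
  const char* chunk =
      "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}";
  // Strip the 6-byte prefix, exactly as ChunkParser does.
  std::string payload(chunk + 6, std::strlen(chunk) - 6);
  if (payload.find("[DONE]") != std::string::npos) {
    std::cout << "stream finished" << std::endl;
  } else {
    std::string piece =
        nlohmann::json::parse(payload)["choices"][0]["delta"]["content"];
    std::cout << piece << std::endl;  // prints: Hello
  }
  return 0;
}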
19 changes: 19 additions & 0 deletions engine/commands/chat_cmd.h
@@ -0,0 +1,19 @@
#pragma once
#include <string>
#include <vector>
#include "config/model_config.h"
#include "nlohmann/json.hpp"

namespace commands {
class ChatCmd {
 public:
  ChatCmd(std::string host, int port, const config::ModelConfig& mc);
  void Exec(std::string msg);

 private:
  std::string host_;
  int port_;
  const config::ModelConfig& mc_;
  std::vector<nlohmann::json> histories_;
};
}  // namespace commands
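For context, a minimal sketch of how this interface is driven outside the CLI. It assumes a model.yml has already been written for the model; the host, port, and YamlHandler calls simply mirror the wiring added in command_line_parser.cc below, and RunInteractiveChat is a hypothetical helper, not part of the PR.

#include <string>

#include "commands/chat_cmd.h"
#include "config/yaml_config.h"

void RunInteractiveChat(const std::string& model_yaml_path) {
  // Load the model configuration the same way the CLI callback does.
  config::YamlHandler yaml_handler;
  yaml_handler.ModelConfigFromFile(model_yaml_path);
  // ChatCmd keeps a const reference to the config, so yaml_handler must
  // outlive it (it does here, since both share this scope).
  commands::ChatCmd chat("127.0.0.1", 3928, yaml_handler.GetModelConfig());
  chat.Exec("");  // an empty first message makes Exec prompt on stdin
}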
20 changes: 19 additions & 1 deletion engine/controllers/command_line_parser.cc
@@ -5,6 +5,7 @@
#include "commands/start_model_cmd.h"
#include "commands/stop_model_cmd.h"
#include "commands/stop_server_cmd.h"
#include "commands/chat_cmd.h"
#include "config/yaml_config.h"
#include "utils/cortex_utils.h"

@@ -66,7 +67,24 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
    models_cmd->add_subcommand("update", "Update configuration of a model");
  }

  auto chat_cmd = app_.add_subcommand("chat", "Send a chat request to a model");
  {
    auto chat_cmd =
        app_.add_subcommand("chat", "Send a chat request to a model");
    std::string model_id;
    chat_cmd->add_option("model_id", model_id, "");
    std::string msg;
    chat_cmd->add_option("-m,--message", msg,
                         "Message to chat with model");

    chat_cmd->callback([&model_id, &msg] {
      // TODO(sang) switch to <model_id>.yaml when implement model manager
      config::YamlHandler yaml_handler;
      yaml_handler.ModelConfigFromFile(cortex_utils::GetCurrentPath() +
                                       "/models/" + model_id + "/model.yml");
      commands::ChatCmd cc("127.0.0.1", 3928, yaml_handler.GetModelConfig());
      cc.Exec(msg);
    });
  }

  auto ps_cmd =
      app_.add_subcommand("ps", "Show running models and their status");
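With this registration in place, the new subcommand is invoked roughly as chat <model_id> -m "your message" (the executable name is outside this diff). When -m is omitted, ChatCmd::Exec falls back to prompting on stdin until exit() is typed, and the model configuration is read from <current path>/models/<model_id>/model.yml as hard-coded in the callback above.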
25 changes: 13 additions & 12 deletions engine/main.cc
@@ -22,18 +22,7 @@
#endif

int main(int argc, char* argv[]) {
  // Create logs/ folder and setup log to file
  std::filesystem::create_directory(cortex_utils::logs_folder);
  trantor::AsyncFileLogger asyncFileLogger;
  asyncFileLogger.setFileName(cortex_utils::logs_base_name);
  asyncFileLogger.startLogging();
  trantor::Logger::setOutputFunction(
      [&](const char* msg, const uint64_t len) {
        asyncFileLogger.output(msg, len);
      },
      [&]() { asyncFileLogger.flush(); });
  asyncFileLogger.setFileSizeLimit(cortex_utils::log_file_size_limit);


  // Check if this process is for python execution
  if (argc > 1) {
    if (strcmp(argv[1], "--run_python_file") == 0) {
@@ -61,6 +50,18 @@ int main(int argc, char* argv[]) {
    return 0;
  }

  // Create logs/ folder and setup log to file
  std::filesystem::create_directory(cortex_utils::logs_folder);
  trantor::AsyncFileLogger asyncFileLogger;
  asyncFileLogger.setFileName(cortex_utils::logs_base_name);
  asyncFileLogger.startLogging();
  trantor::Logger::setOutputFunction(
      [&](const char* msg, const uint64_t len) {
        asyncFileLogger.output(msg, len);
      },
      [&]() { asyncFileLogger.flush(); });
  asyncFileLogger.setFileSizeLimit(cortex_utils::log_file_size_limit);

  int thread_num = 1;
  std::string host = "127.0.0.1";
  int port = 3928;