fix: correct model_id in chat #1379

Merged (3 commits) on Oct 1, 2024
Changes from 2 commits
14 changes: 8 additions & 6 deletions engine/commands/chat_completion_cmd.cc
@@ -4,10 +4,11 @@
#include "cortex_upd_cmd.h"
#include "database/models.h"
#include "model_status_cmd.h"
#include "run_cmd.h"
#include "server_start_cmd.h"
#include "trantor/utils/Logger.h"
#include "utils/logging_utils.h"
#include "run_cmd.h"
#include "config/yaml_config.h"

namespace commands {
namespace {
@@ -39,7 +40,7 @@ struct ChunkParser {
};

void ChatCompletionCmd::Exec(const std::string& host, int port,
                             const std::string& model_handle, std::string msg) {
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
try {
@@ -50,15 +51,16 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
}
yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
auto mc = yaml_handler.GetModelConfig();
Exec(host, port, mc, std::move(msg));
Exec(host, port, model_handle, mc, std::move(msg));
} catch (const std::exception& e) {
CLI_LOG("Fail to start model information with ID '" + model_handle +
"': " + e.what());
}
}

void ChatCompletionCmd::Exec(const std::string& host, int port,
const config::ModelConfig& mc, std::string msg) {
const std::string& model_handle,
const config::ModelConfig& mc, std::string msg) {
auto address = host + ":" + std::to_string(port);
// Check if server is started
{
@@ -71,7 +73,7 @@

// Only check if llamacpp engine
if ((mc.engine.find("llamacpp") != std::string::npos) &&
!commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
!commands::ModelStatusCmd().IsLoaded(host, port, model_handle)) {
CLI_LOG("Model is not loaded yet!");
return;
}
@@ -104,7 +106,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
histories_.push_back(std::move(new_data));
json_data["engine"] = mc.engine;
json_data["messages"] = histories_;
json_data["model"] = mc.name;
json_data["model"] = model_handle;
//TODO: support non-stream
json_data["stream"] = true;
json_data["stop"] = mc.stop;
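Note: the core of this fix is that the "model" field of the chat request now carries the model handle (the ID registered in the model database) rather than mc.name from the YAML config, which may differ from the ID the server loaded the model under. A minimal standalone sketch of the resulting request body, using a hypothetical handle:

// Sketch only, not project code: mirrors how the updated command assembles
// the chat request body after this change.
#include <string>
#include "nlohmann/json.hpp"

nlohmann::json BuildChatBody(const std::string& model_handle,
                             const std::string& engine,
                             const nlohmann::json& histories) {
  nlohmann::json json_data;
  json_data["engine"] = engine;
  json_data["messages"] = histories;
  json_data["model"] = model_handle;  // previously mc.name
  json_data["stream"] = true;
  return json_data;
}

// Example: BuildChatBody("tinyllama:1b", "llamacpp", histories_) -- the handle
// "tinyllama:1b" is a hypothetical registered model ID.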
4 changes: 2 additions & 2 deletions engine/commands/chat_completion_cmd.h
@@ -9,8 +9,8 @@ class ChatCompletionCmd {
public:
void Exec(const std::string& host, int port, const std::string& model_handle,
std::string msg);
void Exec(const std::string& host, int port, const config::ModelConfig& mc,
std::string msg);
void Exec(const std::string& host, int port, const std::string& model_handle,
const config::ModelConfig& mc, std::string msg);

private:
std::vector<nlohmann::json> histories_;
75 changes: 7 additions & 68 deletions engine/commands/model_start_cmd.cc
@@ -5,84 +5,23 @@
#include "model_status_cmd.h"
#include "nlohmann/json.hpp"
#include "server_start_cmd.h"
#include "services/model_service.h"
#include "trantor/utils/Logger.h"
#include "utils/file_manager_utils.h"
#include "utils/logging_utils.h"

namespace commands {
bool ModelStartCmd::Exec(const std::string& host, int port,
const std::string& model_handle) {
ModelService ms;
auto res = ms.StartModel(host, port, model_handle);

cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
try {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CLI_LOG("Error: " + model_entry.error());
return false;
}
yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
auto mc = yaml_handler.GetModelConfig();
return Exec(host, port, mc);
} catch (const std::exception& e) {
CLI_LOG("Fail to start model information with ID '" + model_handle +
"': " + e.what());
if (res.has_error()) {
CLI_LOG("Error: " + res.error());
return false;
}
}

bool ModelStartCmd::Exec(const std::string& host, int port,
const config::ModelConfig& mc) {
// Check if server is started
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Server is not started yet, please run `"
<< commands::GetCortexBinary() << " start` to start server!");
return false;
}

// Only check for llamacpp for now
if ((mc.engine.find("llamacpp") != std::string::npos) &&
commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
CLI_LOG("Model has already been started!");
return true;
}

httplib::Client cli(host + ":" + std::to_string(port));

nlohmann::json json_data;
if (mc.files.size() > 0) {
// TODO(sang) support multiple files
json_data["model_path"] = mc.files[0];
} else {
LOG_WARN << "model_path is empty";
return false;
}
json_data["model"] = mc.name;
json_data["system_prompt"] = mc.system_template;
json_data["user_prompt"] = mc.user_template;
json_data["ai_prompt"] = mc.ai_template;
json_data["ctx_len"] = mc.ctx_len;
json_data["stop"] = mc.stop;
json_data["engine"] = mc.engine;

auto data_str = json_data.dump();
cli.set_read_timeout(std::chrono::seconds(60));
auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
CLI_LOG("Model loaded!");
return true;
} else {
CTL_ERR("Model failed to load with status code: " << res->status);
return false;
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
return false;
}
return false;
CLI_LOG("Model loaded!");
return true;
}

}; // namespace commands
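The start command now delegates the YAML lookup and HTTP call to ModelService. The service's exact signature is not part of this diff; the call sites above only rely on has_error() and error() on the returned result, so an interface along these lines (a sketch with assumed names) would satisfy them:

// Assumed shape of the service these commands delegate to; the real
// implementation lives in services/model_service.h, which is not shown here.
#include <optional>
#include <string>

struct ServiceResult {  // stand-in for whatever result type the service returns
  std::optional<std::string> err;
  bool has_error() const { return err.has_value(); }
  const std::string& error() const { return *err; }
};

class ModelService {
 public:
  // Each call presumably resolves the handle to its model.yml and issues the
  // HTTP request that used to live in the command (loadmodel / unloadmodel /
  // modelstatus), reporting failures as an error string.
  ServiceResult StartModel(const std::string& host, int port,
                           const std::string& model_handle);
  ServiceResult StopModel(const std::string& host, int port,
                          const std::string& model_handle);
  ServiceResult GetModelStatus(const std::string& host, int port,
                               const std::string& model_handle);
};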
2 changes: 0 additions & 2 deletions engine/commands/model_start_cmd.h
@@ -1,13 +1,11 @@
#pragma once
#include <string>
#include "config/model_config.h"

namespace commands {

class ModelStartCmd {
public:
bool Exec(const std::string& host, int port, const std::string& model_handle);

bool Exec(const std::string& host, int port, const config::ModelConfig& mc);
};
} // namespace commands
43 changes: 6 additions & 37 deletions engine/commands/model_status_cmd.cc
@@ -4,49 +4,18 @@
#include "httplib.h"
#include "nlohmann/json.hpp"
#include "utils/logging_utils.h"
#include "services/model_service.h"

namespace commands {
bool ModelStatusCmd::IsLoaded(const std::string& host, int port,
const std::string& model_handle) {
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
try {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CLI_LOG("Error: " + model_entry.error());
return false;
}
yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
auto mc = yaml_handler.GetModelConfig();
return IsLoaded(host, port, mc);
} catch (const std::exception& e) {
CLI_LOG("Fail to get model status with ID '" + model_handle +
"': " + e.what());
return false;
}
}
ModelService ms;
auto res = ms.GetModelStatus(host, port, model_handle);

bool ModelStatusCmd::IsLoaded(const std::string& host, int port,
const config::ModelConfig& mc) {
httplib::Client cli(host + ":" + std::to_string(port));
nlohmann::json json_data;
json_data["model"] = mc.name;
json_data["engine"] = mc.engine;

auto data_str = json_data.dump();

auto res = cli.Post("/inferences/server/modelstatus", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
return true;
}
} else {
auto err = res.error();
CTL_WRN("HTTP error: " << httplib::to_string(err));
if (res.has_error()) {
// CLI_LOG("Error: " + res.error());
return false;
}

return false;
return true;
}
} // namespace commands
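For reference, the status probe removed here (and presumably now living behind ModelService::GetModelStatus, whose implementation is not part of this diff) amounts to the following, reconstructed from the deleted lines with the handle standing in for mc.name:

// Reconstruction for illustration only; endpoint and payload are taken from
// the removed code above, and the implementation inside ModelService may differ.
#include <string>
#include "httplib.h"
#include "nlohmann/json.hpp"

bool IsModelLoaded(const std::string& host, int port,
                   const std::string& model_handle, const std::string& engine) {
  httplib::Client cli(host + ":" + std::to_string(port));
  nlohmann::json json_data;
  json_data["model"] = model_handle;  // the registered handle, not mc.name
  json_data["engine"] = engine;
  auto data_str = json_data.dump();
  auto res = cli.Post("/inferences/server/modelstatus", httplib::Headers(),
                      data_str.data(), data_str.size(), "application/json");
  return res && res->status == httplib::StatusCode::OK_200;
}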
3 changes: 0 additions & 3 deletions engine/commands/model_status_cmd.h
@@ -1,14 +1,11 @@
#pragma once
#include <string>
#include "config/yaml_config.h"

namespace commands {

class ModelStatusCmd {
public:
bool IsLoaded(const std::string& host, int port,
const std::string& model_handle);
bool IsLoaded(const std::string& host, int port,
const config::ModelConfig& mc);
};
} // namespace commands
39 changes: 7 additions & 32 deletions engine/commands/model_stop_cmd.cc
@@ -5,45 +5,20 @@
#include "nlohmann/json.hpp"
#include "utils/file_manager_utils.h"
#include "utils/logging_utils.h"
#include "services/model_service.h"

namespace commands {

void ModelStopCmd::Exec(const std::string& host, int port,
const std::string& model_handle) {
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
try {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CLI_LOG("Error: " + model_entry.error());
return;
}
yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
auto mc = yaml_handler.GetModelConfig();
httplib::Client cli(host + ":" + std::to_string(port));
nlohmann::json json_data;
json_data["model"] = mc.name;
json_data["engine"] = mc.engine;
ModelService ms;
auto res = ms.StopModel(host, port, model_handle);

auto data_str = json_data.dump();

auto res = cli.Post("/inferences/server/unloadmodel", httplib::Headers(),
data_str.data(), data_str.size(), "application/json");
if (res) {
if (res->status == httplib::StatusCode::OK_200) {
// LOG_INFO << res->body;
CLI_LOG("Model unloaded!");
} else {
CLI_LOG("Error: could not unload model - " << res->status);
}
} else {
auto err = res.error();
CTL_ERR("HTTP error: " << httplib::to_string(err));
}
} catch (const std::exception& e) {
CLI_LOG("Fail to stop model information with ID '" + model_handle +
"': " + e.what());
if (res.has_error()) {
CLI_LOG("Error: " + res.error());
return;
}
CLI_LOG("Model unloaded!");
}

}; // namespace commands
6 changes: 3 additions & 3 deletions engine/commands/run_cmd.cc
@@ -72,16 +72,16 @@ void RunCmd::Exec(bool chat_flag) {
// If it is llamacpp, then check model status first
{
if ((mc.engine.find("llamacpp") == std::string::npos) ||
!commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) {
if (!ModelStartCmd().Exec(host_, port_, mc)) {
!commands::ModelStatusCmd().IsLoaded(host_, port_, model_handle_)) {
if (!ModelStartCmd().Exec(host_, port_, model_handle_)) {
return;
}
}
}

// Chat
if (chat_flag) {
ChatCompletionCmd().Exec(host_, port_, mc, "");
ChatCompletionCmd().Exec(host_, port_, model_handle_, mc, "");
} else {
CLI_LOG(*model_id << " model started successfully. Use `"
<< commands::GetCortexBinary() << " chat " << *model_id
52 changes: 49 additions & 3 deletions engine/controllers/models.cc
@@ -226,9 +226,8 @@ void Models::ImportModel(
std::filesystem::path("imported") /
std::filesystem::path(modelHandle + ".yml"))
.string();
cortex::db::ModelEntry model_entry{
modelHandle, "local", "imported",
model_yaml_path, modelHandle};
cortex::db::ModelEntry model_entry{modelHandle, "local", "imported",
model_yaml_path, modelHandle};
try {
std::filesystem::create_directories(
std::filesystem::path(model_yaml_path).parent_path());
@@ -331,3 +330,50 @@ void Models::SetModelAlias(
callback(resp);
}
}

void Models::StartModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
if (!http_util::HasFieldInReq(req, callback, "model"))
return;
auto config = file_manager_utils::GetCortexConfig();
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto result = model_service_.StartModel(
config.apiServerHost, std::stoi(config.apiServerPort), model_handle);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
} else {
Json::Value ret;
ret["message"] = "Started successfully!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
}
}

void Models::StopModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
if (!http_util::HasFieldInReq(req, callback, "model"))
return;
auto config = file_manager_utils::GetCortexConfig();
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto result = model_service_.StopModel(
config.apiServerHost, std::stoi(config.apiServerPort), model_handle);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
} else {
Json::Value ret;
ret["message"] = "Started successfully!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
}
}
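These handlers expect a JSON body with a "model" field holding the model handle. The route registration is not part of this diff, so the path below is an assumption; only the body shape and the success/error responses are taken from the code above.

// Client-side sketch. "/v1/models/start" and the host/port are placeholder
// assumptions; "tinyllama:1b" is a hypothetical registered model handle.
#include <iostream>
#include <string>
#include "httplib.h"
#include "nlohmann/json.hpp"

int main() {
  httplib::Client cli("127.0.0.1:39281");
  nlohmann::json body;
  body["model"] = "tinyllama:1b";
  auto res = cli.Post("/v1/models/start", body.dump(), "application/json");
  if (res && res->status == 200) {
    std::cout << res->body << std::endl;  // e.g. {"message":"Started successfully!"}
  } else {
    std::cerr << "Request failed" << std::endl;
  }
  return 0;
}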