fix: cortex models start has no output if variant not given #1531

Merged: 3 commits, Oct 22, 2024
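Context for the change: per the title, `cortex models start` produced no output when the model variant was not given. The diff routes ModelStartCmd::Exec through the same SelectLocalModel flow that RunCmd uses, so an empty or ambiguous handle now yields a selection prompt or an explicit message, and success is reported with a hint to `cortex run`. The following is a minimal standalone sketch of the std::optional guard pattern the diff introduces; every name here is illustrative, not the cortex.cpp sources.

// Sketch only: resolve the model first, bail out with a visible message
// when nothing can be resolved, and only then proceed to start it.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for SelectLocalModel(): returns the resolved
// model id, or std::nullopt after printing why selection failed.
std::optional<std::string> ResolveModel(const std::vector<std::string>& local,
                                        const std::string& handle) {
  if (!handle.empty()) return handle;           // explicit handle wins
  if (local.empty()) {
    std::cout << "No local models available!\n";
    return std::nullopt;
  }
  if (local.size() == 1) return local.front();  // unambiguous default
  std::cout << "Please select an option\n";     // real code prompts here
  return local.front();                         // sketch: pick the first
}

int main() {
  std::vector<std::string> local{"tinyllama:1b-gguf"};
  if (auto id = ResolveModel(local, "")) {
    std::cout << *id << " model started successfully.\n";
  }
  return 0;  // before the fix, the empty-handle path produced no output
}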
18 changes: 16 additions & 2 deletions engine/cli/commands/model_start_cmd.cc
@@ -1,11 +1,23 @@
 #include "model_start_cmd.h"
+#include "config/yaml_config.h"
+#include "cortex_upd_cmd.h"
+#include "database/models.h"
 #include "httplib.h"
+#include "run_cmd.h"
 #include "server_start_cmd.h"
+#include "utils/cli_selection_utils.h"
 #include "utils/logging_utils.h"
 
 namespace commands {
 bool ModelStartCmd::Exec(const std::string& host, int port,
                          const std::string& model_handle) {
+  std::optional<std::string> model_id =
+      SelectLocalModel(model_service_, model_handle);
+
+  if (!model_id.has_value()) {
+    return false;
+  }
+
   // Start server if server is not started yet
   if (!commands::IsServerAlive(host, port)) {
     CLI_LOG("Starting server ...");
@@ -17,14 +29,16 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
   // Call API to start model
   httplib::Client cli(host + ":" + std::to_string(port));
   Json::Value json_data;
-  json_data["model"] = model_handle;
+  json_data["model"] = model_id.value();
   auto data_str = json_data.toStyledString();
   cli.set_read_timeout(std::chrono::seconds(60));
   auto res = cli.Post("/v1/models/start", httplib::Headers(), data_str.data(),
                       data_str.size(), "application/json");
   if (res) {
     if (res->status == httplib::StatusCode::OK_200) {
-      CLI_LOG("Model loaded!");
+      CLI_LOG(model_id.value() << " model started successfully. Use `"
+              << commands::GetCortexBinary() << " run "
+              << *model_id << "` for interactive chat shell");
       return true;
     } else {
       CTL_ERR("Model failed to load with status code: " << res->status);
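Two things change in this file: the start request now posts the resolved model_id (which can include an interactively selected variant) rather than the raw CLI handle, and the success path names the started model instead of printing a bare "Model loaded!". Below is a hedged sketch of that request path using the same cpp-httplib and jsoncpp calls as the hunk above; StartModel is an illustrative wrapper, not a function in the codebase.

// Sketch of the start request as the diff builds it, extracted into a
// free function for illustration; assumes cpp-httplib and jsoncpp are
// available and a cortex server is listening on host:port.
#include <chrono>
#include <iostream>
#include <string>
#include <json/json.h>  // jsoncpp
#include "httplib.h"    // cpp-httplib

bool StartModel(const std::string& host, int port, const std::string& model_id) {
  httplib::Client cli(host + ":" + std::to_string(port));
  cli.set_read_timeout(std::chrono::seconds(60));

  Json::Value json_data;
  json_data["model"] = model_id;  // the resolved id, not the raw CLI handle
  auto data_str = json_data.toStyledString();

  auto res = cli.Post("/v1/models/start", httplib::Headers(), data_str.data(),
                      data_str.size(), "application/json");
  if (res && res->status == httplib::StatusCode::OK_200) {
    std::cout << model_id << " model started successfully.\n";
    return true;
  }
  std::cerr << "Model failed to load\n";
  return false;
}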
114 changes: 60 additions & 54 deletions engine/cli/commands/run_cmd.cc
@@ -10,6 +10,59 @@
 
 namespace commands {
 
+std::optional<std::string> SelectLocalModel(ModelService& model_service,
+                                            const std::string& model_handle) {
+  std::optional<std::string> model_id = model_handle;
+  cortex::db::Models modellist_handler;
+
+  if (model_handle.empty()) {
+    auto all_local_models = modellist_handler.LoadModelList();
+    if (all_local_models.has_error() || all_local_models.value().empty()) {
+      CLI_LOG("No local models available!");
+      return std::nullopt;
+    }
+
+    if (all_local_models.value().size() == 1) {
+      model_id = all_local_models.value().front().model;
+    } else {
+      std::vector<std::string> model_id_list{};
+      for (const auto& model : all_local_models.value()) {
+        model_id_list.push_back(model.model);
+      }
+
+      auto selection = cli_selection_utils::PrintSelection(
+          model_id_list, "Please select an option");
+      if (!selection.has_value()) {
+        return std::nullopt;
+      }
+      model_id = selection.value();
+      CLI_LOG("Selected: " << selection.value());
+    }
+  } else {
+    auto related_models_ids = modellist_handler.FindRelatedModel(model_handle);
+    if (related_models_ids.has_error() || related_models_ids.value().empty()) {
+      auto result = model_service.DownloadModel(model_handle);
+      if (result.has_error()) {
+        CLI_LOG("Model " << model_handle << " not found!");
+        return std::nullopt;
+      }
+      model_id = result.value();
+      CTL_INF("model_id: " << model_id.value());
+    } else if (related_models_ids.value().size() == 1) {
+      model_id = related_models_ids.value().front();
+    } else {  // multiple models with nearly same name found
+      auto selection = cli_selection_utils::PrintSelection(
+          related_models_ids.value(), "Local Models: (press enter to select)");
+      if (!selection.has_value()) {
+        return std::nullopt;
+      }
+      model_id = selection.value();
+      CLI_LOG("Selected: " << selection.value());
+    }
+  }
+  return model_id;
+}
+
 namespace {
 std::string Repo2Engine(const std::string& r) {
   if (r == kLlamaRepo) {
@@ -24,63 +77,16 @@ std::string Repo2Engine(const std::string& r) {
 }  // namespace
 
 void RunCmd::Exec(bool run_detach) {
-  std::optional<std::string> model_id = model_handle_;
-
+  std::optional<std::string> model_id =
+      SelectLocalModel(model_service_, model_handle_);
+  if (!model_id.has_value()) {
+    return;
+  }
+
   cortex::db::Models modellist_handler;
   config::YamlHandler yaml_handler;
   auto address = host_ + ":" + std::to_string(port_);
 
-  {
-    if (model_handle_.empty()) {
-      auto all_local_models = modellist_handler.LoadModelList();
-      if (all_local_models.has_error() || all_local_models.value().empty()) {
-        CLI_LOG("No local models available!");
-        return;
-      }
-
-      if (all_local_models.value().size() == 1) {
-        model_id = all_local_models.value().front().model;
-      } else {
-        std::vector<std::string> model_id_list{};
-        for (const auto& model : all_local_models.value()) {
-          model_id_list.push_back(model.model);
-        }
-
-        auto selection = cli_selection_utils::PrintSelection(
-            model_id_list, "Please select an option");
-        if (!selection.has_value()) {
-          return;
-        }
-        model_id = selection.value();
-        CLI_LOG("Selected: " << selection.value());
-      }
-    } else {
-      auto related_models_ids =
-          modellist_handler.FindRelatedModel(model_handle_);
-      if (related_models_ids.has_error() ||
-          related_models_ids.value().empty()) {
-        auto result = model_service_.DownloadModel(model_handle_);
-        if (result.has_error()) {
-          CLI_LOG("Model " << model_handle_ << " not found!");
-          return;
-        }
-        model_id = result.value();
-        CTL_INF("model_id: " << model_id.value());
-      } else if (related_models_ids.value().size() == 1) {
-        model_id = related_models_ids.value().front();
-      } else {  // multiple models with nearly same name found
-        auto selection = cli_selection_utils::PrintSelection(
-            related_models_ids.value(),
-            "Local Models: (press enter to select)");
-        if (!selection.has_value()) {
-          return;
-        }
-        model_id = selection.value();
-        CLI_LOG("Selected: " << selection.value());
-      }
-    }
-  }
-
   try {
     namespace fs = std::filesystem;
     namespace fmu = file_manager_utils;
@@ -148,7 +154,7 @@ void RunCmd::Exec(bool run_detach) {
   // Chat
   if (run_detach) {
     CLI_LOG(*model_id << " model started successfully. Use `"
-            << commands::GetCortexBinary() << " chat " << *model_id
+            << commands::GetCortexBinary() << " run " << *model_id
             << "` for interactive chat shell");
   } else {
     ChatCompletionCmd(model_service_).Exec(host_, port_, *model_id, mc, "");
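SelectLocalModel, hoisted here out of the body of RunCmd::Exec, resolves a handle along three paths: no handle given, so list local models (auto-picking a single one, otherwise prompting); a handle with no local match, so attempt a download; or several near-matching local models, so prompt. Here is a self-contained sketch of the handle-given branch, with stubs in place of the model database and download service; every name is illustrative, not the cortex.cpp API.

// Sketch: resolve a user-supplied handle against locally installed
// variants, as the else-branch of SelectLocalModel does.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

std::vector<std::string> FindRelatedStub(const std::string& handle) {
  // Pretend two variants of the handle are installed locally.
  return {handle + ":q4", handle + ":q8"};
}

std::optional<std::string> ResolveHandle(const std::string& handle) {
  auto related = FindRelatedStub(handle);
  if (related.empty()) {
    // Real code calls model_service.DownloadModel(handle) here and
    // reports "Model <handle> not found!" when that also fails.
    return std::nullopt;
  }
  if (related.size() == 1) return related.front();
  // Multiple near-matches: real code prompts via cli_selection_utils.
  std::cout << "Local Models: (press enter to select)\n";
  for (size_t i = 0; i < related.size(); ++i)
    std::cout << "  " << i + 1 << ". " << related[i] << "\n";
  return related.front();  // sketch: take the first instead of prompting
}

int main() {
  if (auto id = ResolveHandle("tinyllama")) {
    std::cout << "Selected: " << *id << "\n";
  }
}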
4 changes: 4 additions & 0 deletions engine/cli/commands/run_cmd.h
@@ -5,6 +5,10 @@
 #include "services/model_service.h"
 
 namespace commands {
+
+std::optional<std::string> SelectLocalModel(ModelService& model_service,
+                                            const std::string& model_handle);
+
 class RunCmd {
  public:
   explicit RunCmd(std::string host, int port, std::string model_handle,
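Exposing the declaration in run_cmd.h is what lets ModelStartCmd (first hunk above) reuse the selection flow instead of duplicating it. A small runnable sketch of such a caller follows; the stub ModelService and the inline definition merely stand in for the real services/model_service.h and run_cmd.cc so the example is self-contained.

#include <iostream>
#include <optional>
#include <string>

struct ModelService {};  // stand-in for the real services/model_service.h

namespace commands {
// Same signature as the declaration added above; the trivial body here
// only replaces run_cmd.cc so the sketch links and runs.
std::optional<std::string> SelectLocalModel(ModelService& /*model_service*/,
                                            const std::string& model_handle) {
  if (model_handle.empty()) return std::nullopt;
  return model_handle;
}
}  // namespace commands

int main() {
  ModelService svc;
  if (auto model_id = commands::SelectLocalModel(svc, "tinyllama")) {
    std::cout << "would POST /v1/models/start for " << *model_id << "\n";
  } else {
    std::cout << "selection failed; nothing to start\n";
  }
}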