Skip to content

Commit

Permalink
add: cli API
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Nov 6, 2024
1 parent eac2ac4 commit 3922612
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 14 deletions.
64 changes: 63 additions & 1 deletion engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include "commands/config_get_cmd.h"
#include "commands/config_upd_cmd.h"
#include "commands/cortex_upd_cmd.h"
#include "commands/engine_get_cmd.h"
#include "commands/engine_install_cmd.h"
Expand Down Expand Up @@ -31,6 +34,7 @@ constexpr const auto kInferenceGroup = "Inference";
constexpr const auto kModelsGroup = "Models";
constexpr const auto kEngineGroup = "Engines";
constexpr const auto kSystemGroup = "Server";
constexpr const auto kConfigGroup = "Configurations";
constexpr const auto kSubcommands = "Subcommands";
} // namespace

Expand All @@ -57,6 +61,8 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {

SetupSystemCommands();

SetupConfigsCommands();

app_.add_flag("--verbose", log_verbose, "Get verbose logs");

// Logic is handled in main.cc, just for cli helper command
Expand Down Expand Up @@ -301,6 +307,62 @@ void CommandLineParser::SetupModelCommands() {
});
}

void CommandLineParser::SetupConfigsCommands() {
auto config_cmd =
app_.add_subcommand("config", "Subcommands for managing configurations");
config_cmd->usage(
"Usage:\n" + commands::GetCortexBinary() +
" config status for listing all API server configuration.\n" +
commands::GetCortexBinary() +
" config --cors [on/off] to toggle CORS.\n" +
commands::GetCortexBinary() +
" config --allowed_origins [comma separated origin] to set a list of "
"allowed origin");
config_cmd->group(kConfigGroup);
auto config_status_cmd =
config_cmd->add_subcommand("status", "Print all configurations");
config_status_cmd->callback([this] {
if (std::exchange(executed_, true))
return;
commands::ConfigGetCmd().Exec(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort));
});

// TODO: this can be improved
std::vector<std::string> avai_opts{"cors", "allowed_origins"};
std::unordered_map<std::string, std::string> description{
{"cors", "[on/off] Toggling CORS."},
{"allowed_origins",
"Allowed origins for CORS. Comma separated. E.g. "
"http://localhost,https://cortex.so"}};
for (const auto& opt : avai_opts) {
std::string option = "--" + opt;
config_cmd->add_option(option, config_update_opts_[opt], description[opt])
->expected(0, 1)
->default_str("*");
}

config_cmd->callback([this, config_cmd] {
if (std::exchange(executed_, true))
return;

auto is_empty = true;
for (const auto& [key, value] : config_update_opts_) {
if (!value.empty()) {
is_empty = false;
break;
}
}
if (is_empty) {
CLI_LOG(config_cmd->help());
return;
}
commands::ConfigUpdCmd().Exec(cml_data_.config.apiServerHost,
std::stoi(cml_data_.config.apiServerPort),
config_update_opts_);
});
}

void CommandLineParser::SetupEngineCommands() {
auto engines_cmd =
app_.add_subcommand("engines", "Subcommands for managing engines");
Expand Down Expand Up @@ -339,7 +401,7 @@ void CommandLineParser::SetupEngineCommands() {
CLI_LOG(install_cmd->help());
}
});
for (auto& engine : engine_service_.kSupportEngines) {
for (const auto& engine : engine_service_.kSupportEngines) {
std::string engine_name{engine};
EngineInstall(install_cmd, engine_name, cml_data_.engine_version,
cml_data_.engine_src);
Expand Down
4 changes: 4 additions & 0 deletions engine/cli/command_line_parser.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <memory>
#include <unordered_map>
#include "CLI/CLI.hpp"
#include "services/engine_service.h"
#include "services/model_service.h"
Expand All @@ -22,6 +23,8 @@ class CommandLineParser {

void SetupSystemCommands();

void SetupConfigsCommands();

void EngineInstall(CLI::App* parent, const std::string& engine_name,
std::string& version, std::string& src);

Expand Down Expand Up @@ -62,5 +65,6 @@ class CommandLineParser {
std::unordered_map<std::string, std::string> model_update_options;
};
CmlData cml_data_;
std::unordered_map<std::string, std::string> config_update_opts_;
bool executed_ = false;
};
46 changes: 46 additions & 0 deletions engine/cli/commands/config_get_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "config_get_cmd.h"
#include <tabulate/table.hpp>
#include "commands/server_start_cmd.h"
#include "utils/curl_utils.h"
#include "utils/logging_utils.h"
#include "utils/url_parser.h"

void commands::ConfigGetCmd::Exec(const std::string& host, int port) {
// Start server if server is not started yet
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
commands::ServerStartCmd ssc;
if (!ssc.Exec(host, port)) {
return;
}
}
auto url = url_parser::Url{
.protocol = "http",
.host = host + ":" + std::to_string(port),
.pathParams = {"v1", "configs"},
};

auto get_config_result = curl_utils::SimpleGetJson(url.ToFullPath());
if (get_config_result.has_error()) {
CLI_LOG_ERROR(
"Failed to get configurations: " << get_config_result.error());
return;
}

auto json_value = get_config_result.value();
tabulate::Table table;
table.add_row({"Config name", "Value"});

for (const auto& key : json_value.getMemberNames()) {
if (json_value[key].isArray()) {
for (const auto& value : json_value[key]) {
table.add_row({key, value.asString()});
}
} else {
table.add_row({key, json_value[key].asString()});
}
}

std::cout << table << std::endl;
return;
}
10 changes: 10 additions & 0 deletions engine/cli/commands/config_get_cmd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#include <string>

namespace commands {
class ConfigGetCmd {
public:
void Exec(const std::string& host, int port);
};
} // namespace commands
73 changes: 73 additions & 0 deletions engine/cli/commands/config_upd_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "config_upd_cmd.h"
#include "commands/server_start_cmd.h"
#include "utils/curl_utils.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
#include "utils/url_parser.h"

namespace {
const std::vector<std::string> config_keys{"cors", "allowed_origins"};

inline Json::Value NormalizeJson(
const std::unordered_map<std::string, std::string> options) {
Json::Value root;
for (const auto& [key, value] : options) {
if (std::find(config_keys.begin(), config_keys.end(), key) ==
config_keys.end()) {
continue;
}

if (key == "cors") {
if (string_utils::EqualsIgnoreCase("on", value)) {
root["cors"] = true;
} else if (string_utils::EqualsIgnoreCase("off", value)) {
root["cors"] = false;
}
} else if (key == "allowed_origins") {
auto origins = string_utils::SplitBy(value, ",");
Json::Value origin_array(Json::arrayValue);
for (const auto& origin : origins) {
origin_array.append(origin);
}
root[key] = origin_array;
}
}

CTL_DBG("Normalized config update request: " << root.toStyledString());

return root;
}
}; // namespace

void commands::ConfigUpdCmd::Exec(
const std::string& host, int port,
const std::unordered_map<std::string, std::string>& options) {
if (!commands::IsServerAlive(host, port)) {
CLI_LOG("Starting server ...");
commands::ServerStartCmd ssc;
if (!ssc.Exec(host, port)) {
return;
}
}

auto url = url_parser::Url{
.protocol = "http",
.host = host + ":" + std::to_string(port),
.pathParams = {"v1", "configs"},
};

auto json = NormalizeJson(options);
if (json.empty()) {
CLI_LOG_ERROR("Invalid configuration options provided");
return;
}

auto update_cnf_result =
curl_utils::SimplePatch(url.ToFullPath(), json.toStyledString());
if (update_cnf_result.has_error()) {
CLI_LOG_ERROR(update_cnf_result.error());
return;
}

CLI_LOG("Configuration updated successfully!");
}
12 changes: 12 additions & 0 deletions engine/cli/commands/config_upd_cmd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <string>
#include <unordered_map>

namespace commands {
class ConfigUpdCmd {
public:
void Exec(const std::string& host, int port,
const std::unordered_map<std::string, std::string>& options);
};
} // namespace commands
8 changes: 4 additions & 4 deletions engine/controllers/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ void Configs::GetConfigurations(
auto get_config_result = config_service_->GetApiServerConfiguration();
if (get_config_result.has_error()) {
Json::Value error_json;
error_json["error"] = get_config_result.error();
error_json["message"] = get_config_result.error();
auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
Expand All @@ -24,9 +24,9 @@ void Configs::UpdateConfigurations(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) {
auto json_body = req->getJsonObject();
if (!json_body) {
if (json_body == nullptr) {
Json::Value error_json;
error_json["error"] = "Configuration must be provided via JSON body";
error_json["message"] = "Configuration must be provided via JSON body";
auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
Expand All @@ -36,7 +36,7 @@ void Configs::UpdateConfigurations(
config_service_->UpdateApiServerConfiguration(*json_body);
if (update_config_result.has_error()) {
Json::Value error_json;
error_json["error"] = update_config_result.error();
error_json["message"] = update_config_result.error();
auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
Expand Down
20 changes: 12 additions & 8 deletions engine/test/components/test_string_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ TEST_F(StringUtilsTestSuite, ParsePrompt) {

TEST_F(StringUtilsTestSuite, TestSplitBy) {
auto input = "this is a test";
std::string delimiter{' '};
auto result = SplitBy(input, delimiter);
auto result = SplitBy(input, " ");

EXPECT_EQ(result.size(), 4);
EXPECT_EQ(result[0], "this");
Expand All @@ -35,16 +34,14 @@ TEST_F(StringUtilsTestSuite, TestSplitBy) {

TEST_F(StringUtilsTestSuite, TestSplitByWithEmptyString) {
auto input = "";
std::string delimiter{' '};
auto result = SplitBy(input, delimiter);
auto result = SplitBy(input, " ");

EXPECT_EQ(result.size(), 0);
}

TEST_F(StringUtilsTestSuite, TestSplitModelHandle) {
auto input = "cortexso/tinyllama";
std::string delimiter{'/'};
auto result = SplitBy(input, delimiter);
auto result = SplitBy(input, "/");

EXPECT_EQ(result.size(), 2);
EXPECT_EQ(result[0], "cortexso");
Expand All @@ -53,13 +50,20 @@ TEST_F(StringUtilsTestSuite, TestSplitModelHandle) {

TEST_F(StringUtilsTestSuite, TestSplitModelHandleWithEmptyModelName) {
auto input = "cortexso/";
std::string delimiter{'/'};
auto result = SplitBy(input, delimiter);
auto result = SplitBy(input, "/");

EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result[0], "cortexso");
}

TEST_F(StringUtilsTestSuite, TestSplitIfNotContainDelimiter) {
auto input = "https://cortex.so";
auto result = SplitBy(input, ",");

EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result[0], "https://cortex.so");
}

TEST_F(StringUtilsTestSuite, TestStartsWith) {
auto input = "this is a test";
auto prefix = "this";
Expand Down
Loading

0 comments on commit 3922612

Please sign in to comment.