diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 625750248..5ce54c532 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -2,6 +2,9 @@ #include #include #include +#include +#include "commands/config_get_cmd.h" +#include "commands/config_upd_cmd.h" #include "commands/cortex_upd_cmd.h" #include "commands/engine_get_cmd.h" #include "commands/engine_install_cmd.h" @@ -31,6 +34,7 @@ constexpr const auto kInferenceGroup = "Inference"; constexpr const auto kModelsGroup = "Models"; constexpr const auto kEngineGroup = "Engines"; constexpr const auto kSystemGroup = "Server"; +constexpr const auto kConfigGroup = "Configurations"; constexpr const auto kSubcommands = "Subcommands"; } // namespace @@ -57,6 +61,8 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { SetupSystemCommands(); + SetupConfigsCommands(); + app_.add_flag("--verbose", log_verbose, "Get verbose logs"); // Logic is handled in main.cc, just for cli helper command @@ -301,6 +307,62 @@ void CommandLineParser::SetupModelCommands() { }); } +void CommandLineParser::SetupConfigsCommands() { + auto config_cmd = + app_.add_subcommand("config", "Subcommands for managing configurations"); + config_cmd->usage( + "Usage:\n" + commands::GetCortexBinary() + + " config status for listing all API server configuration.\n" + + commands::GetCortexBinary() + + " config --cors [on/off] to toggle CORS.\n" + + commands::GetCortexBinary() + + " config --allowed_origins [comma separated origin] to set a list of " + "allowed origin"); + config_cmd->group(kConfigGroup); + auto config_status_cmd = + config_cmd->add_subcommand("status", "Print all configurations"); + config_status_cmd->callback([this] { + if (std::exchange(executed_, true)) + return; + commands::ConfigGetCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)); + }); + + // TODO: this can be improved + std::vector avai_opts{"cors", "allowed_origins"}; + std::unordered_map description{ + {"cors", "[on/off] Toggling CORS."}, + {"allowed_origins", + "Allowed origins for CORS. Comma separated. E.g. " + "http://localhost,https://cortex.so"}}; + for (const auto& opt : avai_opts) { + std::string option = "--" + opt; + config_cmd->add_option(option, config_update_opts_[opt], description[opt]) + ->expected(0, 1) + ->default_str("*"); + } + + config_cmd->callback([this, config_cmd] { + if (std::exchange(executed_, true)) + return; + + auto is_empty = true; + for (const auto& [key, value] : config_update_opts_) { + if (!value.empty()) { + is_empty = false; + break; + } + } + if (is_empty) { + CLI_LOG(config_cmd->help()); + return; + } + commands::ConfigUpdCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + config_update_opts_); + }); +} + void CommandLineParser::SetupEngineCommands() { auto engines_cmd = app_.add_subcommand("engines", "Subcommands for managing engines"); @@ -339,7 +401,7 @@ void CommandLineParser::SetupEngineCommands() { CLI_LOG(install_cmd->help()); } }); - for (auto& engine : engine_service_.kSupportEngines) { + for (const auto& engine : engine_service_.kSupportEngines) { std::string engine_name{engine}; EngineInstall(install_cmd, engine_name, cml_data_.engine_version, cml_data_.engine_src); diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 9f3cdda12..de51ba212 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "CLI/CLI.hpp" #include "services/engine_service.h" #include "services/model_service.h" @@ -22,6 +23,8 @@ class CommandLineParser { void SetupSystemCommands(); + void SetupConfigsCommands(); + void EngineInstall(CLI::App* parent, const std::string& engine_name, std::string& version, std::string& src); @@ -62,5 +65,6 @@ class CommandLineParser { std::unordered_map model_update_options; }; CmlData cml_data_; + std::unordered_map config_update_opts_; bool executed_ = false; }; diff --git a/engine/cli/commands/config_get_cmd.cc b/engine/cli/commands/config_get_cmd.cc new file mode 100644 index 000000000..aa2e059d7 --- /dev/null +++ b/engine/cli/commands/config_get_cmd.cc @@ -0,0 +1,46 @@ +#include "config_get_cmd.h" +#include +#include "commands/server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/url_parser.h" + +void commands::ConfigGetCmd::Exec(const std::string& host, int port) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return; + } + } + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "configs"}, + }; + + auto get_config_result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (get_config_result.has_error()) { + CLI_LOG_ERROR( + "Failed to get configurations: " << get_config_result.error()); + return; + } + + auto json_value = get_config_result.value(); + tabulate::Table table; + table.add_row({"Config name", "Value"}); + + for (const auto& key : json_value.getMemberNames()) { + if (json_value[key].isArray()) { + for (const auto& value : json_value[key]) { + table.add_row({key, value.asString()}); + } + } else { + table.add_row({key, json_value[key].asString()}); + } + } + + std::cout << table << std::endl; + return; +} diff --git a/engine/cli/commands/config_get_cmd.h b/engine/cli/commands/config_get_cmd.h new file mode 100644 index 000000000..9431dc100 --- /dev/null +++ b/engine/cli/commands/config_get_cmd.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace commands { +class ConfigGetCmd { + public: + void Exec(const std::string& host, int port); +}; +} // namespace commands diff --git a/engine/cli/commands/config_upd_cmd.cc b/engine/cli/commands/config_upd_cmd.cc new file mode 100644 index 000000000..ebf86fcea --- /dev/null +++ b/engine/cli/commands/config_upd_cmd.cc @@ -0,0 +1,73 @@ +#include "config_upd_cmd.h" +#include "commands/server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" + +namespace { +const std::vector config_keys{"cors", "allowed_origins"}; + +inline Json::Value NormalizeJson( + const std::unordered_map options) { + Json::Value root; + for (const auto& [key, value] : options) { + if (std::find(config_keys.begin(), config_keys.end(), key) == + config_keys.end()) { + continue; + } + + if (key == "cors") { + if (string_utils::EqualsIgnoreCase("on", value)) { + root["cors"] = true; + } else if (string_utils::EqualsIgnoreCase("off", value)) { + root["cors"] = false; + } + } else if (key == "allowed_origins") { + auto origins = string_utils::SplitBy(value, ","); + Json::Value origin_array(Json::arrayValue); + for (const auto& origin : origins) { + origin_array.append(origin); + } + root[key] = origin_array; + } + } + + CTL_DBG("Normalized config update request: " << root.toStyledString()); + + return root; +} +}; // namespace + +void commands::ConfigUpdCmd::Exec( + const std::string& host, int port, + const std::unordered_map& options) { + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "configs"}, + }; + + auto json = NormalizeJson(options); + if (json.empty()) { + CLI_LOG_ERROR("Invalid configuration options provided"); + return; + } + + auto update_cnf_result = + curl_utils::SimplePatch(url.ToFullPath(), json.toStyledString()); + if (update_cnf_result.has_error()) { + CLI_LOG_ERROR(update_cnf_result.error()); + return; + } + + CLI_LOG("Configuration updated successfully!"); +} diff --git a/engine/cli/commands/config_upd_cmd.h b/engine/cli/commands/config_upd_cmd.h new file mode 100644 index 000000000..55375b3b7 --- /dev/null +++ b/engine/cli/commands/config_upd_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { +class ConfigUpdCmd { + public: + void Exec(const std::string& host, int port, + const std::unordered_map& options); +}; +} // namespace commands diff --git a/engine/controllers/configs.cc b/engine/controllers/configs.cc index 630e6e65e..41b08cf45 100644 --- a/engine/controllers/configs.cc +++ b/engine/controllers/configs.cc @@ -6,7 +6,7 @@ void Configs::GetConfigurations( auto get_config_result = config_service_->GetApiServerConfiguration(); if (get_config_result.has_error()) { Json::Value error_json; - error_json["error"] = get_config_result.error(); + error_json["message"] = get_config_result.error(); auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); @@ -24,9 +24,9 @@ void Configs::UpdateConfigurations( const HttpRequestPtr& req, std::function&& callback) { auto json_body = req->getJsonObject(); - if (!json_body) { + if (json_body == nullptr) { Json::Value error_json; - error_json["error"] = "Configuration must be provided via JSON body"; + error_json["message"] = "Configuration must be provided via JSON body"; auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); @@ -36,7 +36,7 @@ void Configs::UpdateConfigurations( config_service_->UpdateApiServerConfiguration(*json_body); if (update_config_result.has_error()) { Json::Value error_json; - error_json["error"] = update_config_result.error(); + error_json["message"] = update_config_result.error(); auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index 71ab78a78..c412c5ec4 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -23,8 +23,7 @@ TEST_F(StringUtilsTestSuite, ParsePrompt) { TEST_F(StringUtilsTestSuite, TestSplitBy) { auto input = "this is a test"; - std::string delimiter{' '}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, " "); EXPECT_EQ(result.size(), 4); EXPECT_EQ(result[0], "this"); @@ -35,16 +34,14 @@ TEST_F(StringUtilsTestSuite, TestSplitBy) { TEST_F(StringUtilsTestSuite, TestSplitByWithEmptyString) { auto input = ""; - std::string delimiter{' '}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, " "); EXPECT_EQ(result.size(), 0); } TEST_F(StringUtilsTestSuite, TestSplitModelHandle) { auto input = "cortexso/tinyllama"; - std::string delimiter{'/'}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, "/"); EXPECT_EQ(result.size(), 2); EXPECT_EQ(result[0], "cortexso"); @@ -53,13 +50,20 @@ TEST_F(StringUtilsTestSuite, TestSplitModelHandle) { TEST_F(StringUtilsTestSuite, TestSplitModelHandleWithEmptyModelName) { auto input = "cortexso/"; - std::string delimiter{'/'}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, "/"); EXPECT_EQ(result.size(), 1); EXPECT_EQ(result[0], "cortexso"); } +TEST_F(StringUtilsTestSuite, TestSplitIfNotContainDelimiter) { + auto input = "https://cortex.so"; + auto result = SplitBy(input, ","); + + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "https://cortex.so"); +} + TEST_F(StringUtilsTestSuite, TestStartsWith) { auto input = "this is a test"; auto prefix = "this"; diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 88b05828a..0c31d0830 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -71,6 +71,62 @@ inline cpp::result SimpleGet(const std::string& url) { return readBuffer; } +inline cpp::result SimplePatch( + const std::string& url, const std::string& body = "") { + auto curl = curl_easy_init(); + + if (!curl) { + return cpp::fail("Failed to init CURL"); + } + + auto headers = GetHeaders(url); + curl_slist* curl_headers = nullptr; + curl_headers = + curl_slist_append(curl_headers, "Content-Type: application/json"); + + if (headers.has_value()) { + for (const auto& [key, value] : headers.value()) { + auto header = key + ": " + value; + curl_headers = curl_slist_append(curl_headers, header.c_str()); + } + } + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); + + std::string readBuffer; + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "PATCH"); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + + // Set content length if body is not empty + if (!body.empty()) { + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + } + + // Perform the request + auto res = curl_easy_perform(curl); + + curl_slist_free_all(curl_headers); + curl_easy_cleanup(curl); + if (res != CURLE_OK) { + CTL_ERR("CURL request failed: " + std::string(curl_easy_strerror(res))); + return cpp::fail("CURL request failed: " + + static_cast(curl_easy_strerror(res))); + } + auto http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code >= 400) { + CTL_ERR("HTTP request failed with status code: " + + std::to_string(http_code)); + return cpp::fail(readBuffer); + } + + return readBuffer; +} + inline cpp::result SimplePost( const std::string& url, const std::string& body = "") { curl_global_init(CURL_GLOBAL_DEFAULT); @@ -223,6 +279,25 @@ inline cpp::result SimplePostJson( return root; } +inline cpp::result SimplePatchJson( + const std::string& url, const std::string& body = "") { + auto result = SimplePatch(url, body); + if (result.has_error()) { + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + CTL_INF("Response: " + result.value()); + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} + inline std::optional> GetHeaders( const std::string& url) { auto url_obj = url_parser::FromUrlString(url); diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 9e40e423b..264d04025 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -99,7 +99,7 @@ inline bool EndsWith(const std::string& str, const std::string& suffix) { } inline std::vector SplitBy(const std::string& str, - const std::string& delimiter) { + const std::string&& delimiter) { std::vector tokens; size_t prev = 0, pos = 0; do {