From 0b75004bf15f457201246278817707ca14d9539e Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 10:01:09 +0700 Subject: [PATCH 1/7] fix: cors --- engine/CMakeLists.txt | 5 -- engine/addon.cc | 83 --------------------- engine/cli/CMakeLists.txt | 5 -- engine/cli/main.cc | 4 - engine/controllers/engines.h | 7 +- engine/controllers/models.h | 13 ++-- engine/controllers/process_manager.h | 2 +- engine/main.cc | 6 ++ engine/test/components/test_cortex_utils.cc | 41 ---------- engine/test/components/test_models_db.cc | 4 +- engine/utils/cortex_utils.h | 53 +------------ 11 files changed, 22 insertions(+), 201 deletions(-) delete mode 100644 engine/addon.cc delete mode 100644 engine/test/components/test_cortex_utils.cc diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index d4e9ac5f6..dc4ce8807 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -30,11 +30,6 @@ if(MSVC) ) endif() -if(DEBUG) - message(STATUS "CORTEX-CPP DEBUG IS ON") - add_compile_definitions(ALLOW_ALL_CORS) -endif() - if(NOT DEFINED CORTEX_VARIANT) set(CORTEX_VARIANT "prod") endif() diff --git a/engine/addon.cc b/engine/addon.cc deleted file mode 100644 index 503358160..000000000 --- a/engine/addon.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include - -#include -#include -#include -#include // for PATH_MAX -#include -#include "cortex-common/cortexpythoni.h" -#include "utils/cortex_utils.h" -#include "utils/dylib.h" - -#if defined(__APPLE__) && defined(__MACH__) -#include // for dirname() -#include -#elif defined(__linux__) -#include // for dirname() -#include // for readlink() -#elif defined(_WIN32) -#include -#undef max -#else -#error "Unsupported platform!" -#endif - -static Napi::Env* s_env = nullptr; - -void start(const int port = 3929) { - int thread_num = 1; - std::string host = "127.0.0.1"; - int logical_cores = std::thread::hardware_concurrency(); - int drogon_thread_num = std::max(thread_num, logical_cores); -#ifdef CORTEX_CPP_VERSION - LOG_INFO << "cortex-cpp version: " << CORTEX_CPP_VERSION; -#else - LOG_INFO << "cortex-cpp version: undefined"; -#endif - - LOG_INFO << "Server started, listening at: " << host << ":" << port; - LOG_INFO << "Please load your model"; - drogon::app().addListener(host, port); - drogon::app().setThreadNum(drogon_thread_num); - LOG_INFO << "Number of thread is:" << drogon::app().getThreadNum(); - - drogon::app().run(); -} - -void stop() { - drogon::app().quit(); -} - -void exitCallback() { - Napi::TypeError::New(*s_env, "Process Exited!").ThrowAsJavaScriptException(); -} - -Napi::Value Start(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - - s_env = &env; - - // Register exitCallback with atexit - std::atexit(exitCallback); - - - Napi::Number jsParam = info[0].As(); - int port = jsParam.Int32Value(); - - start(port); - return env.Undefined(); -} - -Napi::Value Stop(const Napi::CallbackInfo& info) { - Napi::Env env = info.Env(); - stop(); - return Napi::String::New(env, "Server stopped successfully"); -} - -Napi::Object Init(Napi::Env env, Napi::Object exports) { - exports.Set(Napi::String::New(env, "start"), Napi::Function::New(env, Start)); - exports.Set(Napi::String::New(env, "stop"), Napi::Function::New(env, Start)); - return exports; -} - -NODE_API_MODULE(cortex-cpp, Init) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index be0a7dcfe..0e25a4873 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -28,11 +28,6 @@ if(MSVC) ) endif() -if(DEBUG) - message(STATUS "CORTEX-CPP DEBUG IS ON") - add_compile_definitions(ALLOW_ALL_CORS) -endif() - if(NOT DEFINED CORTEX_VARIANT) set(CORTEX_VARIANT "prod") endif() diff --git a/engine/cli/main.cc b/engine/cli/main.cc index 62a88eb38..8fa771fa6 100644 --- a/engine/cli/main.cc +++ b/engine/cli/main.cc @@ -1,13 +1,9 @@ #include #include "command_line_parser.h" #include "commands/cortex_upd_cmd.h" -#include "cortex-common/cortexpythoni.h" #include "services/download_service.h" -#include "services/model_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" -#include "utils/dylib.h" -#include "utils/event_processor.h" #include "utils/file_logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" diff --git a/engine/controllers/engines.h b/engine/controllers/engines.h index de1dbf6ea..85d09172f 100644 --- a/engine/controllers/engines.h +++ b/engine/controllers/engines.h @@ -14,13 +14,14 @@ class Engines : public drogon::HttpController { METHOD_ADD(Engines::GetInstalledEngineVariants, "/{1}", Get); METHOD_ADD(Engines::InstallEngine, "/{1}?version={2}&variant={3}", Post); - METHOD_ADD(Engines::UninstallEngine, "/{1}?version={2}&variant={3}", Delete); + METHOD_ADD(Engines::UninstallEngine, "/{1}?version={2}&variant={3}", Options, + Delete); METHOD_ADD(Engines::SetDefaultEngineVariant, "/{1}/default?version={2}&variant={3}", Post); METHOD_ADD(Engines::GetDefaultEngineVariant, "/{1}/default", Get); METHOD_ADD(Engines::LoadEngine, "/{1}/load", Post); - METHOD_ADD(Engines::UnloadEngine, "/{1}/load", Delete); + METHOD_ADD(Engines::UnloadEngine, "/{1}/load", Options, Delete); METHOD_ADD(Engines::UpdateEngine, "/{1}/update", Post); METHOD_ADD(Engines::ListEngine, "", Get); METHOD_ADD(Engines::GetEngineVersions, "/{1}/versions", Get); @@ -30,7 +31,7 @@ class Engines : public drogon::HttpController { ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/{1}?version={2}&variant={3}", Post); ADD_METHOD_TO(Engines::UninstallEngine, - "/v1/engines/{1}?version={2}&variant={3}", Delete); + "/v1/engines/{1}?version={2}&variant={3}", Options, Delete); ADD_METHOD_TO(Engines::SetDefaultEngineVariant, "/v1/engines/{1}/default?version={2}&variant={3}", Post); ADD_METHOD_TO(Engines::GetDefaultEngineVariant, "/v1/engines/{1}/default", diff --git a/engine/controllers/models.h b/engine/controllers/models.h index b48a0d1aa..14e19e102 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -12,23 +12,23 @@ class Models : public drogon::HttpController { METHOD_LIST_BEGIN METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::GetModelPullInfo, "/pull/info", Post); - METHOD_ADD(Models::AbortPullModel, "/pull", Delete); + METHOD_ADD(Models::AbortPullModel, "/pull", Options, Delete); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); METHOD_ADD(Models::UpdateModel, "/{1}", Patch); METHOD_ADD(Models::ImportModel, "/import", Post); - METHOD_ADD(Models::DeleteModel, "/{1}", Delete); + METHOD_ADD(Models::DeleteModel, "/{1}", Options, Delete); METHOD_ADD(Models::StartModel, "/start", Post); METHOD_ADD(Models::StopModel, "/stop", Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Post); - ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Delete); + ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); ADD_METHOD_TO(Models::ListModel, "/v1/models", Get); ADD_METHOD_TO(Models::GetModel, "/v1/models/{1}", Get); ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch); ADD_METHOD_TO(Models::ImportModel, "/v1/models/import", Post); - ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete); + ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Options, Delete); ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Post); ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); @@ -40,8 +40,9 @@ class Models : public drogon::HttpController { void PullModel(const HttpRequestPtr& req, std::function&& callback); - void GetModelPullInfo(const HttpRequestPtr& req, - std::function&& callback) const; + void GetModelPullInfo( + const HttpRequestPtr& req, + std::function&& callback) const; void AbortPullModel(const HttpRequestPtr& req, std::function&& callback); void ListModel(const HttpRequestPtr& req, diff --git a/engine/controllers/process_manager.h b/engine/controllers/process_manager.h index f62c51abb..bded7b103 100644 --- a/engine/controllers/process_manager.h +++ b/engine/controllers/process_manager.h @@ -8,7 +8,7 @@ using namespace drogon; class ProcessManager : public drogon::HttpController { public: METHOD_LIST_BEGIN - METHOD_ADD(ProcessManager::destroy, "/destroy", Delete); + METHOD_ADD(ProcessManager::destroy, "/destroy", Options, Delete); METHOD_LIST_END void destroy(const HttpRequestPtr& req, diff --git a/engine/main.cc b/engine/main.cc index f8c20410f..257b36b24 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -126,6 +126,12 @@ void RunServer(std::optional port) { drogon::app().setThreadNum(drogon_thread_num); LOG_INFO << "Number of thread is:" << drogon::app().getThreadNum(); drogon::app().disableSigtermHandling(); + drogon::app().registerPostHandlingAdvice( + [](const drogon::HttpRequestPtr& req, + const drogon::HttpResponsePtr& resp) { + resp->addHeader("Access-Control-Allow-Origin", "*"); + resp->addHeader("Access-Control-Allow-Methods", "*"); + }); drogon::app().run(); } diff --git a/engine/test/components/test_cortex_utils.cc b/engine/test/components/test_cortex_utils.cc deleted file mode 100644 index 2d85f6909..000000000 --- a/engine/test/components/test_cortex_utils.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "gtest/gtest.h" -#include "utils/cortex_utils.h" - -class NitroUtilTest : public ::testing::Test { -}; - -TEST_F(NitroUtilTest, left_trim) { - { - std::string empty; - cortex_utils::ltrim(empty); - EXPECT_EQ(empty, ""); - } - - { - std::string s = "abc"; - std::string expected = "abc"; - cortex_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = " abc"; - std::string expected = "abc"; - cortex_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = "1 abc 2 "; - std::string expected = "1 abc 2 "; - cortex_utils::ltrim(s); - EXPECT_EQ(s, expected); - } - - { - std::string s = " |abc"; - std::string expected = "|abc"; - cortex_utils::ltrim(s); - EXPECT_EQ(s, expected); - } -} diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index 20726bfbc..ef54fe7e0 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -1,5 +1,3 @@ -#include -#include #include "database/models.h" #include "gtest/gtest.h" #include "utils/file_manager_utils.h" @@ -164,4 +162,4 @@ TEST_F(ModelsTestSuite, TestHasModel) { // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -} // namespace cortex::db \ No newline at end of file +} // namespace cortex::db diff --git a/engine/utils/cortex_utils.h b/engine/utils/cortex_utils.h index f0c2a5c1b..5e62661ba 100644 --- a/engine/utils/cortex_utils.h +++ b/engine/utils/cortex_utils.h @@ -27,7 +27,6 @@ #endif namespace cortex_utils { -inline std::string models_folder = "./models"; inline std::string logs_folder = "./logs"; inline std::string logs_base_name = "./logs/cortex.log"; inline std::string logs_cli_base_name = "./logs/cortex-cli.log"; @@ -99,25 +98,6 @@ inline std::string imageToBase64(const std::string& imagePath) { return base64Encode(buffer); } -// Helper function to generate a unique filename -inline std::string generateUniqueFilename(const std::string& prefix, - const std::string& extension) { - // Get current time as a timestamp - auto now = std::chrono::system_clock::now(); - auto now_ms = std::chrono::time_point_cast(now); - auto epoch = now_ms.time_since_epoch(); - auto value = std::chrono::duration_cast(epoch); - - // Generate a random number - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(1000, 9999); - - std::stringstream ss; - ss << prefix << value.count() << "_" << dis(gen) << extension; - return ss.str(); -} - inline void processLocalImage( const std::string& localPath, std::function callback) { @@ -163,11 +143,6 @@ inline std::vector listFilesInDir(const std::string& path) { return files; } -inline std::string rtrim(const std::string& str) { - size_t end = str.find_last_not_of("\n\t "); - return (end == std::string::npos) ? "" : str.substr(0, end + 1); -} - inline std::string generate_random_string(std::size_t length) { const std::string characters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; @@ -259,43 +234,21 @@ inline void nitro_logo() { } inline drogon::HttpResponsePtr CreateCortexHttpResponse() { - auto resp = drogon::HttpResponse::newHttpResponse(); -#ifdef ALLOW_ALL_CORS - LOG_INFO << "Respond for all cors!"; - resp->addHeader("Access-Control-Allow-Origin", "*"); -#endif - return resp; + return drogon::HttpResponse::newHttpResponse(); } inline drogon::HttpResponsePtr CreateCortexHttpJsonResponse( const Json::Value& data) { - auto resp = drogon::HttpResponse::newHttpJsonResponse(data); -#ifdef ALLOW_ALL_CORS - LOG_INFO << "Respond for all cors!"; - resp->addHeader("Access-Control-Allow-Origin", "*"); -#endif - // Drogon already set the content-type header to "application/json" - return resp; + return drogon::HttpResponse::newHttpJsonResponse(data); }; inline drogon::HttpResponsePtr CreateCortexStreamResponse( const std::function& callback, const std::string& attachmentFileName = "") { - auto resp = drogon::HttpResponse::newStreamResponse( + return drogon::HttpResponse::newStreamResponse( callback, attachmentFileName, drogon::CT_NONE, "text/event-stream"); -#ifdef ALLOW_ALL_CORS - LOG_INFO << "Respond for all cors!"; - resp->addHeader("Access-Control-Allow-Origin", "*"); -#endif - return resp; } -inline void ltrim(std::string& s) { - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { - return !std::isspace(ch); - })); -}; - #if defined(_WIN32) inline std::string GetCurrentPath() { wchar_t path[MAX_PATH]; From c047ef838ef840cf66e7a0bfb3c22c9c203dc39c Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 16:02:02 +0700 Subject: [PATCH 2/7] feat: add api for configuration --- engine/common/api_server_configuration.h | 78 +++++++++++ engine/controllers/configs.cc | 53 +++++++ engine/controllers/configs.h | 34 +++++ engine/controllers/prelight.cc | 13 -- engine/controllers/prelight.h | 18 --- engine/main.cc | 32 ++++- engine/services/config_service.cc | 33 +++++ engine/services/config_service.h | 13 ++ .../test_api_server_configuration.cc | 132 ++++++++++++++++++ engine/utils/config_yaml_utils.h | 17 ++- engine/utils/file_manager_utils.h | 3 +- 11 files changed, 387 insertions(+), 39 deletions(-) create mode 100644 engine/common/api_server_configuration.h create mode 100644 engine/controllers/configs.cc create mode 100644 engine/controllers/configs.h delete mode 100644 engine/controllers/prelight.cc delete mode 100644 engine/controllers/prelight.h create mode 100644 engine/services/config_service.cc create mode 100644 engine/services/config_service.h create mode 100644 engine/test/components/test_api_server_configuration.cc diff --git a/engine/common/api_server_configuration.h b/engine/common/api_server_configuration.h new file mode 100644 index 000000000..72d0aeedf --- /dev/null +++ b/engine/common/api_server_configuration.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include + +class ApiServerConfiguration { + public: + ApiServerConfiguration(bool cors = true, + std::vector allowed_origins = {}) + : cors{cors}, allowed_origins{allowed_origins} {} + + bool cors{true}; + std::vector allowed_origins; + + Json::Value ToJson() const { + Json::Value root; + root["cors"] = cors; + root["allowed_origins"] = Json::Value(Json::arrayValue); + for (const auto& origin : allowed_origins) { + root["allowed_origins"].append(origin); + } + return root; + } + + void UpdateFromJson(const Json::Value& json, + std::vector* updated_fields = nullptr, + std::vector* invalid_fields = nullptr, + std::vector* unknown_fields = nullptr) { + const std::unordered_map> + field_updater{ + {"cors", + [this](const Json::Value& value) -> bool { + if (!value.isBool()) { + return false; + } + cors = value.asBool(); + return true; + }}, + + {"allowed_origins", [this](const Json::Value& value) -> bool { + if (!value.isArray()) { + return false; + } + for (const auto& origin : value) { + if (!origin.isString()) { + return false; + } + } + + this->allowed_origins.clear(); + for (const auto& origin : value) { + this->allowed_origins.push_back(origin.asString()); + } + return true; + }}}; + + for (const auto& key : json.getMemberNames()) { + auto updater = field_updater.find(key); + if (updater != field_updater.end()) { + if (updater->second(json[key])) { + if (updated_fields != nullptr) { + updated_fields->push_back(key); + } + } else { + if (invalid_fields != nullptr) { + invalid_fields->push_back(key); + } + } + } else { + if (unknown_fields != nullptr) { + unknown_fields->push_back(key); + } + } + } + }; +}; diff --git a/engine/controllers/configs.cc b/engine/controllers/configs.cc new file mode 100644 index 000000000..630e6e65e --- /dev/null +++ b/engine/controllers/configs.cc @@ -0,0 +1,53 @@ +#include "configs.h" + +void Configs::GetConfigurations( + const HttpRequestPtr& req, + std::function&& callback) const { + auto get_config_result = config_service_->GetApiServerConfiguration(); + if (get_config_result.has_error()) { + Json::Value error_json; + error_json["error"] = get_config_result.error(); + auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + auto resp = drogon::HttpResponse::newHttpJsonResponse( + get_config_result.value().ToJson()); + resp->setStatusCode(drogon::k200OK); + callback(resp); + return; +} + +void Configs::UpdateConfigurations( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (!json_body) { + Json::Value error_json; + error_json["error"] = "Configuration must be provided via JSON body"; + auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + auto update_config_result = + config_service_->UpdateApiServerConfiguration(*json_body); + if (update_config_result.has_error()) { + Json::Value error_json; + error_json["error"] = update_config_result.error(); + auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + Json::Value root; + root["message"] = "Configuration updated successfully"; + root["config"] = update_config_result.value().ToJson(); + auto resp = drogon::HttpResponse::newHttpJsonResponse(root); + resp->setStatusCode(drogon::k200OK); + callback(resp); + return; +} diff --git a/engine/controllers/configs.h b/engine/controllers/configs.h new file mode 100644 index 000000000..48f277f46 --- /dev/null +++ b/engine/controllers/configs.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include "services/config_service.h" + +using namespace drogon; + +class Configs : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + + METHOD_ADD(Configs::GetConfigurations, "", Get); + METHOD_ADD(Configs::UpdateConfigurations, "", Patch); + + ADD_METHOD_TO(Configs::GetConfigurations, "/v1/configs", Get); + ADD_METHOD_TO(Configs::UpdateConfigurations, "/v1/configs", Patch); + + METHOD_LIST_END + + explicit Configs(std::shared_ptr config_service) + : config_service_{config_service} {} + + void GetConfigurations( + const HttpRequestPtr& req, + std::function&& callback) const; + + void UpdateConfigurations( + const HttpRequestPtr& req, + std::function&& callback); + + private: + std::shared_ptr config_service_; +}; diff --git a/engine/controllers/prelight.cc b/engine/controllers/prelight.cc deleted file mode 100644 index 9c4c63095..000000000 --- a/engine/controllers/prelight.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "prelight.h" - -void prelight::handlePrelight( - const HttpRequestPtr &req, - std::function &&callback) { - auto resp = drogon::HttpResponse::newHttpResponse(); - resp->setStatusCode(drogon::HttpStatusCode::k200OK); - resp->addHeader("Access-Control-Allow-Origin", "*"); - resp->addHeader("Access-Control-Allow-Methods", "POST, OPTIONS"); - resp->addHeader("Access-Control-Allow-Headers", "*"); - callback(resp); -} - diff --git a/engine/controllers/prelight.h b/engine/controllers/prelight.h deleted file mode 100644 index 387f5a51b..000000000 --- a/engine/controllers/prelight.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include - -using namespace drogon; - -class prelight : public drogon::HttpController { -public: - METHOD_LIST_BEGIN - ADD_METHOD_TO(prelight::handlePrelight, "/v1/chat/completions", Options); - ADD_METHOD_TO(prelight::handlePrelight, "/v1/embeddings", Options); - ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/transcriptions", Options); - ADD_METHOD_TO(prelight::handlePrelight, "/v1/audio/translations", Options); - METHOD_LIST_END - - void handlePrelight(const HttpRequestPtr &req, - std::function &&callback); -}; diff --git a/engine/main.cc b/engine/main.cc index 257b36b24..6053fa91c 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,12 +1,14 @@ #include #include #include +#include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" #include "cortex-common/cortexpythoni.h" +#include "services/config_service.h" #include "services/model_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -100,6 +102,7 @@ void RunServer(std::optional port) { std::make_shared(engine_service); auto model_service = std::make_shared(download_service, inference_svc); + auto config_service = std::make_shared(); // initialize custom controllers auto engine_ctl = std::make_shared(engine_service); @@ -108,12 +111,14 @@ void RunServer(std::optional port) { auto pm_ctl = std::make_shared(); auto server_ctl = std::make_shared(inference_svc, engine_service); + auto config_ctl = std::make_shared(config_service); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); drogon::app().registerController(event_ctl); drogon::app().registerController(pm_ctl); drogon::app().registerController(server_ctl); + drogon::app().registerController(config_ctl); auto upload_path = std::filesystem::temp_directory_path() / "cortex-uploads"; drogon::app().setUploadPath(upload_path.string()); @@ -126,11 +131,30 @@ void RunServer(std::optional port) { drogon::app().setThreadNum(drogon_thread_num); LOG_INFO << "Number of thread is:" << drogon::app().getThreadNum(); drogon::app().disableSigtermHandling(); + + // CORS drogon::app().registerPostHandlingAdvice( - [](const drogon::HttpRequestPtr& req, - const drogon::HttpResponsePtr& resp) { - resp->addHeader("Access-Control-Allow-Origin", "*"); - resp->addHeader("Access-Control-Allow-Methods", "*"); + [config_service](const drogon::HttpRequestPtr& req, + const drogon::HttpResponsePtr& resp) { + if (!config_service->GetApiServerConfiguration()->cors) { + CTL_INF("CORS is disabled!"); + return; + } + + const std::string& origin = req->getHeader("Origin"); + CTL_INF("Origin: " << origin); + + auto allowed_origins = + config_service->GetApiServerConfiguration()->allowed_origins; + + // Check if the origin is in our allowed list + auto it = + std::find(allowed_origins.begin(), allowed_origins.end(), origin); + if (it != allowed_origins.end()) { + resp->addHeader("Access-Control-Allow-Origin", origin); + } else if (allowed_origins.empty()) { + resp->addHeader("Access-Control-Allow-Origin", "*"); + } }); drogon::app().run(); diff --git a/engine/services/config_service.cc b/engine/services/config_service.cc new file mode 100644 index 000000000..f0a36a430 --- /dev/null +++ b/engine/services/config_service.cc @@ -0,0 +1,33 @@ +#include "config_service.h" +#include "common/api_server_configuration.h" +#include "utils/file_manager_utils.h" + +cpp::result +ConfigService::UpdateApiServerConfiguration(const Json::Value& json) { + auto config = file_manager_utils::GetCortexConfig(); + ApiServerConfiguration api_server_config{config.enableCors, + config.allowedOrigins}; + std::cout << json.toStyledString() << std::endl; + std::vector updated_fields; + std::vector invalid_fields; + std::vector unknown_fields; + + api_server_config.UpdateFromJson(json, &updated_fields, &invalid_fields, + &unknown_fields); + + if (updated_fields.empty()) { + return cpp::fail("No configuration updated"); + } + + config.enableCors = api_server_config.cors; + config.allowedOrigins = api_server_config.allowed_origins; + + auto result = file_manager_utils::UpdateCortexConfig(config); + return api_server_config; +} + +cpp::result +ConfigService::GetApiServerConfiguration() const { + auto config = file_manager_utils::GetCortexConfig(); + return ApiServerConfiguration{config.enableCors, config.allowedOrigins}; +} diff --git a/engine/services/config_service.h b/engine/services/config_service.h new file mode 100644 index 000000000..3e9b9ec6e --- /dev/null +++ b/engine/services/config_service.h @@ -0,0 +1,13 @@ +#pragma once + +#include "common/api_server_configuration.h" +#include "utils/result.hpp" + +class ConfigService { + public: + cpp::result UpdateApiServerConfiguration( + const Json::Value& json); + + cpp::result GetApiServerConfiguration() + const; +}; diff --git a/engine/test/components/test_api_server_configuration.cc b/engine/test/components/test_api_server_configuration.cc new file mode 100644 index 000000000..97bc3a253 --- /dev/null +++ b/engine/test/components/test_api_server_configuration.cc @@ -0,0 +1,132 @@ +#include +#include +#include "common/api_server_configuration.h" + +class ApiServerConfigurationTest : public ::testing::Test { + protected: + ApiServerConfiguration config; + + // Helper to create JSON from string + Json::Value parseJson(const std::string& jsonStr) { + Json::Value root; + Json::Reader reader; + reader.parse(jsonStr, root); + return root; + } +}; + +// Test default values +TEST_F(ApiServerConfigurationTest, DefaultValues) { + EXPECT_TRUE(config.cors); + EXPECT_TRUE(config.allowed_origins.empty()); +} + +// Test CORS update +TEST_F(ApiServerConfigurationTest, UpdateCors) { + auto json = parseJson(R"({"cors": false})"); + std::vector updated_fields; + config.UpdateFromJson(json, &updated_fields); + + EXPECT_FALSE(config.cors); + ASSERT_EQ(updated_fields.size(), 1); + EXPECT_EQ(updated_fields[0], "cors"); +} + +// Test allowed origins update +TEST_F(ApiServerConfigurationTest, UpdateAllowedOrigins) { + auto json = parseJson(R"({ + "allowed_origins": ["https://example.com", "https://test.com"] + })"); + std::vector updated_fields; + config.UpdateFromJson(json, &updated_fields); + + ASSERT_EQ(config.allowed_origins.size(), 2); + EXPECT_EQ(config.allowed_origins[0], "https://example.com"); + EXPECT_EQ(config.allowed_origins[1], "https://test.com"); + ASSERT_EQ(updated_fields.size(), 1); + EXPECT_EQ(updated_fields[0], "allowed_origins"); +} + +// Test multiple field updates +TEST_F(ApiServerConfigurationTest, UpdateMultipleFields) { + auto json = parseJson(R"({ + "cors": false, + "allowed_origins": ["https://example.com"] + })"); + std::vector updated_fields; + config.UpdateFromJson(json, &updated_fields); + + EXPECT_FALSE(config.cors); + ASSERT_EQ(config.allowed_origins.size(), 1); + EXPECT_EQ(config.allowed_origins[0], "https://example.com"); + ASSERT_EQ(updated_fields.size(), 2); +} + +// Test unknown fields +TEST_F(ApiServerConfigurationTest, UnknownFields) { + auto json = parseJson(R"({ + "cors": false, + "unknown_field": "value" + })"); + std::vector updated_fields; + std::vector unknown_fields; + config.UpdateFromJson(json, &updated_fields, nullptr, &unknown_fields); + + EXPECT_FALSE(config.cors); + ASSERT_EQ(updated_fields.size(), 1); + ASSERT_EQ(unknown_fields.size(), 1); + EXPECT_EQ(unknown_fields[0], "unknown_field"); +} + +// Test invalid field types +TEST_F(ApiServerConfigurationTest, InvalidFieldTypes) { + auto json = parseJson(R"({ + "cors": "invalid_bool", + "allowed_origins": "invalid_array" + })"); + std::vector updated_fields; + std::vector invalid_fields; + config.UpdateFromJson(json, &updated_fields, &invalid_fields); + + for (const auto& field : updated_fields) { + std::cout << field << std::endl; + } + + EXPECT_TRUE(config.cors); // Should retain default value + EXPECT_TRUE(config.allowed_origins.empty()); // Should retain default value + EXPECT_TRUE(updated_fields.empty()); +} + +// Test empty update +TEST_F(ApiServerConfigurationTest, EmptyUpdate) { + auto json = parseJson("{}"); + std::vector updated_fields; + config.UpdateFromJson(json, &updated_fields); + + EXPECT_TRUE(config.cors); // Should retain default value + EXPECT_TRUE(config.allowed_origins.empty()); + EXPECT_TRUE(updated_fields.empty()); +} + +// Test allowed_origins with invalid array elements +TEST_F(ApiServerConfigurationTest, InvalidArrayElements) { + auto json = parseJson(R"({ + "allowed_origins": ["valid", 123, true, "also_valid"] + })"); + std::vector updated_fields; + config.UpdateFromJson(json, &updated_fields); + + ASSERT_EQ(updated_fields.size(), 0); + ASSERT_EQ(config.allowed_origins.size(), 0); +} + +// Test nullopt parameters +TEST_F(ApiServerConfigurationTest, NulloptParameters) { + auto json = parseJson(R"({ + "cors": false, + "unknown_field": "value" + })"); + config.UpdateFromJson(json); // Should not crash + + EXPECT_FALSE(config.cors); +} diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index 87a114d25..ae66ce171 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -29,6 +29,9 @@ struct CortexConfig { std::string gitHubToken; std::string llamacppVariant; std::string llamacppVersion; + + bool enableCors; + std::vector allowedOrigins; }; const std::string kDefaultHost{"127.0.0.1"}; @@ -36,6 +39,8 @@ const std::string kDefaultPort{"39281"}; const int kDefaultMaxLines{100000}; constexpr const uint64_t kDefaultCheckedForUpdateAt = 0u; constexpr const auto kDefaultLatestRelease = "default_version"; +constexpr const auto kDefaultCorsEnabled = true; +const std::vector kDefaultEnabledOrigins{}; inline cpp::result DumpYamlConfig(const CortexConfig& config, const std::string& path) { @@ -62,6 +67,8 @@ inline cpp::result DumpYamlConfig(const CortexConfig& config, node["gitHubToken"] = config.gitHubToken; node["llamacppVariant"] = config.llamacppVariant; node["llamacppVersion"] = config.llamacppVersion; + node["enableCors"] = config.enableCors; + node["allowedOrigins"] = config.allowedOrigins; out_file << node; out_file.close(); @@ -89,7 +96,8 @@ inline CortexConfig FromYaml(const std::string& path, !node["logOnnxPath"] || !node["logTensorrtLLMPath"] || !node["huggingFaceToken"] || !node["gitHubUserAgent"] || !node["gitHubToken"] || !node["llamacppVariant"] || - !node["llamacppVersion"]); + !node["llamacppVersion"] || !node["enableCors"] || + !node["allowedOrigins"]); CortexConfig config = { .logFolderPath = node["logFolderPath"] @@ -135,7 +143,11 @@ inline CortexConfig FromYaml(const std::string& path, .llamacppVersion = node["llamacppVersion"] ? node["llamacppVersion"].as() : "", - }; + .enableCors = node["enableCors"] ? node["enableCors"].as() : true, + .allowedOrigins = + node["allowedOrigins"] + ? node["allowedOrigins"].as>() + : std::vector{}}; if (should_update_config) { auto result = DumpYamlConfig(config, path); if (result.has_error()) { @@ -148,5 +160,4 @@ inline CortexConfig FromYaml(const std::string& path, throw; } } - } // namespace config_yaml_utils diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index b6d1f1c5a..e060eed8a 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -196,7 +196,8 @@ inline config_yaml_utils::CortexConfig GetCortexConfig() { .apiServerPort = config_yaml_utils::kDefaultPort, .checkedForUpdateAt = config_yaml_utils::kDefaultCheckedForUpdateAt, .latestRelease = config_yaml_utils::kDefaultLatestRelease, - }; + .enableCors = config_yaml_utils::kDefaultCorsEnabled, + .allowedOrigins = config_yaml_utils::kDefaultEnabledOrigins}; return config_yaml_utils::FromYaml(config_path.string(), default_cfg); } From aec5351c91f65aad6f925fb5be7d9ebfb7ea16ba Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 16:08:19 +0700 Subject: [PATCH 3/7] remove log --- engine/services/config_service.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/engine/services/config_service.cc b/engine/services/config_service.cc index f0a36a430..9f1589887 100644 --- a/engine/services/config_service.cc +++ b/engine/services/config_service.cc @@ -7,7 +7,6 @@ ConfigService::UpdateApiServerConfiguration(const Json::Value& json) { auto config = file_manager_utils::GetCortexConfig(); ApiServerConfiguration api_server_config{config.enableCors, config.allowedOrigins}; - std::cout << json.toStyledString() << std::endl; std::vector updated_fields; std::vector invalid_fields; std::vector unknown_fields; From cdfda3a43effec3aab4920cee1d1e283f43db139 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 16:35:13 +0700 Subject: [PATCH 4/7] chore: update API references --- docs/static/openapi/cortex.json | 136 +++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 9 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 84bb7efed..763337b5c 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -1142,7 +1142,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The type of engine" } @@ -1200,7 +1201,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The type of engine" }, @@ -1245,7 +1247,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The type of engine" }, @@ -1335,7 +1338,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The type of engine" } @@ -1378,7 +1382,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The type of engine" }, @@ -1433,7 +1438,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The name of the engine to update" } @@ -1468,7 +1474,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The name of the engine to update" } @@ -1505,7 +1512,8 @@ "required": true, "schema": { "type": "string", - "enum": ["onnxruntime", "llama-cpp", "tensorrt-llm"] + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "default": "llama-cpp" }, "description": "The name of the engine to update" } @@ -1530,6 +1538,111 @@ }, "tags": ["Engines"] } + }, + "/v1/configs": { + "get": { + "summary": "Get Configurations", + "description": "Retrieves the current configuration settings of the Cortex server.", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "allowed_origins": { + "type": "array", + "items": { + "type": "string" + }, + "example": ["http://localhost:39281", "https://cortex.so"] + }, + "cors": { + "type": "boolean", + "example": false + } + } + }, + "example": { + "allowed_origins": [ + "http://localhost:39281", + "https://cortex.so" + ], + "cors": false + } + } + } + } + }, + "tags": ["Configurations"] + }, + "patch": { + "tags": ["Configurations"], + "summary": "Update configuration settings", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "cors": { + "type": "boolean", + "description": "Indicates whether CORS is enabled.", + "example": false + }, + "allowed_origins": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of allowed origins.", + "example": ["http://localhost:39281", "https://cortex.so"] + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Configuration updated successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "allowed_origins": { + "type": "array", + "items": { + "type": "string" + }, + "example": [ + "http://localhost:39281", + "https://cortex.so" + ] + }, + "cors": { + "type": "boolean", + "example": false + } + } + }, + "message": { + "type": "string", + "example": "Configuration updated successfully" + } + } + } + } + } + } + } + } } }, "info": { @@ -1559,6 +1672,10 @@ "name": "Server", "description": "These endpoints manage the lifecycle of Server, including heath check and shutdown." }, + { + "name": "Configuration", + "description": "These endpoints manage the configuration of the Cortex server." + }, { "name": "Messages", "description": "These endpoints manage the retrieval and storage of conversation content, including responses from LLMs and other metadata related to chat interactions." @@ -1587,7 +1704,8 @@ "Running Models", "Processes", "Status", - "Server" + "Server", + "Configurations" ] } ], From 01a5d76e4c1f48a1db023b5bf6767e7ae5b6fbf7 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 23:16:09 +0700 Subject: [PATCH 5/7] add: cli API --- engine/cli/command_line_parser.cc | 64 +++++++++++++++++- engine/cli/command_line_parser.h | 4 ++ engine/cli/commands/config_get_cmd.cc | 46 +++++++++++++ engine/cli/commands/config_get_cmd.h | 10 +++ engine/cli/commands/config_upd_cmd.cc | 73 ++++++++++++++++++++ engine/cli/commands/config_upd_cmd.h | 12 ++++ engine/controllers/configs.cc | 8 +-- engine/test/components/test_string_utils.cc | 20 +++--- engine/utils/curl_utils.h | 75 +++++++++++++++++++++ engine/utils/string_utils.h | 2 +- 10 files changed, 300 insertions(+), 14 deletions(-) create mode 100644 engine/cli/commands/config_get_cmd.cc create mode 100644 engine/cli/commands/config_get_cmd.h create mode 100644 engine/cli/commands/config_upd_cmd.cc create mode 100644 engine/cli/commands/config_upd_cmd.h diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 625750248..5ce54c532 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -2,6 +2,9 @@ #include #include #include +#include +#include "commands/config_get_cmd.h" +#include "commands/config_upd_cmd.h" #include "commands/cortex_upd_cmd.h" #include "commands/engine_get_cmd.h" #include "commands/engine_install_cmd.h" @@ -31,6 +34,7 @@ constexpr const auto kInferenceGroup = "Inference"; constexpr const auto kModelsGroup = "Models"; constexpr const auto kEngineGroup = "Engines"; constexpr const auto kSystemGroup = "Server"; +constexpr const auto kConfigGroup = "Configurations"; constexpr const auto kSubcommands = "Subcommands"; } // namespace @@ -57,6 +61,8 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { SetupSystemCommands(); + SetupConfigsCommands(); + app_.add_flag("--verbose", log_verbose, "Get verbose logs"); // Logic is handled in main.cc, just for cli helper command @@ -301,6 +307,62 @@ void CommandLineParser::SetupModelCommands() { }); } +void CommandLineParser::SetupConfigsCommands() { + auto config_cmd = + app_.add_subcommand("config", "Subcommands for managing configurations"); + config_cmd->usage( + "Usage:\n" + commands::GetCortexBinary() + + " config status for listing all API server configuration.\n" + + commands::GetCortexBinary() + + " config --cors [on/off] to toggle CORS.\n" + + commands::GetCortexBinary() + + " config --allowed_origins [comma separated origin] to set a list of " + "allowed origin"); + config_cmd->group(kConfigGroup); + auto config_status_cmd = + config_cmd->add_subcommand("status", "Print all configurations"); + config_status_cmd->callback([this] { + if (std::exchange(executed_, true)) + return; + commands::ConfigGetCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)); + }); + + // TODO: this can be improved + std::vector avai_opts{"cors", "allowed_origins"}; + std::unordered_map description{ + {"cors", "[on/off] Toggling CORS."}, + {"allowed_origins", + "Allowed origins for CORS. Comma separated. E.g. " + "http://localhost,https://cortex.so"}}; + for (const auto& opt : avai_opts) { + std::string option = "--" + opt; + config_cmd->add_option(option, config_update_opts_[opt], description[opt]) + ->expected(0, 1) + ->default_str("*"); + } + + config_cmd->callback([this, config_cmd] { + if (std::exchange(executed_, true)) + return; + + auto is_empty = true; + for (const auto& [key, value] : config_update_opts_) { + if (!value.empty()) { + is_empty = false; + break; + } + } + if (is_empty) { + CLI_LOG(config_cmd->help()); + return; + } + commands::ConfigUpdCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + config_update_opts_); + }); +} + void CommandLineParser::SetupEngineCommands() { auto engines_cmd = app_.add_subcommand("engines", "Subcommands for managing engines"); @@ -339,7 +401,7 @@ void CommandLineParser::SetupEngineCommands() { CLI_LOG(install_cmd->help()); } }); - for (auto& engine : engine_service_.kSupportEngines) { + for (const auto& engine : engine_service_.kSupportEngines) { std::string engine_name{engine}; EngineInstall(install_cmd, engine_name, cml_data_.engine_version, cml_data_.engine_src); diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 9f3cdda12..de51ba212 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "CLI/CLI.hpp" #include "services/engine_service.h" #include "services/model_service.h" @@ -22,6 +23,8 @@ class CommandLineParser { void SetupSystemCommands(); + void SetupConfigsCommands(); + void EngineInstall(CLI::App* parent, const std::string& engine_name, std::string& version, std::string& src); @@ -62,5 +65,6 @@ class CommandLineParser { std::unordered_map model_update_options; }; CmlData cml_data_; + std::unordered_map config_update_opts_; bool executed_ = false; }; diff --git a/engine/cli/commands/config_get_cmd.cc b/engine/cli/commands/config_get_cmd.cc new file mode 100644 index 000000000..aa2e059d7 --- /dev/null +++ b/engine/cli/commands/config_get_cmd.cc @@ -0,0 +1,46 @@ +#include "config_get_cmd.h" +#include +#include "commands/server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/url_parser.h" + +void commands::ConfigGetCmd::Exec(const std::string& host, int port) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return; + } + } + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "configs"}, + }; + + auto get_config_result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (get_config_result.has_error()) { + CLI_LOG_ERROR( + "Failed to get configurations: " << get_config_result.error()); + return; + } + + auto json_value = get_config_result.value(); + tabulate::Table table; + table.add_row({"Config name", "Value"}); + + for (const auto& key : json_value.getMemberNames()) { + if (json_value[key].isArray()) { + for (const auto& value : json_value[key]) { + table.add_row({key, value.asString()}); + } + } else { + table.add_row({key, json_value[key].asString()}); + } + } + + std::cout << table << std::endl; + return; +} diff --git a/engine/cli/commands/config_get_cmd.h b/engine/cli/commands/config_get_cmd.h new file mode 100644 index 000000000..9431dc100 --- /dev/null +++ b/engine/cli/commands/config_get_cmd.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace commands { +class ConfigGetCmd { + public: + void Exec(const std::string& host, int port); +}; +} // namespace commands diff --git a/engine/cli/commands/config_upd_cmd.cc b/engine/cli/commands/config_upd_cmd.cc new file mode 100644 index 000000000..ebf86fcea --- /dev/null +++ b/engine/cli/commands/config_upd_cmd.cc @@ -0,0 +1,73 @@ +#include "config_upd_cmd.h" +#include "commands/server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" + +namespace { +const std::vector config_keys{"cors", "allowed_origins"}; + +inline Json::Value NormalizeJson( + const std::unordered_map options) { + Json::Value root; + for (const auto& [key, value] : options) { + if (std::find(config_keys.begin(), config_keys.end(), key) == + config_keys.end()) { + continue; + } + + if (key == "cors") { + if (string_utils::EqualsIgnoreCase("on", value)) { + root["cors"] = true; + } else if (string_utils::EqualsIgnoreCase("off", value)) { + root["cors"] = false; + } + } else if (key == "allowed_origins") { + auto origins = string_utils::SplitBy(value, ","); + Json::Value origin_array(Json::arrayValue); + for (const auto& origin : origins) { + origin_array.append(origin); + } + root[key] = origin_array; + } + } + + CTL_DBG("Normalized config update request: " << root.toStyledString()); + + return root; +} +}; // namespace + +void commands::ConfigUpdCmd::Exec( + const std::string& host, int port, + const std::unordered_map& options) { + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "configs"}, + }; + + auto json = NormalizeJson(options); + if (json.empty()) { + CLI_LOG_ERROR("Invalid configuration options provided"); + return; + } + + auto update_cnf_result = + curl_utils::SimplePatch(url.ToFullPath(), json.toStyledString()); + if (update_cnf_result.has_error()) { + CLI_LOG_ERROR(update_cnf_result.error()); + return; + } + + CLI_LOG("Configuration updated successfully!"); +} diff --git a/engine/cli/commands/config_upd_cmd.h b/engine/cli/commands/config_upd_cmd.h new file mode 100644 index 000000000..55375b3b7 --- /dev/null +++ b/engine/cli/commands/config_upd_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { +class ConfigUpdCmd { + public: + void Exec(const std::string& host, int port, + const std::unordered_map& options); +}; +} // namespace commands diff --git a/engine/controllers/configs.cc b/engine/controllers/configs.cc index 630e6e65e..41b08cf45 100644 --- a/engine/controllers/configs.cc +++ b/engine/controllers/configs.cc @@ -6,7 +6,7 @@ void Configs::GetConfigurations( auto get_config_result = config_service_->GetApiServerConfiguration(); if (get_config_result.has_error()) { Json::Value error_json; - error_json["error"] = get_config_result.error(); + error_json["message"] = get_config_result.error(); auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); @@ -24,9 +24,9 @@ void Configs::UpdateConfigurations( const HttpRequestPtr& req, std::function&& callback) { auto json_body = req->getJsonObject(); - if (!json_body) { + if (json_body == nullptr) { Json::Value error_json; - error_json["error"] = "Configuration must be provided via JSON body"; + error_json["message"] = "Configuration must be provided via JSON body"; auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); @@ -36,7 +36,7 @@ void Configs::UpdateConfigurations( config_service_->UpdateApiServerConfiguration(*json_body); if (update_config_result.has_error()) { Json::Value error_json; - error_json["error"] = update_config_result.error(); + error_json["message"] = update_config_result.error(); auto resp = drogon::HttpResponse::newHttpJsonResponse(error_json); resp->setStatusCode(drogon::k400BadRequest); callback(resp); diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index 71ab78a78..c412c5ec4 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -23,8 +23,7 @@ TEST_F(StringUtilsTestSuite, ParsePrompt) { TEST_F(StringUtilsTestSuite, TestSplitBy) { auto input = "this is a test"; - std::string delimiter{' '}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, " "); EXPECT_EQ(result.size(), 4); EXPECT_EQ(result[0], "this"); @@ -35,16 +34,14 @@ TEST_F(StringUtilsTestSuite, TestSplitBy) { TEST_F(StringUtilsTestSuite, TestSplitByWithEmptyString) { auto input = ""; - std::string delimiter{' '}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, " "); EXPECT_EQ(result.size(), 0); } TEST_F(StringUtilsTestSuite, TestSplitModelHandle) { auto input = "cortexso/tinyllama"; - std::string delimiter{'/'}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, "/"); EXPECT_EQ(result.size(), 2); EXPECT_EQ(result[0], "cortexso"); @@ -53,13 +50,20 @@ TEST_F(StringUtilsTestSuite, TestSplitModelHandle) { TEST_F(StringUtilsTestSuite, TestSplitModelHandleWithEmptyModelName) { auto input = "cortexso/"; - std::string delimiter{'/'}; - auto result = SplitBy(input, delimiter); + auto result = SplitBy(input, "/"); EXPECT_EQ(result.size(), 1); EXPECT_EQ(result[0], "cortexso"); } +TEST_F(StringUtilsTestSuite, TestSplitIfNotContainDelimiter) { + auto input = "https://cortex.so"; + auto result = SplitBy(input, ","); + + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "https://cortex.so"); +} + TEST_F(StringUtilsTestSuite, TestStartsWith) { auto input = "this is a test"; auto prefix = "this"; diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 88b05828a..0c31d0830 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -71,6 +71,62 @@ inline cpp::result SimpleGet(const std::string& url) { return readBuffer; } +inline cpp::result SimplePatch( + const std::string& url, const std::string& body = "") { + auto curl = curl_easy_init(); + + if (!curl) { + return cpp::fail("Failed to init CURL"); + } + + auto headers = GetHeaders(url); + curl_slist* curl_headers = nullptr; + curl_headers = + curl_slist_append(curl_headers, "Content-Type: application/json"); + + if (headers.has_value()) { + for (const auto& [key, value] : headers.value()) { + auto header = key + ": " + value; + curl_headers = curl_slist_append(curl_headers, header.c_str()); + } + } + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); + + std::string readBuffer; + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "PATCH"); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + + // Set content length if body is not empty + if (!body.empty()) { + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + } + + // Perform the request + auto res = curl_easy_perform(curl); + + curl_slist_free_all(curl_headers); + curl_easy_cleanup(curl); + if (res != CURLE_OK) { + CTL_ERR("CURL request failed: " + std::string(curl_easy_strerror(res))); + return cpp::fail("CURL request failed: " + + static_cast(curl_easy_strerror(res))); + } + auto http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code >= 400) { + CTL_ERR("HTTP request failed with status code: " + + std::to_string(http_code)); + return cpp::fail(readBuffer); + } + + return readBuffer; +} + inline cpp::result SimplePost( const std::string& url, const std::string& body = "") { curl_global_init(CURL_GLOBAL_DEFAULT); @@ -223,6 +279,25 @@ inline cpp::result SimplePostJson( return root; } +inline cpp::result SimplePatchJson( + const std::string& url, const std::string& body = "") { + auto result = SimplePatch(url, body); + if (result.has_error()) { + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + CTL_INF("Response: " + result.value()); + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} + inline std::optional> GetHeaders( const std::string& url) { auto url_obj = url_parser::FromUrlString(url); diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 9e40e423b..264d04025 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -99,7 +99,7 @@ inline bool EndsWith(const std::string& str, const std::string& suffix) { } inline std::vector SplitBy(const std::string& str, - const std::string& delimiter) { + const std::string&& delimiter) { std::vector tokens; size_t prev = 0, pos = 0; do { From 38233d8ce916a9c44ca33b50bd5db13599ffa2ed Mon Sep 17 00:00:00 2001 From: James Date: Wed, 6 Nov 2024 23:32:17 +0700 Subject: [PATCH 6/7] fix build --- engine/common/api_server_configuration.h | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/common/api_server_configuration.h b/engine/common/api_server_configuration.h index 72d0aeedf..9d261231f 100644 --- a/engine/common/api_server_configuration.h +++ b/engine/common/api_server_configuration.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include From 229be62e88bc9bf1e2ddec93d72860bd9690055d Mon Sep 17 00:00:00 2001 From: James Date: Thu, 7 Nov 2024 08:38:26 +0700 Subject: [PATCH 7/7] fix build windows --- engine/cli/commands/config_get_cmd.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/engine/cli/commands/config_get_cmd.cc b/engine/cli/commands/config_get_cmd.cc index aa2e059d7..62d9638a5 100644 --- a/engine/cli/commands/config_get_cmd.cc +++ b/engine/cli/commands/config_get_cmd.cc @@ -1,9 +1,11 @@ #include "config_get_cmd.h" -#include #include "commands/server_start_cmd.h" #include "utils/curl_utils.h" #include "utils/logging_utils.h" #include "utils/url_parser.h" +// clang-format off +#include +// clang-format on void commands::ConfigGetCmd::Exec(const std::string& host, int port) { // Start server if server is not started yet