diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 479e300ce..d006f0f2d 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5,77 +5,470 @@ "post": { "operationId": "AssistantsController_create", "summary": "Create assistant", - "description": "Creates a new assistant.", - "parameters": [], + "description": "Creates a new assistant with the specified configuration.", "requestBody": { "required": true, "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateAssistantDto" + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": ["model"] } } } }, "responses": { - "201": { - "description": "The assistant has been successfully created." + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } } }, "tags": ["Assistants"] }, - "get": { - "operationId": "AssistantsController_findAll", - "summary": "List assistants", - "description": "Returns a list of assistants.", + "patch": { + "operationId": "AssistantsController_update", + "summary": "Update assistant", + "description": "Updates an assistant. Requires at least one modifiable field.", "parameters": [ { - "name": "limit", - "required": false, - "in": "query", - "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", - "schema": { - "type": "number" - } - }, - { - "name": "order", - "required": false, - "in": "query", - "description": "Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.", - "schema": { - "type": "string" - } - }, - { - "name": "after", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.", + "name": "id", + "required": true, + "in": "path", + "description": "The unique identifier of the assistant.", "schema": { "type": "string" } }, { - "name": "before", - "required": false, - "in": "query", - "description": "A cursor for use in pagination. before is an object ID that defines your place in the list. 
For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.", + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", "schema": { - "type": "string" + "type": "string", + "enum": ["assistants=v2"] } } ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The model identifier to use for the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant. Maximum of 128 tools.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs for the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "minProperties": 1 + } + } + } + }, "responses": { "200": { "description": "Ok", "content": { "application/json": { "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/AssistantEntity" - } + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." + }, + "name": { + "type": "string", + "description": "The name of the assistant." + }, + "description": { + "type": "string", + "description": "The description of the assistant." + }, + "instructions": { + "type": "string", + "description": "Instructions for the assistant's behavior." + }, + "tools": { + "type": "array", + "description": "A list of tools enabled on the assistant.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": [ + "code_interpreter", + "file_search", + "function" + ] + } + } + } + }, + "tool_resources": { + "type": "object", + "description": "Resources used by the assistant's tools.", + "properties": { + "code_interpreter": { + "type": "object" + }, + "file_search": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + }, + "temperature": { + "type": "number", + "format": "float", + "description": "Temperature parameter for response generation." + }, + "top_p": { + "type": "number", + "format": "float", + "description": "Top p parameter for response generation." 
+ }, + "response_format": { + "oneOf": [ + { + "type": "string", + "enum": ["auto"] + }, + { + "type": "object" + } + ] + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + } + } + }, + "tags": ["Assistants"] + }, + "get": { + "operationId": "AssistantsController_list", + "summary": "List assistants", + "description": "Returns a list of assistants.", + "responses": { + "200": { + "description": "Ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "object": { + "type": "string", + "enum": ["list"], + "description": "The object type, which is always 'list' for a list response." + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." 
+ }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs that can be attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] + } + } + }, + "required": ["object", "data"] } } } @@ -88,7 +481,7 @@ "get": { "operationId": "AssistantsController_findOne", "summary": "Get assistant", - "description": "Retrieves a specific assistant defined by an assistant's `id`.", + "description": "Retrieves a specific assistant by ID.", "parameters": [ { "name": "id", @@ -98,6 +491,16 @@ "schema": { "type": "string" } + }, + { + "name": "OpenAI-Beta", + "required": true, + "in": "header", + "description": "Beta feature header.", + "schema": { + "type": "string", + "enum": ["assistants=v2"] + } } ], "responses": { @@ -106,7 +509,38 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/AssistantEntity" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the assistant." + }, + "object": { + "type": "string", + "enum": ["assistant"], + "description": "The object type, which is always 'assistant'." + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp (in seconds) of when the assistant was created." + }, + "model": { + "type": "string", + "description": "The model identifier used by the assistant." 
+ }, + "metadata": { + "type": "object", + "description": "Set of key-value pairs attached to the assistant.", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "model", + "metadata" + ] } } } @@ -117,7 +551,7 @@ "delete": { "operationId": "AssistantsController_remove", "summary": "Delete assistant", - "description": "Deletes a specific assistant defined by an assistant's `id`.", + "description": "Deletes a specific assistant by ID.", "parameters": [ { "name": "id", @@ -131,11 +565,28 @@ ], "responses": { "200": { - "description": "The assistant has been successfully deleted.", + "description": "Ok", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteAssistantResponseDto" + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier of the deleted assistant." + }, + "object": { + "type": "string", + "enum": ["assistant.deleted"], + "description": "The object type for a deleted assistant." + }, + "deleted": { + "type": "boolean", + "enum": [true], + "description": "Indicates the assistant was successfully deleted." 
+ } + }, + "required": ["id", "object", "deleted"] } } } @@ -3456,6 +3907,7 @@ "Files", "Hardware", "Events", + "Assistants", "Threads", "Messages", "Pulling Models", diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 420434eb9..024f015a8 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -162,7 +162,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp Drogon::Drogon Ope target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp) target_link_libraries(${TARGET_NAME} PRIVATE eventpp::eventpp) target_link_libraries(${TARGET_NAME} PRIVATE lfreist-hwinfo::hwinfo) - + # ############################################################################## if(CMAKE_CXX_STANDARD LESS 17) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index efff03d10..4ca734d6a 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -83,6 +83,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../services/database_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/python-engine/python_engine.cc diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 825780895..6f8f227e6 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -49,8 +49,9 @@ CommandLineParser::CommandLineParser() : app_("\nCortex.cpp CLI\n"), download_service_{std::make_shared()}, dylib_path_manager_{std::make_shared()}, - engine_service_{std::make_shared(download_service_, - dylib_path_manager_)} { + db_service_{std::make_shared()}, + engine_service_{std::make_shared( + download_service_, dylib_path_manager_, db_service_)} { supported_engines_ = engine_service_->GetSupportedEngineNames().value(); } @@ -177,7 +178,7 @@ 
void CommandLineParser::SetupCommonCommands() { return; commands::RunCmd rc(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), - cml_data_.model_id, engine_service_); + cml_data_.model_id, db_service_, engine_service_); rc.Exec(cml_data_.run_detach, run_settings_); }); } @@ -216,9 +217,10 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(model_start_cmd->help()); return; }; - commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - cml_data_.model_id, run_settings_); + commands::ModelStartCmd(db_service_) + .Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, + run_settings_); }); auto stop_model_cmd = diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 14e10e420..5b64f7f4d 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -45,6 +45,7 @@ class CommandLineParser { CLI::App app_; std::shared_ptr download_service_; std::shared_ptr dylib_path_manager_; + std::shared_ptr db_service_; std::shared_ptr engine_service_; std::vector supported_engines_; diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index 77d222176..77ee4fca3 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -56,10 +56,9 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, const std::string& model_handle, std::string msg) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CLI_LOG("Error: " + model_entry.error()); return; diff --git a/engine/cli/commands/chat_completion_cmd.h b/engine/cli/commands/chat_completion_cmd.h index 
a784b4604..44de5d256 100644 --- a/engine/cli/commands/chat_completion_cmd.h +++ b/engine/cli/commands/chat_completion_cmd.h @@ -3,16 +3,20 @@ #include #include #include "config/model_config.h" +#include "services/database_service.h" namespace commands { class ChatCompletionCmd { public: + explicit ChatCompletionCmd(std::shared_ptr db_service) + : db_service_(db_service) {} void Exec(const std::string& host, int port, const std::string& model_handle, std::string msg); void Exec(const std::string& host, int port, const std::string& model_handle, const config::ModelConfig& mc, std::string msg); private: + std::shared_ptr db_service_; std::vector histories_; }; } // namespace commands diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index 12aec944d..ef5d5c1f2 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -13,7 +13,7 @@ bool ModelStartCmd::Exec( const std::unordered_map& options, bool print_success_log) { std::optional model_id = - SelectLocalModel(host, port, model_handle); + SelectLocalModel(host, port, model_handle, *db_service_); if (!model_id.has_value()) { return false; diff --git a/engine/cli/commands/model_start_cmd.h b/engine/cli/commands/model_start_cmd.h index 124ef463d..c69bfc32a 100644 --- a/engine/cli/commands/model_start_cmd.h +++ b/engine/cli/commands/model_start_cmd.h @@ -3,16 +3,23 @@ #include #include #include "json/json.h" +#include "services/database_service.h" namespace commands { class ModelStartCmd { public: + explicit ModelStartCmd(std::shared_ptr db_service) + : db_service_(db_service) {} bool Exec(const std::string& host, int port, const std::string& model_handle, const std::unordered_map& options, bool print_success_log = true); - private: + + private: bool UpdateConfig(Json::Value& data, const std::string& key, const std::string& value); + + private: + std::shared_ptr db_service_; }; } // namespace commands diff --git 
a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 91a813d64..c01d3d806 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -14,12 +14,11 @@ namespace commands { std::optional SelectLocalModel(std::string host, int port, - const std::string& model_handle) { + const std::string& model_handle, + DatabaseService& db_service) { std::optional model_id = model_handle; - cortex::db::Models modellist_handler; - if (model_handle.empty()) { - auto all_local_models = modellist_handler.LoadModelList(); + auto all_local_models = db_service.LoadModelList(); if (all_local_models.has_error() || all_local_models.value().empty()) { CLI_LOG("No local models available!"); return std::nullopt; @@ -42,7 +41,7 @@ std::optional SelectLocalModel(std::string host, int port, CLI_LOG("Selected: " << selection.value()); } } else { - auto related_models_ids = modellist_handler.FindRelatedModel(model_handle); + auto related_models_ids = db_service.FindRelatedModel(model_handle); if (related_models_ids.has_error() || related_models_ids.value().empty()) { auto result = ModelPullCmd().Exec(host, port, model_handle); if (!result) { @@ -69,19 +68,18 @@ std::optional SelectLocalModel(std::string host, int port, void RunCmd::Exec(bool run_detach, const std::unordered_map& options) { std::optional model_id = - SelectLocalModel(host_, port_, model_handle_); + SelectLocalModel(host_, port_, model_handle_, *db_service_); if (!model_id.has_value()) { return; } - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); try { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - auto model_entry = modellist_handler.GetModelInfo(*model_id); + auto model_entry = db_service_->GetModelInfo(*model_id); if (model_entry.has_error()) { CLI_LOG("Error: " + model_entry.error()); return; @@ -128,7 +126,7 @@ void RunCmd::Exec(bool run_detach, mc.engine.find(kLlamaEngine) == 
std::string::npos) || !commands::ModelStatusCmd().IsLoaded(host_, port_, *model_id)) { - auto res = commands::ModelStartCmd() + auto res = commands::ModelStartCmd(db_service_) .Exec(host_, port_, *model_id, options, false /*print_success_log*/); if (!res) { @@ -144,7 +142,7 @@ void RunCmd::Exec(bool run_detach, << commands::GetCortexBinary() << " run " << *model_id << "` for interactive chat shell"); } else { - ChatCompletionCmd().Exec(host_, port_, *model_id, mc, ""); + ChatCompletionCmd(db_service_).Exec(host_, port_, *model_id, mc, ""); } } } catch (const std::exception& e) { diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index b22b064f9..ec5c61fd3 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -2,20 +2,24 @@ #include #include +#include "services/database_service.h" #include "services/engine_service.h" namespace commands { std::optional SelectLocalModel(std::string host, int port, - const std::string& model_handle); + const std::string& model_handle, + DatabaseService& db_service); class RunCmd { public: explicit RunCmd(std::string host, int port, std::string model_handle, + std::shared_ptr db_service, std::shared_ptr engine_service) : host_{std::move(host)}, port_{port}, model_handle_{std::move(model_handle)}, + db_service_(db_service), engine_service_{engine_service} {}; void Exec(bool chat_flag, @@ -25,6 +29,7 @@ class RunCmd { std::string host_; int port_; std::string model_handle_; + std::shared_ptr db_service_; std::shared_ptr engine_service_; }; } // namespace commands diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 3d6045cd5..4268f6362 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -114,7 +114,8 @@ bool ServerStartCmd::Exec(const std::string& host, int port, // Some engines requires to add lib search path before process being created auto download_srv = std::make_shared(); auto 
dylib_path_mng = std::make_shared(); - EngineService(download_srv, dylib_path_mng).RegisterEngineLibPath(); + auto db_srv = std::make_shared(); + EngineService(download_srv, dylib_path_mng, db_srv).RegisterEngineLibPath(); std::string p = cortex_utils::GetCurrentPath() + "/" + exe; execl(p.c_str(), exe.c_str(), "--start-server", "--config_file_path", diff --git a/engine/common/assistant.h b/engine/common/assistant.h index e49147e9e..6210a0c2c 100644 --- a/engine/common/assistant.h +++ b/engine/common/assistant.h @@ -1,9 +1,13 @@ #pragma once #include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" #include "common/assistant_tool.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" namespace OpenAi { @@ -75,7 +79,49 @@ struct JanAssistant : JsonSerializable { } }; -struct Assistant { +struct Assistant : JsonSerializable { + Assistant() = default; + + ~Assistant() = default; + + Assistant(const Assistant&) = delete; + + Assistant& operator=(const Assistant&) = delete; + + Assistant(Assistant&& other) noexcept + : id{std::move(other.id)}, + object{std::move(other.object)}, + created_at{other.created_at}, + name{std::move(other.name)}, + description{std::move(other.description)}, + model(std::move(other.model)), + instructions(std::move(other.instructions)), + tools(std::move(other.tools)), + tool_resources(std::move(other.tool_resources)), + metadata(std::move(other.metadata)), + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + Assistant& operator=(Assistant&& other) noexcept { + if (this != &other) { + id = std::move(other.id); + object = std::move(other.object); + created_at = other.created_at; + name = std::move(other.name); + description = 
std::move(other.description); + model = std::move(other.model); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources); + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + /** * The identifier, which can be referenced in API endpoints. */ @@ -126,8 +172,7 @@ struct Assistant { * requires a list of file IDs, while the file_search tool requires a list * of vector store IDs. */ - std::optional> - tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. This can be @@ -153,5 +198,223 @@ struct Assistant { * We generally recommend altering this or temperature but not both. */ std::optional top_p; + + std::variant response_format; + + cpp::result ToJson() override { + try { + Json::Value root; + + root["id"] = std::move(id); + root["object"] = "assistant"; + root["created_at"] = created_at; + if (name.has_value()) { + root["name"] = name.value(); + } + if (description.has_value()) { + root["description"] = description.value(); + } + root["model"] = model; + if (instructions.has_value()) { + root["instructions"] = instructions.value(); + } + + Json::Value tools_jarr{Json::arrayValue}; + for (auto& tool_ptr : tools) { + if (auto it = tool_ptr->ToJson(); it.has_value()) { + tools_jarr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + root["tools"] = tools_jarr; + if (tool_resources) { + Json::Value tool_resources_json{Json::objectValue}; + + if (auto* code_interpreter = + dynamic_cast(tool_resources.get())) { + auto result = code_interpreter->ToJson(); + if (result.has_value()) { + tool_resources_json["code_interpreter"] = result.value(); + } else { + CTL_WRN("Failed to convert code_interpreter to json: " + + result.error()); + } + } else 
if (auto* file_search = dynamic_cast( + tool_resources.get())) { + auto result = file_search->ToJson(); + if (result.has_value()) { + tool_resources_json["file_search"] = result.value(); + } else { + CTL_WRN("Failed to convert file_search to json: " + result.error()); + } + } + + // Only add tool_resources to root if we successfully serialized some resources + if (!tool_resources_json.empty()) { + root["tool_resources"] = tool_resources_json; + } + } + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + root["metadata"] = metadata_json; + + if (temperature.has_value()) { + root["temperature"] = temperature.value(); + } + if (top_p.has_value()) { + root["top_p"] = top_p.value(); + } + return root; + } catch (const std::exception& e) { + return cpp::fail("ToJson failed: " + std::string(e.what())); + } + } + + static cpp::result FromJson(Json::Value&& json) { + try { + Assistant assistant; + + // Parse required fields + if (!json.isMember("id") || !json["id"].isString()) { + return cpp::fail("Missing or invalid 'id' field"); + } + assistant.id = json["id"].asString(); + + if (!json.isMember("object") || !json["object"].isString() || + json["object"].asString() != "assistant") { + return cpp::fail("Missing or invalid 'object' field"); + } + + if (!json.isMember("created_at") || !json["created_at"].isUInt64()) { + return cpp::fail("Missing or invalid 'created_at' field"); + } + assistant.created_at = json["created_at"].asUInt64(); + + if (!json.isMember("model") || !json["model"].isString()) { + return cpp::fail("Missing or invalid 'model' field"); + } + assistant.model = json["model"].asString(); + + // Parse optional 
fields + if (json.isMember("name") && json["name"].isString()) { + assistant.name = json["name"].asString(); + } + + if (json.isMember("description") && json["description"].isString()) { + assistant.description = json["description"].asString(); + } + + if (json.isMember("instructions") && json["instructions"].isString()) { + assistant.instructions = json["instructions"].asString(); + } + + // Parse tools array + if (json.isMember("tools") && json["tools"].isArray()) { + auto tools_array = json["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + assistant.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + + result.error()); + } + } else if (tool_type == "function") { + auto result = AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + assistant.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + + if (json.isMember("tool_resources") && + json["tool_resources"].isObject()) { + const auto& tool_resources_json = json["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); 
+ if (result.has_value()) { + assistant.tool_resources = + std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + assistant.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + + // Parse metadata + if (json.isMember("metadata") && json["metadata"].isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_value()) { + assistant.metadata = res.value(); + } else { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } + } + + if (json.isMember("temperature") && json["temperature"].isDouble()) { + assistant.temperature = json["temperature"].asFloat(); + } + + if (json.isMember("top_p") && json["top_p"].isDouble()) { + assistant.top_p = json["top_p"].asFloat(); + } + + return assistant; + } catch (const std::exception& e) { + return cpp::fail("FromJson failed: " + std::string(e.what())); + } + } }; } // namespace OpenAi diff --git a/engine/common/assistant_code_interpreter_tool.h b/engine/common/assistant_code_interpreter_tool.h new file mode 100644 index 000000000..43bfac47c --- /dev/null +++ b/engine/common/assistant_code_interpreter_tool.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/assistant_tool.h" + +namespace OpenAi { +struct AssistantCodeInterpreterTool : public AssistantTool { + AssistantCodeInterpreterTool() : AssistantTool("code_interpreter") {} + + AssistantCodeInterpreterTool(const AssistantCodeInterpreterTool&) = delete; + + AssistantCodeInterpreterTool& operator=(const AssistantCodeInterpreterTool&) = + delete; + + AssistantCodeInterpreterTool(AssistantCodeInterpreterTool&&) = default; + + 
AssistantCodeInterpreterTool& operator=(AssistantCodeInterpreterTool&&) = + default; + + ~AssistantCodeInterpreterTool() = default; + + static cpp::result FromJson() { + AssistantCodeInterpreterTool tool; + return std::move(tool); + } + + cpp::result ToJson() override { + Json::Value json; + json["type"] = type; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/common/assistant_file_search_tool.h b/engine/common/assistant_file_search_tool.h new file mode 100644 index 000000000..2abaa7f6e --- /dev/null +++ b/engine/common/assistant_file_search_tool.h @@ -0,0 +1,151 @@ +#pragma once + +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct FileSearchRankingOption : public JsonSerializable { + /** + * The ranker to use for the file search. If not specified will use the auto ranker. + */ + std::string ranker; + + /** + * The score threshold for the file search. All values must be a + * floating point number between 0 and 1. + */ + float score_threshold; + + FileSearchRankingOption(float score_threshold, + const std::string& ranker = "auto") + : ranker{ranker}, score_threshold{score_threshold} {} + + FileSearchRankingOption(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption& operator=(const FileSearchRankingOption&) = delete; + + FileSearchRankingOption(FileSearchRankingOption&&) = default; + + FileSearchRankingOption& operator=(FileSearchRankingOption&&) = default; + + ~FileSearchRankingOption() = default; + + static cpp::result FromJson( + const Json::Value& json) { + if (!json.isMember("score_threshold")) { + return cpp::fail("score_threshold must be provided"); + } + + FileSearchRankingOption option{ + json["score_threshold"].asFloat(), + std::move(json.get("ranker", "auto").asString())}; + return option; + } + + cpp::result ToJson() override { + Json::Value json; + json["ranker"] = ranker; + json["score_threshold"] = score_threshold; + return json; + } +}; + +/** + * Overrides for 
the file search tool. + */ +struct AssistantFileSearch : public JsonSerializable { + /** + * The maximum number of results the file search tool should output. + * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. + * This number should be between 1 and 50 inclusive. + * + * Note that the file search tool may output fewer than max_num_results results. + * See the file search tool documentation for more information. + */ + int max_num_results; + + /** + * The ranking options for the file search. If not specified, + * the file search tool will use the auto ranker and a score_threshold of 0. + * + * See the file search tool documentation for more information. + */ + FileSearchRankingOption ranking_options; + + AssistantFileSearch(int max_num_results, + FileSearchRankingOption&& ranking_options) + : max_num_results{max_num_results}, + ranking_options{std::move(ranking_options)} {} + + AssistantFileSearch(const AssistantFileSearch&) = delete; + + AssistantFileSearch& operator=(const AssistantFileSearch&) = delete; + + AssistantFileSearch(AssistantFileSearch&&) = default; + + AssistantFileSearch& operator=(AssistantFileSearch&&) = default; + + ~AssistantFileSearch() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{ + json["max_num_results"].asInt(), + FileSearchRankingOption::FromJson(json["ranking_options"]).value()}; + return search; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + Json::Value root; + root["max_num_results"] = max_num_results; + root["ranking_options"] = ranking_options.ToJson().value(); + return root; + } +}; + +struct AssistantFileSearchTool : public AssistantTool { + AssistantFileSearch file_search; + + AssistantFileSearchTool(AssistantFileSearch& file_search) + : AssistantTool("file_search"), file_search{std::move(file_search)} {} + + AssistantFileSearchTool(const 
AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool& operator=(const AssistantFileSearchTool&) = delete; + + AssistantFileSearchTool(AssistantFileSearchTool&&) = default; + + AssistantFileSearchTool& operator=(AssistantFileSearchTool&&) = default; + + ~AssistantFileSearchTool() = default; + + static cpp::result FromJson( + const Json::Value& json) { + try { + AssistantFileSearch search{json["file_search"]["max_num_results"].asInt(), + FileSearchRankingOption::FromJson( + json["file_search"]["ranking_options"]) + .value()}; + AssistantFileSearchTool tool{search}; + return tool; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value root; + root["type"] = type; + root["file_search"] = file_search.ToJson().value(); + return root; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_function_tool.h b/engine/common/assistant_function_tool.h new file mode 100644 index 000000000..7998cb8ff --- /dev/null +++ b/engine/common/assistant_function_tool.h @@ -0,0 +1,130 @@ +#pragma once + +#include +#include "common/assistant_tool.h" +#include "common/json_serializable.h" + +namespace OpenAi { +struct AssistantFunction : public JsonSerializable { + AssistantFunction(const std::string& description, const std::string& name, + const Json::Value& parameters, + const std::optional& strict) + : description{std::move(description)}, + name{std::move(name)}, + parameters{std::move(parameters)}, + strict{strict} {} + + AssistantFunction(const AssistantFunction&) = delete; + + AssistantFunction& operator=(const AssistantFunction&) = delete; + + AssistantFunction(AssistantFunction&&) = default; + + AssistantFunction& operator=(AssistantFunction&&) = default; + + ~AssistantFunction() = default; + + /** + * A description of what the function does, 
used by the model to choose + * when and how to call the function. + */ + std::string description; + + /** + * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + std::string name; + + /** + * The parameters the functions accepts, described as a JSON Schema object. + * See the guide for examples, and the JSON Schema reference for documentation + * about the format. + * + * Omitting parameters defines a function with an empty parameter list. + */ + Json::Value parameters; + + /** + * Whether to enable strict schema adherence when generating the function call. + * If set to true, the model will follow the exact schema defined in the parameters + * field. Only a subset of JSON Schema is supported when strict is true. + * + * Learn more about Structured Outputs in the function calling guide. + */ + std::optional strict; + + static cpp::result FromJson( + const Json::Value& json) { + if (json.empty()) { + return cpp::fail("Function json can't be empty"); + } + + if (!json.isMember("name") || json.get("name", "").asString().empty()) { + return cpp::fail("Function name can't be empty"); + } + + if (!json.isMember("description")) { + return cpp::fail("Function description is mandatory"); + } + + if (!json.isMember("parameters")) { + return cpp::fail("Function parameters are mandatory"); + } + + std::optional is_strict = std::nullopt; + if (json.isMember("strict")) { + is_strict = json["strict"].asBool(); + } + AssistantFunction function{json["description"].asString(), + json["name"].asString(), json["parameters"], + is_strict}; + function.parameters = json["parameters"]; + return function; + } + + cpp::result ToJson() override { + Json::Value json; + json["description"] = description; + json["name"] = name; + if (strict.has_value()) { + json["strict"] = *strict; + } + json["parameters"] = parameters; + return json; + } +}; + +struct AssistantFunctionTool : public AssistantTool { + 
AssistantFunctionTool(AssistantFunction& function) + : AssistantTool("function"), function{std::move(function)} {} + + AssistantFunctionTool(const AssistantFunctionTool&) = delete; + + AssistantFunctionTool& operator=(const AssistantFunctionTool&) = delete; + + AssistantFunctionTool(AssistantFunctionTool&&) = default; + + AssistantFunctionTool& operator=(AssistantFunctionTool&&) = default; + + ~AssistantFunctionTool() = default; + + AssistantFunction function; + + static cpp::result FromJson( + const Json::Value& json) { + auto function_res = AssistantFunction::FromJson(json["function"]); + if (function_res.has_error()) { + return cpp::fail("Failed to parse function: " + function_res.error()); + } + return AssistantFunctionTool{function_res.value()}; + } + + cpp::result ToJson() override { + Json::Value root; + root["type"] = type; + root["function"] = function.ToJson().value(); + return root; + } +}; +}; // namespace OpenAi diff --git a/engine/common/assistant_tool.h b/engine/common/assistant_tool.h index 622721708..d02392392 100644 --- a/engine/common/assistant_tool.h +++ b/engine/common/assistant_tool.h @@ -1,91 +1,27 @@ #pragma once -#include #include +#include "common/json_serializable.h" namespace OpenAi { -struct AssistantTool { +struct AssistantTool : public JsonSerializable { std::string type; AssistantTool(const std::string& type) : type{type} {} - virtual ~AssistantTool() = default; -}; - -struct AssistantCodeInterpreterTool : public AssistantTool { - AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {} - - ~AssistantCodeInterpreterTool() = default; -}; - -struct AssistantFileSearchTool : public AssistantTool { - AssistantFileSearchTool() : AssistantTool("file_search") {} - - ~AssistantFileSearchTool() = default; + AssistantTool(const AssistantTool&) = delete; - /** - * The ranking options for the file search. If not specified, - * the file search tool will use the auto ranker and a score_threshold of 0. 
- * - * See the file search tool documentation for more information. - */ - struct RankingOption { - /** - * The ranker to use for the file search. If not specified will use the auto ranker. - */ - std::string ranker; + AssistantTool& operator=(const AssistantTool&) = delete; - /** - * The score threshold for the file search. All values must be a - * floating point number between 0 and 1. - */ - float score_threshold; - }; + AssistantTool(AssistantTool&& other) noexcept : type{std::move(other.type)} {} - /** - * Overrides for the file search tool. - */ - struct FileSearch { - /** - * The maximum number of results the file search tool should output. - * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. - * This number should be between 1 and 50 inclusive. - * - * Note that the file search tool may output fewer than max_num_results results. - * See the file search tool documentation for more information. - */ - int max_num_result; - }; -}; - -struct AssistantFunctionTool : public AssistantTool { - AssistantFunctionTool() : AssistantTool("function") {} - - ~AssistantFunctionTool() = default; - - struct Function { - /** - * A description of what the function does, used by the model to choose - * when and how to call the function. - */ - std::string description; + AssistantTool& operator=(AssistantTool&& other) noexcept { + if (this != &other) { + type = std::move(other.type); + } + return *this; + } - /** - * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain - * underscores and dashes, with a maximum length of 64. - */ - std::string name; - - // TODO: namh handle parameters - - /** - * Whether to enable strict schema adherence when generating the function call. - * If set to true, the model will follow the exact schema defined in the parameters - * field. Only a subset of JSON Schema is supported when strict is true. - * - * Learn more about Structured Outputs in the function calling guide. 
- */ - std::optional strict; - }; + virtual ~AssistantTool() = default; }; } // namespace OpenAi diff --git a/engine/common/dto/assistant_create_dto.h b/engine/common/dto/assistant_create_dto.h new file mode 100644 index 000000000..19d79b833 --- /dev/null +++ b/engine/common/dto/assistant_create_dto.h @@ -0,0 +1,211 @@ +#pragma once + +#include +#include +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/assistant_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct CreateAssistantDto : public BaseDto { + CreateAssistantDto() = default; + + ~CreateAssistantDto() = default; + + CreateAssistantDto(const CreateAssistantDto&) = delete; + + CreateAssistantDto& operator=(const CreateAssistantDto&) = delete; + + CreateAssistantDto(CreateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + CreateAssistantDto& operator=(CreateAssistantDto&& other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources), + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + + std::string model; + + std::optional 
name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. + */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (model.empty()) { + return cpp::fail("Model is mandatory"); + } + + if (response_format.has_value()) { + const auto& variant_value = response_format.value(); + if (std::holds_alternative(variant_value)) { + if (std::get(variant_value) != "auto") { + return cpp::fail("Invalid response_format"); + } + } + } + + return {}; + } + + static CreateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + CreateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instructions")) { + dto.instructions = std::move(root["instructions"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if 
(root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + 
OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } + return dto; + } +}; +} // namespace dto diff --git a/engine/common/dto/assistant_update_dto.h b/engine/common/dto/assistant_update_dto.h new file mode 100644 index 000000000..01e5844d7 --- /dev/null +++ b/engine/common/dto/assistant_update_dto.h @@ -0,0 +1,201 @@ +#pragma once + +#include "common/assistant_code_interpreter_tool.h" +#include "common/assistant_file_search_tool.h" +#include "common/assistant_function_tool.h" +#include "common/dto/base_dto.h" +#include "common/tool_resources.h" +#include "common/variant_map.h" +#include "utils/logging_utils.h" + +namespace dto { +struct UpdateAssistantDto : public BaseDto { + UpdateAssistantDto() = default; + + ~UpdateAssistantDto() = default; + + UpdateAssistantDto(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto& operator=(const UpdateAssistantDto&) = delete; + + UpdateAssistantDto(UpdateAssistantDto&& other) noexcept + : model{std::move(other.model)}, + name{std::move(other.name)}, + description{std::move(other.description)}, + instructions{std::move(other.instructions)}, + tools{std::move(other.tools)}, + tool_resources{std::move(other.tool_resources)}, + metadata{std::move(other.metadata)}, + temperature{std::move(other.temperature)}, + top_p{std::move(other.top_p)}, + response_format{std::move(other.response_format)} {} + + UpdateAssistantDto& operator=(UpdateAssistantDto&& 
other) noexcept { + if (this != &other) { + model = std::move(other.model); + name = std::move(other.name); + description = std::move(other.description); + instructions = std::move(other.instructions); + tools = std::move(other.tools); + tool_resources = std::move(other.tool_resources), + metadata = std::move(other.metadata); + temperature = std::move(other.temperature); + top_p = std::move(other.top_p); + response_format = std::move(other.response_format); + } + return *this; + } + std::optional model; + + std::optional name; + + std::optional description; + + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of + * 128 tools per assistant. Tools can be of types code_interpreter, + * file_search, or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. 
+ */ + std::unique_ptr tool_resources; + + std::optional metadata; + + std::optional temperature; + + std::optional top_p; + + std::optional> response_format; + + cpp::result Validate() const override { + if (!model.has_value() && !name.has_value() && !description.has_value() && + !instructions.has_value() && !metadata.has_value() && + !temperature.has_value() && !top_p.has_value() && + !response_format.has_value()) { + return cpp::fail("At least one field must be provided"); + } + + return {}; + } + + static UpdateAssistantDto FromJson(Json::Value&& root) { + if (root.empty()) { + throw std::runtime_error("Json passed in FromJson can't be empty"); + } + UpdateAssistantDto dto; + dto.model = std::move(root["model"].asString()); + if (root.isMember("name")) { + dto.name = std::move(root["name"].asString()); + } + if (root.isMember("description")) { + dto.description = std::move(root["description"].asString()); + } + if (root.isMember("instruction")) { + dto.instructions = std::move(root["instruction"].asString()); + } + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + dto.metadata = std::move(res.value()); + } + } + if (root.isMember("temperature")) { + dto.temperature = root["temperature"].asFloat(); + } + if (root.isMember("top_p")) { + dto.top_p = root["top_p"].asFloat(); + } + if (root.isMember("tools") && root["tools"].isArray()) { + auto tools_array = root["tools"]; + for (const auto& tool : tools_array) { + if (!tool.isMember("type") || !tool["type"].isString()) { + CTL_WRN("Tool missing type field or invalid type"); + continue; + } + + std::string tool_type = tool["type"].asString(); + if (tool_type == "file_search") { + auto result = OpenAi::AssistantFileSearchTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); 
+ } else { + CTL_WRN("Failed to parse file_search tool: " + result.error()); + } + } else if (tool_type == "code_interpreter") { + auto result = OpenAi::AssistantCodeInterpreterTool::FromJson(); + if (result.has_value()) { + dto.tools.push_back( + std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse code_interpreter tool: " + result.error()); + } + } else if (tool_type == "function") { + auto result = OpenAi::AssistantFunctionTool::FromJson(tool); + if (result.has_value()) { + dto.tools.push_back(std::make_unique( + std::move(result.value()))); + } else { + CTL_WRN("Failed to parse function tool: " + result.error()); + } + } else { + CTL_WRN("Unknown tool type: " + tool_type); + } + } + } + if (root.isMember("tool_resources") && root["tool_resources"].isObject()) { + const auto& tool_resources_json = root["tool_resources"]; + + // Parse code interpreter resources + if (tool_resources_json.isMember("code_interpreter")) { + auto result = OpenAi::CodeInterpreter::FromJson( + tool_resources_json["code_interpreter"]); + if (result.has_value()) { + dto.tool_resources = std::make_unique( + std::move(result.value())); + } else { + CTL_WRN("Failed to parse code_interpreter resources: " + + result.error()); + } + } + + // Parse file search resources + if (tool_resources_json.isMember("file_search")) { + auto result = + OpenAi::FileSearch::FromJson(tool_resources_json["file_search"]); + if (result.has_value()) { + dto.tool_resources = + std::make_unique(std::move(result.value())); + } else { + CTL_WRN("Failed to parse file_search resources: " + result.error()); + } + } + } + if (root.isMember("response_format")) { + const auto& response_format = root["response_format"]; + if (response_format.isString()) { + dto.response_format = response_format.asString(); + } else if (response_format.isObject()) { + dto.response_format = response_format; + } else { + throw std::runtime_error( + "response_format must be either a string or an object"); + } + } 
+ return dto; + }; +}; +} // namespace dto diff --git a/engine/common/dto/base_dto.h b/engine/common/dto/base_dto.h new file mode 100644 index 000000000..ed7460aa3 --- /dev/null +++ b/engine/common/dto/base_dto.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include "utils/result.hpp" + +namespace dto { +template +struct BaseDto { + virtual ~BaseDto() = default; + + /** + * Validate itself. + */ + virtual cpp::result Validate() const = 0; +}; +} // namespace dto diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index 767ec9bea..6a0fb02e9 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -4,22 +4,27 @@ #include "common/json_serializable.h" namespace OpenAi { - // The tools to add this file to. struct Tool { std::string type; Tool(const std::string& type) : type{type} {} + + virtual ~Tool() = default; }; // The type of tool being defined: code_interpreter -struct CodeInterpreter : Tool { - CodeInterpreter() : Tool{"code_interpreter"} {} +struct MessageCodeInterpreter : Tool { + MessageCodeInterpreter() : Tool{"code_interpreter"} {} + + ~MessageCodeInterpreter() = default; }; // The type of tool being defined: file_search -struct FileSearch : Tool { - FileSearch() : Tool{"file_search"} {} +struct MessageFileSearch : Tool { + MessageFileSearch() : Tool{"file_search"} {} + + ~MessageFileSearch() = default; }; // A list of files attached to the message, and the tools they were added to. 
diff --git a/engine/common/repository/assistant_repository.h b/engine/common/repository/assistant_repository.h new file mode 100644 index 000000000..d0ff1908d --- /dev/null +++ b/engine/common/repository/assistant_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/assistant.h" +#include "utils/result.hpp" + +class AssistantRepository { + public: + virtual cpp::result, std::string> + ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, const std::string& before) const = 0; + + virtual cpp::result CreateAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result RetrieveAssistant( + const std::string assistant_id) const = 0; + + virtual cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) = 0; + + virtual cpp::result DeleteAssistant( + const std::string& assitant_id) = 0; + + virtual ~AssistantRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h index 2bd5d866b..dc57ba32d 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -4,7 +4,7 @@ #include #include #include "common/assistant.h" -#include "common/thread_tool_resources.h" +#include "common/tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" #include "utils/logging_utils.h" @@ -36,7 +36,7 @@ struct Thread : JsonSerializable { * of tool. For example, the code_interpreter tool requires a list of * file IDs, while the file_search tool requires a list of vector store IDs. */ - std::unique_ptr tool_resources; + std::unique_ptr tool_resources; /** * Set of 16 key-value pairs that can be attached to an object. 
@@ -65,7 +65,7 @@ struct Thread : JsonSerializable { const auto& tool_json = json["tool_resources"]; if (tool_json.isMember("code_interpreter")) { - auto code_interpreter = std::make_unique(); + auto code_interpreter = std::make_unique(); const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; if (file_ids.isArray()) { for (const auto& file_id : file_ids) { @@ -74,7 +74,7 @@ struct Thread : JsonSerializable { } thread.tool_resources = std::move(code_interpreter); } else if (tool_json.isMember("file_search")) { - auto file_search = std::make_unique(); + auto file_search = std::make_unique(); const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; if (store_ids.isArray()) { for (const auto& store_id : store_ids) { @@ -148,10 +148,10 @@ struct Thread : JsonSerializable { Json::Value tool_json; if (auto code_interpreter = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["code_interpreter"] = tool_result.value(); } else if (auto file_search = - dynamic_cast(tool_resources.get())) { + dynamic_cast(tool_resources.get())) { tool_json["file_search"] = tool_result.value(); } json["tool_resources"] = tool_json; diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h deleted file mode 100644 index 3c22a4480..000000000 --- a/engine/common/thread_tool_resources.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include -#include -#include "common/json_serializable.h" - -namespace OpenAi { - -struct ThreadToolResources : JsonSerializable { - ~ThreadToolResources() = default; - - virtual cpp::result ToJson() override = 0; -}; - -struct ThreadCodeInterpreter : ThreadToolResources { - std::vector file_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value file_ids_json{Json::arrayValue}; - for (auto& file_id : file_ids) { - file_ids_json.append(file_id); - } - json["file_ids"] = file_ids_json; - return json; - } catch (const std::exception& e) { - return 
cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; - -struct ThreadFileSearch : ThreadToolResources { - std::vector vector_store_ids; - - cpp::result ToJson() override { - try { - Json::Value json; - Json::Value vector_store_ids_json{Json::arrayValue}; - for (auto& vector_store_id : vector_store_ids) { - vector_store_ids_json.append(vector_store_id); - } - json["vector_store_ids"] = vector_store_ids_json; - return json; - } catch (const std::exception& e) { - return cpp::fail(std::string("ToJson failed: ") + e.what()); - } - } -}; -} // namespace OpenAi diff --git a/engine/common/tool_resources.h b/engine/common/tool_resources.h new file mode 100644 index 000000000..5aadb3f8b --- /dev/null +++ b/engine/common/tool_resources.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ToolResources : JsonSerializable { + ToolResources() = default; + + ToolResources(const ToolResources&) = delete; + + ToolResources& operator=(const ToolResources&) = delete; + + ToolResources(ToolResources&&) noexcept = default; + + ToolResources& operator=(ToolResources&&) noexcept = default; + + virtual ~ToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct CodeInterpreter : ToolResources { + CodeInterpreter() = default; + + ~CodeInterpreter() override = default; + + CodeInterpreter(const CodeInterpreter&) = delete; + + CodeInterpreter& operator=(const CodeInterpreter&) = delete; + + CodeInterpreter(CodeInterpreter&& other) noexcept + : ToolResources(std::move(other)), file_ids(std::move(other.file_ids)) {} + + CodeInterpreter& operator=(CodeInterpreter&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + file_ids = std::move(other.file_ids); + } + return *this; + } + + std::vector file_ids; + + static cpp::result FromJson( + const Json::Value& json) { + CodeInterpreter code_interpreter; + if (json.isMember("file_ids")) { + for 
(const auto& file_id : json["file_ids"]) { + code_interpreter.file_ids.push_back(file_id.asString()); + } + } + return code_interpreter; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } +}; + +struct FileSearch : ToolResources { + FileSearch() = default; + + ~FileSearch() override = default; + + FileSearch(const FileSearch&) = delete; + + FileSearch& operator=(const FileSearch&) = delete; + + FileSearch(FileSearch&& other) noexcept + : ToolResources(std::move(other)), + vector_store_ids{std::move(other.vector_store_ids)} {} + + FileSearch& operator=(FileSearch&& other) noexcept { + if (this != &other) { + ToolResources::operator=(std::move(other)); + + vector_store_ids = std::move(other.vector_store_ids); + } + return *this; + } + + std::vector vector_store_ids; + + static cpp::result FromJson( + const Json::Value& json) { + FileSearch file_search; + if (json.isMember("vector_store_ids")) { + for (const auto& vector_store_id : json["vector_store_ids"]) { + file_search.vector_store_ids.push_back(vector_store_id.asString()); + } + } + return file_search; + } + + cpp::result ToJson() override { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } +}; +} // namespace OpenAi diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc index 405d7ed3c..530e180a5 100644 --- a/engine/controllers/assistants.cc +++ b/engine/controllers/assistants.cc @@ -1,4 +1,6 @@ #include "assistants.h" +#include "common/api-dto/delete_success_response.h" +#include "common/dto/assistant_create_dto.h" #include "utils/cortex_utils.h" #include "utils/logging_utils.h" @@ -6,7 +8,12 @@ void 
Assistants::RetrieveAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const { - CTL_INF("RetrieveAssistant: " + assistant_id); + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return RetrieveAssistantV2(req, std::move(callback), assistant_id); + } + auto res = assistant_service_->RetrieveAssistant(assistant_id); if (res.has_error()) { Json::Value ret; @@ -33,6 +40,78 @@ void Assistants::RetrieveAssistant( } } +void Assistants::RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + auto res = assistant_service_->RetrieveAssistantV2(assistant_id); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::CreateAssistantDto::FromJson(std::move(*json_body)); + 
CTL_INF("CreateAssistantV2: " << dto.model); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->CreateAssistantV2(dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto to_json_res = res->ToJson(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(to_json_res.value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::CreateAssistant( const HttpRequestPtr& req, std::function&& callback, @@ -88,10 +167,55 @@ void Assistants::CreateAssistant( callback(resp); } +void Assistants::ModifyAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto dto = dto::UpdateAssistantDto::FromJson(std::move(*json_body)); + auto validate_res = dto.Validate(); + if (validate_res.has_error()) { + Json::Value ret; + ret["message"] = validate_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = assistant_service_->ModifyAssistantV2(assistant_id, dto); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + void Assistants::ModifyAssistant( const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) { + const auto& headers = req->headers(); + auto it = headers.find(kOpenAiAssistantKeyV2); + if (it != headers.end() && it->second == kOpenAiAssistantValueV2) { + return ModifyAssistantV2(req, std::move(callback), assistant_id); + } auto json_body = req->getJsonObject(); if (json_body == nullptr) { Json::Value ret; @@ -142,3 +266,62 @@ void Assistants::ModifyAssistant( resp->setStatusCode(k200OK); callback(resp); } + +void Assistants::ListAssistants( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + + auto res = assistant_service_->ListAssistants( + std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or("")); + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + Json::Value assistant_list(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + assistant_list.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = assistant_list; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Assistants::DeleteAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto res = assistant_service_->DeleteAssistantV2(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = 
cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = assistant_id; + response.object = "assistant.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h index 94ddd14b1..30111bb01 100644 --- a/engine/controllers/assistants.h +++ b/engine/controllers/assistants.h @@ -7,33 +7,72 @@ using namespace drogon; class Assistants : public drogon::HttpController { + constexpr static auto kOpenAiAssistantKeyV2 = "openai-beta"; + constexpr static auto kOpenAiAssistantValueV2 = "assistants=v2"; + public: METHOD_LIST_BEGIN + ADD_METHOD_TO( + Assistants::ListAssistants, + "/v1/" + "assistants?limit={limit}&order={order}&after={after}&before={before}", + Get); + + ADD_METHOD_TO(Assistants::DeleteAssistant, "/v1/assistants/{assistant_id}", + Options, Delete); + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", Get); ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", Options, Post); + ADD_METHOD_TO(Assistants::CreateAssistantV2, "/v1/assistants", Options, Post); + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", Options, Patch); + METHOD_LIST_END explicit Assistants(std::shared_ptr assistant_srv) : assistant_service_{assistant_srv} {}; + void ListAssistants(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + void RetrieveAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id) const; + void RetrieveAssistantV2( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void 
DeleteAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + void CreateAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void CreateAssistantV2( + const HttpRequestPtr& req, + std::function&& callback); + void ModifyAssistant(const HttpRequestPtr& req, std::function&& callback, const std::string& assistant_id); + void ModifyAssistantV2(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + private: std::shared_ptr assistant_service_; }; diff --git a/engine/controllers/hardware.h b/engine/controllers/hardware.h index 6cca4fd2a..8b2b551ce 100644 --- a/engine/controllers/hardware.h +++ b/engine/controllers/hardware.h @@ -9,7 +9,7 @@ using namespace drogon; class Hardware : public drogon::HttpController { public: explicit Hardware(std::shared_ptr engine_svc, - std::shared_ptr hw_svc) + std::shared_ptr hw_svc) : engine_svc_(engine_svc), hw_svc_(hw_svc) {} METHOD_LIST_BEGIN METHOD_ADD(Hardware::GetHardwareInfo, "/hardware", Get); @@ -27,5 +27,5 @@ class Hardware : public drogon::HttpController { private: std::shared_ptr engine_svc_ = nullptr; - std::shared_ptr hw_svc_= nullptr; + std::shared_ptr hw_svc_= nullptr; }; \ No newline at end of file diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 3ad94467c..d6b985ffb 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -165,10 +165,9 @@ void Models::ListModel( model_service_->ForceIndexingModelList(); // Iterate through directory - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto list_entry = modellist_handler.LoadModelList(); + auto list_entry = db_service_->LoadModelList(); if (list_entry) { for (const auto& model_entry : list_entry.value()) { try { @@ -266,9 +265,8 @@ void Models::GetModel(const HttpRequestPtr& req, Json::Value ret; try { - cortex::db::Models modellist_handler; config::YamlHandler 
yaml_handler; - auto model_entry = modellist_handler.GetModelInfo(model_id); + auto model_entry = db_service_->GetModelInfo(model_id); if (model_entry.has_error()) { ret["id"] = model_id; ret["object"] = "model"; @@ -361,8 +359,7 @@ void Models::UpdateModel(const HttpRequestPtr& req, namespace fmu = file_manager_utils; auto json_body = *(req->getJsonObject()); try { - cortex::db::Models model_list_utils; - auto model_entry = model_list_utils.GetModelInfo(model_id); + auto model_entry = db_service_->GetModelInfo(model_id); config::YamlHandler yaml_handler; auto yaml_fp = fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)); @@ -432,7 +429,6 @@ void Models::ImportModel( auto option = (*(req->getJsonObject())).get("option", "symlink").asString(); config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; - cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("imported") / std::filesystem::path(modelHandle + ".yml")) @@ -471,7 +467,7 @@ void Models::ImportModel( model_config.name = modelName.empty() ? 
model_config.name : modelName; yaml_handler.UpdateModelConfig(model_config); - if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + if (db_service_->AddModelEntry(model_entry).value()) { yaml_handler.WriteYamlFile(model_yaml_path); std::string success_message = "Model is imported successfully!"; LOG_INFO << success_message; @@ -698,7 +694,6 @@ void Models::AddRemoteModel( config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); - cortex::db::Models modellist_utils_obj; std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / std::filesystem::path("remote") / std::filesystem::path(model_handle + ".yml")) @@ -714,7 +709,7 @@ void Models::AddRemoteModel( "openai"}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); - if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + if (db_service_->AddModelEntry(model_entry).value()) { model_config.SaveToYamlFile(model_yaml_path); std::string success_message = "Model is imported successfully!"; LOG_INFO << success_message; diff --git a/engine/controllers/models.h b/engine/controllers/models.h index d3200f33a..60053acdb 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -45,10 +45,12 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::GetModelSources, "/v1/models/sources", Get); METHOD_LIST_END - explicit Models(std::shared_ptr model_service, + explicit Models(std::shared_ptr db_service, + std::shared_ptr model_service, std::shared_ptr engine_service, - std::shared_ptr mss) - : model_service_{model_service}, + std::shared_ptr mss) + : db_service_(db_service), + model_service_{model_service}, engine_service_{engine_service}, model_src_svc_(mss) {} @@ -105,7 +107,8 @@ class Models : public drogon::HttpController { std::function&& callback); private: + std::shared_ptr db_service_; std::shared_ptr model_service_; std::shared_ptr engine_service_; - std::shared_ptr 
model_src_svc_; + std::shared_ptr model_src_svc_; }; diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 1c455e262..cc5cee54a 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -8,7 +8,7 @@ using namespace inferences; namespace inferences { -server::server(std::shared_ptr inference_service, +server::server(std::shared_ptr inference_service, std::shared_ptr engine_service) : inference_svc_(inference_service), engine_service_(engine_service) { #if defined(_WIN32) @@ -45,7 +45,7 @@ void server::ChatCompletion( }(); LOG_DEBUG << "request body: " << json_body->toStyledString(); - auto q = std::make_shared(); + auto q = std::make_shared(); auto ir = inference_svc_->HandleChatCompletion(q, json_body); if (ir.has_error()) { auto err = ir.error(); @@ -67,7 +67,7 @@ void server::ChatCompletion( void server::Embedding(const HttpRequestPtr& req, std::function&& callback) { LOG_TRACE << "Start embedding"; - auto q = std::make_shared(); + auto q = std::make_shared(); auto ir = inference_svc_->HandleEmbedding(q, req->getJsonObject()); if (ir.has_error()) { auto err = ir.error(); @@ -188,7 +188,7 @@ void server::LoadModel(const HttpRequestPtr& req, } void server::ProcessStreamRes(std::function cb, - std::shared_ptr q, + std::shared_ptr q, const std::string& engine_type, const std::string& model_id) { auto err_or_done = std::make_shared(false); @@ -228,7 +228,7 @@ void server::ProcessStreamRes(std::function cb, } void server::ProcessNonStreamRes(std::function cb, - services::SyncQueue& q) { + SyncQueue& q) { auto [status, res] = q.wait_and_pop(); function_calling_utils::PostProcessResponse(res); LOG_DEBUG << "response: " << res.toStyledString(); diff --git a/engine/controllers/server.h b/engine/controllers/server.h index 5f2a14677..42214a641 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -27,7 +27,7 @@ class server : public drogon::HttpController, public BaseChatCompletion, public BaseEmbedding { 
public: - server(std::shared_ptr inference_service, + server(std::shared_ptr inference_service, std::shared_ptr engine_service); ~server(); METHOD_LIST_BEGIN @@ -79,14 +79,14 @@ class server : public drogon::HttpController, private: void ProcessStreamRes(std::function cb, - std::shared_ptr q, + std::shared_ptr q, const std::string& engine_type, const std::string& model_id); void ProcessNonStreamRes(std::function cb, - services::SyncQueue& q); + SyncQueue& q); private: - std::shared_ptr inference_svc_; + std::shared_ptr inference_svc_; std::shared_ptr engine_service_; }; }; // namespace inferences diff --git a/engine/main.cc b/engine/main.cc index ddf1eefd8..77f51c7fa 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -15,11 +15,13 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/assistant_fs_repository.h" #include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" #include "services/assistant_service.h" #include "services/config_service.h" +#include "services/database_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" @@ -119,7 +121,8 @@ void RunServer(std::optional host, std::optional port, LOG_INFO << "cortex.cpp version: undefined"; #endif - auto hw_service = std::make_shared(); + auto db_service = std::make_shared(); + auto hw_service = std::make_shared(db_service); hw_service->UpdateHardwareInfos(); if (hw_service->ShouldRestart()) { CTL_INF("Restart to update hardware configuration"); @@ -139,12 +142,16 @@ void RunServer(std::optional host, std::optional port, // utils auto dylib_path_manager = std::make_shared(); - auto file_repo = std::make_shared(data_folder_path); + auto file_repo = + std::make_shared(data_folder_path, db_service); auto msg_repo = std::make_shared(data_folder_path); auto thread_repo = 
std::make_shared(data_folder_path); + auto assistant_repo = + std::make_shared(data_folder_path); auto file_srv = std::make_shared(file_repo); - auto assistant_srv = std::make_shared(thread_repo); + auto assistant_srv = + std::make_shared(thread_repo, assistant_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); @@ -152,13 +159,12 @@ void RunServer(std::optional host, std::optional port, auto config_service = std::make_shared(); auto download_service = std::make_shared(event_queue_ptr, config_service); - auto engine_service = - std::make_shared(download_service, dylib_path_manager); - auto inference_svc = - std::make_shared(engine_service); - auto model_src_svc = std::make_shared(); + auto engine_service = std::make_shared( + download_service, dylib_path_manager, db_service); + auto inference_svc = std::make_shared(engine_service); + auto model_src_svc = std::make_shared(db_service); auto model_service = std::make_shared( - download_service, inference_svc, engine_service); + db_service, hw_service, download_service, inference_svc, engine_service); inference_svc->SetModelService(model_service); auto file_watcher_srv = std::make_shared( @@ -173,8 +179,8 @@ void RunServer(std::optional host, std::optional port, auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); - auto model_ctl = - std::make_shared(model_service, engine_service, model_src_svc); + auto model_ctl = std::make_shared(db_service, model_service, + engine_service, model_src_svc); auto event_ctl = std::make_shared(event_queue_ptr); auto pm_ctl = std::make_shared(); auto hw_ctl = std::make_shared(engine_service, hw_service); diff --git a/engine/repositories/assistant_fs_repository.cc b/engine/repositories/assistant_fs_repository.cc new file mode 100644 index 000000000..87b4174fd --- /dev/null +++ b/engine/repositories/assistant_fs_repository.cc @@ 
-0,0 +1,214 @@ +#include "assistant_fs_repository.h" +#include +#include +#include +#include +#include "utils/result.hpp" + +cpp::result, std::string> +AssistantFsRepository::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + std::vector assistants; + try { + auto assistant_container_path = + data_folder_path_ / kAssistantContainerFolderName; + std::vector all_assistants; + + for (const auto& entry : + std::filesystem::directory_iterator(assistant_container_path)) { + if (!entry.is_directory()) { + continue; + } + + auto assistant_file = entry.path() / kAssistantFileName; + if (!std::filesystem::exists(assistant_file)) { + continue; + } + + auto current_assistant_id = entry.path().filename().string(); + + if (!after.empty() && current_assistant_id <= after) { + continue; + } + + if (!before.empty() && current_assistant_id >= before) { + continue; + } + + std::shared_lock assistant_lock(GrabAssistantMutex(current_assistant_id)); + auto assistant_res = LoadAssistant(current_assistant_id); + if (assistant_res.has_value()) { + all_assistants.push_back(std::move(assistant_res.value())); + } + assistant_lock.unlock(); + } + + // sorting + if (order == "desc") { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at > assistant2.created_at; + }); + } else { + std::sort(all_assistants.begin(), all_assistants.end(), + [](const OpenAi::Assistant& assistant1, + const OpenAi::Assistant& assistant2) { + return assistant1.created_at < assistant2.created_at; + }); + } + + size_t assistant_count = + std::min(static_cast(limit), all_assistants.size()); + for (size_t i = 0; i < assistant_count; i++) { + assistants.push_back(std::move(all_assistants[i])); + } + + return assistants; + } catch (const std::exception& e) { + return cpp::fail("Failed to list assistants: " + std::string(e.what())); + } 
+} + +cpp::result +AssistantFsRepository::RetrieveAssistant(const std::string assistant_id) const { + std::shared_lock lock(GrabAssistantMutex(assistant_id)); + return LoadAssistant(assistant_id); +} + +cpp::result AssistantFsRepository::ModifyAssistant( + OpenAi::Assistant& assistant) { + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (!std::filesystem::exists(path)) { + lock.unlock(); + return cpp::fail("Assistant doesn't exist: " + assistant.id); + } + } + + return SaveAssistant(assistant); +} + +cpp::result AssistantFsRepository::DeleteAssistant( + const std::string& assitant_id) { + { + std::unique_lock assistant_lock(GrabAssistantMutex(assitant_id)); + auto path = GetAssistantPath(assitant_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Assistant doesn't exist: " + assitant_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::fail(""); + } + } + + std::unique_lock map_lock(map_mutex_); + assistant_mutexes_.erase(assitant_id); + return {}; +} + +cpp::result +AssistantFsRepository::CreateAssistant(OpenAi::Assistant& assistant) { + CTL_INF("CreateAssistant: " + assistant.id); + { + std::unique_lock lock(GrabAssistantMutex(assistant.id)); + auto path = GetAssistantPath(assistant.id); + + if (std::filesystem::exists(path)) { + return cpp::fail("Assistant already exists: " + assistant.id); + } + + std::filesystem::create_directories(path); + auto assistant_file_path = path / kAssistantFileName; + std::ofstream assistant_file(assistant_file_path); + assistant_file.close(); + + CTL_INF("CreateAssistant created new file: " + assistant.id); + auto save_result = SaveAssistant(assistant); + if (save_result.has_error()) { + lock.unlock(); + return cpp::fail("Failed to save assistant: " + save_result.error()); + } + } + return RetrieveAssistant(assistant.id); +} + +cpp::result AssistantFsRepository::SaveAssistant( + OpenAi::Assistant& 
assistant) { + auto path = GetAssistantPath(assistant.id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + + std::ofstream file(path); + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + try { + file << assistant.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save assistant: " + std::string(e.what())); + } +} + +std::filesystem::path AssistantFsRepository::GetAssistantPath( + const std::string& assistant_id) const { + auto container_folder_path = + data_folder_path_ / kAssistantContainerFolderName; + if (!std::filesystem::exists(container_folder_path)) { + std::filesystem::create_directories(container_folder_path); + } + + return data_folder_path_ / kAssistantContainerFolderName / assistant_id; +} + +std::shared_mutex& AssistantFsRepository::GrabAssistantMutex( + const std::string& assistant_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = assistant_mutexes_.find(assistant_id); + if (it != assistant_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *assistant_mutexes_ + .try_emplace(assistant_id, std::make_unique()) + .first->second; +} + +cpp::result +AssistantFsRepository::LoadAssistant(const std::string& assistant_id) const { + auto path = GetAssistantPath(assistant_id) / kAssistantFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Assistant::FromJson(std::move(root)); + } 
catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} diff --git a/engine/repositories/assistant_fs_repository.h b/engine/repositories/assistant_fs_repository.h new file mode 100644 index 000000000..f310bd54e --- /dev/null +++ b/engine/repositories/assistant_fs_repository.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include "common/repository/assistant_repository.h" + +class AssistantFsRepository : public AssistantRepository { + public: + constexpr static auto kAssistantContainerFolderName = "assistants"; + constexpr static auto kAssistantFileName = "assistant.json"; + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result CreateAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result RetrieveAssistant( + const std::string assistant_id) const override; + + cpp::result ModifyAssistant( + OpenAi::Assistant& assistant) override; + + cpp::result DeleteAssistant( + const std::string& assitant_id) override; + + explicit AssistantFsRepository(const std::filesystem::path& data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing AssistantFsRepository.."); + auto path = data_folder_path_ / kAssistantContainerFolderName; + + if (!std::filesystem::exists(path)) { + std::filesystem::create_directories(path); + } + } + + ~AssistantFsRepository() = default; + + private: + std::filesystem::path GetAssistantPath(const std::string& assistant_id) const; + + std::shared_mutex& GrabAssistantMutex(const std::string& assistant_id) const; + + cpp::result SaveAssistant(OpenAi::Assistant& assistant); + + cpp::result LoadAssistant( + const std::string& assistant_id) const; + + /** + * The path to the data folder. 
+ */ + std::filesystem::path data_folder_path_; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map<std::string, std::unique_ptr<std::shared_mutex>> + assistant_mutexes_; +}; diff --git a/engine/repositories/file_fs_repository.cc b/engine/repositories/file_fs_repository.cc index a209d33c3..e6c28b38e 100644 --- a/engine/repositories/file_fs_repository.cc +++ b/engine/repositories/file_fs_repository.cc @@ -17,7 +17,6 @@ cpp::result FileFsRepository::StoreFile( std::filesystem::create_directories(file_container_path); } - cortex::db::File db; auto original_filename = file_metadata.filename; auto file_full_path = file_container_path / original_filename; @@ -53,7 +52,7 @@ cpp::result FileFsRepository::StoreFile( file.flush(); file.close(); - auto result = db.AddFileEntry(file_metadata); + auto result = db_service_->AddFileEntry(file_metadata); if (result.has_error()) { std::filesystem::remove(file_full_path); return cpp::fail(result.error()); @@ -70,8 +69,7 @@ cpp::result FileFsRepository::StoreFile( cpp::result<std::vector<OpenAi::File>, std::string> FileFsRepository::ListFiles( const std::string& purpose, uint8_t limit, const std::string& order, const std::string& after) const { - cortex::db::File db; - auto res = db.GetFileList(); + auto res = db_service_->GetFileList(); if (res.has_error()) { return cpp::fail(res.error()); } @@ -101,8 +99,7 @@ cpp::result FileFsRepository::RetrieveFile( CTL_INF("Retrieving file: " + file_id); auto file_container_path = GetFilePath(); - cortex::db::File db; - auto res = db.GetFileById(file_id); + auto res = db_service_->GetFileById(file_id); if (res.has_error()) { return cpp::fail(res.error()); } @@ -158,15 +155,14 @@ cpp::result FileFsRepository::DeleteFileLocal( const std::string& file_id) { CTL_INF("Deleting file: " + file_id); auto file_container_path = GetFilePath(); - cortex::db::File db; - auto file_metadata = db.GetFileById(file_id); + auto file_metadata = db_service_->GetFileById(file_id); if (file_metadata.has_error()) { return cpp::fail(file_metadata.error()); } auto file_path 
= file_container_path / file_metadata->filename; - auto res = db.DeleteFileEntry(file_id); + auto res = db_service_->DeleteFileEntry(file_id); if (res.has_error()) { CTL_ERR("Failed to delete file entry: " << res.error()); return cpp::fail(res.error()); diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h index 974e81fa4..e2ad424a7 100644 --- a/engine/repositories/file_fs_repository.h +++ b/engine/repositories/file_fs_repository.h @@ -2,6 +2,7 @@ #include #include "common/repository/file_repository.h" +#include "services/database_service.h" #include "utils/logging_utils.h" class FileFsRepository : public FileRepository { @@ -28,8 +29,9 @@ class FileFsRepository : public FileRepository { cpp::result DeleteFileLocal( const std::string& file_id) override; - explicit FileFsRepository(std::filesystem::path data_folder_path) - : data_folder_path_{data_folder_path} { + explicit FileFsRepository(const std::filesystem::path& data_folder_path, + std::shared_ptr db_service) + : data_folder_path_{data_folder_path}, db_service_(db_service) { CTL_INF("Constructing FileFsRepository.."); auto file_container_path = data_folder_path_ / kFileContainerFolderName; @@ -47,4 +49,5 @@ class FileFsRepository : public FileRepository { * The path to the data folder. 
*/ std::filesystem::path data_folder_path_; + std::shared_ptr db_service_ = nullptr; }; diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index 2146778bf..0ca6e89b3 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -32,7 +32,7 @@ class MessageFsRepository : public MessageRepository { const std::string& thread_id, std::optional> messages) override; - explicit MessageFsRepository(std::filesystem::path data_folder_path) + explicit MessageFsRepository(const std::filesystem::path& data_folder_path) : data_folder_path_{data_folder_path} { CTL_INF("Constructing MessageFsRepository.."); auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc index e769bf23f..08a5a743f 100644 --- a/engine/services/assistant_service.cc +++ b/engine/services/assistant_service.cc @@ -1,5 +1,7 @@ #include "assistant_service.h" +#include #include "utils/logging_utils.h" +#include "utils/ulid_generator.h" cpp::result AssistantService::CreateAssistant(const std::string& thread_id, @@ -26,3 +28,181 @@ AssistantService::ModifyAssistant(const std::string& thread_id, CTL_INF("RetrieveAssistant: " + thread_id); return thread_repository_->ModifyAssistant(thread_id, assistant); } + +cpp::result, std::string> +AssistantService::ListAssistants(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("List assistants invoked"); + return assistant_repository_->ListAssistants(limit, order, after, before); +} + +cpp::result AssistantService::CreateAssistantV2( + const dto::CreateAssistantDto& create_dto) { + + OpenAi::Assistant assistant; + assistant.id = "asst_" + ulid::GenerateUlid(); + assistant.model = create_dto.model; + if (create_dto.name) { + assistant.name = *create_dto.name; + } + if (create_dto.description) { + 
assistant.description = *create_dto.description; + } + if (create_dto.instructions) { + assistant.instructions = *create_dto.instructions; + } + if (create_dto.metadata) { + assistant.metadata = *create_dto.metadata; + } + if (create_dto.temperature) { + assistant.temperature = *create_dto.temperature; + } + if (create_dto.top_p) { + assistant.top_p = *create_dto.top_p; + } + for (auto& tool_ptr : create_dto.tools) { + // Create a new unique_ptr in assistant.tools that takes ownership + if (auto* function_tool = + dynamic_cast(tool_ptr.get())) { + assistant.tools.push_back(std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (create_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + create_dto.tool_resources.get())) { + assistant.tool_resources = + std::make_unique(std::move(*file_search)); + } + } + if (create_dto.response_format) { + assistant.response_format = *create_dto.response_format; + } + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + assistant.created_at = seconds_since_epoch; + return assistant_repository_->CreateAssistant(assistant); +} +cpp::result +AssistantService::RetrieveAssistantV2(const std::string& assistant_id) const { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->RetrieveAssistant(assistant_id); +} + +cpp::result AssistantService::ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& 
update_dto) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + if (!update_dto.Validate()) { + return cpp::fail("Invalid update assistant dto"); + } + + // First retrieve the existing assistant + auto existing_assistant = + assistant_repository_->RetrieveAssistant(assistant_id); + if (existing_assistant.has_error()) { + return cpp::fail(existing_assistant.error()); + } + + OpenAi::Assistant updated_assistant; + updated_assistant.id = assistant_id; + + // Update fields if they are present in the DTO + if (update_dto.model) { + updated_assistant.model = *update_dto.model; + } + if (update_dto.name) { + updated_assistant.name = *update_dto.name; + } + if (update_dto.description) { + updated_assistant.description = *update_dto.description; + } + if (update_dto.instructions) { + updated_assistant.instructions = *update_dto.instructions; + } + if (update_dto.metadata) { + updated_assistant.metadata = *update_dto.metadata; + } + if (update_dto.temperature) { + updated_assistant.temperature = *update_dto.temperature; + } + if (update_dto.top_p) { + updated_assistant.top_p = *update_dto.top_p; + } + for (auto& tool_ptr : update_dto.tools) { + if (auto* function_tool = + dynamic_cast(tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*function_tool))); + } else if (auto* code_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*code_tool))); + } else if (auto* search_tool = + dynamic_cast( + tool_ptr.get())) { + updated_assistant.tools.push_back( + std::make_unique( + std::move(*search_tool))); + } + } + if (update_dto.tool_resources) { + if (auto* code_interpreter = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + std::make_unique( + std::move(*code_interpreter)); + } else if (auto* file_search = dynamic_cast( + update_dto.tool_resources.get())) { + updated_assistant.tool_resources = + 
std::make_unique(std::move(*file_search)); + } + } + if (update_dto.response_format) { + updated_assistant.response_format = *update_dto.response_format; + } + + auto res = assistant_repository_->ModifyAssistant(updated_assistant); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return updated_assistant; +} + +cpp::result AssistantService::DeleteAssistantV2( + const std::string& assistant_id) { + if (assistant_id.empty()) { + return cpp::fail("Assistant ID cannot be empty"); + } + + return assistant_repository_->DeleteAssistant(assistant_id); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h index e7f7414d1..ad31104ff 100644 --- a/engine/services/assistant_service.h +++ b/engine/services/assistant_service.h @@ -1,15 +1,14 @@ #pragma once #include "common/assistant.h" +#include "common/dto/assistant_create_dto.h" +#include "common/dto/assistant_update_dto.h" +#include "common/repository/assistant_repository.h" #include "repositories/thread_fs_repository.h" #include "utils/result.hpp" class AssistantService { public: - explicit AssistantService( - std::shared_ptr thread_repository) - : thread_repository_{thread_repository} {} - cpp::result CreateAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); @@ -19,6 +18,31 @@ class AssistantService { cpp::result ModifyAssistant( const std::string& thread_id, const OpenAi::JanAssistant& assistant); + // V2 + cpp::result CreateAssistantV2( + const dto::CreateAssistantDto& create_dto); + + cpp::result, std::string> ListAssistants( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveAssistantV2( + const std::string& assistant_id) const; + + cpp::result ModifyAssistantV2( + const std::string& assistant_id, + const dto::UpdateAssistantDto& update_dto); + + cpp::result DeleteAssistantV2( + const std::string& assistant_id); + + explicit AssistantService( + std::shared_ptr 
thread_repository, + std::shared_ptr assistant_repository) + : thread_repository_{thread_repository}, + assistant_repository_{assistant_repository} {} + private: std::shared_ptr thread_repository_; + std::shared_ptr assistant_repository_; }; diff --git a/engine/services/database_service.cc b/engine/services/database_service.cc new file mode 100644 index 000000000..d4cd977a9 --- /dev/null +++ b/engine/services/database_service.cc @@ -0,0 +1,130 @@ +#include "database_service.h" + +// begin engines +std::optional DatabaseService::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + return cortex::db::Engines().UpsertEngine(engine_name, type, api_key, url, + version, variant, status, metadata); +} + +std::optional> DatabaseService::GetEngines() const { + return cortex::db::Engines().GetEngines(); +} + +std::optional DatabaseService::GetEngineById(int id) const { + return cortex::db::Engines().GetEngineById(id); +} + +std::optional DatabaseService::GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant) const { + return cortex::db::Engines().GetEngineByNameAndVariant(engine_name, variant); +} + +std::optional DatabaseService::DeleteEngineById(int id) { + return cortex::db::Engines().DeleteEngineById(id); +} +// end engines + +// begin file +cpp::result, std::string> +DatabaseService::GetFileList() const { + return cortex::db::File().GetFileList(); +} + +cpp::result DatabaseService::GetFileById( + const std::string& file_id) const { + return cortex::db::File().GetFileById(file_id); +} + +cpp::result DatabaseService::AddFileEntry( + OpenAi::File& file) { + return cortex::db::File().AddFileEntry(file); +} + +cpp::result DatabaseService::DeleteFileEntry( + const std::string& file_id) { + return cortex::db::File().DeleteFileEntry(file_id); +} +// 
end file + +// begin hardware +cpp::result, std::string> +DatabaseService::LoadHardwareList() const { + return cortex::db::Hardware().LoadHardwareList(); +} + +cpp::result DatabaseService::AddHardwareEntry( + const HardwareEntry& new_entry) { + return cortex::db::Hardware().AddHardwareEntry(new_entry); +} + +cpp::result DatabaseService::UpdateHardwareEntry( + const std::string& id, const HardwareEntry& updated_entry) { + return cortex::db::Hardware().UpdateHardwareEntry(id, updated_entry); +} + +cpp::result DatabaseService::DeleteHardwareEntry( + const std::string& id) { + return cortex::db::Hardware().DeleteHardwareEntry(id); +} +// end hardware + +// begin models +cpp::result, std::string> +DatabaseService::LoadModelList() const { + return cortex::db::Models().LoadModelList(); +} + +cpp::result DatabaseService::GetModelInfo( + const std::string& identifier) const { + return cortex::db::Models().GetModelInfo(identifier); +} + +cpp::result DatabaseService::AddModelEntry( + ModelEntry new_entry) { + return cortex::db::Models().AddModelEntry(new_entry); +} + +cpp::result DatabaseService::UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry) { + return cortex::db::Models().UpdateModelEntry(identifier, updated_entry); +} + +cpp::result DatabaseService::DeleteModelEntry( + const std::string& identifier) { + return cortex::db::Models().DeleteModelEntry(identifier); +} + +cpp::result DatabaseService::DeleteModelEntryWithOrg( + const std::string& src) { + return cortex::db::Models().DeleteModelEntryWithOrg(src); +} + +cpp::result DatabaseService::DeleteModelEntryWithRepo( + const std::string& src) { + return cortex::db::Models().DeleteModelEntryWithRepo(src); +} + +cpp::result, std::string> +DatabaseService::FindRelatedModel(const std::string& identifier) const { + return cortex::db::Models().FindRelatedModel(identifier); +} + +bool DatabaseService::HasModel(const std::string& identifier) const { + return 
cortex::db::Models().HasModel(identifier); +} + +cpp::result, std::string> +DatabaseService::GetModelSources() const { + return cortex::db::Models().GetModelSources(); +} + +cpp::result, std::string> DatabaseService::GetModels( + const std::string& model_src) const { + return cortex::db::Models().GetModels(model_src); +} +// end models \ No newline at end of file diff --git a/engine/services/database_service.h b/engine/services/database_service.h new file mode 100644 index 000000000..4fb4f7be0 --- /dev/null +++ b/engine/services/database_service.h @@ -0,0 +1,68 @@ +#pragma once +#include "database/engines.h" +#include "database/file.h" +#include "database/hardware.h" +#include "database/models.h" + +using EngineEntry = cortex::db::EngineEntry; +using HardwareEntry = cortex::db::HardwareEntry; +using ModelEntry = cortex::db::ModelEntry; + +class DatabaseService { + public: + // engines + std::optional UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); + + std::optional> GetEngines() const; + std::optional GetEngineById(int id) const; + std::optional GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) const; + + std::optional DeleteEngineById(int id); + + // file + cpp::result, std::string> GetFileList() const; + + cpp::result GetFileById( + const std::string& file_id) const; + + cpp::result AddFileEntry(OpenAi::File& file); + + cpp::result DeleteFileEntry(const std::string& file_id); + + // hardware + cpp::result, std::string> LoadHardwareList() const; + cpp::result AddHardwareEntry( + const HardwareEntry& new_entry); + cpp::result UpdateHardwareEntry( + const std::string& id, const HardwareEntry& updated_entry); + cpp::result DeleteHardwareEntry(const std::string& id); + + // models + cpp::result, std::string> 
LoadModelList() const; + cpp::result GetModelInfo( + const std::string& identifier) const; + void PrintModelInfo(const ModelEntry& entry) const; + cpp::result AddModelEntry(ModelEntry new_entry); + cpp::result UpdateModelEntry( + const std::string& identifier, const ModelEntry& updated_entry); + cpp::result DeleteModelEntry( + const std::string& identifier); + cpp::result DeleteModelEntryWithOrg( + const std::string& src); + cpp::result DeleteModelEntryWithRepo( + const std::string& src); + cpp::result, std::string> FindRelatedModel( + const std::string& identifier) const; + bool HasModel(const std::string& identifier) const; + cpp::result, std::string> GetModelSources() const; + cpp::result, std::string> GetModels( + const std::string& model_src) const; + + private: +}; \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index b6e097c0f..39e6e7961 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -13,7 +13,6 @@ #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" -#include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" #include "utils/file_manager_utils.h" @@ -137,8 +136,8 @@ cpp::result EngineService::UnzipEngine( CTL_INF("Found cuda variant, extract it"); found_cuda = true; // extract binary - auto cuda_path = - file_manager_utils::GetCudaToolkitPath(NormalizeEngine(engine)); + auto cuda_path = file_manager_utils::GetCudaToolkitPath( + NormalizeEngine(engine), true); archive_utils::ExtractArchive(path + "/" + cf, cuda_path.string(), true); } @@ -370,10 +369,10 @@ cpp::result EngineService::DownloadEngine( }; auto downloadTask = - DownloadTask{.id = engine, + DownloadTask{.id = selected_variant->name, .type = DownloadType::Engine, .items = {DownloadItem{ - .id = engine, + .id = selected_variant->name, .downloadUrl = selected_variant->browser_download_url, .localPath = variant_path, 
}}}; @@ -440,7 +439,8 @@ cpp::result EngineService::DownloadCuda( }}; auto on_finished = [engine](const DownloadTask& finishedTask) { - auto engine_path = file_manager_utils::GetCudaToolkitPath(engine); + auto engine_path = file_manager_utils::GetCudaToolkitPath(engine, true); + archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), engine_path.string()); try { @@ -1050,8 +1050,8 @@ cpp::result EngineService::UpdateEngine( cpp::result, std::string> EngineService::GetEngines() { - cortex::db::Engines engines; - auto get_res = engines.GetEngines(); + assert(db_service_); + auto get_res = db_service_->GetEngines(); if (!get_res.has_value()) { return cpp::fail("Failed to get engine entries"); @@ -1062,8 +1062,8 @@ EngineService::GetEngines() { cpp::result EngineService::GetEngineById( int id) { - cortex::db::Engines engines; - auto get_res = engines.GetEngineById(id); + assert(db_service_); + auto get_res = db_service_->GetEngineById(id); if (!get_res.has_value()) { return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); @@ -1076,8 +1076,8 @@ cpp::result EngineService::GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant) { - cortex::db::Engines engines; - auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); + assert(db_service_); + auto get_res = db_service_->GetEngineByNameAndVariant(engine_name, variant); if (!get_res.has_value()) { if (variant.has_value()) { @@ -1096,9 +1096,9 @@ cpp::result EngineService::UpsertEngine( const std::string& api_key, const std::string& url, const std::string& version, const std::string& variant, const std::string& status, const std::string& metadata) { - cortex::db::Engines engines; - auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, - version, variant, status, metadata); + assert(db_service_); + auto upsert_res = db_service_->UpsertEngine( + engine_name, type, api_key, url, version, variant, status, metadata); if 
(upsert_res.has_value()) { return upsert_res.value(); } else { @@ -1107,8 +1107,8 @@ cpp::result EngineService::UpsertEngine( } std::string EngineService::DeleteEngine(int id) { - cortex::db::Engines engines; - auto delete_res = engines.DeleteEngineById(id); + assert(db_service_); + auto delete_res = db_service_->DeleteEngineById(id); if (delete_res.has_value()) { return delete_res.value(); } else { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 0ef1a3284..a460582c6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -12,6 +12,7 @@ #include "cortex-common/cortexpythoni.h" #include "cortex-common/remote_enginei.h" #include "database/engines.h" +#include "services/database_service.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -59,16 +60,19 @@ class EngineService : public EngineServiceI { std::string cuda_driver_version; }; HardwareInfo hw_inf_; + std::shared_ptr db_service_ = nullptr; public: explicit EngineService( std::shared_ptr download_service, - std::shared_ptr dylib_path_manager) + std::shared_ptr dylib_path_manager, + std::shared_ptr db_service) : download_service_{download_service}, dylib_path_manager_{dylib_path_manager}, hw_inf_{.sys_inf = system_info_utils::GetSystemInfo(), .cuda_driver_version = - system_info_utils::GetDriverAndCudaVersion().second} {} + system_info_utils::GetDriverAndCudaVersion().second}, + db_service_(db_service) {} std::vector GetEngineInfoList() const; diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index ca2bd8ed9..5552aca56 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -11,8 +11,6 @@ #include "database/hardware.h" #include "utils/cortex_utils.h" -namespace services { - namespace { bool TryConnectToServer(const std::string& host, int port) { constexpr const auto kMaxRetry = 4u; @@ -34,9 +32,8 @@ bool 
TryConnectToServer(const std::string& host, int port) { HardwareInfo HardwareService::GetHardwareInfo() { // append active state - cortex::db::Hardware hw_db; auto gpus = cortex::hw::GetGPUInfo(); - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { // Only a few elements, brute-force is enough for (auto& entry : res.value()) { @@ -210,7 +207,6 @@ bool HardwareService::SetActivateHardwareConfig( const cortex::hw::ActivateHardwareConfig& ahc) { // Note: need to map software_id and hardware_id // Update to db - cortex::db::Hardware hw_db; // copy all gpu information to new vector auto ahc_gpus = ahc.gpus; auto activate = [&ahc](int software_id) { @@ -225,7 +221,7 @@ bool HardwareService::SetActivateHardwareConfig( return INT_MAX; }; - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { bool need_update = false; std::vector> activated_ids; @@ -258,7 +254,7 @@ bool HardwareService::SetActivateHardwareConfig( for (auto& e : res.value()) { e.activated = activate(e.software_id); e.priority = priority(e.software_id); - auto res = hw_db.UpdateHardwareEntry(e.uuid, e); + auto res = db_service_->UpdateHardwareEntry(e.uuid, e); if (res.has_error()) { CTL_WRN(res.error()); } @@ -271,8 +267,7 @@ bool HardwareService::SetActivateHardwareConfig( void HardwareService::UpdateHardwareInfos() { using HwEntry = cortex::db::HardwareEntry; auto gpus = cortex::hw::GetGPUInfo(); - cortex::db::Hardware hw_db; - auto b = hw_db.LoadHardwareList(); + auto b = db_service_->LoadHardwareList(); std::vector> activated_gpu_bf; std::string debug_b; for (auto const& he : b.value()) { @@ -285,7 +280,8 @@ void HardwareService::UpdateHardwareInfos() { for (auto const& gpu : gpus) { // ignore error // Note: only support NVIDIA for now, so hardware_id = software_id - auto res = hw_db.AddHardwareEntry(HwEntry{.uuid = gpu.uuid, + auto res = + db_service_->AddHardwareEntry(HwEntry{.uuid = gpu.uuid, 
.type = "gpu", .hardware_id = std::stoi(gpu.id), .software_id = std::stoi(gpu.id), @@ -296,7 +292,7 @@ void HardwareService::UpdateHardwareInfos() { } } - auto a = hw_db.LoadHardwareList(); + auto a = db_service_->LoadHardwareList(); std::vector a_gpu; std::vector> activated_gpu_af; std::string debug_a; @@ -350,11 +346,10 @@ bool HardwareService::IsValidConfig( const cortex::hw::ActivateHardwareConfig& ahc) { if (ahc.gpus.empty()) return true; - cortex::db::Hardware hw_db; auto is_valid = [&ahc](int software_id) { return std::count(ahc.gpus.begin(), ahc.gpus.end(), software_id) > 0; }; - auto res = hw_db.LoadHardwareList(); + auto res = db_service_->LoadHardwareList(); if (res.has_value()) { for (auto const& e : res.value()) { if (is_valid(e.software_id)) { @@ -364,4 +359,3 @@ bool HardwareService::IsValidConfig( } return false; } -} // namespace services diff --git a/engine/services/hardware_service.h b/engine/services/hardware_service.h index 48ab7a4b1..ad9d70233 100644 --- a/engine/services/hardware_service.h +++ b/engine/services/hardware_service.h @@ -4,6 +4,7 @@ #include #include "common/hardware_config.h" +#include "database_service.h" #include "utils/hardware/cpu_info.h" #include "utils/hardware/gpu_info.h" #include "utils/hardware/os_info.h" @@ -11,8 +12,6 @@ #include "utils/hardware/ram_info.h" #include "utils/hardware/storage_info.h" -namespace services { - struct HardwareInfo { cortex::hw::CPU cpu; cortex::hw::OS os; @@ -24,6 +23,8 @@ struct HardwareInfo { class HardwareService { public: + explicit HardwareService(std::shared_ptr db_service) + : db_service_(db_service) {} HardwareInfo GetHardwareInfo(); bool Restart(const std::string& host, int port); bool SetActivateHardwareConfig(const cortex::hw::ActivateHardwareConfig& ahc); @@ -32,6 +33,6 @@ class HardwareService { bool IsValidConfig(const cortex::hw::ActivateHardwareConfig& ahc); private: + std::shared_ptr db_service_ = nullptr; std::optional ahc_; -}; -} // namespace services +}; \ No newline at 
end of file diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index b0d8cf550..3668fb6fe 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -4,7 +4,6 @@ #include "utils/function_calling/common.h" #include "utils/jinja_utils.h" -namespace services { cpp::result InferenceService::HandleChatCompletion( std::shared_ptr q, std::shared_ptr json_body) { std::string engine_type; @@ -395,4 +394,3 @@ bool InferenceService::HasFieldInReq(std::shared_ptr json_body, } return true; } -} // namespace services diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index ec5f556f5..f23be3f23 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -8,8 +8,6 @@ #include "services/model_service.h" #include "utils/result.hpp" -namespace services { - // Status and result using InferResult = std::pair; @@ -75,4 +73,3 @@ class InferenceService { std::shared_ptr engine_service_; std::weak_ptr model_service_; }; -} // namespace services diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 1e5b739a9..34ca60b3b 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -26,7 +26,8 @@ #include "utils/widechar_conv.h" namespace { -void ParseGguf(const DownloadItem& ggufDownloadItem, +void ParseGguf(DatabaseService& db_service, + const DownloadItem& ggufDownloadItem, std::optional author, std::optional name, std::optional size) { @@ -69,8 +70,7 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, CTL_INF("path_to_model_yaml: " << rel.string()); auto author_id = author.has_value() ? 
author.value() : "cortexso"; - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(ggufDownloadItem.id)) { + if (!db_service.HasModel(ggufDownloadItem.id)) { cortex::db::ModelEntry model_entry{ .model = ggufDownloadItem.id, .author_repo_id = author_id, @@ -78,18 +78,17 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, .path_to_model_yaml = rel.string(), .model_alias = ggufDownloadItem.id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service.AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(ggufDownloadItem.id); + if (auto m = db_service.GetModelInfo(ggufDownloadItem.id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; - if (auto r = - modellist_utils_obj.UpdateModelEntry(ggufDownloadItem.id, upd_m); + if (auto r = db_service.UpdateModelEntry(ggufDownloadItem.id, upd_m); r.has_error()) { CTL_ERR(r.error()); } @@ -143,10 +142,9 @@ cpp::result GetDownloadTask( void ModelService::ForceIndexingModelList() { CTL_INF("Force indexing model list"); - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto list_entry = modellist_handler.LoadModelList(); + auto list_entry = db_service_->LoadModelList(); if (list_entry.has_error()) { CTL_ERR("Failed to load model list: " << list_entry.error()); return; @@ -170,8 +168,7 @@ void ModelService::ForceIndexingModelList() { yaml_handler.Reset(); } catch (const std::exception& e) { // remove in db - auto remove_result = - modellist_handler.DeleteModelEntry(model_entry.model); + auto remove_result = db_service_->DeleteModelEntry(model_entry.model); // silently ignore result } } @@ -224,10 +221,8 @@ cpp::result ModelService::HandleCortexsoModel( auto default_model_branch = huggingface_utils::GetDefaultBranch(modelName); - 
cortex::db::Models modellist_handler; - auto downloaded_model_ids = - modellist_handler.FindRelatedModel(modelName).value_or( - std::vector{}); + auto downloaded_model_ids = db_service_->FindRelatedModel(modelName).value_or( + std::vector{}); std::vector avai_download_opts{}; for (const auto& branch : branches.value()) { @@ -267,9 +262,8 @@ cpp::result ModelService::HandleCortexsoModel( std::optional ModelService::GetDownloadedModel( const std::string& modelId) const { - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; - auto model_entry = modellist_handler.GetModelInfo(modelId); + auto model_entry = db_service_->GetModelInfo(modelId); if (!model_entry.has_value()) { return std::nullopt; } @@ -316,7 +310,6 @@ cpp::result ModelService::HandleDownloadUrlAsync( } std::string huggingFaceHost{kHuggingFaceHost}; - cortex::db::Models modellist_handler; std::string unique_model_id = ""; if (temp_model_id.has_value()) { unique_model_id = temp_model_id.value(); @@ -324,7 +317,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( unique_model_id = author + ":" + model_id + ":" + file_name; } - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value() && model_entry->status == cortex::db::ModelStatus::Downloaded) { CLI_LOG("Model already downloaded: " << unique_model_id); @@ -352,14 +345,15 @@ cpp::result ModelService::HandleDownloadUrlAsync( .localPath = local_path, }}}}; - auto on_finished = [author, temp_name](const DownloadTask& finishedTask) { + auto on_finished = [this, author, + temp_name](const DownloadTask& finishedTask) { // Sum downloadedBytes from all items uint64_t model_size = 0; for (const auto& item : finishedTask.items) { model_size = model_size + item.bytes.value_or(0); } auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author, temp_name, model_size); + ParseGguf(*db_service_, gguf_download_item, 
author, temp_name, model_size); }; downloadTask.id = unique_model_id; @@ -372,11 +366,10 @@ ModelService::GetEstimation(const std::string& model_handle, int n_ubatch) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -390,8 +383,8 @@ ModelService::GetEstimation(const std::string& model_handle, fs::path(model_entry.value().path_to_model_yaml)) .string()); auto mc = yaml_handler.GetModelConfig(); - services::HardwareService hw_svc; - auto hw_info = hw_svc.GetHardwareInfo(); + assert(hw_service_); + auto hw_info = hw_service_->GetHardwareInfo(); auto free_vram_MiB = 0u; for (const auto& gpu : hw_info.gpus) { free_vram_MiB += gpu.free_vram; @@ -444,8 +437,7 @@ cpp::result ModelService::HandleUrl( std::string huggingFaceHost{kHuggingFaceHost}; std::string unique_model_id{author + ":" + model_id + ":" + file_name}; - cortex::db::Models modellist_handler; - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value()) { CLI_LOG("Model already downloaded: " << unique_model_id); @@ -473,14 +465,14 @@ cpp::result ModelService::HandleUrl( .localPath = local_path, }}}}; - auto on_finished = [author](const DownloadTask& finishedTask) { + auto on_finished = [this, author](const DownloadTask& finishedTask) { // Sum downloadedBytes from all items uint64_t model_size = 0; for (const auto& item : finishedTask.items) { model_size = model_size + item.bytes.value_or(0); } auto gguf_download_item = finishedTask.items[0]; - ParseGguf(gguf_download_item, author, std::nullopt, model_size); + ParseGguf(*db_service_, gguf_download_item, author, 
std::nullopt, model_size); }; auto result = download_service_->AddDownloadTask(downloadTask, on_finished); @@ -494,7 +486,7 @@ cpp::result ModelService::HandleUrl( } bool ModelService::HasModel(const std::string& id) const { - return cortex::db::Models().HasModel(id); + return db_service_->HasModel(id); } cpp::result @@ -507,7 +499,6 @@ ModelService::DownloadModelFromCortexsoAsync( return cpp::fail(download_task.error()); } - cortex::db::Models modellist_handler; std::string unique_model_id = ""; if (temp_model_id.has_value()) { unique_model_id = temp_model_id.value(); @@ -515,13 +506,13 @@ ModelService::DownloadModelFromCortexsoAsync( unique_model_id = name + ":" + branch; } - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value() && model_entry->status == cortex::db::ModelStatus::Downloaded) { return cpp::fail("Please delete the model before downloading again"); } - auto on_finished = [unique_model_id, + auto on_finished = [this, unique_model_id, branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -621,8 +612,7 @@ ModelService::DownloadModelFromCortexsoAsync( file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); CTL_INF("path_to_model_yaml: " << rel.string()); - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(unique_model_id)) { + if (!db_service_->HasModel(unique_model_id)) { cortex::db::ModelEntry model_entry{ .model = unique_model_id, .author_repo_id = "cortexso", @@ -630,18 +620,16 @@ ModelService::DownloadModelFromCortexsoAsync( .path_to_model_yaml = rel.string(), .model_alias = unique_model_id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service_->AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + 
result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(unique_model_id); - m.has_value()) { + if (auto m = db_service_->GetModelInfo(unique_model_id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; - if (auto r = - modellist_utils_obj.UpdateModelEntry(unique_model_id, upd_m); + if (auto r = db_service_->UpdateModelEntry(unique_model_id, upd_m); r.has_error()) { CTL_ERR(r.error()); } @@ -665,7 +653,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [branch, model_id](const DownloadTask& finishedTask) { + auto on_finished = [this, branch, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -692,8 +680,7 @@ cpp::result ModelService::DownloadModelFromCortexso( file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); CTL_INF("path_to_model_yaml: " << rel.string()); - cortex::db::Models modellist_utils_obj; - if (!modellist_utils_obj.HasModel(model_id)) { + if (!db_service_->HasModel(model_id)) { cortex::db::ModelEntry model_entry{ .model = model_id, .author_repo_id = "cortexso", @@ -701,16 +688,16 @@ cpp::result ModelService::DownloadModelFromCortexso( .path_to_model_yaml = rel.string(), .model_alias = model_id, .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); + auto result = db_service_->AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); } } else { - if (auto m = modellist_utils_obj.GetModelInfo(model_id); m.has_value()) { + if (auto m = db_service_->GetModelInfo(model_id); m.has_value()) { auto upd_m = m.value(); upd_m.status = cortex::db::ModelStatus::Downloaded; - if (auto r = modellist_utils_obj.UpdateModelEntry(model_id, upd_m); + if (auto r = db_service_->UpdateModelEntry(model_id, upd_m); r.has_error()) { 
CTL_ERR(r.error()); } @@ -764,7 +751,6 @@ cpp::result ModelService::DeleteModel( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; auto result = StopModel(model_handle); @@ -776,7 +762,7 @@ cpp::result ModelService::DeleteModel( } try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -807,7 +793,7 @@ cpp::result ModelService::DeleteModel( } // update model.list - if (modellist_handler.DeleteModelEntry(model_handle)) { + if (db_service_->DeleteModelEntry(model_handle)) { return {}; } else { return cpp::fail("Could not delete model: " + model_handle); @@ -823,7 +809,6 @@ cpp::result ModelService::StartModel( bool bypass_model_check) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; std::optional custom_prompt_template; std::optional ctx_len; @@ -857,6 +842,7 @@ cpp::result ModelService::StartModel( config::PythonModelConfig python_model_config; python_model_config.ReadFromYaml( + fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) .string()); @@ -1040,7 +1026,6 @@ cpp::result ModelService::StopModel( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { @@ -1048,7 +1033,7 @@ cpp::result ModelService::StopModel( bypass_stop_check_set_.end()); std::string engine_name = ""; if (!bypass_check) { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return 
cpp::fail(model_entry.error()); @@ -1105,11 +1090,10 @@ cpp::result ModelService::GetModelStatus( const std::string& model_handle) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; - cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; try { - auto model_entry = modellist_handler.GetModelInfo(model_handle); + auto model_entry = db_service_->GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); return cpp::fail(model_entry.error()); @@ -1230,8 +1214,7 @@ cpp::result ModelService::GetModelPullInfo( auto default_model_branch = huggingface_utils::GetDefaultBranch(model_name); - cortex::db::Models modellist_handler; - auto downloaded_model_ids = modellist_handler.FindRelatedModel(model_name) + auto downloaded_model_ids = db_service_->FindRelatedModel(model_name) .value_or(std::vector{}); std::vector avai_download_opts{}; @@ -1275,8 +1258,8 @@ cpp::result, std::string> ModelService::MayFallbackToCpu(const std::string& model_path, int ngl, int ctx_len, int n_batch, int n_ubatch, const std::string& kv_cache_type) { - services::HardwareService hw_svc; - auto hw_info = hw_svc.GetHardwareInfo(); + assert(hw_service_); + auto hw_info = hw_service_->GetHardwareInfo(); assert(!!engine_svc_); auto default_engine = engine_svc_->GetDefaultEngineVariant(kLlamaEngine); bool is_cuda = false; diff --git a/engine/services/model_service.h b/engine/services/model_service.h index ab3596812..cc659fea5 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -6,12 +6,12 @@ #include "common/engine_servicei.h" #include "common/model_metadata.h" #include "config/model_config.h" +#include "services/database_service.h" #include "services/download_service.h" +#include "services/hardware_service.h" #include "utils/hardware/gguf/gguf_file_estimate.h" -namespace services { class InferenceService; -} struct ModelPullInfo { std::string id; @@ -31,14 +31,14 @@ class ModelService { public: 
void ForceIndexingModelList(); - explicit ModelService(std::shared_ptr download_service) - : download_service_{download_service} {}; - - explicit ModelService( - std::shared_ptr download_service, - std::shared_ptr inference_service, - std::shared_ptr engine_svc) - : download_service_{download_service}, + explicit ModelService(std::shared_ptr db_service, + std::shared_ptr hw_service, + std::shared_ptr download_service, + std::shared_ptr inference_service, + std::shared_ptr engine_svc) + : db_service_(db_service), + hw_service_(hw_service), + download_service_{download_service}, inference_svc_(inference_service), engine_svc_(engine_svc) {}; @@ -115,8 +115,10 @@ class ModelService { const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048, int n_ubatch = 2048, const std::string& kv_cache_type = "f16"); + std::shared_ptr db_service_; + std::shared_ptr hw_service_; std::shared_ptr download_service_; - std::shared_ptr inference_svc_; + std::shared_ptr inference_svc_; std::unordered_set bypass_stop_check_set_; std::shared_ptr engine_svc_ = nullptr; diff --git a/engine/services/model_source_service.cc b/engine/services/model_source_service.cc index a7d9d5e6e..7fc0ef5b2 100644 --- a/engine/services/model_source_service.cc +++ b/engine/services/model_source_service.cc @@ -9,7 +9,6 @@ #include "utils/string_utils.h" #include "utils/url_parser.h" -namespace services { namespace hu = huggingface_utils; namespace { @@ -61,10 +60,13 @@ std::vector ParseJsonString(const std::string& json_str) { } // namespace -ModelSourceService::ModelSourceService() { +ModelSourceService::ModelSourceService( + std::shared_ptr db_service) + : db_service_(db_service) { sync_db_thread_ = std::thread(&ModelSourceService::SyncModelSource, this); running_ = true; } + ModelSourceService::~ModelSourceService() { running_ = false; if (sync_db_thread_.joinable()) { @@ -106,8 +108,7 @@ cpp::result ModelSourceService::AddModelSource( cpp::result ModelSourceService::RemoveModelSource( const 
std::string& model_source) { - cortex::db::Models model_db; - auto srcs = model_db.GetModelSources(); + auto srcs = db_service_->GetModelSources(); if (srcs.has_error()) { return cpp::fail(srcs.error()); } else { @@ -127,13 +128,13 @@ cpp::result ModelSourceService::RemoveModelSource( } if (r.pathParams.size() == 1) { - if (auto del_res = model_db.DeleteModelEntryWithOrg(model_source); + if (auto del_res = db_service_->DeleteModelEntryWithOrg(model_source); del_res.has_error()) { CTL_INF(del_res.error()); return cpp::fail(del_res.error()); } } else { - if (auto del_res = model_db.DeleteModelEntryWithRepo(model_source); + if (auto del_res = db_service_->DeleteModelEntryWithRepo(model_source); del_res.has_error()) { CTL_INF(del_res.error()); return cpp::fail(del_res.error()); @@ -145,8 +146,7 @@ cpp::result ModelSourceService::RemoveModelSource( cpp::result, std::string> ModelSourceService::GetModelSources() { - cortex::db::Models model_db; - return model_db.GetModelSources(); + return db_service_->GetModelSources(); } cpp::result ModelSourceService::AddHfOrg( @@ -156,10 +156,9 @@ cpp::result ModelSourceService::AddHfOrg( if (res.has_value()) { auto models = ParseJsonString(res.value()); // Get models from db - cortex::db::Models model_db; - auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + auto model_list_before = db_service_->GetModels(model_source) + .value_or(std::vector{}); std::unordered_set updated_model_list; // Add new models for (auto const& m : models) { @@ -179,7 +178,7 @@ cpp::result ModelSourceService::AddHfOrg( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); + if (auto del_res = db_service_->DeleteModelEntry(mid); del_res.has_error()) { CTL_INF(del_res.error()); } @@ -195,10 +194,9 @@ cpp::result ModelSourceService::AddHfRepo( const std::string& model_source, const std::string& author, const 
std::string& model_name) { // Get models from db - cortex::db::Models model_db; auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + db_service_->GetModels(model_source).value_or(std::vector{}); std::unordered_set updated_model_list; auto add_res = AddRepoSiblings(model_source, author, model_name); if (add_res.has_error()) { @@ -208,7 +206,8 @@ cpp::result ModelSourceService::AddHfRepo( } for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + if (auto del_res = db_service_->DeleteModelEntry(mid); + del_res.has_error()) { CTL_INF(del_res.error()); } } @@ -234,7 +233,6 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, for (const auto& sibling : repo_info->siblings) { if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { - cortex::db::Models model_db; std::string model_id = author + ":" + model_name + ":" + sibling.rfilename; cortex::db::ModelEntry e = { @@ -248,15 +246,15 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source, .status = cortex::db::ModelStatus::Downloadable, .engine = "llama-cpp", .metadata = repo_info->metadata}; - if (!model_db.HasModel(model_id)) { - if (auto add_res = model_db.AddModelEntry(e); add_res.has_error()) { + if (!db_service_->HasModel(model_id)) { + if (auto add_res = db_service_->AddModelEntry(e); add_res.has_error()) { CTL_INF(add_res.error()); } } else { - if (auto m = model_db.GetModelInfo(model_id); + if (auto m = db_service_->GetModelInfo(model_id); m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { - if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + if (auto upd_res = db_service_->UpdateModelEntry(model_id, e); upd_res.has_error()) { CTL_INF(upd_res.error()); } @@ -276,10 +274,9 @@ cpp::result ModelSourceService::AddCortexsoOrg( if (res.has_value()) { auto models = ParseJsonString(res.value()); 
// Get models from db - cortex::db::Models model_db; - auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + auto model_list_before = db_service_->GetModels(model_source) + .value_or(std::vector{}); std::unordered_set updated_model_list; for (auto const& m : models) { CTL_INF(m.id); @@ -313,7 +310,7 @@ cpp::result ModelSourceService::AddCortexsoOrg( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); + if (auto del_res = db_service_->DeleteModelEntry(mid); del_res.has_error()) { CTL_INF(del_res.error()); } @@ -340,10 +337,9 @@ cpp::result ModelSourceService::AddCortexsoRepo( return cpp::fail(repo_info.error()); } // Get models from db - cortex::db::Models model_db; auto model_list_before = - model_db.GetModels(model_source).value_or(std::vector{}); + db_service_->GetModels(model_source).value_or(std::vector{}); std::unordered_set updated_model_list; for (auto const& [branch, _] : branches.value()) { @@ -359,7 +355,8 @@ cpp::result ModelSourceService::AddCortexsoRepo( // Clean up for (auto const& mid : model_list_before) { if (updated_model_list.find(mid) == updated_model_list.end()) { - if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + if (auto del_res = db_service_->DeleteModelEntry(mid); + del_res.has_error()) { CTL_INF(del_res.error()); } } @@ -397,7 +394,6 @@ ModelSourceService::AddCortexsoRepoBranch(const std::string& model_source, CTL_INF("Only support gguf file format! 
- branch: " << branch); return {}; } else { - cortex::db::Models model_db; std::string model_id = model_name + ":" + branch; cortex::db::ModelEntry e = {.model = model_id, .author_repo_id = author, @@ -409,16 +405,16 @@ ModelSourceService::AddCortexsoRepoBranch(const std::string& model_source, .status = cortex::db::ModelStatus::Downloadable, .engine = "llama-cpp", .metadata = metadata}; - if (!model_db.HasModel(model_id)) { + if (!db_service_->HasModel(model_id)) { CTL_INF("Adding model to db: " << model_name << ":" << branch); - if (auto res = model_db.AddModelEntry(e); + if (auto res = db_service_->AddModelEntry(e); res.has_error() || !res.value()) { CTL_DBG("Cannot add model to db: " << model_id); } } else { - if (auto m = model_db.GetModelInfo(model_id); + if (auto m = db_service_->GetModelInfo(model_id); m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { - if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + if (auto upd_res = db_service_->UpdateModelEntry(model_id, e); upd_res.has_error()) { CTL_INF(upd_res.error()); } @@ -444,8 +440,7 @@ void ModelSourceService::SyncModelSource() { CTL_DBG("Start to sync cortex.db"); start_time = current_time; - cortex::db::Models model_db; - auto res = model_db.GetModelSources(); + auto res = db_service_->GetModelSources(); if (res.has_error()) { CTL_INF(res.error()); } else { @@ -489,5 +484,3 @@ void ModelSourceService::SyncModelSource() { } } } - -} // namespace services \ No newline at end of file diff --git a/engine/services/model_source_service.h b/engine/services/model_source_service.h index aa0b37259..7227267d3 100644 --- a/engine/services/model_source_service.h +++ b/engine/services/model_source_service.h @@ -2,14 +2,14 @@ #include #include #include +#include "services/database_service.h" #include "utils/result.hpp" -namespace services { class ModelSourceService { public: - explicit ModelSourceService(); + explicit ModelSourceService(std::shared_ptr db_service); ~ModelSourceService(); - + 
cpp::result AddModelSource( const std::string& model_source); @@ -22,9 +22,9 @@ class ModelSourceService { cpp::result AddHfOrg(const std::string& model_source, const std::string& author); - cpp::result AddHfRepo( - const std::string& model_source, const std::string& author, - const std::string& model_name); + cpp::result AddHfRepo(const std::string& model_source, + const std::string& author, + const std::string& model_name); cpp::result, std::string> AddRepoSiblings( const std::string& model_source, const std::string& author, @@ -41,13 +41,12 @@ class ModelSourceService { AddCortexsoRepoBranch(const std::string& model_source, const std::string& author, const std::string& model_name, - const std::string& branch, - const std::string& metadata); + const std::string& branch, const std::string& metadata); void SyncModelSource(); private: + std::shared_ptr db_service_ = nullptr; std::thread sync_db_thread_; std::atomic running_; -}; -} // namespace services \ No newline at end of file +}; \ No newline at end of file diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc index 0ec0ac89d..9c5e7e857 100644 --- a/engine/services/thread_service.cc +++ b/engine/services/thread_service.cc @@ -4,7 +4,7 @@ #include "utils/ulid_generator.h" cpp::result ThreadService::CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "CreateThread"; @@ -46,7 +46,7 @@ cpp::result ThreadService::RetrieveThread( cpp::result ThreadService::ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata) { LOG_TRACE << "ModifyThread " << thread_id; auto retrieve_res = RetrieveThread(thread_id); diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h index 966b0ab01..7011f46f3 100644 --- a/engine/services/thread_service.h +++ b/engine/services/thread_service.h @@ -2,7 +2,6 @@ #include #include 
"common/repository/thread_repository.h" -#include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "utils/result.hpp" @@ -12,7 +11,7 @@ class ThreadService { : thread_repository_{thread_repository} {} cpp::result CreateThread( - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result, std::string> ListThreads( @@ -24,7 +23,7 @@ class ThreadService { cpp::result ModifyThread( const std::string& thread_id, - std::unique_ptr tool_resources, + std::unique_ptr tool_resources, std::optional metadata); cpp::result DeleteThread( diff --git a/engine/test/components/test_assistant.cc b/engine/test/components/test_assistant.cc new file mode 100644 index 000000000..20ba08f34 --- /dev/null +++ b/engine/test/components/test_assistant.cc @@ -0,0 +1,194 @@ +#include +#include "common/assistant.h" + +namespace OpenAi { +namespace { + +class AssistantTest : public ::testing::Test { + protected: + void SetUp() override { + // Set up base assistant with minimal required fields + base_assistant.id = "asst_123"; + base_assistant.object = "assistant"; + base_assistant.created_at = 1702000000; + base_assistant.model = "gpt-4"; + } + + Assistant base_assistant; +}; + +TEST_F(AssistantTest, MinimalAssistantToJson) { + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["id"].asString(), "asst_123"); + EXPECT_EQ(json["object"].asString(), "assistant"); + EXPECT_EQ(json["created_at"].asUInt64(), 1702000000); + EXPECT_EQ(json["model"].asString(), "gpt-4"); +} + +TEST_F(AssistantTest, FullAssistantToJson) { + base_assistant.name = "Test Assistant"; + base_assistant.description = "Test Description"; + base_assistant.instructions = "Test Instructions"; + base_assistant.temperature = 0.7f; + base_assistant.top_p = 0.9f; + + // Add a code interpreter tool + auto code_tool = std::make_unique(); + base_assistant.tools.push_back(std::move(code_tool)); + 
+ // Add metadata + base_assistant.metadata["key1"] = std::string("value1"); + base_assistant.metadata["key2"] = true; + base_assistant.metadata["key3"] = static_cast(42ULL); + + auto result = base_assistant.ToJson(); + ASSERT_TRUE(result.has_value()); + + Json::Value json = result.value(); + EXPECT_EQ(json["name"].asString(), "Test Assistant"); + EXPECT_EQ(json["description"].asString(), "Test Description"); + EXPECT_EQ(json["instructions"].asString(), "Test Instructions"); + EXPECT_FLOAT_EQ(json["temperature"].asFloat(), 0.7f); + EXPECT_FLOAT_EQ(json["top_p"].asFloat(), 0.9f); + + EXPECT_TRUE(json["tools"].isArray()); + EXPECT_EQ(json["tools"].size(), 1); + EXPECT_EQ(json["tools"][0]["type"].asString(), "code_interpreter"); + + EXPECT_TRUE(json["metadata"].isObject()); + EXPECT_EQ(json["metadata"]["key1"].asString(), "value1"); + EXPECT_EQ(json["metadata"]["key2"].asBool(), true); + EXPECT_EQ(json["metadata"]["key3"].asUInt64(), 42ULL); +} + +TEST_F(AssistantTest, FromJsonMinimal) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.id, "asst_123"); + EXPECT_EQ(assistant.object, "assistant"); + EXPECT_EQ(assistant.created_at, 1702000000); + EXPECT_EQ(assistant.model, "gpt-4"); +} + +TEST_F(AssistantTest, FromJsonComplete) { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + input["name"] = "Test Assistant"; + input["description"] = "Test Description"; + input["instructions"] = "Test Instructions"; + input["temperature"] = 0.7; + input["top_p"] = 0.9; + + // Add tools + Json::Value tools(Json::arrayValue); + Json::Value code_tool; + code_tool["type"] = "code_interpreter"; + tools.append(code_tool); + + Json::Value 
function_tool; + function_tool["type"] = "function"; + function_tool["function"] = Json::Value(Json::objectValue); + function_tool["function"]["name"] = "test_function"; + function_tool["function"]["description"] = "Test function"; + function_tool["function"]["parameters"] = Json::Value(Json::objectValue); + tools.append(function_tool); + input["tools"] = tools; + + // Add metadata + Json::Value metadata(Json::objectValue); + metadata["key1"] = "value1"; + metadata["key2"] = true; + metadata["key3"] = 42; + input["metadata"] = metadata; + + auto result = Assistant::FromJson(std::move(input)); + ASSERT_TRUE(result.has_value()); + + const auto& assistant = result.value(); + EXPECT_EQ(assistant.name.value(), "Test Assistant"); + EXPECT_EQ(assistant.description.value(), "Test Description"); + EXPECT_EQ(assistant.instructions.value(), "Test Instructions"); + EXPECT_FLOAT_EQ(assistant.temperature.value(), 0.7f); + EXPECT_FLOAT_EQ(assistant.top_p.value(), 0.9f); + + EXPECT_EQ(assistant.tools.size(), 2); + EXPECT_TRUE(dynamic_cast(assistant.tools[0].get()) != nullptr); + EXPECT_TRUE(dynamic_cast(assistant.tools[1].get()) != nullptr); + + EXPECT_EQ(assistant.metadata.size(), 3); + EXPECT_EQ(std::get(assistant.metadata.at("key1")), "value1"); + EXPECT_EQ(std::get(assistant.metadata.at("key2")), true); + EXPECT_EQ(std::get(assistant.metadata.at("key3")), 42ULL); +} + +TEST_F(AssistantTest, FromJsonInvalidInput) { + // Missing required field 'id' + { + Json::Value input; + input["object"] = "assistant"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid object type + { + Json::Value input; + input["id"] = "asst_123"; + input["object"] = "invalid"; + input["created_at"] = 1702000000; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } + + // Invalid created_at type + { + Json::Value 
input; + input["id"] = "asst_123"; + input["object"] = "assistant"; + input["created_at"] = "invalid"; + input["model"] = "gpt-4"; + + auto result = Assistant::FromJson(std::move(input)); + EXPECT_FALSE(result.has_value()); + } +} + +TEST_F(AssistantTest, MoveConstructorAndAssignment) { + base_assistant.name = "Test Assistant"; + base_assistant.tools.push_back(std::make_unique()); + + // Test move constructor + Assistant moved_assistant(std::move(base_assistant)); + EXPECT_EQ(moved_assistant.id, "asst_123"); + EXPECT_EQ(moved_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(moved_assistant.tools.size(), 1); + + // Test move assignment + Assistant another_assistant; + another_assistant = std::move(moved_assistant); + EXPECT_EQ(another_assistant.id, "asst_123"); + EXPECT_EQ(another_assistant.name.value(), "Test Assistant"); + EXPECT_EQ(another_assistant.tools.size(), 1); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_code_interpreter.cc b/engine/test/components/test_assistant_tool_code_interpreter.cc new file mode 100644 index 000000000..f32526504 --- /dev/null +++ b/engine/test/components/test_assistant_tool_code_interpreter.cc @@ -0,0 +1,49 @@ +#include +#include +#include "common/assistant_code_interpreter_tool.h" + +namespace OpenAi { +namespace { + +class AssistantCodeInterpreterToolTest : public ::testing::Test {}; + +TEST_F(AssistantCodeInterpreterToolTest, BasicConstruction) { + AssistantCodeInterpreterTool tool; + EXPECT_EQ(tool.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveConstructor) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool moved(std::move(original)); + EXPECT_EQ(moved.type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, MoveAssignment) { + AssistantCodeInterpreterTool original; + AssistantCodeInterpreterTool target; + target = std::move(original); + EXPECT_EQ(target.type, "code_interpreter"); +} + 
+TEST_F(AssistantCodeInterpreterToolTest, FromJson) { + Json::Value json; // Empty JSON is fine for this tool + auto result = AssistantCodeInterpreterTool::FromJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().type, "code_interpreter"); +} + +TEST_F(AssistantCodeInterpreterToolTest, ToJson) { + AssistantCodeInterpreterTool tool; + auto result = tool.ToJson(); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value()["type"].asString(), "code_interpreter"); + + // Verify no extra fields + Json::Value::Members members = result.value().getMemberNames(); + EXPECT_EQ(members.size(), 1); // Only "type" field should be present + EXPECT_EQ(members[0], "type"); +} +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_file_search.cc b/engine/test/components/test_assistant_tool_file_search.cc new file mode 100644 index 000000000..25a2ffc05 --- /dev/null +++ b/engine/test/components/test_assistant_tool_file_search.cc @@ -0,0 +1,207 @@ +#include +#include +#include "common/assistant_file_search_tool.h" + +namespace OpenAi { +namespace { + +class AssistantFileSearchToolTest : public ::testing::Test {}; + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionBasicConstruction) { + const float threshold = 0.75f; + const std::string ranker = "test_ranker"; + FileSearchRankingOption option{threshold, ranker}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, ranker); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionDefaultRanker) { + const float threshold = 0.5f; + FileSearchRankingOption option{threshold}; + + EXPECT_EQ(option.score_threshold, threshold); + EXPECT_EQ(option.ranker, "auto"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromValidJson) { + Json::Value json; + json["score_threshold"] = 0.8f; + json["ranker"] = "custom_ranker"; + + auto result = FileSearchRankingOption::FromJson(json); + ASSERT_TRUE(result.has_value()); + + 
EXPECT_EQ(result.value().score_threshold, 0.8f); + EXPECT_EQ(result.value().ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionFromInvalidJson) { + Json::Value json; + auto result = FileSearchRankingOption::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, FileSearchRankingOptionToJson) { + FileSearchRankingOption option{0.9f, "special_ranker"}; + auto json_result = option.ToJson(); + + ASSERT_TRUE(json_result.has_value()); + Json::Value json = json_result.value(); + + EXPECT_EQ(json["score_threshold"].asFloat(), 0.9f); + EXPECT_EQ(json["ranker"].asString(), "special_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchBasicConstruction) { + FileSearchRankingOption ranking_option{0.7f, "test_ranker"}; + AssistantFileSearch search{10, std::move(ranking_option)}; + + EXPECT_EQ(search.max_num_results, 10); + EXPECT_EQ(search.ranking_options.score_threshold, 0.7f); + EXPECT_EQ(search.ranking_options.ranker, "test_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromValidJson) { + Json::Value json; + json["max_num_results"] = 15; + + Json::Value ranking_json; + ranking_json["score_threshold"] = 0.85f; + ranking_json["ranker"] = "custom_ranker"; + json["ranking_options"] = ranking_json; + + auto result = AssistantFileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + + EXPECT_EQ(result.value().max_num_results, 15); + EXPECT_EQ(result.value().ranking_options.score_threshold, 0.85f); + EXPECT_EQ(result.value().ranking_options.ranker, "custom_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchFromInvalidJson) { + Json::Value json; + // Missing required fields + auto result = AssistantFileSearch::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToJson) { + FileSearchRankingOption ranking_option{0.95f, "advanced_ranker"}; + AssistantFileSearch search{20, 
std::move(ranking_option)}; + + auto json_result = search.ToJson(); + ASSERT_TRUE(json_result.has_value()); + + Json::Value json = json_result.value(); + EXPECT_EQ(json["max_num_results"].asInt(), 20); + EXPECT_EQ(json["ranking_options"]["score_threshold"].asFloat(), 0.95f); + EXPECT_EQ(json["ranking_options"]["ranker"].asString(), "advanced_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolConstruction) { + FileSearchRankingOption ranking_option{0.8f, "tool_ranker"}; + AssistantFileSearch search{25, std::move(ranking_option)}; + AssistantFileSearchTool tool{search}; + + EXPECT_EQ(tool.type, "file_search"); + EXPECT_EQ(tool.file_search.max_num_results, 25); + EXPECT_EQ(tool.file_search.ranking_options.score_threshold, 0.8f); + EXPECT_EQ(tool.file_search.ranking_options.ranker, "tool_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromValidJson) { + Json::Value json; + json["type"] = "file_search"; + + Json::Value file_search; + file_search["max_num_results"] = 30; + + Json::Value ranking_options; + ranking_options["score_threshold"] = 0.75f; + ranking_options["ranker"] = "json_ranker"; + file_search["ranking_options"] = ranking_options; + + json["file_search"] = file_search; + + auto result = AssistantFileSearchTool::FromJson(json); + ASSERT_TRUE(result.has_value()); + + EXPECT_EQ(result.value().type, "file_search"); + EXPECT_EQ(result.value().file_search.max_num_results, 30); + EXPECT_EQ(result.value().file_search.ranking_options.score_threshold, 0.75f); + EXPECT_EQ(result.value().file_search.ranking_options.ranker, "json_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolFromInvalidJson) { + Json::Value json; + // Missing required fields + auto result = AssistantFileSearchTool::FromJson(json); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(AssistantFileSearchToolTest, AssistantFileSearchToolToJson) { + FileSearchRankingOption ranking_option{0.65f, "final_ranker"}; + AssistantFileSearch 
search{35, std::move(ranking_option)}; + AssistantFileSearchTool tool{search}; + + auto json_result = tool.ToJson(); + ASSERT_TRUE(json_result.has_value()); + + Json::Value json = json_result.value(); + EXPECT_EQ(json["type"].asString(), "file_search"); + EXPECT_EQ(json["file_search"]["max_num_results"].asInt(), 35); + EXPECT_EQ(json["file_search"]["ranking_options"]["score_threshold"].asFloat(), + 0.65f); + EXPECT_EQ(json["file_search"]["ranking_options"]["ranker"].asString(), + "final_ranker"); +} + +TEST_F(AssistantFileSearchToolTest, MoveConstructorsAndAssignments) { + // Test FileSearchRankingOption move operations + FileSearchRankingOption original_option{0.8f, "original_ranker"}; + FileSearchRankingOption moved_option{std::move(original_option)}; + EXPECT_EQ(moved_option.score_threshold, 0.8f); + EXPECT_EQ(moved_option.ranker, "original_ranker"); + + FileSearchRankingOption assign_target{0.5f}; + assign_target = std::move(moved_option); + EXPECT_EQ(assign_target.score_threshold, 0.8f); + EXPECT_EQ(assign_target.ranker, "original_ranker"); + + // Test AssistantFileSearch move operations + FileSearchRankingOption search_option{0.9f, "search_ranker"}; + AssistantFileSearch original_search{40, std::move(search_option)}; + AssistantFileSearch moved_search{std::move(original_search)}; + EXPECT_EQ(moved_search.max_num_results, 40); + EXPECT_EQ(moved_search.ranking_options.score_threshold, 0.9f); + + // Test AssistantFileSearchTool move operations + FileSearchRankingOption tool_option{0.7f, "tool_ranker"}; + AssistantFileSearch tool_search{45, std::move(tool_option)}; + AssistantFileSearchTool original_tool{tool_search}; + AssistantFileSearchTool moved_tool{std::move(original_tool)}; + EXPECT_EQ(moved_tool.type, "file_search"); + EXPECT_EQ(moved_tool.file_search.max_num_results, 45); +} + +TEST_F(AssistantFileSearchToolTest, EdgeCases) { + // Test boundary values for score_threshold + FileSearchRankingOption min_threshold{0.0f}; + 
EXPECT_EQ(min_threshold.score_threshold, 0.0f); + + FileSearchRankingOption max_threshold{1.0f}; + EXPECT_EQ(max_threshold.score_threshold, 1.0f); + + // Test boundary values for max_num_results + FileSearchRankingOption ranking_option{0.5f}; + AssistantFileSearch min_results{1, std::move(ranking_option)}; + EXPECT_EQ(min_results.max_num_results, 1); + + FileSearchRankingOption ranking_option2{0.5f}; + AssistantFileSearch max_results{50, std::move(ranking_option2)}; + EXPECT_EQ(max_results.max_num_results, 50); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_assistant_tool_function.cc b/engine/test/components/test_assistant_tool_function.cc new file mode 100644 index 000000000..6f59df693 --- /dev/null +++ b/engine/test/components/test_assistant_tool_function.cc @@ -0,0 +1,240 @@ +#include <gtest/gtest.h> +#include "common/assistant_function_tool.h" +#include <optional> + +namespace OpenAi { +namespace { + +class AssistantFunctionTest : public ::testing::Test { +protected: + void SetUp() override { + // Common test setup + basic_description = "Test function description"; + basic_name = "test_function"; + basic_params = Json::Value(Json::objectValue); + basic_params["type"] = "object"; + basic_params["properties"] = Json::Value(Json::objectValue); + } + + std::string basic_description; + std::string basic_name; + Json::Value basic_params; +}; + +TEST_F(AssistantFunctionTest, BasicConstructionWithoutStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt); + + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, basic_params); + EXPECT_FALSE(function.strict.has_value()); +} + +TEST_F(AssistantFunctionTest, BasicConstructionWithStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, true); + + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, 
basic_params); + ASSERT_TRUE(function.strict.has_value()); + EXPECT_TRUE(*function.strict); +} + +TEST_F(AssistantFunctionTest, MoveConstructor) { + AssistantFunction original(basic_description, basic_name, basic_params, true); + + AssistantFunction moved(std::move(original)); + + EXPECT_EQ(moved.description, basic_description); + EXPECT_EQ(moved.name, basic_name); + EXPECT_EQ(moved.parameters, basic_params); + ASSERT_TRUE(moved.strict.has_value()); + EXPECT_TRUE(*moved.strict); +} + +TEST_F(AssistantFunctionTest, MoveAssignment) { + AssistantFunction original(basic_description, basic_name, basic_params, true); + + AssistantFunction target("other", "other_name", Json::Value(Json::objectValue), false); + target = std::move(original); + + EXPECT_EQ(target.description, basic_description); + EXPECT_EQ(target.name, basic_name); + EXPECT_EQ(target.parameters, basic_params); + ASSERT_TRUE(target.strict.has_value()); + EXPECT_TRUE(*target.strict); +} + +TEST_F(AssistantFunctionTest, FromValidJson) { + Json::Value json; + json["description"] = basic_description; + json["name"] = basic_name; + json["strict"] = true; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_value()); + + const auto& function = result.value(); + EXPECT_EQ(function.description, basic_description); + EXPECT_EQ(function.name, basic_name); + EXPECT_EQ(function.parameters, basic_params); + ASSERT_TRUE(function.strict.has_value()); + EXPECT_TRUE(*function.strict); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationEmptyJson) { + Json::Value json; + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function json can't be empty"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationEmptyName) { + Json::Value json; + json["description"] = basic_description; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); 
+ EXPECT_EQ(result.error(), "Function name can't be empty"); + + // Test with empty name value + json["name"] = ""; + result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function name can't be empty"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationMissingDescription) { + Json::Value json; + json["name"] = basic_name; + json["parameters"] = basic_params; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function description is mandatory"); +} + +TEST_F(AssistantFunctionTest, FromJsonValidationMissingParameters) { + Json::Value json; + json["name"] = basic_name; + json["description"] = basic_description; + + auto result = AssistantFunction::FromJson(json); + ASSERT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Function parameters are mandatory"); +} + +TEST_F(AssistantFunctionTest, ToJsonWithStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, true); + + auto result = function.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["description"].asString(), basic_description); + EXPECT_EQ(json["name"].asString(), basic_name); + EXPECT_EQ(json["parameters"], basic_params); + EXPECT_TRUE(json["strict"].asBool()); +} + +TEST_F(AssistantFunctionTest, ToJsonWithoutStrict) { + AssistantFunction function(basic_description, basic_name, basic_params, std::nullopt); + + auto result = function.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["description"].asString(), basic_description); + EXPECT_EQ(json["name"].asString(), basic_name); + EXPECT_EQ(json["parameters"], basic_params); + EXPECT_FALSE(json.isMember("strict")); +} + +// AssistantFunctionTool Tests +class AssistantFunctionToolTest : public ::testing::Test { +protected: + void SetUp() override { + description = "Test tool description"; + name = 
"test_tool"; + params = Json::Value(Json::objectValue); + params["type"] = "object"; + } + + std::string description; + std::string name; + Json::Value params; +}; + +TEST_F(AssistantFunctionToolTest, BasicConstruction) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool tool(function); + + EXPECT_EQ(tool.type, "function"); + EXPECT_EQ(tool.function.description, description); + EXPECT_EQ(tool.function.name, name); + EXPECT_EQ(tool.function.parameters, params); + ASSERT_TRUE(tool.function.strict.has_value()); + EXPECT_TRUE(*tool.function.strict); +} + +TEST_F(AssistantFunctionToolTest, MoveConstructor) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool original(function); + + AssistantFunctionTool moved(std::move(original)); + + EXPECT_EQ(moved.type, "function"); + EXPECT_EQ(moved.function.description, description); + EXPECT_EQ(moved.function.name, name); + EXPECT_EQ(moved.function.parameters, params); +} + +TEST_F(AssistantFunctionToolTest, FromValidJson) { + Json::Value function_json; + function_json["description"] = description; + function_json["name"] = name; + function_json["strict"] = true; + function_json["parameters"] = params; + + Json::Value json; + json["type"] = "function"; + json["function"] = function_json; + + auto result = AssistantFunctionTool::FromJson(json); + ASSERT_TRUE(result.has_value()); + + const auto& tool = result.value(); + EXPECT_EQ(tool.type, "function"); + EXPECT_EQ(tool.function.description, description); + EXPECT_EQ(tool.function.name, name); + EXPECT_EQ(tool.function.parameters, params); + ASSERT_TRUE(tool.function.strict.has_value()); + EXPECT_TRUE(*tool.function.strict); +} + +TEST_F(AssistantFunctionToolTest, FromInvalidJson) { + Json::Value json; + auto result = AssistantFunctionTool::FromJson(json); + EXPECT_TRUE(result.has_error()); + EXPECT_EQ(result.error(), "Failed to parse function: Function json can't be empty"); +} + 
+TEST_F(AssistantFunctionToolTest, ToJson) { + AssistantFunction function(description, name, params, true); + AssistantFunctionTool tool(function); + + auto result = tool.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + EXPECT_EQ(json["type"].asString(), "function"); + EXPECT_EQ(json["function"]["description"].asString(), description); + EXPECT_EQ(json["function"]["name"].asString(), name); + EXPECT_EQ(json["function"]["parameters"], params); + EXPECT_TRUE(json["function"]["strict"].asBool()); +} + +} // namespace +} // namespace OpenAi diff --git a/engine/test/components/test_tool_resources.cc b/engine/test/components/test_tool_resources.cc new file mode 100644 index 000000000..2b78e6494 --- /dev/null +++ b/engine/test/components/test_tool_resources.cc @@ -0,0 +1,212 @@ +#include <gtest/gtest.h> +#include <json/value.h> +#include "common/tool_resources.h" + +namespace OpenAi { +namespace { + +// Mock class for testing abstract ToolResources +class MockToolResources : public ToolResources { + public: + cpp::result<Json::Value, std::string> ToJson() override { + Json::Value json; + json["mock"] = "value"; + return json; + } +}; + +class ToolResourcesTest : public ::testing::Test {}; + +TEST_F(ToolResourcesTest, MoveConstructor) { + MockToolResources original; + MockToolResources moved(std::move(original)); + + auto json_result = moved.ToJson(); + ASSERT_TRUE(json_result.has_value()); + EXPECT_EQ(json_result.value()["mock"].asString(), "value"); +} + +TEST_F(ToolResourcesTest, MoveAssignment) { + MockToolResources original; + MockToolResources target; + target = std::move(original); + + auto json_result = target.ToJson(); + ASSERT_TRUE(json_result.has_value()); + EXPECT_EQ(json_result.value()["mock"].asString(), "value"); +} + +class CodeInterpreterTest : public ::testing::Test { + protected: + void SetUp() override { sample_file_ids = {"file1", "file2", "file3"}; } + + std::vector<std::string> sample_file_ids; +}; + +TEST_F(CodeInterpreterTest, DefaultConstruction) { + CodeInterpreter interpreter; 
+ EXPECT_TRUE(interpreter.file_ids.empty()); +} + +TEST_F(CodeInterpreterTest, MoveConstructor) { + CodeInterpreter original; + original.file_ids = sample_file_ids; + + CodeInterpreter moved(std::move(original)); + EXPECT_EQ(moved.file_ids, sample_file_ids); + EXPECT_TRUE(original.file_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(CodeInterpreterTest, MoveAssignment) { + CodeInterpreter original; + original.file_ids = sample_file_ids; + + CodeInterpreter target; + target = std::move(original); + EXPECT_EQ(target.file_ids, sample_file_ids); + EXPECT_TRUE(original.file_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(CodeInterpreterTest, FromJsonWithFileIds) { + Json::Value json; + Json::Value file_ids(Json::arrayValue); + for (const auto& id : sample_file_ids) { + file_ids.append(id); + } + json["file_ids"] = file_ids; + + auto result = CodeInterpreter::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().file_ids, sample_file_ids); +} + +TEST_F(CodeInterpreterTest, FromJsonWithoutFileIds) { + Json::Value json; // Empty JSON + auto result = CodeInterpreter::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(result.value().file_ids.empty()); +} + +TEST_F(CodeInterpreterTest, ToJson) { + CodeInterpreter interpreter; + interpreter.file_ids = sample_file_ids; + + auto result = interpreter.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("file_ids")); + ASSERT_TRUE(json["file_ids"].isArray()); + ASSERT_EQ(json["file_ids"].size(), sample_file_ids.size()); + + for (Json::ArrayIndex i = 0; i < json["file_ids"].size(); ++i) { + EXPECT_EQ(json["file_ids"][i].asString(), sample_file_ids[i]); + } +} + +TEST_F(CodeInterpreterTest, ToJsonEmptyFileIds) { + CodeInterpreter interpreter; + + auto result = interpreter.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("file_ids")); + 
ASSERT_TRUE(json["file_ids"].isArray()); + EXPECT_EQ(json["file_ids"].size(), 0); +} + +class FileSearchTest : public ::testing::Test { + protected: + void SetUp() override { + sample_vector_store_ids = {"store1", "store2", "store3"}; + } + + std::vector<std::string> sample_vector_store_ids; +}; + +TEST_F(FileSearchTest, DefaultConstruction) { + FileSearch search; + EXPECT_TRUE(search.vector_store_ids.empty()); +} + +TEST_F(FileSearchTest, MoveConstructor) { + FileSearch original; + original.vector_store_ids = sample_vector_store_ids; + + FileSearch moved(std::move(original)); + EXPECT_EQ(moved.vector_store_ids, sample_vector_store_ids); + EXPECT_TRUE( + original.vector_store_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(FileSearchTest, MoveAssignment) { + FileSearch original; + original.vector_store_ids = sample_vector_store_ids; + + FileSearch target; + target = std::move(original); + EXPECT_EQ(target.vector_store_ids, sample_vector_store_ids); + EXPECT_TRUE( + original.vector_store_ids.empty()); // NOLINT: Checking moved-from state +} + +TEST_F(FileSearchTest, FromJsonWithVectorStoreIds) { + Json::Value json; + Json::Value vector_store_ids(Json::arrayValue); + for (const auto& id : sample_vector_store_ids) { + vector_store_ids.append(id); + } + json["vector_store_ids"] = vector_store_ids; + + auto result = FileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().vector_store_ids, sample_vector_store_ids); +} + +TEST_F(FileSearchTest, FromJsonWithoutVectorStoreIds) { + Json::Value json; // Empty JSON + auto result = FileSearch::FromJson(json); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(result.value().vector_store_ids.empty()); +} + +TEST_F(FileSearchTest, ToJson) { + FileSearch search; + search.vector_store_ids = sample_vector_store_ids; + + auto result = search.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("vector_store_ids")); + 
ASSERT_TRUE(json["vector_store_ids"].isArray()); + ASSERT_EQ(json["vector_store_ids"].size(), sample_vector_store_ids.size()); + + for (Json::ArrayIndex i = 0; i < json["vector_store_ids"].size(); ++i) { + EXPECT_EQ(json["vector_store_ids"][i].asString(), + sample_vector_store_ids[i]); + } +} + +TEST_F(FileSearchTest, ToJsonEmptyVectorStoreIds) { + FileSearch search; + + auto result = search.ToJson(); + ASSERT_TRUE(result.has_value()); + + const auto& json = result.value(); + ASSERT_TRUE(json.isMember("vector_store_ids")); + ASSERT_TRUE(json["vector_store_ids"].isArray()); + EXPECT_EQ(json["vector_store_ids"].size(), 0); +} + +TEST_F(FileSearchTest, SelfAssignment) { + FileSearch search; + search.vector_store_ids = sample_vector_store_ids; + + search = std::move(search); // Self-assignment with move + EXPECT_EQ(search.vector_store_ids, sample_vector_store_ids); +} +} // namespace +} // namespace OpenAi diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index e6859c018..a83c93efa 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -290,13 +290,14 @@ std::filesystem::path GetModelsContainerPath() { return models_container_path; } -std::filesystem::path GetCudaToolkitPath(const std::string& engine) { +std::filesystem::path GetCudaToolkitPath(const std::string& engine, + bool create_if_not_exist) { auto engine_path = getenv("ENGINE_PATH") ? 
std::filesystem::path(getenv("ENGINE_PATH")) : GetCortexDataPath(); auto cuda_path = engine_path / "engines" / engine / "deps"; - if (!std::filesystem::exists(cuda_path)) { + if (create_if_not_exist && !std::filesystem::exists(cuda_path)) { std::filesystem::create_directories(cuda_path); } diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index 91102d002..059fe6ae3 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -45,7 +45,8 @@ void CreateDirectoryRecursively(const std::string& path); std::filesystem::path GetModelsContainerPath(); -std::filesystem::path GetCudaToolkitPath(const std::string& engine); +std::filesystem::path GetCudaToolkitPath(const std::string& engine, + bool create_if_not_exist = false); std::filesystem::path GetEnginesContainerPath();