diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 59793b2a6..1c33ab1dc 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -488,55 +488,31 @@ void Models::StartModel( if (!http_util::HasFieldInReq(req, callback, "model")) return; auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); - StartParameterOverride params_override; - if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) { - params_override.custom_prompt_template = o.asString(); - } - - if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) { - params_override.cache_enabled = o.asBool(); - } - - if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) { - params_override.ngl = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) { - params_override.n_parallel = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) { - params_override.ctx_len = o.asInt(); - } - - if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) { - params_override.cache_type = o.asString(); - } + std::optional mmproj; if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) { - params_override.mmproj = o.asString(); + mmproj = o.asString(); } + auto bypass_llama_model_path = false; // Support both llama_model_path and model_path for backward compatible // model_path has higher priority if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) { - params_override.model_path = o.asString(); + auto model_path = o.asString(); if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) { // Bypass if model does not exist in DB and llama_model_path exists - if (std::filesystem::exists(params_override.model_path.value()) && + if (std::filesystem::exists(model_path) && !model_service_->HasModel(model_handle)) { CTL_INF("llama_model_path exists, bypass check model id"); - params_override.bypass_llama_model_path = true; + 
bypass_llama_model_path = true; } } } - if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) { - params_override.model_path = o.asString(); - } + auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path); auto model_entry = model_service_->GetDownloadedModel(model_handle); - if (!model_entry.has_value() && !params_override.bypass_model_check()) { + if (!model_entry.has_value() && !bypass_model_check) { Json::Value ret; ret["message"] = "Cannot find model: " + model_handle; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); @@ -544,9 +520,8 @@ void Models::StartModel( callback(resp); return; } - std::string engine_name = params_override.bypass_model_check() - ? kLlamaEngine - : model_entry.value().engine; + std::string engine_name = + bypass_model_check ? kLlamaEngine : model_entry.value().engine; auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -565,7 +540,9 @@ void Models::StartModel( return; } - auto result = model_service_->StartModel(model_handle, params_override); + auto result = model_service_->StartModel( + model_handle, *(req->getJsonObject()) /*params_override*/, + bypass_model_check); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); @@ -668,7 +645,7 @@ void Models::AddRemoteModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); - + auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -687,7 +664,7 @@ void Models::AddRemoteModel( callback(resp); return; } - + config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); cortex::db::Models modellist_utils_obj; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 0d909b61f..be0eb12a7 100644 --- a/engine/services/model_service.cc +++ 
b/engine/services/model_service.cc @@ -749,19 +749,28 @@ cpp::result ModelService::DeleteModel( } cpp::result ModelService::StartModel( - const std::string& model_handle, - const StartParameterOverride& params_override) { + const std::string& model_handle, const Json::Value& params_override, + bool bypass_model_check) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; config::YamlHandler yaml_handler; + std::optional custom_prompt_template; + std::optional ctx_len; + if (auto& o = params_override["prompt_template"]; !o.isNull()) { + custom_prompt_template = o.asString(); + } + + if (auto& o = params_override["ctx_len"]; !o.isNull()) { + ctx_len = o.asInt(); + } try { constexpr const int kDefautlContextLength = 8192; int max_model_context_length = kDefautlContextLength; Json::Value json_data; // Currently we don't support download vision models, so we need to bypass check - if (!params_override.bypass_model_check()) { + if (!bypass_model_check) { auto model_entry = modellist_handler.GetModelInfo(model_handle); if (model_entry.has_error()) { CTL_WRN("Error: " + model_entry.error()); @@ -839,29 +848,19 @@ cpp::result ModelService::StartModel( } json_data["model"] = model_handle; - if (auto& cpt = params_override.custom_prompt_template; - !cpt.value_or("").empty()) { + if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) { auto parse_prompt_result = string_utils::ParsePrompt(cpt.value()); json_data["system_prompt"] = parse_prompt_result.system_prompt; json_data["user_prompt"] = parse_prompt_result.user_prompt; json_data["ai_prompt"] = parse_prompt_result.ai_prompt; } -#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \ - if (param_override.param_name) { \ - json_obj[#param_name] = param_override.param_name.value(); \ - } + json_helper::MergeJson(json_data, params_override); - ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled); - ASSIGN_IF_PRESENT(json_data, params_override, 
ngl); - ASSIGN_IF_PRESENT(json_data, params_override, n_parallel); - ASSIGN_IF_PRESENT(json_data, params_override, cache_type); - ASSIGN_IF_PRESENT(json_data, params_override, mmproj); - ASSIGN_IF_PRESENT(json_data, params_override, model_path); -#undef ASSIGN_IF_PRESENT - if (params_override.ctx_len) { + // Set the latest ctx_len + if (ctx_len) { json_data["ctx_len"] = - std::min(params_override.ctx_len.value(), max_model_context_length); + std::min(ctx_len.value(), max_model_context_length); } CTL_INF(json_data.toStyledString()); auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(), diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 8b24b3421..ab3596812 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -22,21 +22,6 @@ struct ModelPullInfo { std::string download_url; }; -struct StartParameterOverride { - std::optional cache_enabled; - std::optional ngl; - std::optional n_parallel; - std::optional ctx_len; - std::optional custom_prompt_template; - std::optional cache_type; - std::optional mmproj; - std::optional model_path; - bool bypass_llama_model_path = false; - bool bypass_model_check() const { - return mmproj.has_value() || bypass_llama_model_path; - } -}; - struct StartModelResult { bool success; std::optional warning; @@ -82,8 +67,8 @@ class ModelService { cpp::result DeleteModel(const std::string& model_handle); cpp::result StartModel( - const std::string& model_handle, - const StartParameterOverride& params_override); + const std::string& model_handle, const Json::Value& params_override, + bool bypass_model_check); cpp::result StopModel(const std::string& model_handle); diff --git a/engine/test/components/test_json_helper.cc b/engine/test/components/test_json_helper.cc index cb3f4683a..ba5e27165 100644 --- a/engine/test/components/test_json_helper.cc +++ b/engine/test/components/test_json_helper.cc @@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) { 
EXPECT_TRUE(result.isNull());
 }
+
+TEST(MergeJsonTest, MergeSimpleObjects) {
+  Json::Value json1, json2;
+  json1["name"] = "John";
+  json1["age"] = 30;
+
+  json2["age"] = 31;
+  json2["email"] = "john@example.com";
+
+  json_helper::MergeJson(json1, json2);
+
+  Json::Value expected;
+  expected["name"] = "John";
+  expected["age"] = 31;
+  expected["email"] = "john@example.com";
+
+  EXPECT_EQ(json1, expected);
+}
+
+TEST(MergeJsonTest, MergeNestedObjects) {
+  Json::Value json1, json2;
+  json1["person"]["name"] = "John";
+  json1["person"]["age"] = 30;
+
+  json2["person"]["age"] = 31;
+  json2["person"]["email"] = "john@example.com";
+
+  json_helper::MergeJson(json1, json2);
+
+  Json::Value expected;
+  expected["person"]["name"] = "John";
+  expected["person"]["age"] = 31;
+  expected["person"]["email"] = "john@example.com";
+
+  EXPECT_EQ(json1, expected);
+}
+
+TEST(MergeJsonTest, MergeArrays) {
+  Json::Value json1, json2;
+  json1["hobbies"] = Json::Value(Json::arrayValue);
+  json1["hobbies"].append("reading");
+  json1["hobbies"].append("painting");
+
+  json2["hobbies"] = Json::Value(Json::arrayValue);
+  json2["hobbies"].append("hiking");
+  json2["hobbies"].append("painting");
+
+  json_helper::MergeJson(json1, json2);
+
+  Json::Value expected;
+  expected["hobbies"] = Json::Value(Json::arrayValue);
+  expected["hobbies"].append("reading");
+  expected["hobbies"].append("painting");
+  expected["hobbies"].append("hiking");
+  expected["hobbies"].append("painting");
+
+  EXPECT_EQ(json1, expected);
+}
diff --git a/engine/utils/json_helper.h b/engine/utils/json_helper.h
index 82f994751..3b08651c4 100644
--- a/engine/utils/json_helper.h
+++ b/engine/utils/json_helper.h
@@ -16,4 +16,46 @@ inline std::string DumpJsonString(const Json::Value& json) {
   builder["indentation"] = "";
   return Json::writeString(builder, json);
 }
+
+// Recursively merges `source` into `target`, modifying `target` in place.
+//
+// For each member of `source`:
+//  - if both sides hold an object, the two objects are merged recursively;
+//  - if both sides hold an array, the source elements are appended to the
+//    target array (existing elements and duplicates are kept);
+//  - otherwise the source value is copied into the target, replacing any
+//    existing value.
+//
+// The call is a no-op unless `source` is an object and `target` is an object
+// or null (a null target becomes an object as members are added).
+inline void MergeJson(Json::Value& target, const Json::Value& source) {
+  // Member-wise access (getMemberNames/isMember/operator[] by key) is only
+  // valid on object or null values in JsonCpp, so bail out instead of
+  // tripping its assertions on arrays or scalars.
+  if (!source.isObject() || !(target.isObject() || target.isNull())) {
+    return;
+  }
+
+  for (const auto& member : source.getMemberNames()) {
+    const Json::Value& src = source[member];
+    if (!target.isMember(member)) {
+      // Member is new to the target: take the source value as-is.
+      target[member] = src;
+      continue;
+    }
+
+    Json::Value& dst = target[member];
+    if (dst.isObject() && src.isObject()) {
+      MergeJson(dst, src);
+    } else if (dst.isArray() && src.isArray()) {
+      // Arrays are concatenated, not replaced.
+      for (const auto& value : src) {
+        dst.append(value);
+      }
+    } else {
+      // Mismatched or scalar types: the source value wins.
+      dst = src;
+    }
+  }
+}
 } // namespace json_helper