Skip to content

Commit

Permalink
Merge branch 'dev' into fix/linux-arm
Browse files Browse the repository at this point in the history
  • Loading branch information
vansangpfiev authored Dec 27, 2024
2 parents 537caf5 + 3456c7b commit 26daec9
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 91 deletions.
53 changes: 15 additions & 38 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,65 +488,40 @@ void Models::StartModel(
if (!http_util::HasFieldInReq(req, callback, "model"))
return;
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
StartParameterOverride params_override;
if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
params_override.custom_prompt_template = o.asString();
}

if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
params_override.cache_enabled = o.asBool();
}

if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
params_override.ngl = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
params_override.n_parallel = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
params_override.ctx_len = o.asInt();
}

if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
params_override.cache_type = o.asString();
}

std::optional<std::string> mmproj;
if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) {
params_override.mmproj = o.asString();
mmproj = o.asString();
}

auto bypass_llama_model_path = false;
// Support both llama_model_path and model_path for backward compatible
// model_path has higher priority
if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
auto model_path = o.asString();
if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
// Bypass if model does not exist in DB and llama_model_path exists
if (std::filesystem::exists(params_override.model_path.value()) &&
if (std::filesystem::exists(model_path) &&
!model_service_->HasModel(model_handle)) {
CTL_INF("llama_model_path exists, bypass check model id");
params_override.bypass_llama_model_path = true;
bypass_llama_model_path = true;
}
}
}

if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) {
params_override.model_path = o.asString();
}
auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path);

auto model_entry = model_service_->GetDownloadedModel(model_handle);
if (!model_entry.has_value() && !params_override.bypass_model_check()) {
if (!model_entry.has_value() && !bypass_model_check) {
Json::Value ret;
ret["message"] = "Cannot find model: " + model_handle;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k400BadRequest);
callback(resp);
return;
}
std::string engine_name = params_override.bypass_model_check()
? kLlamaEngine
: model_entry.value().engine;
std::string engine_name =
bypass_model_check ? kLlamaEngine : model_entry.value().engine;
auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -565,7 +540,9 @@ void Models::StartModel(
return;
}

auto result = model_service_->StartModel(model_handle, params_override);
auto result = model_service_->StartModel(
model_handle, *(req->getJsonObject()) /*params_override*/,
bypass_model_check);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
Expand Down Expand Up @@ -668,7 +645,7 @@ void Models::AddRemoteModel(

auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
auto engine_name = (*(req->getJsonObject())).get("engine", "").asString();

auto engine_validate = engine_service_->IsEngineReady(engine_name);
if (engine_validate.has_error()) {
Json::Value ret;
Expand All @@ -687,7 +664,7 @@ void Models::AddRemoteModel(
callback(resp);
return;
}

config::RemoteModelConfig model_config;
model_config.LoadFromJson(*(req->getJsonObject()));
cortex::db::Models modellist_utils_obj;
Expand Down
7 changes: 4 additions & 3 deletions engine/services/engine_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ cpp::result<bool, std::string> EngineService::UnzipEngine(
CTL_INF("Found cuda variant, extract it");
found_cuda = true;
// extract binary
auto cuda_path =
file_manager_utils::GetCudaToolkitPath(NormalizeEngine(engine));
auto cuda_path = file_manager_utils::GetCudaToolkitPath(
NormalizeEngine(engine), true);
archive_utils::ExtractArchive(path + "/" + cf, cuda_path.string(),
true);
}
Expand Down Expand Up @@ -434,7 +434,8 @@ cpp::result<bool, std::string> EngineService::DownloadCuda(
}};

auto on_finished = [engine](const DownloadTask& finishedTask) {
auto engine_path = file_manager_utils::GetCudaToolkitPath(engine);
auto engine_path = file_manager_utils::GetCudaToolkitPath(engine, true);

archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(),
engine_path.string());
try {
Expand Down
4 changes: 2 additions & 2 deletions engine/services/file_service.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "file_service.h"
#include <cstdint>
#include "utils/ulid/ulid.hh"
#include "utils/ulid_generator.h"

cpp::result<OpenAi::File, std::string> FileService::UploadFile(
const std::string& filename, const std::string& purpose,
Expand All @@ -11,7 +11,7 @@ cpp::result<OpenAi::File, std::string> FileService::UploadFile(
std::chrono::system_clock::now().time_since_epoch())
.count();

auto file_id{"file-" + ulid::Marshal(ulid::CreateNowRand())};
auto file_id{"file-" + ulid::GenerateUlid()};
OpenAi::File file;
file.id = file_id;
file.object = "file";
Expand Down
7 changes: 2 additions & 5 deletions engine/services/message_service.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "services/message_service.h"
#include "utils/logging_utils.h"
#include "utils/result.hpp"
#include "utils/ulid/ulid.hh"
#include "utils/ulid_generator.h"

cpp::result<OpenAi::Message, std::string> MessageService::CreateMessage(
const std::string& thread_id, const OpenAi::Role& role,
Expand All @@ -27,11 +27,8 @@ cpp::result<OpenAi::Message, std::string> MessageService::CreateMessage(
std::get<std::vector<std::unique_ptr<OpenAi::Content>>>(content));
}

auto ulid = ulid::CreateNowRand();
auto msg_id = ulid::Marshal(ulid);

OpenAi::Message msg;
msg.id = msg_id;
msg.id = ulid::GenerateUlid();
msg.object = "thread.message";
msg.created_at = seconds_since_epoch;
msg.thread_id = thread_id;
Expand Down
35 changes: 17 additions & 18 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,19 +749,28 @@ cpp::result<void, std::string> ModelService::DeleteModel(
}

cpp::result<StartModelResult, std::string> ModelService::StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override) {
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check) {
namespace fs = std::filesystem;
namespace fmu = file_manager_utils;
cortex::db::Models modellist_handler;
config::YamlHandler yaml_handler;
std::optional<std::string> custom_prompt_template;
std::optional<int> ctx_len;
if (auto& o = params_override["prompt_template"]; !o.isNull()) {
custom_prompt_template = o.asString();
}

if (auto& o = params_override["ctx_len"]; !o.isNull()) {
ctx_len = o.asInt();
}

try {
constexpr const int kDefautlContextLength = 8192;
int max_model_context_length = kDefautlContextLength;
Json::Value json_data;
// Currently we don't support download vision models, so we need to bypass check
if (!params_override.bypass_model_check()) {
if (!bypass_model_check) {
auto model_entry = modellist_handler.GetModelInfo(model_handle);
if (model_entry.has_error()) {
CTL_WRN("Error: " + model_entry.error());
Expand Down Expand Up @@ -839,29 +848,19 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
}

json_data["model"] = model_handle;
if (auto& cpt = params_override.custom_prompt_template;
!cpt.value_or("").empty()) {
if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) {
auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
json_data["system_prompt"] = parse_prompt_result.system_prompt;
json_data["user_prompt"] = parse_prompt_result.user_prompt;
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
}

#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
if (param_override.param_name) { \
json_obj[#param_name] = param_override.param_name.value(); \
}
json_helper::MergeJson(json_data, params_override);

ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
ASSIGN_IF_PRESENT(json_data, params_override, ngl);
ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
ASSIGN_IF_PRESENT(json_data, params_override, mmproj);
ASSIGN_IF_PRESENT(json_data, params_override, model_path);
#undef ASSIGN_IF_PRESENT
if (params_override.ctx_len) {
// Set the latest ctx_len
if (ctx_len) {
json_data["ctx_len"] =
std::min(params_override.ctx_len.value(), max_model_context_length);
std::min(ctx_len.value(), max_model_context_length);
}
CTL_INF(json_data.toStyledString());
auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(),
Expand Down
19 changes: 2 additions & 17 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ struct ModelPullInfo {
std::string download_url;
};

struct StartParameterOverride {
std::optional<bool> cache_enabled;
std::optional<int> ngl;
std::optional<int> n_parallel;
std::optional<int> ctx_len;
std::optional<std::string> custom_prompt_template;
std::optional<std::string> cache_type;
std::optional<std::string> mmproj;
std::optional<std::string> model_path;
bool bypass_llama_model_path = false;
bool bypass_model_check() const {
return mmproj.has_value() || bypass_llama_model_path;
}
};

struct StartModelResult {
bool success;
std::optional<std::string> warning;
Expand Down Expand Up @@ -82,8 +67,8 @@ class ModelService {
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);

cpp::result<StartModelResult, std::string> StartModel(
const std::string& model_handle,
const StartParameterOverride& params_override);
const std::string& model_handle, const Json::Value& params_override,
bool bypass_model_check);

cpp::result<bool, std::string> StopModel(const std::string& model_handle);

Expand Down
8 changes: 3 additions & 5 deletions engine/services/thread_service.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "thread_service.h"
#include <chrono>
#include "utils/logging_utils.h"
#include "utils/ulid/ulid.hh"
#include "utils/ulid_generator.h"

cpp::result<OpenAi::Thread, std::string> ThreadService::CreateThread(
std::unique_ptr<OpenAi::ThreadToolResources> tool_resources,
Expand All @@ -12,11 +13,8 @@ cpp::result<OpenAi::Thread, std::string> ThreadService::CreateThread(
std::chrono::system_clock::now().time_since_epoch())
.count();

auto ulid = ulid::CreateNowRand();
auto thread_id = ulid::Marshal(ulid);

OpenAi::Thread thread;
thread.id = thread_id;
thread.id = ulid::GenerateUlid();
thread.object = "thread";
thread.created_at = seconds_since_epoch;

Expand Down
58 changes: 58 additions & 0 deletions engine/test/components/test_json_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) {

EXPECT_TRUE(result.isNull());
}

TEST(MergeJsonTest, MergeSimpleObjects) {
  // Destination object: has its own "name" plus an "age" that overlaps
  // with the source object.
  Json::Value dst;
  dst["name"] = "John";
  dst["age"] = 30;

  // Source object: overrides "age" and contributes a new "email" key.
  Json::Value src;
  src["age"] = 31;
  src["email"] = "[email protected]";

  json_helper::MergeJson(dst, src);

  // On key conflicts the source value wins; non-conflicting keys from
  // both sides are kept.
  Json::Value want;
  want["name"] = "John";
  want["age"] = 31;
  want["email"] = "[email protected]";

  EXPECT_EQ(dst, want);
}

TEST(MergeJsonTest, MergeNestedObjects) {
  // Destination: a nested "person" object with "name" and "age".
  Json::Value dst;
  dst["person"]["name"] = "John";
  dst["person"]["age"] = 30;

  // Source: overrides the nested "age" and adds a nested "email".
  Json::Value src;
  src["person"]["age"] = 31;
  src["person"]["email"] = "[email protected]";

  json_helper::MergeJson(dst, src);

  // Nested objects are merged recursively: the source value wins on
  // conflicting leaf keys, and new leaf keys are added in place.
  Json::Value want;
  want["person"]["name"] = "John";
  want["person"]["age"] = 31;
  want["person"]["email"] = "[email protected]";

  EXPECT_EQ(dst, want);
}

TEST(MergeJsonTest, MergeArrays) {
  // Small helper: build a JSON array value from a list of C strings.
  auto make_array = [](std::initializer_list<const char*> items) {
    Json::Value arr(Json::arrayValue);
    for (const auto* item : items) {
      arr.append(item);
    }
    return arr;
  };

  Json::Value dst;
  dst["hobbies"] = make_array({"reading", "painting"});

  Json::Value src;
  src["hobbies"] = make_array({"hiking", "painting"});

  json_helper::MergeJson(dst, src);

  // Arrays are concatenated (dst elements first), not deduplicated —
  // note "painting" appears twice in the expected result.
  Json::Value want;
  want["hobbies"] = make_array({"reading", "painting", "hiking", "painting"});

  EXPECT_EQ(dst, want);
}
5 changes: 3 additions & 2 deletions engine/utils/file_manager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,14 @@ std::filesystem::path GetModelsContainerPath() {
return models_container_path;
}

std::filesystem::path GetCudaToolkitPath(const std::string& engine) {
std::filesystem::path GetCudaToolkitPath(const std::string& engine,
bool create_if_not_exist) {
auto engine_path = getenv("ENGINE_PATH")
? std::filesystem::path(getenv("ENGINE_PATH"))
: GetCortexDataPath();

auto cuda_path = engine_path / "engines" / engine / "deps";
if (!std::filesystem::exists(cuda_path)) {
if (create_if_not_exist && !std::filesystem::exists(cuda_path)) {
std::filesystem::create_directories(cuda_path);
}

Expand Down
3 changes: 2 additions & 1 deletion engine/utils/file_manager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void CreateDirectoryRecursively(const std::string& path);

std::filesystem::path GetModelsContainerPath();

std::filesystem::path GetCudaToolkitPath(const std::string& engine);
std::filesystem::path GetCudaToolkitPath(const std::string& engine,
bool create_if_not_exist = false);

std::filesystem::path GetEnginesContainerPath();

Expand Down
Loading

0 comments on commit 26daec9

Please sign in to comment.