Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/model import cmd #1248

Merged
merged 16 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions engine/commands/model_import_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "model_import_cmd.h"
#include <filesystem>
#include <iostream>
#include <vector>
#include "config/gguf_parser.h"
#include "config/yaml_config.h"
#include "trantor/utils/Logger.h"
#include "utils/file_manager_utils.h"
#include "utils/logging_utils.h"
#include "utils/modellist_utils.h"

namespace commands {

ModelImportCmd::ModelImportCmd(std::string model_handle, std::string model_path)
: model_handle_(std::move(model_handle)),
model_path_(std::move(model_path)) {}

void ModelImportCmd::Exec() {
config::GGUFHandler gguf_handler;
config::YamlHandler yaml_handler;
modellist_utils::ModelListUtils modellist_utils_obj;

std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() /
std::filesystem::path("imported") /
std::filesystem::path(model_handle_ + ".yml"))
.string();
modellist_utils::ModelEntry model_entry{
model_handle_, "local", "imported",
model_yaml_path, model_handle_, modellist_utils::ModelStatus::READY};
try {
std::filesystem::create_directories(
std::filesystem::path(model_yaml_path).parent_path());
gguf_handler.Parse(model_path_);
auto model_config = gguf_handler.GetModelConfig();
model_config.files.push_back(model_path_);
model_config.model = model_handle_;
yaml_handler.UpdateModelConfig(model_config);

if (modellist_utils_obj.AddModelEntry(model_entry)) {
yaml_handler.WriteYamlFile(model_yaml_path);
CLI_LOG("Model is imported successfully!");
} else {
CLI_LOG("Fail to import model, model_id '" + model_handle_ +
"' already exists!");
}

} catch (const std::exception& e) {
std::remove(model_yaml_path.c_str());
CLI_LOG("Error importing model path '" + model_path_ + "' with model_id '" +
model_handle_ + "': " + e.what());
}
}
} // namespace commands
15 changes: 15 additions & 0 deletions engine/commands/model_import_cmd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <string>
namespace commands {

class ModelImportCmd {
public:
ModelImportCmd(std::string model_handle, std::string model_path);
void Exec();

private:
std::string model_handle_;
std::string model_path_;
};
} // namespace commands
14 changes: 14 additions & 0 deletions engine/controllers/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "commands/model_alias_cmd.h"
#include "commands/model_del_cmd.h"
#include "commands/model_get_cmd.h"
#include "commands/model_import_cmd.h"
#include "commands/model_list_cmd.h"
#include "commands/model_pull_cmd.h"
#include "commands/model_start_cmd.h"
Expand Down Expand Up @@ -166,6 +167,19 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
auto model_update_cmd =
models_cmd->add_subcommand("update", "Update configuration of a model");

std::string model_path;
auto model_import_cmd = models_cmd->add_subcommand(
"import", "Import a gguf model from local file");
model_import_cmd->add_option("--model_id", model_id, "");
model_import_cmd->add_option("--model_path", model_path,
"Absolute path to .gguf model, the path should "
"include the gguf file name");
model_import_cmd->require_option(2);
model_import_cmd->callback([&model_id,&model_path]() {
commands::ModelImportCmd command(model_id, model_path);
command.Exec();
});

// Default version is latest
std::string version{"latest"};
// engines group commands
Expand Down
72 changes: 72 additions & 0 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,78 @@ void Models::DeleteModel(const HttpRequestPtr& req,
}
}

void Models::ImportModel(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
if (!http_util::HasFieldInReq(req, callback, "modelId") ||
!http_util::HasFieldInReq(req, callback, "modelPath")) {
return;
}
auto modelHandle = (*(req->getJsonObject())).get("modelId", "").asString();
auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
config::GGUFHandler gguf_handler;
config::YamlHandler yaml_handler;
modellist_utils::ModelListUtils modellist_utils_obj;

std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() /
std::filesystem::path("imported") /
std::filesystem::path(modelHandle + ".yml"))
.string();
modellist_utils::ModelEntry model_entry{
modelHandle, "local", "imported",
model_yaml_path, modelHandle, modellist_utils::ModelStatus::READY};
try {
std::filesystem::create_directories(
std::filesystem::path(model_yaml_path).parent_path());
gguf_handler.Parse(modelPath);
config::ModelConfig model_config = gguf_handler.GetModelConfig();
model_config.files.push_back(modelPath);
model_config.name = modelHandle;
yaml_handler.UpdateModelConfig(model_config);

if (modellist_utils_obj.AddModelEntry(model_entry)) {
yaml_handler.WriteYamlFile(model_yaml_path);
std::string success_message = "Model is imported successfully!";
LOG_INFO << success_message;
Json::Value ret;
ret["result"] = "OK";
ret["modelHandle"] = modelHandle;
ret["message"] = success_message;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);

} else {
std::string error_message = "Fail to import model, model_id '" +
modelHandle + "' already exists!";
LOG_ERROR << error_message;
Json::Value ret;
ret["result"] = "Import failed!";
ret["modelHandle"] = modelHandle;
ret["message"] = error_message;

auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
}

} catch (const std::exception& e) {
std::remove(model_yaml_path.c_str());
std::string error_message = "Error importing model path '" + modelPath +
"' with model_id '" + modelHandle +
"': " + e.what();
LOG_ERROR << error_message;
Json::Value ret;
ret["result"] = "Import failed!";
ret["modelHandle"] = modelHandle;
ret["message"] = error_message;

auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
}
}

void Models::SetModelAlias(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
Expand Down
3 changes: 3 additions & 0 deletions engine/controllers/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Models : public drogon::HttpController<Models> {
METHOD_ADD(Models::PullModel, "/pull", Post);
METHOD_ADD(Models::ListModel, "/list", Get);
METHOD_ADD(Models::GetModel, "/get", Post);
METHOD_ADD(Models::ImportModel, "/import", Post);
METHOD_ADD(Models::DeleteModel, "/{1}", Delete);
METHOD_ADD(Models::SetModelAlias, "/alias", Post);
METHOD_LIST_END
Expand All @@ -25,6 +26,8 @@ class Models : public drogon::HttpController<Models> {
std::function<void(const HttpResponsePtr&)>&& callback) const;
void GetModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;
void ImportModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;
void DeleteModel(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& model_id) const;
Expand Down
1 change: 1 addition & 0 deletions engine/e2e-test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from test_cli_server_start import TestCliServerStart
from test_cortex_update import TestCortexUpdate
from test_create_log_folder import TestCreateLogFolder
from test_cli_model_import import TestCliModelImport

if __name__ == "__main__":
pytest.main([__file__, "-v"])
14 changes: 14 additions & 0 deletions engine/e2e-test/test_cli_model_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from test_runner import run

class TestCliModelImport:

@pytest.mark.skipif(True, reason="Expensive test. Only test when you have local gguf file.")
def test_model_import_should_be_success(self):

exit_code, output, error = run(
"Pull model", ["models", "import", "--model_id","test_model","--model_path","/path/to/local/gguf"],
timeout=None
)
assert exit_code == 0, f"Model import failed failed with error: {error}"
nguyenhoangthuan99 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: skip this test. since download model is taking too long
Loading