Skip to content

Commit

Permalink
Init model.list utils (#1240)
Browse files Browse the repository at this point in the history
* Init model.list utils

* Add cmakelist compile

* Add cmakelist compile

* Fix CI build windows

* Add unit test

* Add test

* Fix failing unit test
  • Loading branch information
nguyenhoangthuan99 authored Sep 18, 2024
1 parent 142adf0 commit b8078af
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 2 deletions.
1 change: 1 addition & 0 deletions engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ find_package(CURL REQUIRED)
# Engine executable: main.cc plus the utility sources that are compiled
# directly into the target (CPU info, file logger, model.list utilities).
add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/modellist_utils.cc
)

# httplib is linked privately, so it is not propagated to dependents.
target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)
Expand Down
5 changes: 3 additions & 2 deletions engine/test/components/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ project(test-components)

enable_testing()

# NOTE(review): the two add_executable lines below look like the removed/added
# pair of a rendered diff ("3 additions & 2 deletions"); presumably only the
# second one, which also compiles modellist_utils.cc into the test binary,
# exists in the real file -- confirm against the repository.
add_executable(${PROJECT_NAME} ${SRCS})
add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc)

# Packages the test binary links against.
find_package(Drogon CONFIG REQUIRED)
find_package(GTest CONFIG REQUIRED)
find_package(yaml-cpp CONFIG REQUIRED)

# NOTE(review): same diff artifact as above -- the first target_link_libraries
# line appears to be the old version and the second (adding yaml-cpp::yaml-cpp)
# the new one; verify before editing.
target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main
target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp
${CMAKE_THREAD_LIBS_INIT})
# Lets tests include engine headers relative to the engine root
# (e.g. "utils/modellist_utils.h").
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)

Expand Down
91 changes: 91 additions & 0 deletions engine/test/components/test_modellist_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include <filesystem>
#include <iostream>
#include "gtest/gtest.h"
#include "utils/modellist_utils.h"
#include "utils/file_manager_utils.h"
class ModelListUtilsTestSuite : public ::testing::Test {
protected:
modellist_utils::ModelListUtils model_list_;

const modellist_utils::ModelEntry kTestModel{
"test_model_id", "test_author",
"main", "/path/to/model.yaml",
"test_alias", modellist_utils::ModelStatus::READY};
};
void SetUp() {
// Create a temporary directory for tests
file_manager_utils::CreateConfigFileIfNotExist();
}

void TearDown() {
// Clean up the temporary directory
}
// Adding the fixture entry succeeds, and the stored entry can be read back by
// its model id with the same id and author/repo.
TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) {
  const bool was_added = model_list_.AddModelEntry(kTestModel);
  EXPECT_TRUE(was_added);

  const auto stored = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(stored.model_id, kTestModel.model_id);
  EXPECT_EQ(stored.author_repo_id, kTestModel.author_repo_id);
}

// An entry is reachable through both its model id and its alias; an unknown
// identifier raises std::runtime_error.
TEST_F(ModelListUtilsTestSuite, TestGetModelInfo) {
  model_list_.AddModelEntry(kTestModel);

  // Lookup by model id.
  const auto found_by_id = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(found_by_id.model_id, kTestModel.model_id);

  // Lookup by alias resolves to the same entry.
  const auto found_by_alias = model_list_.GetModelInfo("test_alias");
  EXPECT_EQ(found_by_alias.model_id, kTestModel.model_id);

  // Unknown identifiers are reported via an exception.
  EXPECT_THROW(model_list_.GetModelInfo("non_existent_model"),
               std::runtime_error);
}

// Updating an existing entry persists the new status. The entry is switched
// back to READY at the end so that it can be deleted afterwards
// (DeleteModelEntry only removes READY entries).
TEST_F(ModelListUtilsTestSuite, TestUpdateModelEntry) {
  model_list_.AddModelEntry(kTestModel);

  modellist_utils::ModelEntry modified = kTestModel;
  modified.status = modellist_utils::ModelStatus::RUNNING;

  const bool was_updated =
      model_list_.UpdateModelEntry("test_model_id", modified);
  EXPECT_TRUE(was_updated);

  const auto reloaded = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(reloaded.status, modellist_utils::ModelStatus::RUNNING);

  // Restore the READY status.
  modified.status = modellist_utils::ModelStatus::READY;
  model_list_.UpdateModelEntry("test_model_id", modified);
}

// Deleting a READY entry succeeds and makes later lookups throw.
TEST_F(ModelListUtilsTestSuite, TestDeleteModelEntry) {
  model_list_.AddModelEntry(kTestModel);

  const bool was_deleted = model_list_.DeleteModelEntry("test_model_id");
  EXPECT_TRUE(was_deleted);

  EXPECT_THROW(model_list_.GetModelInfo("test_model_id"), std::runtime_error);
}

// With no existing entries the alias is the lower-cased file name without its
// extension; once that alias is taken, the parent directory is prepended.
TEST_F(ModelListUtilsTestSuite, TestGenerateShortenedAlias) {
  const std::string full_model_id =
      "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf";

  // No entries registered: the shortest candidate wins.
  const auto shortest = model_list_.GenerateShortenedAlias(full_model_id, {});
  EXPECT_EQ(shortest, "model_id_xxx");

  // An entry already owns the shortest alias, forcing the next candidate.
  modellist_utils::ModelEntry taken = kTestModel;
  taken.model_alias = "model_id_xxx";
  const std::vector<modellist_utils::ModelEntry> registered = {taken};

  const auto longer =
      model_list_.GenerateShortenedAlias(full_model_id, registered);
  EXPECT_EQ(longer, "llama3.1-7b-gguf:model_id_xxx");
}

// Entries written through one ModelListUtils instance are visible to a fresh
// instance, i.e. they are stored in the model.list file rather than in memory.
TEST_F(ModelListUtilsTestSuite, TestPersistence) {
  model_list_.AddModelEntry(kTestModel);

  // A second, independent instance has to read the entry from the file.
  modellist_utils::ModelListUtils second_instance;
  const auto reloaded = second_instance.GetModelInfo("test_model_id");

  EXPECT_EQ(reloaded.model_id, kTestModel.model_id);
  EXPECT_EQ(reloaded.author_repo_id, kTestModel.author_repo_id);

  // Remove the entry again so the file is left as it was found.
  model_list_.DeleteModelEntry("test_model_id");
}
1 change: 1 addition & 0 deletions engine/utils/file_manager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ inline void CreateDirectoryRecursively(const std::string& path) {
}

inline std::filesystem::path GetModelsContainerPath() {
CreateConfigFileIfNotExist();
auto cortex_path = GetCortexDataPath();
auto models_container_path = cortex_path / "models";

Expand Down
218 changes: 218 additions & 0 deletions engine/utils/modellist_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
#include "modellist_utils.h"
#include <algorithm>
#include <cctype>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <regex>
#include <sstream>
#include <stdexcept>
#include "file_manager_utils.h"
namespace modellist_utils {
// Location of the model.list file: "model.list" inside the models container
// directory reported by file_manager_utils. The value is computed once, when
// this static member is initialized.
const std::string ModelListUtils::kModelListPath = [] {
  const std::filesystem::path models_dir =
      file_manager_utils::GetModelsContainerPath();
  return (models_dir / std::filesystem::path("model.list")).string();
}();

std::vector<ModelEntry> ModelListUtils::LoadModelList() const {
std::vector<ModelEntry> entries;
std::filesystem::path file_path(kModelListPath);

// Check if the file exists, if not, create it
if (!std::filesystem::exists(file_path)) {
std::ofstream create_file(kModelListPath);
if (!create_file) {
throw std::runtime_error("Unable to create model.list file: " +
kModelListPath);
}
create_file.close();
return entries; // Return empty vector for newly created file
}

std::ifstream file(kModelListPath);
if (!file.is_open()) {
throw std::runtime_error("Unable to open model.list file: " +
kModelListPath);
}

std::string line;
while (std::getline(file, line)) {
std::istringstream iss(line);
ModelEntry entry;
std::string status_str;
if (!(iss >> entry.model_id >> entry.author_repo_id >> entry.branch_name >>
entry.path_to_model_yaml >> entry.model_alias >> status_str)) {
LOG_WARN << "Invalid entry in model.list: " << line;
} else {
entry.status =
(status_str == "RUNNING") ? ModelStatus::RUNNING : ModelStatus::READY;
entries.push_back(entry);
}
}
return entries;
}

bool ModelListUtils::IsUnique(const std::vector<ModelEntry>& entries,
const std::string& model_id,
const std::string& model_alias) const {
return std::none_of(
entries.begin(), entries.end(), [&](const ModelEntry& entry) {
return entry.model_id == model_id || entry.model_alias == model_id ||
entry.model_id == model_alias ||
entry.model_alias == model_alias;
});
}

// Overwrites the model.list file with `entries`, one entry per line.
//
// Fields are written separated by single spaces in the order: model id,
// author/repo id, branch name, path to model.yaml, alias, status
// ("RUNNING" or "READY").
//
// Throws std::runtime_error if the file cannot be opened for writing or if
// writing the entries fails.
void ModelListUtils::SaveModelList(
    const std::vector<ModelEntry>& entries) const {
  std::ofstream file(kModelListPath);
  if (!file.is_open()) {
    throw std::runtime_error("Unable to open model.list file for writing: " +
                             kModelListPath);
  }

  for (const auto& entry : entries) {
    // '\n' instead of std::endl: no need to flush after every single line.
    file << entry.model_id << " " << entry.author_repo_id << " "
         << entry.branch_name << " " << entry.path_to_model_yaml << " "
         << entry.model_alias << " "
         << (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY")
         << '\n';
  }

  // Flush once and surface write errors (e.g. a full disk) instead of
  // silently leaving a truncated model list behind.
  file.flush();
  if (!file) {
    throw std::runtime_error("Unable to write model.list file: " +
                             kModelListPath);
  }
}

std::string ModelListUtils::GenerateShortenedAlias(
const std::string& model_id, const std::vector<ModelEntry>& entries) const {
std::vector<std::string> parts;
std::istringstream iss(model_id);
std::string part;
while (std::getline(iss, part, '/')) {
parts.push_back(part);
}

if (parts.empty()) {
return model_id; // Return original if no parts
}

// Extract the filename without extension
std::string filename = parts.back();
size_t last_dot_pos = filename.find_last_of('.');
if (last_dot_pos != std::string::npos) {
filename = filename.substr(0, last_dot_pos);
}

// Convert to lowercase
std::transform(filename.begin(), filename.end(), filename.begin(),
[](unsigned char c) { return std::tolower(c); });

// Generate alias candidates
std::vector<std::string> candidates;
candidates.push_back(filename);

if (parts.size() >= 2) {
candidates.push_back(parts[parts.size() - 2] + ":" + filename);
}

if (parts.size() >= 3) {
candidates.push_back(parts[parts.size() - 3] + ":" +
parts[parts.size() - 2] + "/" + filename);
}

if (parts.size() >= 4) {
candidates.push_back(parts[0] + ":" + parts[1] + "/" +
parts[parts.size() - 2] + "/" + filename);
}

// Find the first unique candidate
for (const auto& candidate : candidates) {
if (IsUnique(entries, model_id, candidate)) {
return candidate;
}
}

// If all candidates are taken, append a number to the last candidate
std::string base_candidate = candidates.back();
int suffix = 1;
std::string unique_candidate = base_candidate;
while (!IsUnique(entries, model_id, unique_candidate)) {
unique_candidate = base_candidate + "-" + std::to_string(suffix++);
}

return unique_candidate;
}

// Looks up an entry by model id or alias.
//
// The list is re-read from the model.list file on every call while holding
// the instance mutex. Returns a copy of the first matching entry and throws
// std::runtime_error("Model not found: <identifier>") when there is none.
ModelEntry ModelListUtils::GetModelInfo(const std::string& identifier) const {
  std::lock_guard<std::mutex> guard(mutex_);
  const auto all_entries = LoadModelList();

  for (const ModelEntry& candidate : all_entries) {
    if (candidate.model_id == identifier ||
        candidate.model_alias == identifier) {
      return candidate;
    }
  }

  throw std::runtime_error("Model not found: " + identifier);
}

// Writes every field of `entry` to the log at INFO level, one field per line.
void ModelListUtils::PrintModelInfo(const ModelEntry& entry) const {
  const char* status_label = "READY";
  if (entry.status == ModelStatus::RUNNING) {
    status_label = "RUNNING";
  }

  LOG_INFO << "Model ID: " << entry.model_id;
  LOG_INFO << "Author/Repo ID: " << entry.author_repo_id;
  LOG_INFO << "Branch Name: " << entry.branch_name;
  LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml;
  LOG_INFO << "Model Alias: " << entry.model_alias;
  LOG_INFO << "Status: " << status_label;
}

// Appends `new_entry` to the model list and saves the list to file.
//
// The entry is rejected (returns false) when its model id or alias is already
// used by an existing entry. When `use_short_alias` is true the alias is
// replaced by one produced by GenerateShortenedAlias. The stored status is
// always READY, whatever the caller passed in. Returns true when the entry
// was added.
bool ModelListUtils::AddModelEntry(ModelEntry new_entry, bool use_short_alias) {
  std::lock_guard<std::mutex> guard(mutex_);
  auto current = LoadModelList();

  // Refuse duplicates before touching anything.
  if (!IsUnique(current, new_entry.model_id, new_entry.model_alias)) {
    return false;
  }

  if (use_short_alias) {
    new_entry.model_alias = GenerateShortenedAlias(new_entry.model_id, current);
  }
  new_entry.status = ModelStatus::READY;  // New entries always start as READY

  current.push_back(std::move(new_entry));
  SaveModelList(current);
  return true;
}

// Replaces the first entry whose model id or alias equals `identifier` with
// `updated_entry` and saves the list to file.
//
// Returns true when an entry was replaced, false when no entry matched.
bool ModelListUtils::UpdateModelEntry(const std::string& identifier,
                                      const ModelEntry& updated_entry) {
  std::lock_guard<std::mutex> guard(mutex_);
  auto current = LoadModelList();

  const auto matches_identifier = [&identifier](const ModelEntry& candidate) {
    return candidate.model_id == identifier ||
           candidate.model_alias == identifier;
  };
  const auto target =
      std::find_if(current.begin(), current.end(), matches_identifier);

  if (target == current.end()) {
    return false;  // Entry not found
  }

  *target = updated_entry;
  SaveModelList(current);
  return true;
}

bool ModelListUtils::DeleteModelEntry(const std::string& identifier) {
std::lock_guard<std::mutex> lock(mutex_);
auto entries = LoadModelList();
auto it = std::find_if(entries.begin(), entries.end(),
[&identifier](const ModelEntry& entry) {
return (entry.model_id == identifier ||
entry.model_alias == identifier) &&
entry.status == ModelStatus::READY;
});

if (it != entries.end()) {
entries.erase(it);
SaveModelList(entries);
return true;
}
return false; // Entry not found or not in READY state
}
} // namespace modellist_utils
44 changes: 44 additions & 0 deletions engine/utils/modellist_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once
#include <trantor/utils/Logger.h>
#include <mutex>
#include <string>
#include <vector>
#include "logging_utils.h"
namespace modellist_utils {

/// Lifecycle state recorded for a model in the model list.
enum class ModelStatus { READY, RUNNING };

/// One record of the model list.
///
/// Remains an aggregate, so it can still be brace-initialized with the fields
/// in declaration order.
struct ModelEntry {
  std::string model_id;            ///< Model identifier.
  std::string author_repo_id;      ///< Author / repository identifier.
  std::string branch_name;         ///< Branch name.
  std::string path_to_model_yaml;  ///< Path to the model's model.yaml.
  std::string model_alias;         ///< Alternative name for the model.
  /// Defaults to READY so a default-constructed entry never carries an
  /// indeterminate status.
  ModelStatus status{ModelStatus::READY};
};

class ModelListUtils {

private:
mutable std::mutex mutex_; // For thread safety

std::vector<ModelEntry> LoadModelList() const;
bool IsUnique(const std::vector<ModelEntry>& entries,
const std::string& model_id,
const std::string& model_alias) const;
void SaveModelList(const std::vector<ModelEntry>& entries) const;

public:
static const std::string kModelListPath;
ModelListUtils() = default;
std::string GenerateShortenedAlias(
const std::string& model_id,
const std::vector<ModelEntry>& entries) const;
ModelEntry GetModelInfo(const std::string& identifier) const;
void PrintModelInfo(const ModelEntry& entry) const;
bool AddModelEntry(ModelEntry new_entry, bool use_short_alias = false);
bool UpdateModelEntry(const std::string& identifier,
const ModelEntry& updated_entry);
bool DeleteModelEntry(const std::string& identifier);
};
} // namespace modellist_utils

0 comments on commit b8078af

Please sign in to comment.