-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Init model.list utils * Add cmakelist compile * Add cmakelist compile * Fix CI build on Windows * Add unit test * Add test * Fix failing unit test
- Loading branch information
1 parent
142adf0
commit b8078af
Showing
6 changed files
with
358 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#include <filesystem> | ||
#include <iostream> | ||
#include "gtest/gtest.h" | ||
#include "utils/modellist_utils.h" | ||
#include "utils/file_manager_utils.h" | ||
class ModelListUtilsTestSuite : public ::testing::Test { | ||
protected: | ||
modellist_utils::ModelListUtils model_list_; | ||
|
||
const modellist_utils::ModelEntry kTestModel{ | ||
"test_model_id", "test_author", | ||
"main", "/path/to/model.yaml", | ||
"test_alias", modellist_utils::ModelStatus::READY}; | ||
}; | ||
void SetUp() { | ||
// Create a temporary directory for tests | ||
file_manager_utils::CreateConfigFileIfNotExist(); | ||
} | ||
|
||
void TearDown() { | ||
// Clean up the temporary directory | ||
} | ||
// Adding the reference entry succeeds and it can be read back by its id.
TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) {
  const bool added = model_list_.AddModelEntry(kTestModel);
  EXPECT_TRUE(added);

  const auto stored = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(stored.model_id, kTestModel.model_id);
  EXPECT_EQ(stored.author_repo_id, kTestModel.author_repo_id);
}
|
||
// An entry is reachable through its id and through its alias; an unknown
// identifier raises std::runtime_error.
TEST_F(ModelListUtilsTestSuite, TestGetModelInfo) {
  model_list_.AddModelEntry(kTestModel);

  // Lookup by model id.
  const auto found_by_id = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(found_by_id.model_id, kTestModel.model_id);

  // Lookup by alias resolves to the same entry.
  const auto found_by_alias = model_list_.GetModelInfo("test_alias");
  EXPECT_EQ(found_by_alias.model_id, kTestModel.model_id);

  // Unknown identifiers are reported with an exception.
  EXPECT_THROW(model_list_.GetModelInfo("non_existent_model"),
               std::runtime_error);
}
|
||
// Updating an existing entry replaces the stored record.
TEST_F(ModelListUtilsTestSuite, TestUpdateModelEntry) {
  model_list_.AddModelEntry(kTestModel);

  // Same entry as the reference one, but flagged as RUNNING.
  auto modified = kTestModel;
  modified.status = modellist_utils::ModelStatus::RUNNING;

  const bool updated = model_list_.UpdateModelEntry("test_model_id", modified);
  EXPECT_TRUE(updated);

  const auto stored = model_list_.GetModelInfo("test_model_id");
  EXPECT_EQ(stored.status, modellist_utils::ModelStatus::RUNNING);

  // Put the entry back into READY: only READY entries can be deleted.
  modified.status = modellist_utils::ModelStatus::READY;
  model_list_.UpdateModelEntry("test_model_id", modified);
}
|
||
// Deleting a READY entry succeeds and makes later lookups throw.
TEST_F(ModelListUtilsTestSuite, TestDeleteModelEntry) {
  model_list_.AddModelEntry(kTestModel);

  const bool deleted = model_list_.DeleteModelEntry("test_model_id");
  EXPECT_TRUE(deleted);

  EXPECT_THROW(model_list_.GetModelInfo("test_model_id"), std::runtime_error);
}
|
||
// The shortest alias is the lower-cased file name; when that one is taken the
// parent path segment is used as a prefix.
TEST_F(ModelListUtilsTestSuite, TestGenerateShortenedAlias) {
  const std::string full_id =
      "huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf";

  // No existing entries: the bare file name is free.
  const std::string shortest = model_list_.GenerateShortenedAlias(full_id, {});
  EXPECT_EQ(shortest, "model_id_xxx");

  // Occupy the shortest alias to force the next, longer candidate.
  auto occupying_entry = kTestModel;
  occupying_entry.model_alias = "model_id_xxx";
  const std::vector<modellist_utils::ModelEntry> taken = {occupying_entry};

  const std::string longer = model_list_.GenerateShortenedAlias(full_id, taken);
  EXPECT_EQ(longer, "llama3.1-7b-gguf:model_id_xxx");
}
|
||
// Entries written through one instance are visible to a fresh instance,
// i.e. the list really lives in the model.list file.
TEST_F(ModelListUtilsTestSuite, TestPersistence) {
  model_list_.AddModelEntry(kTestModel);

  // A second, independent instance has to load the entry from the file.
  modellist_utils::ModelListUtils fresh_instance;
  const auto reloaded = fresh_instance.GetModelInfo("test_model_id");

  EXPECT_EQ(reloaded.model_id, kTestModel.model_id);
  EXPECT_EQ(reloaded.author_repo_id, kTestModel.author_repo_id);

  // Leave the persisted list clean.
  model_list_.DeleteModelEntry("test_model_id");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
#include "modellist_utils.h" | ||
#include <algorithm> | ||
#include <filesystem> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <regex> | ||
#include <sstream> | ||
#include <stdexcept> | ||
#include "file_manager_utils.h" | ||
namespace modellist_utils { | ||
// Full path of the persisted list: the file "model.list" inside the directory
// returned by file_manager_utils::GetModelsContainerPath().
// NOTE(review): this is evaluated once during static initialization, so it
// depends on the file manager being usable that early - confirm.
const std::string ModelListUtils::kModelListPath = [] {
  const auto list_file = file_manager_utils::GetModelsContainerPath() /
                         std::filesystem::path("model.list");
  return list_file.string();
}();
|
||
std::vector<ModelEntry> ModelListUtils::LoadModelList() const { | ||
std::vector<ModelEntry> entries; | ||
std::filesystem::path file_path(kModelListPath); | ||
|
||
// Check if the file exists, if not, create it | ||
if (!std::filesystem::exists(file_path)) { | ||
std::ofstream create_file(kModelListPath); | ||
if (!create_file) { | ||
throw std::runtime_error("Unable to create model.list file: " + | ||
kModelListPath); | ||
} | ||
create_file.close(); | ||
return entries; // Return empty vector for newly created file | ||
} | ||
|
||
std::ifstream file(kModelListPath); | ||
if (!file.is_open()) { | ||
throw std::runtime_error("Unable to open model.list file: " + | ||
kModelListPath); | ||
} | ||
|
||
std::string line; | ||
while (std::getline(file, line)) { | ||
std::istringstream iss(line); | ||
ModelEntry entry; | ||
std::string status_str; | ||
if (!(iss >> entry.model_id >> entry.author_repo_id >> entry.branch_name >> | ||
entry.path_to_model_yaml >> entry.model_alias >> status_str)) { | ||
LOG_WARN << "Invalid entry in model.list: " << line; | ||
} else { | ||
entry.status = | ||
(status_str == "RUNNING") ? ModelStatus::RUNNING : ModelStatus::READY; | ||
entries.push_back(entry); | ||
} | ||
} | ||
return entries; | ||
} | ||
|
||
bool ModelListUtils::IsUnique(const std::vector<ModelEntry>& entries, | ||
const std::string& model_id, | ||
const std::string& model_alias) const { | ||
return std::none_of( | ||
entries.begin(), entries.end(), [&](const ModelEntry& entry) { | ||
return entry.model_id == model_id || entry.model_alias == model_id || | ||
entry.model_id == model_alias || | ||
entry.model_alias == model_alias; | ||
}); | ||
} | ||
|
||
void ModelListUtils::SaveModelList( | ||
const std::vector<ModelEntry>& entries) const { | ||
std::ofstream file(kModelListPath); | ||
if (!file.is_open()) { | ||
throw std::runtime_error("Unable to open model.list file for writing: " + | ||
kModelListPath); | ||
} | ||
|
||
for (const auto& entry : entries) { | ||
file << entry.model_id << " " << entry.author_repo_id << " " | ||
<< entry.branch_name << " " << entry.path_to_model_yaml << " " | ||
<< entry.model_alias << " " | ||
<< (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY") | ||
<< std::endl; | ||
} | ||
} | ||
|
||
std::string ModelListUtils::GenerateShortenedAlias( | ||
const std::string& model_id, const std::vector<ModelEntry>& entries) const { | ||
std::vector<std::string> parts; | ||
std::istringstream iss(model_id); | ||
std::string part; | ||
while (std::getline(iss, part, '/')) { | ||
parts.push_back(part); | ||
} | ||
|
||
if (parts.empty()) { | ||
return model_id; // Return original if no parts | ||
} | ||
|
||
// Extract the filename without extension | ||
std::string filename = parts.back(); | ||
size_t last_dot_pos = filename.find_last_of('.'); | ||
if (last_dot_pos != std::string::npos) { | ||
filename = filename.substr(0, last_dot_pos); | ||
} | ||
|
||
// Convert to lowercase | ||
std::transform(filename.begin(), filename.end(), filename.begin(), | ||
[](unsigned char c) { return std::tolower(c); }); | ||
|
||
// Generate alias candidates | ||
std::vector<std::string> candidates; | ||
candidates.push_back(filename); | ||
|
||
if (parts.size() >= 2) { | ||
candidates.push_back(parts[parts.size() - 2] + ":" + filename); | ||
} | ||
|
||
if (parts.size() >= 3) { | ||
candidates.push_back(parts[parts.size() - 3] + ":" + | ||
parts[parts.size() - 2] + "/" + filename); | ||
} | ||
|
||
if (parts.size() >= 4) { | ||
candidates.push_back(parts[0] + ":" + parts[1] + "/" + | ||
parts[parts.size() - 2] + "/" + filename); | ||
} | ||
|
||
// Find the first unique candidate | ||
for (const auto& candidate : candidates) { | ||
if (IsUnique(entries, model_id, candidate)) { | ||
return candidate; | ||
} | ||
} | ||
|
||
// If all candidates are taken, append a number to the last candidate | ||
std::string base_candidate = candidates.back(); | ||
int suffix = 1; | ||
std::string unique_candidate = base_candidate; | ||
while (!IsUnique(entries, model_id, unique_candidate)) { | ||
unique_candidate = base_candidate + "-" + std::to_string(suffix++); | ||
} | ||
|
||
return unique_candidate; | ||
} | ||
|
||
// Returns the first entry whose model id or alias equals `identifier`.
//
// The list is reloaded from the file on every call, under the instance mutex.
// Throws std::runtime_error when no entry matches (or when the list file
// cannot be read).
ModelEntry ModelListUtils::GetModelInfo(const std::string& identifier) const {
  std::lock_guard<std::mutex> lock(mutex_);
  auto entries = LoadModelList();

  for (const ModelEntry& candidate : entries) {
    if (candidate.model_id == identifier ||
        candidate.model_alias == identifier) {
      return candidate;
    }
  }
  throw std::runtime_error("Model not found: " + identifier);
}
|
||
// Writes every field of `entry` to the info log, one log record per field.
void ModelListUtils::PrintModelInfo(const ModelEntry& entry) const {
  const char* const status_text =
      entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY";

  LOG_INFO << "Model ID: " << entry.model_id;
  LOG_INFO << "Author/Repo ID: " << entry.author_repo_id;
  LOG_INFO << "Branch Name: " << entry.branch_name;
  LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml;
  LOG_INFO << "Model Alias: " << entry.model_alias;
  LOG_INFO << "Status: " << status_text;
}
|
||
// Appends `new_entry` to the persisted list.
//
// The entry is rejected (returns false) when its id or alias collides with
// the id or alias of an existing entry. With `use_short_alias` the supplied
// alias is replaced by a generated one after that check. The stored status is
// always READY, whatever the caller passed in.
//
// Returns true when the entry was written. File errors surface as
// std::runtime_error from the load/save helpers.
bool ModelListUtils::AddModelEntry(ModelEntry new_entry, bool use_short_alias) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto entries = LoadModelList();

  // Entry not added due to non-uniqueness
  if (!IsUnique(entries, new_entry.model_id, new_entry.model_alias)) {
    return false;
  }

  if (use_short_alias) {
    new_entry.model_alias = GenerateShortenedAlias(new_entry.model_id, entries);
  }

  // Newly added models always start out as READY.
  new_entry.status = ModelStatus::READY;
  entries.push_back(std::move(new_entry));
  SaveModelList(entries);
  return true;
}
|
||
bool ModelListUtils::UpdateModelEntry(const std::string& identifier, | ||
const ModelEntry& updated_entry) { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
auto entries = LoadModelList(); | ||
auto it = std::find_if( | ||
entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { | ||
return entry.model_id == identifier || entry.model_alias == identifier; | ||
}); | ||
|
||
if (it != entries.end()) { | ||
*it = updated_entry; | ||
SaveModelList(entries); | ||
return true; | ||
} | ||
return false; // Entry not found | ||
} | ||
|
||
bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
auto entries = LoadModelList(); | ||
auto it = std::find_if(entries.begin(), entries.end(), | ||
[&identifier](const ModelEntry& entry) { | ||
return (entry.model_id == identifier || | ||
entry.model_alias == identifier) && | ||
entry.status == ModelStatus::READY; | ||
}); | ||
|
||
if (it != entries.end()) { | ||
entries.erase(it); | ||
SaveModelList(entries); | ||
return true; | ||
} | ||
return false; // Entry not found or not in READY state | ||
} | ||
} // namespace modellist_utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#pragma once | ||
#include <trantor/utils/Logger.h> | ||
#include <mutex> | ||
#include <string> | ||
#include <vector> | ||
#include "logging_utils.h" | ||
namespace modellist_utils { | ||
|
||
/// State recorded for a model in the model list.
enum class ModelStatus { READY, RUNNING };

/// One record of the model list. In the list file every field is stored as a
/// single whitespace-separated token.
struct ModelEntry {
  std::string model_id;            ///< Identifier used to look the model up.
  std::string author_repo_id;      ///< Author/repository identifier.
  std::string branch_name;         ///< Branch name.
  std::string path_to_model_yaml;  ///< Path to the model's model.yaml.
  std::string model_alias;         ///< Alternative identifier for lookups.
  /// Defaults to READY so a default-constructed entry never carries an
  /// indeterminate status.
  ModelStatus status = ModelStatus::READY;
};
|
||
class ModelListUtils { | ||
|
||
private: | ||
mutable std::mutex mutex_; // For thread safety | ||
|
||
std::vector<ModelEntry> LoadModelList() const; | ||
bool IsUnique(const std::vector<ModelEntry>& entries, | ||
const std::string& model_id, | ||
const std::string& model_alias) const; | ||
void SaveModelList(const std::vector<ModelEntry>& entries) const; | ||
|
||
public: | ||
static const std::string kModelListPath; | ||
ModelListUtils() = default; | ||
std::string GenerateShortenedAlias( | ||
const std::string& model_id, | ||
const std::vector<ModelEntry>& entries) const; | ||
ModelEntry GetModelInfo(const std::string& identifier) const; | ||
void PrintModelInfo(const ModelEntry& entry) const; | ||
bool AddModelEntry(ModelEntry new_entry, bool use_short_alias = false); | ||
bool UpdateModelEntry(const std::string& identifier, | ||
const ModelEntry& updated_entry); | ||
bool DeleteModelEntry(const std::string& identifier); | ||
}; | ||
} // namespace modellist_utils |