From c508e6846b914030c809572f3aeef411dda002b0 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 13 Jan 2025 08:59:46 +0700 Subject: [PATCH] fix: add cpu_threads to model.yaml (#1845) Co-authored-by: vansangpfiev --- engine/cli/command_line_parser.cc | 1 + engine/cli/commands/model_upd_cmd.cc | 6 + engine/config/model_config.h | 8 ++ engine/config/yaml_config.cc | 124 ++++++++++---------- engine/test/components/test_format_utils.cc | 12 +- engine/test/components/test_yaml_handler.cc | 6 + engine/utils/format_utils.h | 18 +-- 7 files changed, 101 insertions(+), 74 deletions(-) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 6f8f227e6..b423a6896 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -908,6 +908,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) { "ngl", "ctx_len", "n_parallel", + "cpu_threads", "engine", "prompt_template", "system_template", diff --git a/engine/cli/commands/model_upd_cmd.cc b/engine/cli/commands/model_upd_cmd.cc index 6534d1fbd..1572581ec 100644 --- a/engine/cli/commands/model_upd_cmd.cc +++ b/engine/cli/commands/model_upd_cmd.cc @@ -228,6 +228,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key, data["n_parallel"] = static_cast(f); }); }}, + {"cpu_threads", + [this](Json::Value &data, const std::string& k, const std::string& v) { + UpdateNumericField(k, v, [&data](float f) { + data["cpu_threads"] = static_cast(f); + }); + }}, {"tp", [this](Json::Value &data, const std::string& k, const std::string& v) { UpdateNumericField(k, v, [&data](float f) { diff --git a/engine/config/model_config.h b/engine/config/model_config.h index d8ede92f7..ea671354e 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -164,6 +164,7 @@ struct ModelConfig { int ngl = std::numeric_limits::quiet_NaN(); int ctx_len = std::numeric_limits::quiet_NaN(); int n_parallel = 1; + int cpu_threads = -1; std::string engine; std::string prompt_template; std::string system_template; @@ -272,6 +273,8 @@ struct ModelConfig { ctx_len = json["ctx_len"].asInt(); if (json.isMember("n_parallel")) n_parallel = json["n_parallel"].asInt(); + if (json.isMember("cpu_threads")) + cpu_threads = json["cpu_threads"].asInt(); if (json.isMember("engine")) engine = json["engine"].asString(); if (json.isMember("prompt_template")) @@ -362,6 +365,9 @@ struct ModelConfig { obj["ngl"] = ngl; obj["ctx_len"] = ctx_len; obj["n_parallel"] = n_parallel; + if (cpu_threads > 0) { + obj["cpu_threads"] = cpu_threads; + } obj["engine"] = engine; obj["prompt_template"] = prompt_template; obj["system_template"] = system_template; @@ -474,6 +480,8 @@ struct ModelConfig { format_utils::MAGENTA); oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel), format_utils::MAGENTA); + oss << format_utils::print_kv("cpu_threads", std::to_string(cpu_threads), + format_utils::MAGENTA); if (ngl != std::numeric_limits::quiet_NaN()) oss << format_utils::print_kv("ngl", std::to_string(ngl), format_utils::MAGENTA); diff --git a/engine/config/yaml_config.cc b/engine/config/yaml_config.cc index bbe7f430c..57b2b3ecb 100644 --- a/engine/config/yaml_config.cc +++ b/engine/config/yaml_config.cc @@ -119,6 +119,8 @@ void YamlHandler::ModelConfigFromYaml() { tmp.ctx_len = yaml_node_["ctx_len"].as(); if (yaml_node_["n_parallel"]) tmp.n_parallel = yaml_node_["n_parallel"].as(); + if (yaml_node_["cpu_threads"]) + tmp.cpu_threads = yaml_node_["cpu_threads"].as(); if (yaml_node_["tp"]) tmp.tp = yaml_node_["tp"].as(); if (yaml_node_["stream"]) @@ -224,6 +226,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { yaml_node_["ctx_len"] = model_config_.ctx_len; if (!std::isnan(static_cast(model_config_.n_parallel))) yaml_node_["n_parallel"] = model_config_.n_parallel; + if (!std::isnan(static_cast(model_config_.cpu_threads))) + yaml_node_["cpu_threads"] = model_config_.cpu_threads; if (!std::isnan(static_cast(model_config_.tp))) yaml_node_["tp"] = model_config_.tp; if (!std::isnan(static_cast(model_config_.stream))) @@ -283,110 +287,112 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) { // Method to write all attributes to a YAML file void YamlHandler::WriteYamlFile(const std::string& file_path) const { try { - std::ofstream outFile(file_path); - if (!outFile) { + std::ofstream out_file(file_path); + if (!out_file) { throw std::runtime_error("Failed to open output file."); } // Write GENERAL GGUF METADATA - outFile << "# BEGIN GENERAL GGUF METADATA\n"; - outFile << format_utils::writeKeyValue( + out_file << "# BEGIN GENERAL GGUF METADATA\n"; + out_file << format_utils::WriteKeyValue( "id", yaml_node_["id"], "Model ID unique between models (author / quantization)"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "model", yaml_node_["model"], "Model ID which is used for request construct - should be " "unique between models (author / quantization)"); - outFile << format_utils::writeKeyValue("name", yaml_node_["name"], + out_file << format_utils::WriteKeyValue("name", yaml_node_["name"], "metadata.general.name"); if (yaml_node_["version"]) { - outFile << "version: " << yaml_node_["version"].as() << "\n"; + out_file << "version: " << yaml_node_["version"].as() << "\n"; } if (yaml_node_["files"] && yaml_node_["files"].size()) { - outFile << "files: # Can be relative OR absolute local file " + out_file << "files: # Can be relative OR absolute local file " "path\n"; for (const auto& source : yaml_node_["files"]) { - outFile << " - " << source << "\n"; + out_file << " - " << source << "\n"; } } - outFile << "# END GENERAL GGUF METADATA\n"; - outFile << "\n"; + out_file << "# END GENERAL GGUF METADATA\n"; + out_file << "\n"; // Write INFERENCE PARAMETERS - outFile << "# BEGIN INFERENCE PARAMETERS\n"; - outFile << "# BEGIN REQUIRED\n"; + out_file << "# BEGIN INFERENCE PARAMETERS\n"; + out_file << "# BEGIN REQUIRED\n"; if (yaml_node_["stop"] && yaml_node_["stop"].size()) { - outFile << "stop: # tokenizer.ggml.eos_token_id\n"; + out_file << "stop: # tokenizer.ggml.eos_token_id\n"; for (const auto& stop : yaml_node_["stop"]) { - outFile << " - " << stop << "\n"; + out_file << " - " << stop << "\n"; } } - outFile << "# END REQUIRED\n"; - outFile << "\n"; - outFile << "# BEGIN OPTIONAL\n"; - outFile << format_utils::writeKeyValue("size", yaml_node_["size"]); - outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"], + out_file << "# END REQUIRED\n"; + out_file << "\n"; + out_file << "# BEGIN OPTIONAL\n"; + out_file << format_utils::WriteKeyValue("size", yaml_node_["size"]); + out_file << format_utils::WriteKeyValue("stream", yaml_node_["stream"], "Default true?"); - outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"], + out_file << format_utils::WriteKeyValue("top_p", yaml_node_["top_p"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "temperature", yaml_node_["temperature"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "frequency_penalty", yaml_node_["frequency_penalty"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "presence_penalty", yaml_node_["presence_penalty"], "Ranges: 0 to 1"); - outFile << format_utils::writeKeyValue( + out_file << format_utils::WriteKeyValue( "max_tokens", yaml_node_["max_tokens"], "Should be default to context length"); - outFile << format_utils::writeKeyValue("seed", yaml_node_["seed"]); - outFile << format_utils::writeKeyValue("dynatemp_range", + out_file << format_utils::WriteKeyValue("seed", yaml_node_["seed"]); + out_file << format_utils::WriteKeyValue("dynatemp_range", yaml_node_["dynatemp_range"]); - outFile << format_utils::writeKeyValue("dynatemp_exponent", + out_file << format_utils::WriteKeyValue("dynatemp_exponent", yaml_node_["dynatemp_exponent"]); - outFile << format_utils::writeKeyValue("top_k", yaml_node_["top_k"]); - outFile << format_utils::writeKeyValue("min_p", yaml_node_["min_p"]); - outFile << format_utils::writeKeyValue("tfs_z", yaml_node_["tfs_z"]); - outFile << format_utils::writeKeyValue("typ_p", yaml_node_["typ_p"]); - outFile << format_utils::writeKeyValue("repeat_last_n", + out_file << format_utils::WriteKeyValue("top_k", yaml_node_["top_k"]); + out_file << format_utils::WriteKeyValue("min_p", yaml_node_["min_p"]); + out_file << format_utils::WriteKeyValue("tfs_z", yaml_node_["tfs_z"]); + out_file << format_utils::WriteKeyValue("typ_p", yaml_node_["typ_p"]); + out_file << format_utils::WriteKeyValue("repeat_last_n", yaml_node_["repeat_last_n"]); - outFile << format_utils::writeKeyValue("repeat_penalty", + out_file << format_utils::WriteKeyValue("repeat_penalty", yaml_node_["repeat_penalty"]); - outFile << format_utils::writeKeyValue("mirostat", yaml_node_["mirostat"]); - outFile << format_utils::writeKeyValue("mirostat_tau", + out_file << format_utils::WriteKeyValue("mirostat", yaml_node_["mirostat"]); + out_file << format_utils::WriteKeyValue("mirostat_tau", yaml_node_["mirostat_tau"]); - outFile << format_utils::writeKeyValue("mirostat_eta", + out_file << format_utils::WriteKeyValue("mirostat_eta", yaml_node_["mirostat_eta"]); - outFile << format_utils::writeKeyValue("penalize_nl", + out_file << format_utils::WriteKeyValue("penalize_nl", yaml_node_["penalize_nl"]); - outFile << format_utils::writeKeyValue("ignore_eos", + out_file << format_utils::WriteKeyValue("ignore_eos", yaml_node_["ignore_eos"]); - outFile << format_utils::writeKeyValue("n_probs", yaml_node_["n_probs"]); - outFile << format_utils::writeKeyValue("min_keep", yaml_node_["min_keep"]); - outFile << format_utils::writeKeyValue("grammar", yaml_node_["grammar"]); - outFile << "# END OPTIONAL\n"; - outFile << "# END INFERENCE PARAMETERS\n"; - outFile << "\n"; + out_file << format_utils::WriteKeyValue("n_probs", yaml_node_["n_probs"]); + out_file << format_utils::WriteKeyValue("min_keep", yaml_node_["min_keep"]); + out_file << format_utils::WriteKeyValue("grammar", yaml_node_["grammar"]); + out_file << "# END OPTIONAL\n"; + out_file << "# END INFERENCE PARAMETERS\n"; + out_file << "\n"; // Write MODEL LOAD PARAMETERS - outFile << "# BEGIN MODEL LOAD PARAMETERS\n"; - outFile << "# BEGIN REQUIRED\n"; - outFile << format_utils::writeKeyValue("engine", yaml_node_["engine"], + out_file << "# BEGIN MODEL LOAD PARAMETERS\n"; + out_file << "# BEGIN REQUIRED\n"; + out_file << format_utils::WriteKeyValue("engine", yaml_node_["engine"], "engine to run model"); - outFile << "prompt_template:"; - outFile << " " << yaml_node_["prompt_template"] << "\n"; - outFile << "# END REQUIRED\n"; - outFile << "\n"; - outFile << "# BEGIN OPTIONAL\n"; - outFile << format_utils::writeKeyValue( + out_file << "prompt_template:"; + out_file << " " << yaml_node_["prompt_template"] << "\n"; + out_file << "# END REQUIRED\n"; + out_file << "\n"; + out_file << "# BEGIN OPTIONAL\n"; + out_file << format_utils::WriteKeyValue( "ctx_len", yaml_node_["ctx_len"], "llama.context_length | 0 or undefined = loaded from model"); - outFile << format_utils::writeKeyValue("n_parallel", + out_file << format_utils::WriteKeyValue("n_parallel", yaml_node_["n_parallel"]); - outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"], + out_file << format_utils::WriteKeyValue("cpu_threads", + yaml_node_["cpu_threads"]); + out_file << format_utils::WriteKeyValue("ngl", yaml_node_["ngl"], "Undefined = loaded from model"); - outFile << "# END OPTIONAL\n"; - outFile << "# END MODEL LOAD PARAMETERS\n"; + out_file << "# END OPTIONAL\n"; + out_file << "# END MODEL LOAD PARAMETERS\n"; - outFile.close(); + out_file.close(); } catch (const std::exception& e) { std::cerr << "Error writing to file: " << e.what() << std::endl; throw; diff --git a/engine/test/components/test_format_utils.cc b/engine/test/components/test_format_utils.cc index cd777d5fa..d279b5940 100644 --- a/engine/test/components/test_format_utils.cc +++ b/engine/test/components/test_format_utils.cc @@ -9,37 +9,37 @@ TEST_F(FormatUtilsTest, WriteKeyValue) { { YAML::Node node; std::string result = - format_utils::writeKeyValue("key", node["does_not_exist"]); + format_utils::WriteKeyValue("key", node["does_not_exist"]); EXPECT_EQ(result, ""); } { YAML::Node node = YAML::Load("value"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: value\n"); } { YAML::Node node = YAML::Load("3.14159"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3.14159\n"); } { YAML::Node node = YAML::Load("3.000000"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3\n"); } { YAML::Node node = YAML::Load("3.140000"); - std::string result = format_utils::writeKeyValue("key", node); + std::string result = format_utils::WriteKeyValue("key", node); EXPECT_EQ(result, "key: 3.14\n"); } { YAML::Node node = YAML::Load("value"); - std::string result = format_utils::writeKeyValue("key", node, "comment"); + std::string result = format_utils::WriteKeyValue("key", node, "comment"); EXPECT_EQ(result, "key: value # comment\n"); } } diff --git a/engine/test/components/test_yaml_handler.cc b/engine/test/components/test_yaml_handler.cc index f699e0c6a..c7e4b6a21 100644 --- a/engine/test/components/test_yaml_handler.cc +++ b/engine/test/components/test_yaml_handler.cc @@ -63,6 +63,7 @@ temperature: 0.7 max_tokens: 100 stream: true n_parallel: 2 +cpu_threads: 3 stop: - "END" files: @@ -84,6 +85,7 @@ n_parallel: 2 EXPECT_EQ(config.max_tokens, 100); EXPECT_TRUE(config.stream); EXPECT_EQ(config.n_parallel, 2); + EXPECT_EQ(config.cpu_threads, 3); EXPECT_EQ(config.stop.size(), 1); EXPECT_EQ(config.stop[0], "END"); EXPECT_EQ(config.files.size(), 1); @@ -104,6 +106,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { new_config.max_tokens = 200; new_config.stream = false; new_config.n_parallel = 2; + new_config.cpu_threads = 3; new_config.stop = {"STOP", "END"}; new_config.files = {"updated_file1.gguf", "updated_file2.gguf"}; @@ -120,6 +123,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) { EXPECT_EQ(config.max_tokens, 200); EXPECT_FALSE(config.stream); EXPECT_EQ(config.n_parallel, 2); + EXPECT_EQ(config.cpu_threads, 3); EXPECT_EQ(config.stop.size(), 2); EXPECT_EQ(config.stop[0], "STOP"); EXPECT_EQ(config.stop[1], "END"); @@ -140,6 +144,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { new_config.max_tokens = 150; new_config.stream = true; new_config.n_parallel = 2; + new_config.cpu_threads = 3; new_config.stop = {"HALT"}; new_config.files = {"write_test_file.gguf"}; @@ -164,6 +169,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) { EXPECT_EQ(read_config.max_tokens, 150); EXPECT_TRUE(read_config.stream); EXPECT_EQ(read_config.n_parallel, 2); + EXPECT_EQ(read_config.cpu_threads, 3); EXPECT_EQ(read_config.stop.size(), 1); EXPECT_EQ(read_config.stop[0], "HALT"); EXPECT_EQ(read_config.files.size(), 1); diff --git a/engine/utils/format_utils.h b/engine/utils/format_utils.h index 141866378..5dccee359 100644 --- a/engine/utils/format_utils.h +++ b/engine/utils/format_utils.h @@ -46,13 +46,13 @@ inline std::string print_float(const std::string& key, float value) { } else return ""; }; -inline std::string writeKeyValue(const std::string& key, +inline std::string WriteKeyValue(const std::string& key, const YAML::Node& value, const std::string& comment = "") { - std::ostringstream outFile; + std::ostringstream out_file; if (!value) return ""; - outFile << key << ": "; + out_file << key << ": "; // Check if the value is a float and round it to 6 decimal places if (value.IsScalar()) { @@ -66,19 +66,19 @@ inline std::string writeKeyValue(const std::string& key, if (strValue.back() == '.') { strValue.pop_back(); } - outFile << strValue; + out_file << strValue; } catch (const std::exception& e) { - outFile << value; // If not a float, write as is + out_file << value; // If not a float, write as is } } else { - outFile << value; + out_file << value; } if (!comment.empty()) { - outFile << " # " << comment; + out_file << " # " << comment; } - outFile << "\n"; - return outFile.str(); + out_file << "\n"; + return out_file.str(); }; inline std::string BytesToHumanReadable(uint64_t bytes) {