Skip to content

Commit

Permalink
fix: add cpu_threads to model.yaml (#1845)
Browse files Browse the repository at this point in the history
Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Jan 13, 2025
1 parent da7576d commit c508e68
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 74 deletions.
1 change: 1 addition & 0 deletions engine/cli/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ void CommandLineParser::ModelUpdate(CLI::App* parent) {
"ngl",
"ctx_len",
"n_parallel",
"cpu_threads",
"engine",
"prompt_template",
"system_template",
Expand Down
6 changes: 6 additions & 0 deletions engine/cli/commands/model_upd_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key,
data["n_parallel"] = static_cast<int>(f);
});
}},
{"cpu_threads",
[this](Json::Value &data, const std::string& k, const std::string& v) {
UpdateNumericField(k, v, [&data](float f) {
data["cpu_threads"] = static_cast<int>(f);
});
}},
{"tp",
[this](Json::Value &data, const std::string& k, const std::string& v) {
UpdateNumericField(k, v, [&data](float f) {
Expand Down
8 changes: 8 additions & 0 deletions engine/config/model_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ struct ModelConfig {
int ngl = std::numeric_limits<int>::quiet_NaN();
int ctx_len = std::numeric_limits<int>::quiet_NaN();
int n_parallel = 1;
int cpu_threads = -1;
std::string engine;
std::string prompt_template;
std::string system_template;
Expand Down Expand Up @@ -272,6 +273,8 @@ struct ModelConfig {
ctx_len = json["ctx_len"].asInt();
if (json.isMember("n_parallel"))
n_parallel = json["n_parallel"].asInt();
if (json.isMember("cpu_threads"))
cpu_threads = json["cpu_threads"].asInt();
if (json.isMember("engine"))
engine = json["engine"].asString();
if (json.isMember("prompt_template"))
Expand Down Expand Up @@ -362,6 +365,9 @@ struct ModelConfig {
obj["ngl"] = ngl;
obj["ctx_len"] = ctx_len;
obj["n_parallel"] = n_parallel;
if (cpu_threads > 0) {
obj["cpu_threads"] = cpu_threads;
}
obj["engine"] = engine;
obj["prompt_template"] = prompt_template;
obj["system_template"] = system_template;
Expand Down Expand Up @@ -474,6 +480,8 @@ struct ModelConfig {
format_utils::MAGENTA);
oss << format_utils::print_kv("n_parallel", std::to_string(n_parallel),
format_utils::MAGENTA);
oss << format_utils::print_kv("cpu_threads", std::to_string(cpu_threads),
format_utils::MAGENTA);
if (ngl != std::numeric_limits<int>::quiet_NaN())
oss << format_utils::print_kv("ngl", std::to_string(ngl),
format_utils::MAGENTA);
Expand Down
124 changes: 65 additions & 59 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void YamlHandler::ModelConfigFromYaml() {
tmp.ctx_len = yaml_node_["ctx_len"].as<int>();
if (yaml_node_["n_parallel"])
tmp.n_parallel = yaml_node_["n_parallel"].as<int>();
if (yaml_node_["cpu_threads"])
tmp.cpu_threads = yaml_node_["cpu_threads"].as<int>();
if (yaml_node_["tp"])
tmp.tp = yaml_node_["tp"].as<int>();
if (yaml_node_["stream"])
Expand Down Expand Up @@ -224,6 +226,8 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["ctx_len"] = model_config_.ctx_len;
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
yaml_node_["n_parallel"] = model_config_.n_parallel;
if (!std::isnan(static_cast<double>(model_config_.cpu_threads)))
yaml_node_["cpu_threads"] = model_config_.cpu_threads;
if (!std::isnan(static_cast<double>(model_config_.tp)))
yaml_node_["tp"] = model_config_.tp;
if (!std::isnan(static_cast<double>(model_config_.stream)))
Expand Down Expand Up @@ -283,110 +287,112 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
// Method to write all attributes to a YAML file
void YamlHandler::WriteYamlFile(const std::string& file_path) const {
try {
std::ofstream outFile(file_path);
if (!outFile) {
std::ofstream out_file(file_path);
if (!out_file) {
throw std::runtime_error("Failed to open output file.");
}
// Write GENERAL GGUF METADATA
outFile << "# BEGIN GENERAL GGUF METADATA\n";
outFile << format_utils::writeKeyValue(
out_file << "# BEGIN GENERAL GGUF METADATA\n";
out_file << format_utils::WriteKeyValue(
"id", yaml_node_["id"],
"Model ID unique between models (author / quantization)");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"model", yaml_node_["model"],
"Model ID which is used for request construct - should be "
"unique between models (author / quantization)");
outFile << format_utils::writeKeyValue("name", yaml_node_["name"],
out_file << format_utils::WriteKeyValue("name", yaml_node_["name"],
"metadata.general.name");
if (yaml_node_["version"]) {
outFile << "version: " << yaml_node_["version"].as<std::string>() << "\n";
out_file << "version: " << yaml_node_["version"].as<std::string>() << "\n";
}
if (yaml_node_["files"] && yaml_node_["files"].size()) {
outFile << "files: # Can be relative OR absolute local file "
out_file << "files: # Can be relative OR absolute local file "
"path\n";
for (const auto& source : yaml_node_["files"]) {
outFile << " - " << source << "\n";
out_file << " - " << source << "\n";
}
}

outFile << "# END GENERAL GGUF METADATA\n";
outFile << "\n";
out_file << "# END GENERAL GGUF METADATA\n";
out_file << "\n";
// Write INFERENCE PARAMETERS
outFile << "# BEGIN INFERENCE PARAMETERS\n";
outFile << "# BEGIN REQUIRED\n";
out_file << "# BEGIN INFERENCE PARAMETERS\n";
out_file << "# BEGIN REQUIRED\n";
if (yaml_node_["stop"] && yaml_node_["stop"].size()) {
outFile << "stop: # tokenizer.ggml.eos_token_id\n";
out_file << "stop: # tokenizer.ggml.eos_token_id\n";
for (const auto& stop : yaml_node_["stop"]) {
outFile << " - " << stop << "\n";
out_file << " - " << stop << "\n";
}
}

outFile << "# END REQUIRED\n";
outFile << "\n";
outFile << "# BEGIN OPTIONAL\n";
outFile << format_utils::writeKeyValue("size", yaml_node_["size"]);
outFile << format_utils::writeKeyValue("stream", yaml_node_["stream"],
out_file << "# END REQUIRED\n";
out_file << "\n";
out_file << "# BEGIN OPTIONAL\n";
out_file << format_utils::WriteKeyValue("size", yaml_node_["size"]);
out_file << format_utils::WriteKeyValue("stream", yaml_node_["stream"],
"Default true?");
outFile << format_utils::writeKeyValue("top_p", yaml_node_["top_p"],
out_file << format_utils::WriteKeyValue("top_p", yaml_node_["top_p"],
"Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"temperature", yaml_node_["temperature"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"frequency_penalty", yaml_node_["frequency_penalty"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"presence_penalty", yaml_node_["presence_penalty"], "Ranges: 0 to 1");
outFile << format_utils::writeKeyValue(
out_file << format_utils::WriteKeyValue(
"max_tokens", yaml_node_["max_tokens"],
"Should be default to context length");
outFile << format_utils::writeKeyValue("seed", yaml_node_["seed"]);
outFile << format_utils::writeKeyValue("dynatemp_range",
out_file << format_utils::WriteKeyValue("seed", yaml_node_["seed"]);
out_file << format_utils::WriteKeyValue("dynatemp_range",
yaml_node_["dynatemp_range"]);
outFile << format_utils::writeKeyValue("dynatemp_exponent",
out_file << format_utils::WriteKeyValue("dynatemp_exponent",
yaml_node_["dynatemp_exponent"]);
outFile << format_utils::writeKeyValue("top_k", yaml_node_["top_k"]);
outFile << format_utils::writeKeyValue("min_p", yaml_node_["min_p"]);
outFile << format_utils::writeKeyValue("tfs_z", yaml_node_["tfs_z"]);
outFile << format_utils::writeKeyValue("typ_p", yaml_node_["typ_p"]);
outFile << format_utils::writeKeyValue("repeat_last_n",
out_file << format_utils::WriteKeyValue("top_k", yaml_node_["top_k"]);
out_file << format_utils::WriteKeyValue("min_p", yaml_node_["min_p"]);
out_file << format_utils::WriteKeyValue("tfs_z", yaml_node_["tfs_z"]);
out_file << format_utils::WriteKeyValue("typ_p", yaml_node_["typ_p"]);
out_file << format_utils::WriteKeyValue("repeat_last_n",
yaml_node_["repeat_last_n"]);
outFile << format_utils::writeKeyValue("repeat_penalty",
out_file << format_utils::WriteKeyValue("repeat_penalty",
yaml_node_["repeat_penalty"]);
outFile << format_utils::writeKeyValue("mirostat", yaml_node_["mirostat"]);
outFile << format_utils::writeKeyValue("mirostat_tau",
out_file << format_utils::WriteKeyValue("mirostat", yaml_node_["mirostat"]);
out_file << format_utils::WriteKeyValue("mirostat_tau",
yaml_node_["mirostat_tau"]);
outFile << format_utils::writeKeyValue("mirostat_eta",
out_file << format_utils::WriteKeyValue("mirostat_eta",
yaml_node_["mirostat_eta"]);
outFile << format_utils::writeKeyValue("penalize_nl",
out_file << format_utils::WriteKeyValue("penalize_nl",
yaml_node_["penalize_nl"]);
outFile << format_utils::writeKeyValue("ignore_eos",
out_file << format_utils::WriteKeyValue("ignore_eos",
yaml_node_["ignore_eos"]);
outFile << format_utils::writeKeyValue("n_probs", yaml_node_["n_probs"]);
outFile << format_utils::writeKeyValue("min_keep", yaml_node_["min_keep"]);
outFile << format_utils::writeKeyValue("grammar", yaml_node_["grammar"]);
outFile << "# END OPTIONAL\n";
outFile << "# END INFERENCE PARAMETERS\n";
outFile << "\n";
out_file << format_utils::WriteKeyValue("n_probs", yaml_node_["n_probs"]);
out_file << format_utils::WriteKeyValue("min_keep", yaml_node_["min_keep"]);
out_file << format_utils::WriteKeyValue("grammar", yaml_node_["grammar"]);
out_file << "# END OPTIONAL\n";
out_file << "# END INFERENCE PARAMETERS\n";
out_file << "\n";
// Write MODEL LOAD PARAMETERS
outFile << "# BEGIN MODEL LOAD PARAMETERS\n";
outFile << "# BEGIN REQUIRED\n";
outFile << format_utils::writeKeyValue("engine", yaml_node_["engine"],
out_file << "# BEGIN MODEL LOAD PARAMETERS\n";
out_file << "# BEGIN REQUIRED\n";
out_file << format_utils::WriteKeyValue("engine", yaml_node_["engine"],
"engine to run model");
outFile << "prompt_template:";
outFile << " " << yaml_node_["prompt_template"] << "\n";
outFile << "# END REQUIRED\n";
outFile << "\n";
outFile << "# BEGIN OPTIONAL\n";
outFile << format_utils::writeKeyValue(
out_file << "prompt_template:";
out_file << " " << yaml_node_["prompt_template"] << "\n";
out_file << "# END REQUIRED\n";
out_file << "\n";
out_file << "# BEGIN OPTIONAL\n";
out_file << format_utils::WriteKeyValue(
"ctx_len", yaml_node_["ctx_len"],
"llama.context_length | 0 or undefined = loaded from model");
outFile << format_utils::writeKeyValue("n_parallel",
out_file << format_utils::WriteKeyValue("n_parallel",
yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
out_file << format_utils::WriteKeyValue("cpu_threads",
yaml_node_["cpu_threads"]);
out_file << format_utils::WriteKeyValue("ngl", yaml_node_["ngl"],
"Undefined = loaded from model");
outFile << "# END OPTIONAL\n";
outFile << "# END MODEL LOAD PARAMETERS\n";
out_file << "# END OPTIONAL\n";
out_file << "# END MODEL LOAD PARAMETERS\n";

outFile.close();
out_file.close();
} catch (const std::exception& e) {
std::cerr << "Error writing to file: " << e.what() << std::endl;
throw;
Expand Down
12 changes: 6 additions & 6 deletions engine/test/components/test_format_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,37 @@ TEST_F(FormatUtilsTest, WriteKeyValue) {
{
YAML::Node node;
std::string result =
format_utils::writeKeyValue("key", node["does_not_exist"]);
format_utils::WriteKeyValue("key", node["does_not_exist"]);
EXPECT_EQ(result, "");
}

{
YAML::Node node = YAML::Load("value");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: value\n");
}

{
YAML::Node node = YAML::Load("3.14159");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3.14159\n");
}

{
YAML::Node node = YAML::Load("3.000000");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3\n");
}

{
YAML::Node node = YAML::Load("3.140000");
std::string result = format_utils::writeKeyValue("key", node);
std::string result = format_utils::WriteKeyValue("key", node);
EXPECT_EQ(result, "key: 3.14\n");
}

{
YAML::Node node = YAML::Load("value");
std::string result = format_utils::writeKeyValue("key", node, "comment");
std::string result = format_utils::WriteKeyValue("key", node, "comment");
EXPECT_EQ(result, "key: value # comment\n");
}
}
Expand Down
6 changes: 6 additions & 0 deletions engine/test/components/test_yaml_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ temperature: 0.7
max_tokens: 100
stream: true
n_parallel: 2
cpu_threads: 3
stop:
- "END"
files:
Expand All @@ -84,6 +85,7 @@ n_parallel: 2
EXPECT_EQ(config.max_tokens, 100);
EXPECT_TRUE(config.stream);
EXPECT_EQ(config.n_parallel, 2);
EXPECT_EQ(config.cpu_threads, 3);
EXPECT_EQ(config.stop.size(), 1);
EXPECT_EQ(config.stop[0], "END");
EXPECT_EQ(config.files.size(), 1);
Expand All @@ -104,6 +106,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
new_config.max_tokens = 200;
new_config.stream = false;
new_config.n_parallel = 2;
new_config.cpu_threads = 3;
new_config.stop = {"STOP", "END"};
new_config.files = {"updated_file1.gguf", "updated_file2.gguf"};

Expand All @@ -120,6 +123,7 @@ TEST_F(YamlHandlerTest, UpdateModelConfig) {
EXPECT_EQ(config.max_tokens, 200);
EXPECT_FALSE(config.stream);
EXPECT_EQ(config.n_parallel, 2);
EXPECT_EQ(config.cpu_threads, 3);
EXPECT_EQ(config.stop.size(), 2);
EXPECT_EQ(config.stop[0], "STOP");
EXPECT_EQ(config.stop[1], "END");
Expand All @@ -140,6 +144,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
new_config.max_tokens = 150;
new_config.stream = true;
new_config.n_parallel = 2;
new_config.cpu_threads = 3;
new_config.stop = {"HALT"};
new_config.files = {"write_test_file.gguf"};

Expand All @@ -164,6 +169,7 @@ TEST_F(YamlHandlerTest, WriteYamlFile) {
EXPECT_EQ(read_config.max_tokens, 150);
EXPECT_TRUE(read_config.stream);
EXPECT_EQ(read_config.n_parallel, 2);
EXPECT_EQ(read_config.cpu_threads, 3);
EXPECT_EQ(read_config.stop.size(), 1);
EXPECT_EQ(read_config.stop[0], "HALT");
EXPECT_EQ(read_config.files.size(), 1);
Expand Down
18 changes: 9 additions & 9 deletions engine/utils/format_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ inline std::string print_float(const std::string& key, float value) {
} else
return "";
};
inline std::string writeKeyValue(const std::string& key,
inline std::string WriteKeyValue(const std::string& key,
const YAML::Node& value,
const std::string& comment = "") {
std::ostringstream outFile;
std::ostringstream out_file;
if (!value)
return "";
outFile << key << ": ";
out_file << key << ": ";

// Check if the value is a float and round it to 6 decimal places
if (value.IsScalar()) {
Expand All @@ -66,19 +66,19 @@ inline std::string writeKeyValue(const std::string& key,
if (strValue.back() == '.') {
strValue.pop_back();
}
outFile << strValue;
out_file << strValue;
} catch (const std::exception& e) {
outFile << value; // If not a float, write as is
out_file << value; // If not a float, write as is
}
} else {
outFile << value;
out_file << value;
}

if (!comment.empty()) {
outFile << " # " << comment;
out_file << " # " << comment;
}
outFile << "\n";
return outFile.str();
out_file << "\n";
return out_file.str();
};

inline std::string BytesToHumanReadable(uint64_t bytes) {
Expand Down

0 comments on commit c508e68

Please sign in to comment.