Opt class for positional argument handling #10508

Open · wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions examples/main/main.cpp
@@ -153,6 +153,7 @@ int main(int argc, char ** argv) {

     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
+    std::cout << params << "\n";
Collaborator:
This should be removed.

     common_init_result llama_init = common_init_from_params(params);

     model = llama_init.model;
180 changes: 84 additions & 96 deletions examples/run/run.cpp
@@ -4,109 +4,106 @@
 #include <unistd.h>
 #endif

-#include <climits>
 #include <cstdio>
 #include <cstring>
 #include <iostream>
 #include <sstream>
 #include <string>
-#include <unordered_map>
 #include <vector>

 #include "llama-cpp.h"

-typedef std::unique_ptr<char[]> char_array_ptr;
-
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
-
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
-
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
-    }
-
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
-    }
-
-    int parse(int argc, const char ** argv) {
-        for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
-                    return 1;
-                }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
-                    return 1;
-                }
-            } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
-            }
-        }
-
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
-        return 0;
-    }
-
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
+
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
+
+        return 0; // Success
+    }
+
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf Hello World\n";
+    }
+
+    int parse(int argc, const char ** argv) {
+        int positional_args_i = 0;
+        for (int i = 1; i < argc; ++i) {
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+                return 0;

Contributor (author):
This should be return 2
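For reference, the return-code contract implied by init_opt() above is: 0 = proceed, 1 = parse error, 2 = help shown. Below is a standalone sketch of that contract with the suggested return 2 applied in the help branch. This is a hypothetical simplification, not the PR's code; a real fix would also need init_opt() to forward parse()'s value rather than re-checking help_.

    #include <cstdio>
    #include <cstring>

    // Sketch of the 0/1/2 contract: 0 = proceed, 1 = bad arguments, 2 = help requested.
    static int parse(int argc, const char ** argv) {
        const char * model = nullptr;
        for (int i = 1; i < argc; ++i) {
            if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
                return 2; // the suggested fix: report "help" directly from parse()
            }
            model = argv[i]; // positional argument (simplified)
        }
        return !model; // 1 if the required model argument is missing
    }

    int main(int argc, const char ** argv) {
        const int ret = parse(argc, argv);
        if (ret == 2) {
            std::printf("usage: demo [-h] MODEL [PROMPT]\n");
            return 0; // showing help is not an error
        }
        return ret; // 0 on success, 1 on bad arguments
    }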

+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
+            } else {
+                prompt_ += " " + std::string(argv[i]);
+            }
+        }
+
+        return !model_; // model_ is the only required value
+    }
+
+    void help() const { printf("%s", help_str_.c_str()); }
+};

 class LlamaData {
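Pulled out of the diff for readability, the positional-argument scheme the new Opt class implements condenses to the standalone sketch below (simplified names, same control flow): the first non-flag token is the model path, the second starts the prompt, and every later token is appended to the prompt with a space. It also preserves a trade-off visible in the diff: std::atoi, unlike the strtol-based parse_int_arg this PR removes, silently yields 0 on non-numeric input.

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <string>

    // Standalone sketch of Opt::parse(): flags consume a value, anything else is positional.
    int main(int argc, char ** argv) {
        const char * model = nullptr;   // first positional argument
        std::string prompt;             // second and later positional arguments
        int context_size = 2048, ngl = 0, positional = 0;

        for (int i = 1; i < argc; ++i) {
            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
                if (i + 1 >= argc) return 1;          // flag given without a value
                context_size = std::atoi(argv[++i]);  // caveat: atoi returns 0 on bad input
            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
                if (i + 1 >= argc) return 1;
                ngl = std::atoi(argv[++i]);
            } else if (!positional) {
                ++positional;
                model = argv[i];                      // 1st positional: the model path
            } else if (positional == 1) {
                ++positional;
                prompt = argv[i];                     // 2nd positional: start of the prompt
            } else {
                prompt += " " + std::string(argv[i]); // later tokens: joined into the prompt
            }
        }

        if (!model) return 1; // the model path is the only required argument
        std::printf("model=%s ctx=%d ngl=%d prompt=\"%s\"\n", model, context_size, ngl, prompt.c_str());
        return 0;
    }

For example, demo --ngl 99 model.gguf Hello World prints model=model.gguf ctx=2048 ngl=99 prompt="Hello World".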
@@ -116,13 +113,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;

-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }

-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -134,6 +131,7 @@ class LlamaData {
   private:
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
+        ggml_backend_load_all();

Contributor (author):
Caught this here @slaren

         llama_model_params model_params = llama_model_default_params();
         model_params.n_gpu_layers = ngl;

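For context on the line flagged above: when backends are built as dynamically loaded modules, no compute backend (not even CPU) is registered until ggml_backend_load_all() runs, so the call has to precede model initialization. A minimal ordering sketch follows; the header names and the llama_load_model_from_file loader reflect the API around this commit, and error handling is elided, so treat it as illustrative.

    #include "ggml-backend.h"
    #include "llama.h"

    // Minimal sketch: backends must be registered before the model is created.
    llama_model * load_model_sketch(const char * path, int ngl) {
        ggml_backend_load_all();  // registers/loads every available backend (CPU, CUDA, ...)

        llama_model_params params = llama_model_default_params();
        params.n_gpu_layers = ngl; // offload up to `ngl` layers if a GPU backend was found

        return llama_load_model_from_file(path, params); // nullptr on failure
    }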
@@ -273,19 +271,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     return 0;
 }

-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty(); // Indicate an error or empty input
@@ -382,17 +367,20 @@ static std::string read_pipe_data() {
 }

 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }

     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }

-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }

     llama_log_set(log_callback, nullptr);
@@ -401,7 +389,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }

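With this change, llama-run takes the model and prompt as positional arguments rather than -m/-p flags. Per the help text in the diff (your_model.gguf and question.txt are placeholders; the piped form goes through the read_pipe_data path shown above):

    llama-run your_model.gguf
    llama-run --ngl 99 your_model.gguf Hello World
    cat question.txt | llama-run your_model.gguf

Since parse() joins everything after the first positional argument into one space-separated string, quoting the prompt is optional.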