Opt class for positional argument handling
Added support for positional arguments `MODEL` and `PROMPT`.

Signed-off-by: Eric Curtin <[email protected]>
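
With this change, the model path and an optional prompt are given positionally rather than via flags, e.g. (invocations taken verbatim from the help text added in the diff below):

    llama-run ~/.local/share/ramalama/models/ollama/smollm\:135m
    llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\:135m Hello World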
ericcurtin committed Nov 26, 2024
1 parent 0eb4e12 commit ff6a907
Showing 1 changed file with 89 additions and 93 deletions.
examples/run/run.cpp: 182 changes (89 additions, 93 deletions)
@@ -17,96 +17,102 @@
 
 typedef std::unique_ptr<char[]> char_array_ptr;
 
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
-
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
-
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
-    }
-
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
-    }
-
-    int parse(int argc, const char ** argv) {
-        for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
-                    return 1;
-                }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
-                    return 1;
-                }
-            } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
-            }
-        }
-
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
-        return 0;
-    }
-
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
+
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
+
+        return 0;  // Success
+    }
+
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m\n"
+            "  llama-run --ngl 99 ~/.local/share/ramalama/models/ollama/smollm\\:135m Hello World\n";
+    }
+
+    int parse(int argc, const char ** argv) {
+        if (parse_arguments(argc, argv) || !model_) {
+            return 1;
+        }
+
+        return 0;
+    }
+
+    int parse_arguments(int argc, const char ** argv) {
+        int positional_args_i = 0;
+        for (int i = 1; i < argc; ++i) {
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
+            } else {
+                prompt_ += " " + std::string(argv[i]);
+            }
+        }
+
+        return 0;
+    }
+
+    void help() const { printf("%s", help_str_.c_str()); }
+};

@@ -116,13 +122,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;
 
-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }
 
-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -273,19 +279,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
     return 0;
 }
 
-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty();  // Indicate an error or empty input
@@ -382,17 +375,20 @@ static std::string read_pipe_data() {
 }
 
 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }
 
-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }
 
     llama_log_set(log_callback, nullptr);
@@ -401,7 +397,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }
 
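For a quick sanity check of the new behavior, here is a hypothetical standalone test (not part of the commit; it assumes the `Opt` class from the diff above is pasted into the same file). It exercises the `parse_arguments` loop: the first positional argument becomes `MODEL`, the second becomes `PROMPT`, and any further positionals are appended to the prompt separated by single spaces.

    #include <cassert>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <string>

    // ... paste the Opt class from the diff above here ...

    int main() {
        const char * argv[] = { "llama-run", "--ngl", "99", "model.gguf", "Hello", "World" };
        Opt opt;
        // init_opt returns 0 to proceed, 1 on a parse error, 2 after printing help
        assert(opt.init_opt(6, argv) == 0);
        assert(std::string(opt.model_) == "model.gguf");  // first positional -> MODEL
        assert(opt.prompt_ == "Hello World");             // later positionals joined -> PROMPT
        assert(opt.ngl_ == 99);
        assert(opt.context_size_ == 2048);                // default; -c/--context-size not passed
        return 0;
    }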
