From 9132124e08cbbaac4b6c69dec64902cc7563bf83 Mon Sep 17 00:00:00 2001
From: Thiago Padilha
Date: Sat, 18 Mar 2023 12:20:20 -0300
Subject: [PATCH] Remove direct access to std streams from llama_main

The goal is to allow running llama_main while connected to other streams,
such as TCP sockets.

Signed-off-by: Thiago Padilha
---
 llama.cpp | 69 +++++++++++++++++++++++++++++--------------------
 llama.h   |  5 +++-
 main.cpp  |  2 +-
 3 files changed, 41 insertions(+), 35 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 35ec3e1401dce1..05e37a0d6e0449 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -718,13 +718,16 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    FILE *instream,
+    FILE *outstream,
+    FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.prompt.empty()) {
@@ -751,13 +754,13 @@ int llama_main(
     // tokenize the reverse prompt
     std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -769,19 +772,19 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if(antiprompt_inp.size()) {
-            fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
-            fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+            fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
+            fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
             for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
             }
-            fprintf(stderr, "\n");
+            fprintf(errstream, "\n");
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<gpt_vocab::id> embd;
 
@@ -795,7 +798,7 @@ int llama_main(
 
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -814,7 +817,7 @@ int llama_main(
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0) {
@@ -823,7 +826,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
 
@@ -877,16 +880,16 @@ int llama_main(
 
             // reset color to default if we there is no pending user input
             if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
-                printf(ANSI_COLOR_RESET);
+                fprintf(outstream, ANSI_COLOR_RESET);
             }
         }
 
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -901,16 +904,16 @@ int llama_main(
                 // currently being interactive
                 bool another_line=true;
                 while (another_line) {
-                    fflush(stdout);
+                    fflush(outstream);
                     char buf[256] = {0};
                     int n_read;
-                    if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+                    if(params.use_color) fprintf(outstream, ANSI_BOLD ANSI_COLOR_GREEN);
+                    if (fscanf(instream, "%255[^\n]%n%*c", buf, &n_read) <= 0) {
                         // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
+                        std::ignore = fscanf(instream, "%*c");
                         n_read=0;
                     }
-                    if(params.use_color) printf(ANSI_COLOR_RESET);
+                    if(params.use_color) fprintf(outstream, ANSI_COLOR_RESET);
 
                     if (n_read > 0 && buf[n_read-1]=='\\') {
                         another_line = true;
@@ -936,7 +939,7 @@ int llama_main(
 
         // end of text token
        if (embd.back() == 2) {
-            fprintf(stderr, " [end of text]\n");
+            fprintf(errstream, " [end of text]\n");
             break;
         }
     }
@@ -949,18 +952,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;
diff --git a/llama.h b/llama.h
index 9cacb613c71f84..7c8409d1a158e0 100644
--- a/llama.h
+++ b/llama.h
@@ -64,5 +64,8 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us);
+    int64_t t_main_start_us,
+    FILE *instream,
+    FILE *outstream,
+    FILE *errstream);
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);
diff --git a/main.cpp b/main.cpp
index 7106a8e1978e11..e3fc73e750a212 100644
--- a/main.cpp
+++ b/main.cpp
@@ -56,5 +56,5 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us, stdin, stdout, stderr);
 }
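
Below is a minimal, illustrative sketch (not part of the patch) of the kind of caller this change enables: wiring llama_main to an accepted TCP connection by wrapping the socket file descriptor in FILE* streams with POSIX fdopen(). The serve_client() helper and the client_fd parameter are hypothetical; params, vocab, model and the timing values are assumed to have been prepared the same way main.cpp does.

// Illustrative sketch only, not part of the patch.
// Assumes POSIX fdopen()/dup() and that params, vocab, model,
// t_main_start_us and t_load_us were set up as in main.cpp.
#include <cstdio>
#include <unistd.h>
#include "llama.h"   // llama_main(); gpt_params/gpt_vocab come from the project's headers

int serve_client(int client_fd, gpt_params params, gpt_vocab vocab, llama_model model,
                 int64_t t_main_start_us, int64_t t_load_us) {
    FILE *instream  = fdopen(client_fd, "r");       // read user input from the socket
    FILE *outstream = fdopen(dup(client_fd), "w");  // write generated text back to it
    if (instream == NULL || outstream == NULL) {
        return 1;
    }
    setvbuf(outstream, NULL, _IONBF, 0);            // push tokens to the client immediately

    // Same argument order as the call in main.cpp, with the std streams swapped
    // for the socket-backed ones; diagnostics stay on the server's local stderr.
    int ret = llama_main(params, vocab, model, t_main_start_us, t_load_us,
                         instream, outstream, stderr);

    fclose(instream);   // closes client_fd
    fclose(outstream);  // closes the dup'd descriptor
    return ret;
}

Keeping errstream separate from outstream is what lets a caller like this keep the seed, sampling-parameter and timing output off the client's stream while the generated text goes over the socket.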