From d09d5ed640583b3504f69926c72334e97aa45b86 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 21:35:46 +0800 Subject: [PATCH 01/10] Initial implementation --- examples/common.cpp | 32 +++++-- examples/common.h | 8 +- examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/main/main.cpp | 114 +++++++++++++++++++++++-- examples/perplexity/perplexity.cpp | 2 +- examples/server/server.cpp | 2 +- examples/simple/simple.cpp | 2 +- 8 files changed, 148 insertions(+), 16 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 3278a064346b4..35be2b5aa465a 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.mirostat_tau = std::stof(argv[i]); + } else if (arg == "--cfg-negative-prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.cfg_scale = std::stof(argv[i]); + } else if (arg == "--cfg-smooth-factor") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.cfg_smooth_factor = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -468,6 +486,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + fprintf(stderr, " --cfg-negative-prompt PROMPT \n"); + fprintf(stderr, " negative prompt to use for guidance. 
(default: empty)\n"); + fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); + fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); @@ -534,7 +556,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } -std::tuple llama_init_from_gpt_params(const gpt_params & params) { +std::tuple llama_init_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; @@ -553,14 +575,14 @@ std::tuple llama_init_from_gpt_par llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return std::make_tuple(nullptr, nullptr); + return std::make_tuple(nullptr, nullptr, lparams); } llama_context * lctx = llama_new_context_with_model(model, lparams); if (lctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_free_model(model); - return std::make_tuple(nullptr, nullptr); + return std::make_tuple(nullptr, nullptr, lparams); } if (!params.lora_adapter.empty()) { @@ -572,11 +594,11 @@ std::tuple llama_init_from_gpt_par fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); llama_free(lctx); llama_free_model(model); - return std::make_tuple(nullptr, nullptr); + return std::make_tuple(nullptr, nullptr, lparams); } } - return std::make_tuple(model, lctx); + return std::make_tuple(model, lctx, lparams); } void console_init(console_state & con_st) { diff --git a/examples/common.h b/examples/common.h index 96f2228f8677b..bed576438858b 100644 --- a/examples/common.h +++ b/examples/common.h @@ -48,6 +48,12 @@ struct gpt_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate + // Classifier-Free Guidance + // https://arxiv.org/abs/2306.17806 + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // How strong is guidance + float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits + std::string model = "models/7B/ggml-model.bin"; // model path std::string model_alias = "unknown"; // model alias std::string prompt = ""; @@ -98,7 +104,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s // Model utils // -std::tuple llama_init_from_gpt_params(const gpt_params & params); +std::tuple llama_init_from_gpt_params(const gpt_params & params); // // Console utils diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 5fa4942be7aaf..576ac0af02a61 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { g_ctx = &ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx) = llama_init_from_gpt_params(params); + std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return nullptr; diff --git 
a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03e801c2a6d4b..7b1135e6af28f 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_context * ctx; // load the model - std::tie(model, ctx) = llama_init_from_gpt_params(params); + std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0f6391acba45d..65ead0a008015 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -54,6 +54,20 @@ void sigint_handler(int signo) { } #endif +void inplace_log_softmax(float* logits, int n_vocab) { + float sum = 0.f; + for (int i = 0; i < n_vocab; ++i) { + float p = expf(logits[i]); + logits[i] = p; + sum += p; + } + + for (int i = 0; i < n_vocab; ++i) { + float p = logits[i]; + logits[i] = logf(p/ sum); + } +} + int main(int argc, char ** argv) { gpt_params params; @@ -109,10 +123,16 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; + llama_context * guidance_ctx = NULL; + struct llama_context_params lparams; g_ctx = &ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx) = llama_init_from_gpt_params(params); + std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params); + if (params.cfg_scale > 1.f) { + guidance_ctx = llama_new_context_with_model(model, lparams); + } + if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; @@ -183,15 +203,28 @@ int main(int argc, char ** argv) { // tokenize the prompt std::vector embd_inp; - if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { - // Add a space in front of the first character to match OG llama tokenizer behavior - params.prompt.insert(0, 1, ' '); + // Add a space in front of the first character to match OG llama tokenizer behavior + params.prompt.insert(0, 1, ' '); + if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { embd_inp = ::llama_tokenize(ctx, params.prompt, true); } else { embd_inp = session_tokens; } + // Tokenize negative prompt + std::vector guidance_inp; + int guidance_offset = 0; + int original_prompt_len = 0; + if (guidance_ctx) { + params.cfg_negative_prompt.insert(0, 1, ' '); + guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true); + + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true); + original_prompt_len = original_inp.size(); + guidance_offset = (int)guidance_inp.size() - original_prompt_len; + } + const int n_ctx = llama_n_ctx(ctx); if ((int) embd_inp.size() > n_ctx - 4) { @@ -258,6 +291,16 @@ int main(int argc, char ** argv) { for (int i = 0; i < (int) embd_inp.size(); i++) { fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); } + + if (guidance_ctx) { + fprintf(stderr, "\n"); + fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); + fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); + for (int i = 0; i < (int) guidance_inp.size(); i++) { + fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i])); + } + } + if (params.n_keep > 0) { fprintf(stderr, "%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { @@ 
-334,11 +377,13 @@ int main(int argc, char ** argv) { int n_remain = params.n_predict; int n_consumed = 0; int n_session_consumed = 0; + int guidance_n_past = 0; // the first thing we will do is to output the prompt, so set color accordingly console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; + std::vector guidance_embd; // do one empty run to warm up the model { @@ -367,11 +412,12 @@ int main(int argc, char ** argv) { // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() > n_ctx) { + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { const int n_left = n_past - params.n_keep; // always keep the first token - BOS n_past = std::max(1, params.n_keep); + guidance_n_past = std::max(1, params.n_keep + guidance_offset); // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); @@ -412,6 +458,48 @@ int main(int argc, char ** argv) { // evaluate tokens in batches // embd is typically prepared beforehand to fit within a batch, but not always + + if (guidance_ctx) { + int input_size = 0; + llama_token* input_buf = NULL; + + if (guidance_n_past < (int) guidance_inp.size()) { + // Guidance context should have the same data with these modifications: + // + // * Replace the initial prompt + // * Shift everything by guidance_offset + guidance_embd = guidance_inp; + if (embd.begin() + original_prompt_len < embd.end()) { + guidance_embd.insert( + guidance_embd.end(), + embd.begin() + original_prompt_len, + embd.end() + ); + } + + input_buf = guidance_embd.data(); + input_size = guidance_embd.size(); + fprintf(stderr, "\n---------------------\n"); + for (int i = 0; i < (int) guidance_embd.size(); i++) { + fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i])); + } + fprintf(stderr, "\n---------------------\n"); + } else { + input_buf = embd.data(); + input_size = embd.size(); + } + + for (int i = 0; i < input_size; i += params.n_batch) { + int n_eval = std::min(input_size - i, params.n_batch); + if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + + guidance_n_past += n_eval; + } + } + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { int n_eval = (int) embd.size() - i; if (n_eval > params.n_batch) { @@ -431,6 +519,7 @@ int main(int argc, char ** argv) { } embd.clear(); + guidance_embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token @@ -465,6 +554,21 @@ int main(int argc, char ** argv) { logits[it->first] += it->second; } + if (guidance_ctx) { + inplace_log_softmax(logits, n_vocab); + auto* guidance_logits = llama_get_logits(guidance_ctx); + inplace_log_softmax(guidance_logits, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i]; + } + inplace_log_softmax(guidance_logits, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor); + } + } + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { diff --git a/examples/perplexity/perplexity.cpp 
b/examples/perplexity/perplexity.cpp index fd4b03cb261f6..768c2b400cfd8 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { llama_context * ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx) = llama_init_from_gpt_params(params); + std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cbfc0018de3a..55cf1c94d8f11 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -245,7 +245,7 @@ struct llama_server_context bool loadModel(const gpt_params ¶ms_) { params = params_; - std::tie(model, ctx) = llama_init_from_gpt_params(params); + std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); if (model == nullptr) { LOG_ERROR("unable to load model", {{"model", params_.model}}); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 2d913cebb813a..f597888656cce 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -71,7 +71,7 @@ int main(int argc, char ** argv) llama_model * model; llama_context * ctx; - std::tie(model, ctx) = llama_init_from_gpt_params( params ); + std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params ); if ( model == NULL ) { From 478630019bf785b16831444e54cc87ebeefb1c03 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 22:20:04 +0800 Subject: [PATCH 02/10] Remove debug print --- examples/main/main.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 65ead0a008015..88262c9204308 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -479,11 +479,11 @@ int main(int argc, char ** argv) { input_buf = guidance_embd.data(); input_size = guidance_embd.size(); - fprintf(stderr, "\n---------------------\n"); - for (int i = 0; i < (int) guidance_embd.size(); i++) { - fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i])); - } - fprintf(stderr, "\n---------------------\n"); + //fprintf(stderr, "\n---------------------\n"); + //for (int i = 0; i < (int) guidance_embd.size(); i++) { + //fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i])); + //} + //fprintf(stderr, "\n---------------------\n"); } else { input_buf = embd.data(); input_size = embd.size(); From 8ba5b137c8dde9bdb348839170212ddd8e167b39 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 22:25:00 +0800 Subject: [PATCH 03/10] Restore signature of llama_init_from_gpt_params --- examples/common.cpp | 16 +++++++++++----- examples/common.h | 3 ++- examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/main/main.cpp | 4 ++-- examples/perplexity/perplexity.cpp | 2 +- examples/server/server.cpp | 2 +- examples/simple/simple.cpp | 2 +- 8 files changed, 20 insertions(+), 13 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 35be2b5aa465a..4ee1ea79a019d 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -556,7 +556,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } -std::tuple llama_init_from_gpt_params(const gpt_params & params) { +struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; 
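Note: taken together, the options introduced in PATCH 01 are meant to be used in a single invocation of the main example. The model path, prompt texts and numeric values below are only illustrative placeholders, not part of the patch:

    ./main -m models/7B/ggml-model.bin -p "Write a polite reply to the email below: ..." --cfg-negative-prompt "Write a rude reply to the email below: ..." --cfg-scale 3.0 --cfg-smooth-factor 0.9

Guidance only becomes active when --cfg-scale is above 1.0, since main.cpp creates the second (guidance) context under "if (params.cfg_scale > 1.f)"; both --cfg-scale and --cfg-smooth-factor default to 1.0.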
@@ -572,17 +572,23 @@ std::tuple llama_init_from_gpt_params(const gpt_params & params) { + auto lparams = llama_get_context_params_from_gpt_params(params); + llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return std::make_tuple(nullptr, nullptr, lparams); + return std::make_tuple(nullptr, nullptr); } llama_context * lctx = llama_new_context_with_model(model, lparams); if (lctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_free_model(model); - return std::make_tuple(nullptr, nullptr, lparams); + return std::make_tuple(nullptr, nullptr); } if (!params.lora_adapter.empty()) { @@ -594,11 +600,11 @@ std::tuple llama_tokenize(struct llama_context * ctx, const std::s // Model utils // -std::tuple llama_init_from_gpt_params(const gpt_params & params); +std::tuple llama_init_from_gpt_params(const gpt_params & params); +struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params); // // Console utils diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 576ac0af02a61..5fa4942be7aaf 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { g_ctx = &ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return nullptr; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 7b1135e6af28f..03e801c2a6d4b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_context * ctx; // load the model - std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 88262c9204308..dbc384f52f5c4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -124,12 +124,12 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; llama_context * guidance_ctx = NULL; - struct llama_context_params lparams; g_ctx = &ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params); + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (params.cfg_scale > 1.f) { + struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params); guidance_ctx = llama_new_context_with_model(model, lparams); } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 768c2b400cfd8..fd4b03cb261f6 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { llama_context * ctx; // load the model and apply lora adapter, if any - std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); 
return 1; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 55cf1c94d8f11..2cbfc0018de3a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -245,7 +245,7 @@ struct llama_server_context bool loadModel(const gpt_params ¶ms_) { params = params_; - std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params); + std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == nullptr) { LOG_ERROR("unable to load model", {{"model", params_.model}}); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index f597888656cce..2d913cebb813a 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -71,7 +71,7 @@ int main(int argc, char ** argv) llama_model * model; llama_context * ctx; - std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params ); + std::tie(model, ctx) = llama_init_from_gpt_params( params ); if ( model == NULL ) { From 8f91b52fdf3cc932c648f5d72f394d2aa7599d63 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 22:50:42 +0800 Subject: [PATCH 04/10] Free guidance context --- examples/main/main.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index dbc384f52f5c4..7e3193839f052 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -772,6 +772,7 @@ int main(int argc, char ** argv) { } llama_print_timings(ctx); + llama_free(guidance_ctx); llama_free(ctx); llama_free_model(model); From 114d4c5389f17bd5ad542d9742ed73c0bd775390 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 23:10:47 +0800 Subject: [PATCH 05/10] Make freeing of guidance_ctx conditional --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7e3193839f052..8d3a3a3b6c543 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -772,7 +772,7 @@ int main(int argc, char ** argv) { } llama_print_timings(ctx); - llama_free(guidance_ctx); + if (guidance_ctx) { llama_free(guidance_ctx); } llama_free(ctx); llama_free_model(model); From 422a7ffdaf643817b4e3bbb15257c53a258d3c43 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 23:45:37 +0800 Subject: [PATCH 06/10] Make Classifier-Free Guidance a sampling function --- examples/main/main.cpp | 33 +++---------------------- llama.cpp | 55 ++++++++++++++++++++++++++++++++++++++++++ llama.h | 12 +++++++++ 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8d3a3a3b6c543..8733d5febcf4b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -54,20 +54,6 @@ void sigint_handler(int signo) { } #endif -void inplace_log_softmax(float* logits, int n_vocab) { - float sum = 0.f; - for (int i = 0; i < n_vocab; ++i) { - float p = expf(logits[i]); - logits[i] = p; - sum += p; - } - - for (int i = 0; i < n_vocab; ++i) { - float p = logits[i]; - logits[i] = logf(p/ sum); - } -} - int main(int argc, char ** argv) { gpt_params params; @@ -554,21 +540,6 @@ int main(int argc, char ** argv) { logits[it->first] += it->second; } - if (guidance_ctx) { - inplace_log_softmax(logits, n_vocab); - auto* guidance_logits = llama_get_logits(guidance_ctx); - inplace_log_softmax(guidance_logits, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i]; - } - inplace_log_softmax(guidance_logits, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - 
logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor); - } - } - std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -577,6 +548,10 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + if (guidance_ctx) { + llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor); + } + // Apply penalties float nl_logit = logits[llama_token_nl()]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); diff --git a/llama.cpp b/llama.cpp index ee6ec0920fc9c..5a8c6cf3b5d97 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2141,6 +2141,61 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +template +void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { + float sum = 0.f; + for (int i = 0; i < size; ++i) { + float& logit = logit_accessor(array[i]); + float p = expf(logit); + sum += p; + logit = p; + } + + for (int i = 0; i < size; ++i) { + float& logit = logit_accessor(array[i]); + logit = logf(logit / sum); + } +} + +void llama_sample_context_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale, + float smooth_factor) { + assert(ctx); + auto n_vocab = llama_n_vocab(ctx); + assert(n_vocab == (int)candidates->size); + assert(!candidates->sorted); + + auto logit_from_token_data = [](llama_token_data& data) -> float& { + return data.logit; + }; + + auto logit_from_float = [](float& item) -> float& { + return item; + }; + + llama_log_softmax(candidates->data, candidates->size, logit_from_token_data); + + auto* guidance_logits = llama_get_logits(guidance_ctx); + llama_log_softmax(guidance_logits, n_vocab, logit_from_float); + + for (int i = 0; i < n_vocab; ++i) { + float guidance_logit = guidance_logits[i]; + float base_logit = candidates->data[i].logit; + guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit; + } + + llama_log_softmax(guidance_logits, n_vocab, logit_from_float); + + for (int i = 0; i < n_vocab; ++i) { + float base_logit = candidates->data[i].logit; + float guidance_logit = guidance_logits[i]; + + candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit; + } +} llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { assert(ctx); diff --git a/llama.h b/llama.h index c1e7dab9f5a9b..efac46ea824ad 100644 --- a/llama.h +++ b/llama.h @@ -307,6 +307,18 @@ extern "C" { /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// @params guidance_ctx A separate context from the same model. 
Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. + LLAMA_API void llama_sample_context_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale, + float smooth_factor); + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); From 66eb048470a5dbe5464770b90c830063ce656926 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Fri, 7 Jul 2023 23:48:07 +0800 Subject: [PATCH 07/10] Correct typo. CFG already means context-free grammar. --- examples/main/main.cpp | 2 +- llama.cpp | 2 +- llama.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8733d5febcf4b..b53c16f513c21 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -549,7 +549,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; if (guidance_ctx) { - llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor); + llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor); } // Apply penalties diff --git a/llama.cpp b/llama.cpp index 5a8c6cf3b5d97..f96c9d143c0a4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2157,7 +2157,7 @@ void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { } } -void llama_sample_context_free_guidance( +void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, diff --git a/llama.h b/llama.h index efac46ea824ad..8789ce6ea6819 100644 --- a/llama.h +++ b/llama.h @@ -312,7 +312,7 @@ extern "C" { /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. 
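Note: written out, the sampler declared here (added in PATCH 06, renamed to llama_sample_classifier_free_guidance just below) computes the following, where l is the log-softmax of the main-context logits, g the log-softmax of the guidance-context logits, s = scale and a = smooth_factor:

    g'_i    = g_i + s * (l_i - g_i)
    g_hat   = log_softmax(g')
    logit_i = a * g_hat_i + (1 - a) * l_i

The first line is the guidance rule log p(x|c_neg) + s * (log p(x|c) - log p(x|c_neg)) from arXiv:2306.17806. With s = 1 (the default) the original logits pass through unchanged, matching "1.0 = disable", and a interpolates between the original logits (a = 0) and the fully guided ones (a = 1).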
- LLAMA_API void llama_sample_context_free_guidance( + LLAMA_API void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, From 8e66e59cdd994931dff59b5a14e5fe4edb6c5612 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Sat, 8 Jul 2023 00:07:49 +0800 Subject: [PATCH 08/10] Record sampling time in llama_sample_classifier_free_guidance --- llama.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/llama.cpp b/llama.cpp index f96c9d143c0a4..cdfb1bbb66a21 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2163,6 +2163,8 @@ void llama_sample_classifier_free_guidance( struct llama_context * guidance_ctx, float scale, float smooth_factor) { + int64_t t_start_sample_us = t_start_sample_us = ggml_time_us(); + assert(ctx); auto n_vocab = llama_n_vocab(ctx); assert(n_vocab == (int)candidates->size); @@ -2195,6 +2197,10 @@ void llama_sample_classifier_free_guidance( candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit; } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { From 325fc8814155a325bea879c92104e34e0145d7ea Mon Sep 17 00:00:00 2001 From: Bach Le Date: Sat, 8 Jul 2023 00:10:26 +0800 Subject: [PATCH 09/10] Shift all values by the max value before applying logsoftmax --- llama.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index cdfb1bbb66a21..6e0f96bf25d34 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l template void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { + T* element = std::max_element( + array, array + size, + [&logit_accessor](T& lhs, T& rhs) { + return logit_accessor(lhs) < logit_accessor(rhs); + } + ); + + float max_l = logit_accessor(*element); float sum = 0.f; for (int i = 0; i < size; ++i) { float& logit = logit_accessor(array[i]); - float p = expf(logit); + float p = expf(logit - max_l); sum += p; logit = p; } From abf164d71ef0d2ee9180f05c7646c75b79a14f39 Mon Sep 17 00:00:00 2001 From: Bach Le Date: Mon, 10 Jul 2023 23:50:17 +0800 Subject: [PATCH 10/10] Fix styling based on review --- examples/common.cpp | 4 +-- examples/common.h | 2 +- examples/main/main.cpp | 48 +++++++++++++++++------------------ llama.cpp | 57 ++++++++++++++++-------------------------- 4 files changed, 49 insertions(+), 62 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 4ee1ea79a019d..836afebc7b1ce 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -556,7 +556,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } -struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) { +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; @@ -576,7 +576,7 @@ struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_p } std::tuple llama_init_from_gpt_params(const gpt_params & params) { - auto lparams = llama_get_context_params_from_gpt_params(params); + auto lparams = llama_context_params_from_gpt_params(params); llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == NULL) { 
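Note on PATCH 09 above (shifting by the maximum before the log-softmax): this is the standard numerically stable form. For any constant m,

    log_softmax(x)_i = (x_i - m) - log( sum_j exp(x_j - m) )

because the shift cancels between numerator and denominator; choosing m = max_j x_j keeps every exponent at or below zero, so expf cannot overflow regardless of the magnitude of the logits (single-precision expf overflows once its argument exceeds roughly 88.7).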
diff --git a/examples/common.h b/examples/common.h index 00cb6888ac36c..6315df9613445 100644 --- a/examples/common.h +++ b/examples/common.h @@ -105,7 +105,7 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s // std::tuple llama_init_from_gpt_params(const gpt_params & params); -struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params); +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); // // Console utils diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b53c16f513c21..19b5e07981e88 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -109,14 +109,14 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; - llama_context * guidance_ctx = NULL; + llama_context * ctx_guidance = NULL; g_ctx = &ctx; // load the model and apply lora adapter, if any std::tie(model, ctx) = llama_init_from_gpt_params(params); if (params.cfg_scale > 1.f) { - struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params); - guidance_ctx = llama_new_context_with_model(model, lparams); + struct llama_context_params lparams = llama_context_params_from_gpt_params(params); + ctx_guidance = llama_new_context_with_model(model, lparams); } if (model == NULL) { @@ -202,9 +202,9 @@ int main(int argc, char ** argv) { std::vector guidance_inp; int guidance_offset = 0; int original_prompt_len = 0; - if (guidance_ctx) { + if (ctx_guidance) { params.cfg_negative_prompt.insert(0, 1, ' '); - guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true); + guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true); std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true); original_prompt_len = original_inp.size(); @@ -278,7 +278,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); } - if (guidance_ctx) { + if (ctx_guidance) { fprintf(stderr, "\n"); fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); @@ -363,13 +363,13 @@ int main(int argc, char ** argv) { int n_remain = params.n_predict; int n_consumed = 0; int n_session_consumed = 0; - int guidance_n_past = 0; + int n_past_guidance = 0; // the first thing we will do is to output the prompt, so set color accordingly console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; - std::vector guidance_embd; + std::vector embd_guidance; // do one empty run to warm up the model { @@ -403,7 +403,7 @@ int main(int argc, char ** argv) { // always keep the first token - BOS n_past = std::max(1, params.n_keep); - guidance_n_past = std::max(1, params.n_keep + guidance_offset); + n_past_guidance = std::max(1, params.n_keep + guidance_offset); // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); @@ -445,29 +445,29 @@ int main(int argc, char ** argv) { // evaluate tokens in batches // embd is typically prepared beforehand to fit within a batch, but not always - if (guidance_ctx) { + if (ctx_guidance) { int input_size = 0; llama_token* input_buf = NULL; - if (guidance_n_past < (int) guidance_inp.size()) { + if (n_past_guidance < (int) guidance_inp.size()) { // Guidance context should have the same data with these 
modifications: // // * Replace the initial prompt // * Shift everything by guidance_offset - guidance_embd = guidance_inp; + embd_guidance = guidance_inp; if (embd.begin() + original_prompt_len < embd.end()) { - guidance_embd.insert( - guidance_embd.end(), + embd_guidance.insert( + embd_guidance.end(), embd.begin() + original_prompt_len, embd.end() ); } - input_buf = guidance_embd.data(); - input_size = guidance_embd.size(); + input_buf = embd_guidance.data(); + input_size = embd_guidance.size(); //fprintf(stderr, "\n---------------------\n"); - //for (int i = 0; i < (int) guidance_embd.size(); i++) { - //fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i])); + //for (int i = 0; i < (int) embd_guidance.size(); i++) { + //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i])); //} //fprintf(stderr, "\n---------------------\n"); } else { @@ -477,12 +477,12 @@ int main(int argc, char ** argv) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) { + if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } - guidance_n_past += n_eval; + n_past_guidance += n_eval; } } @@ -505,7 +505,7 @@ int main(int argc, char ** argv) { } embd.clear(); - guidance_embd.clear(); + embd_guidance.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token @@ -548,8 +548,8 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - if (guidance_ctx) { - llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor); + if (ctx_guidance) { + llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); } // Apply penalties @@ -747,7 +747,7 @@ int main(int argc, char ** argv) { } llama_print_timings(ctx); - if (guidance_ctx) { llama_free(guidance_ctx); } + if (ctx_guidance) { llama_free(ctx_guidance); } llama_free(ctx); llama_free_model(model); diff --git a/llama.cpp b/llama.cpp index 6e0f96bf25d34..4895d8b0df052 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2141,27 +2141,17 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } -template -void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { - T* element = std::max_element( - array, array + size, - [&logit_accessor](T& lhs, T& rhs) { - return logit_accessor(lhs) < logit_accessor(rhs); - } - ); - - float max_l = logit_accessor(*element); +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); float sum = 0.f; - for (int i = 0; i < size; ++i) { - float& logit = logit_accessor(array[i]); - float p = expf(logit - max_l); + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); sum += p; - logit = p; + array[i] = p; } - for (int i = 0; i < size; ++i) { - float& logit = logit_accessor(array[i]); - logit = logf(logit / sum); + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); } } @@ -2178,32 +2168,29 @@ void llama_sample_classifier_free_guidance( assert(n_vocab == (int)candidates->size); assert(!candidates->sorted); - auto logit_from_token_data = [](llama_token_data& data) -> float& { - return data.logit; - }; - - auto logit_from_float = 
[](float& item) -> float& { - return item; - }; - - llama_log_softmax(candidates->data, candidates->size, logit_from_token_data); + std::vector logits_base; + logits_base.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + logits_base.push_back(candidates->data[i].logit); + } + llama_log_softmax(logits_base.data(), candidates->size); - auto* guidance_logits = llama_get_logits(guidance_ctx); - llama_log_softmax(guidance_logits, n_vocab, logit_from_float); + float* logits_guidance = llama_get_logits(guidance_ctx); + llama_log_softmax(logits_guidance, n_vocab); for (int i = 0; i < n_vocab; ++i) { - float guidance_logit = guidance_logits[i]; - float base_logit = candidates->data[i].logit; - guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit; + float logit_guidance = logits_guidance[i]; + float logit_base = logits_base[i]; + logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; } - llama_log_softmax(guidance_logits, n_vocab, logit_from_float); + llama_log_softmax(logits_guidance, n_vocab); for (int i = 0; i < n_vocab; ++i) { - float base_logit = candidates->data[i].logit; - float guidance_logit = guidance_logits[i]; + float logit_base = logits_base[i]; + float logit_guidance = logits_guidance[i]; - candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit; + candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; } if (ctx) {
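Note: a worked example of the negative-prompt bookkeeping in main.cpp. Both the prompt and cfg_negative_prompt get the leading space and are tokenized with BOS, and guidance_offset = guidance_inp.size() - original_prompt_len is how many more tokens the negative prompt has than the real prompt. If, hypothetically, the prompt tokenizes to 10 tokens and the negative prompt to 14, then guidance_offset = 4: the guidance context always holds 4 more tokens than the main one, which is why the context-overflow check adds std::max(0, guidance_offset) and why, after a context swap, n_past_guidance restarts at std::max(1, params.n_keep + guidance_offset) rather than at n_keep. While the negative prompt is still being consumed (n_past_guidance < guidance_inp.size()), embd_guidance is rebuilt as the full negative prompt followed by every token of embd past original_prompt_len, so the guidance context sees the negative prompt exactly where the main context saw the user prompt; once both prompts are consumed, the same embd is evaluated in both contexts.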