Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interactive mode #61

Merged
merged 4 commits into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,29 @@ The number of files generated for each model is as follows:

When running the larger models, make sure you have enough disk space to store all the intermediate files.

### Interactive mode

If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter.
In this mode, you can always interrupt generation by pressing Ctrl+C and enter one or more lines of text which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt which makes LLaMa emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`.

Here is an example few-shot interaction, invoked with the command
```
./main -m ./models/13B/ggml-model-q4_0.bin -t 8 --repeat_penalty 1.2 --temp 0.9 --top_p 0.9 -n 256 \
--color -i -r "User:" \
-p \
"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is London, the capital of the United Kingdom.
User:"
```
Note the use of `--color` to distinguish between user input and generated text.

![image](https://user-images.githubusercontent.com/401380/224572787-d418782f-47b2-49c4-a04e-65bfa7ad4ec0.png)


## Limitations

- Not sure if my tokenizer is correct. There are a few places where we might have a mistake:
Expand Down
137 changes: 127 additions & 10 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
#include <string>
#include <vector>

#include <signal.h>
#include <unistd.h>

#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"

// determine number of model parts based on the dimension
static const std::map<int, int> LLAMA_N_PARTS = {
{ 4096, 1 },
Expand Down Expand Up @@ -733,6 +745,18 @@ bool llama_eval(
return true;
}

static bool is_interacting = false;

void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
} else {
_exit(130);
}
}
}

int main(int argc, char ** argv) {
ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();
Expand Down Expand Up @@ -787,13 +811,34 @@ int main(int argc, char ** argv) {

params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

// tokenize the reverse prompt
std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);

printf("\n");
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
}
printf("\n");
if (params.interactive) {
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);

printf("%s: interactive mode on.\n", __func__);

if(antiprompt_inp.size()) {
printf("%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
printf("%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
printf("%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
}
printf("\n");
}
}
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("\n\n");

Expand All @@ -807,7 +852,28 @@ int main(int argc, char ** argv) {
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);

for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {

if (params.interactive) {
printf("== Running in interactive mode. ==\n"
" - Press Ctrl+C to interject at any time.\n"
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n");
}

int remaining_tokens = params.n_predict;
int input_consumed = 0;
bool input_noecho = false;

// prompt user immediately after the starting prompt has been loaded
if (params.interactive_start) {
is_interacting = true;
}

if (params.use_color) {
printf(ANSI_COLOR_YELLOW);
}

while (remaining_tokens > 0) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
Expand All @@ -823,8 +889,8 @@ int main(int argc, char ** argv) {
n_past += embd.size();
embd.clear();

if (i >= embd_inp.size()) {
// sample next token
if (embd_inp.size() <= input_consumed) {
// out of input, sample next token
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
Expand All @@ -847,24 +913,74 @@ int main(int argc, char ** argv) {

// add it to the context
embd.push_back(id);

// echo this to console
input_noecho = false;

// decrement remaining sampling budget
--remaining_tokens;
} else {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
while (embd_inp.size() > input_consumed) {
embd.push_back(embd_inp[input_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[k]);
last_n_tokens.push_back(embd_inp[input_consumed]);
++input_consumed;
if (embd.size() > params.n_batch) {
break;
}
}
i += embd.size() - 1;

if (params.use_color && embd_inp.size() <= input_consumed) {
printf(ANSI_COLOR_RESET);
}
}

// display text
for (auto id : embd) {
printf("%s", vocab.id_to_token[id].c_str());
if (!input_noecho) {
for (auto id : embd) {
printf("%s", vocab.id_to_token[id].c_str());
}
fflush(stdout);
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) {
// check for reverse prompt
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
// reverse prompt found
is_interacting = true;
}
if (is_interacting) {
// currently being interactive
bool another_line=true;
while (another_line) {
char buf[256] = {0};
int n_read;
if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
scanf("%255[^\n]%n%*c", buf, &n_read);
if(params.use_color) printf(ANSI_COLOR_RESET);

if (n_read > 0 && buf[n_read-1]=='\\') {
another_line = true;
buf[n_read-1] = '\n';
buf[n_read] = 0;
} else {
another_line = false;
buf[n_read] = '\n';
buf[n_read+1] = 0;
}

std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

input_noecho = true; // do not echo this again
}

is_interacting = false;
}
}
fflush(stdout);

// end of text token
if (embd.back() == 2) {
Expand All @@ -873,6 +989,7 @@ int main(int argc, char ** argv) {
}
}


// report timing
{
const int64_t t_main_end_us = ggml_time_us();
Expand Down
14 changes: 14 additions & 0 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
params.interactive_start = true;
} else if (arg == "--color") {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt = argv[++i];
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
Expand All @@ -67,6 +76,11 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -i, --interactive run in interactive mode\n");
fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
Expand Down
6 changes: 6 additions & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ struct gpt_params {

std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt;

bool use_color = false; // use color to distinguish generations and inputs

bool interactive = false; // interactive mode
bool interactive_start = false; // reverse prompt immediately
std::string antiprompt = ""; // string upon seeing which more user input is prompted
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
Expand Down