This repository was archived by the owner on Dec 18, 2024. It is now read-only.

Commit 895022e

bump llama, add stop
royshil committed Nov 5, 2023
1 parent 2721e70 commit 895022e
Showing 10 changed files with 111 additions and 41 deletions.
8 changes: 5 additions & 3 deletions cmake/BuildLlamacpp.cmake
@@ -43,14 +43,16 @@ if(WIN32)
ExternalProject_Get_Property(OpenBLAS INSTALL_DIR)
set(OpenBLAS_DIR ${INSTALL_DIR})
set(LLAMA_ADDITIONAL_ENV "OPENBLAS_PATH=${OpenBLAS_DIR}")
set(LLAMA_ADDITIONAL_CMAKE_ARGS -DLLAMA_BLAS=ON -DLLAMA_CUBLAS=OFF)
set(LLAMA_ADDITIONAL_CMAKE_ARGS -DLLAMA_BLAS=ON -DLLAMA_CUBLAS=OFF
-DLLAMA_BLAS_VENDOR=OpenBLAS -DBLAS_LIBRARIES=${OpenBLAS_DIR}/lib/libopenblas.lib
-DBLAS_INCLUDE_DIRS=${OpenBLAS_DIR}/include)
endif()

ExternalProject_Add(
Llamacpp_Build
DOWNLOAD_EXTRACT_TIMESTAMP true
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG 370359e5baf619f3a8d461023143d1494b1e8fde
GIT_TAG f28af0d81aa1010afa5de74cf627dcb04bea3157
BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --config ${Llamacpp_BUILD_TYPE}
BUILD_BYPRODUCTS
<INSTALL_DIR>/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}llama${CMAKE_STATIC_LIBRARY_SUFFIX}
@@ -76,7 +78,7 @@ else()
Llamacpp_Build
DOWNLOAD_EXTRACT_TIMESTAMP true
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG 370359e5baf619f3a8d461023143d1494b1e8fde
GIT_TAG f28af0d81aa1010afa5de74cf627dcb04bea3157
BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --config ${Llamacpp_BUILD_TYPE}
BUILD_BYPRODUCTS <INSTALL_DIR>/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}llama${CMAKE_STATIC_LIBRARY_SUFFIX}
CMAKE_GENERATOR ${CMAKE_GENERATOR}
1 change: 1 addition & 0 deletions src/llm-dock/LLMSettingsDialog.cpp
@@ -38,6 +38,7 @@ LLMSettingsDialog::LLMSettingsDialog(QWidget *parent) : QDialog(parent), ui(new
global_llm_config.cloud_api_key = this->ui->apiKey->text().toStdString();
global_llm_config.cloud_model_name = this->ui->apiModel->text().toStdString();
global_llm_config.system_prompt = this->ui->sysPrompt->toPlainText().toStdString();
global_llm_config.end_sequence = this->ui->endSeq->text().toStdString();
global_llm_config.max_output_tokens = this->ui->maxTokens->text().toUShort();
global_llm_config.temperature = this->ui->temperature->text().toFloat();

76 changes: 46 additions & 30 deletions src/llm-dock/llama-inference.cpp
@@ -34,29 +34,47 @@ std::string get_system_info(const llama_context_params &params)
return os.str();
}

void llama_batch_clear(struct llama_batch &batch)
{
batch.n_tokens = 0;
}

void llama_batch_add(struct llama_batch &batch, llama_token id, llama_pos pos,
const std::vector<llama_seq_id> &seq_ids, bool logits)
{
batch.token[batch.n_tokens] = id;
batch.pos[batch.n_tokens] = pos, batch.n_seq_id[batch.n_tokens] = (int32_t)seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits[batch.n_tokens] = logits;

batch.n_tokens++;
}

std::vector<llama_token> llama_tokenize(const struct llama_model *model, const std::string &text,
bool add_bos)
bool add_bos, bool special = false)
{
// upper limit for the number of tokens
uint64_t n_tokens = text.length() + (uint64_t)add_bos;
int n_tokens = (int)text.length() + (add_bos ? 1 : 0);
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), (int)text.length(), result.data(),
(int)result.size(), add_bos);
(int)result.size(), add_bos, special);
if (n_tokens < 0) {
result.resize(std::abs((int)n_tokens));
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), (int)text.length(), result.data(),
(int)result.size(), add_bos);
GGML_ASSERT(check == (int)std::abs((int)n_tokens));
(int)result.size(), add_bos, special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}

std::vector<llama_token> llama_tokenize(const struct llama_context *ctx, const std::string &text,
bool add_bos)
bool add_bos, bool special = false)
{
return ::llama_tokenize(llama_get_model(ctx), text, add_bos);
return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
}

std::string llama_token_to_piece(const struct llama_context *ctx, llama_token token)
@@ -118,15 +136,15 @@ struct llama_context *llama_init_context(const std::string &model_file_path)
obs_log(LOG_INFO, "warming up the model with an empty run");

std::vector<llama_token> tokens_list = {
llama_token_bos(ctx_llama),
llama_token_eos(ctx_llama),
llama_token_bos(llama_get_model(ctx_llama)),
llama_token_eos(llama_get_model(ctx_llama)),
};

llama_decode(ctx_llama, llama_batch_get_one(tokens_list.data(),
(int)std::min(tokens_list.size(),
(size_t)lparams.n_batch),
0, 0));
llama_kv_cache_tokens_rm(ctx_llama, -1, -1);
llama_kv_cache_clear(ctx_llama);
llama_reset_timings(ctx_llama);

obs_log(LOG_INFO, "warmed up the model");
@@ -137,7 +155,8 @@ struct llama_context *llama_init_context(const std::string &model_file_path)
}

std::string llama_inference(const std::string &promptIn, struct llama_context *ctx,
std::function<void(const std::string &)> partial_generation_callback)
std::function<void(const std::string &)> partial_generation_callback,
std::function<bool(const std::string &)> should_stop_callback)
{
std::string output = "";

@@ -170,16 +189,11 @@ std::string llama_inference(const std::string &promptIn, struct llama_context *c
// create a llama_batch with size 512
// we use this object to submit token data for decoding

llama_batch batch = llama_batch_init(512, 0);
llama_batch batch = llama_batch_init(512, 0, 1);

// evaluate the initial prompt
batch.n_tokens = (int)tokens_list.size();

for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token[i] = tokens_list[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
for (size_t i = 0; i < tokens_list.size(); i++) {
llama_batch_add(batch, tokens_list[i], (llama_pos)i, {0}, false);
}

// llama_decode will output logits only for the last token of the prompt
@@ -219,22 +233,20 @@ std::string llama_inference(const std::string &promptIn, struct llama_context *c
llama_sample_token_greedy(ctx, &candidates_p);

// is it an end of stream?
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
if (new_token_id == llama_token_eos(llama_get_model(ctx)) ||
n_cur == n_len) {
break;
}

partial_generation_callback(llama_token_to_piece(ctx, new_token_id));
std::string piece = llama_token_to_piece(ctx, new_token_id);
partial_generation_callback(piece);
output += piece;

// prepare the next batch
batch.n_tokens = 0;
llama_batch_clear(batch);

// push this new token for next evaluation
batch.token[batch.n_tokens] = new_token_id;
batch.pos[batch.n_tokens] = n_cur;
batch.seq_id[batch.n_tokens] = 0;
batch.logits[batch.n_tokens] = true;

batch.n_tokens += 1;
llama_batch_add(batch, new_token_id, n_cur, {0}, true);

n_decode += 1;
}
@@ -246,10 +258,14 @@ std::string llama_inference(const std::string &promptIn, struct llama_context *c
obs_log(LOG_ERROR, "%s : failed to eval, return code %d", __func__, 1);
return "";
}

if (should_stop_callback(output)) {
break;
}
}

// reset the KV cache
llama_kv_cache_tokens_rm(ctx, -1, -1);
llama_kv_cache_clear(ctx);
llama_reset_timings(ctx);

const auto t_main_end = ggml_time_us();
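The two helpers added at the top of this file mirror the llama_batch_clear / llama_batch_add utilities from llama.cpp's common code; they are needed because the bumped tag changes llama_batch to carry per-token sequence-id arrays and introduces the three-argument llama_batch_init. A minimal sketch of how they drive a prompt decode, assuming the helpers above are visible in the translation unit and a single sequence with id 0 (not part of the commit):

// Minimal sketch (not part of the commit): feeding a tokenized prompt through
// the batch helpers above, using the three-argument llama_batch_init from the
// bumped llama.cpp tag. Assumes a single sequence (id 0) and a capacity of 512.
#include <vector>
#include <llama.h>

static bool decode_prompt(struct llama_context *ctx, const std::vector<llama_token> &tokens)
{
	llama_batch batch = llama_batch_init(512, 0, 1); // tokens, no embeddings, 1 sequence

	llama_batch_clear(batch);
	for (size_t i = 0; i < tokens.size(); i++) {
		// request logits only for the last prompt token
		llama_batch_add(batch, tokens[i], (llama_pos)i, {0}, i == tokens.size() - 1);
	}

	const bool ok = llama_decode(ctx, batch) == 0;
	llama_batch_free(batch); // release the batch buffers once decoding is done
	return ok;
}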
3 changes: 2 additions & 1 deletion src/llm-dock/llama-inference.h
@@ -6,4 +6,5 @@
struct llama_context *llama_init_context(const std::string &model_file_path);

std::string llama_inference(const std::string &prompt, struct llama_context *ctx,
std::function<void(const std::string &)> partial_generation_callback);
std::function<void(const std::string &)> partial_generation_callback,
std::function<bool(const std::string &)> should_stop_callback);
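
A sketch of how a caller adapts to the extended signature (not part of the commit): the second callback receives the accumulated output after each decoded token and returns true to stop generation early. The end marker used here is only an illustrative assumption.

#include <cstdio>
#include <regex>
#include <string>
#include "llama-inference.h"

std::string run_generation(struct llama_context *ctx, const std::string &prompt)
{
	return llama_inference(
		prompt, ctx,
		[](const std::string &piece) {
			std::printf("%s", piece.c_str()); // stream each partial piece
		},
		[](const std::string &so_far) {
			// stop once an (assumed) chat-style end marker appears
			static const std::regex end_seq("<\\|im_end\\|>");
			return std::regex_search(so_far, end_seq);
		});
}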
3 changes: 3 additions & 0 deletions src/llm-dock/llm-config-data.cpp
@@ -22,6 +22,7 @@ Don't include harmful, unethical, racist, sexist, toxic, dangerous, socially bia
global_llm_config.temperature = 0.9f;
global_llm_config.max_output_tokens = 64;
global_llm_config.system_prompt = LLAMA_DEFAULT_SYSTEM_PROMPT;
global_llm_config.end_sequence = "";
global_llm_config.workflows = {};
}

@@ -117,6 +118,7 @@ std::string llm_config_data_to_json(const llm_config_data &data)
j["temperature"] = data.temperature;
j["max_output_tokens"] = data.max_output_tokens;
j["system_prompt"] = data.system_prompt;
j["end_sequence"] = data.end_sequence;
j["workflows"] = data.workflows;
return j.dump();
}
@@ -133,6 +135,7 @@ llm_config_data llm_config_data_from_json(const std::string &json)
data.temperature = j["temperature"];
data.max_output_tokens = j["max_output_tokens"];
data.system_prompt = j["system_prompt"];
data.end_sequence = j.value("end_sequence", "");
data.workflows = j["workflows"];
return data;
}
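
The j.value("end_sequence", "") fallback keeps configurations saved before this commit loadable, since the key is simply absent there. A small round-trip sketch (not part of the commit, assuming the two JSON helpers are declared in llm-config-data.h; the end-sequence value is a hypothetical chat-style marker, escaped because the dock UI treats it as a regex):

#include <cassert>
#include <string>
#include "llm-config-data.h"

void end_sequence_round_trip()
{
	llm_config_data cfg = {};
	cfg.end_sequence = "<\\|im_end\\|>"; // hypothetical end marker, regex-escaped

	const std::string json = llm_config_data_to_json(cfg);
	const llm_config_data restored = llm_config_data_from_json(json);
	assert(restored.end_sequence == cfg.end_sequence);
}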
4 changes: 4 additions & 0 deletions src/llm-dock/llm-config-data.h
@@ -4,6 +4,7 @@
#include <util/config-file.h>

#include <string>
#include <vector>

struct llm_config_data {
// local or cloud
@@ -27,6 +28,9 @@ struct llm_config_data {
// system prompt
std::string system_prompt;

// end sequence
std::string end_sequence;

// workflows
std::vector<std::string> workflows;
};
22 changes: 21 additions & 1 deletion src/llm-dock/llm-dock-ui.cpp
@@ -75,6 +75,7 @@ LLMDockWidgetUI::LLMDockWidgetUI(QWidget *parent) : QDockWidget(parent), ui(new

this->connect(this->ui->generate, &QPushButton::clicked, this, &LLMDockWidgetUI::generate);
this->connect(this->ui->clear, &QPushButton::clicked, this, &LLMDockWidgetUI::clear);
this->connect(this->ui->stop, &QPushButton::clicked, this, &LLMDockWidgetUI::stop);
this->connect(this, &LLMDockWidgetUI::update_text_signal, this,
&LLMDockWidgetUI::update_text);
// connect workflows
@@ -92,6 +93,7 @@ void LLMDockWidgetUI::generate()
if (input_text.isEmpty()) {
return;
}
this->stop_flag = false;

this->ui->generated->insertHtml(
QString("<p style=\"color:#ffffff;\">%1</p><br/>").arg(input_text));
@@ -107,7 +109,20 @@ void LLMDockWidgetUI::generate()
[this](const std::string &partial_generation) {
emit update_text_signal(QString::fromStdString(partial_generation),
true);
});
},
[this](const std::string &generation) {
// check if the stop button was pressed or the generation ends with the end sequence
if (this->stop_flag) {
return true;
}
if (!global_llm_config.end_sequence.empty()) {
std::regex end_sequence_regex(global_llm_config.end_sequence);
if (std::regex_search(generation, end_sequence_regex)) {
return true;
}
}
return false;
});
emit update_text_signal(QString("<br/>"), true);
});
t.detach();
@@ -119,6 +134,11 @@ void LLMDockWidgetUI::clear()
this->ui->generated->clear();
}

void LLMDockWidgetUI::stop()
{
this->stop_flag = true;
}

void LLMDockWidgetUI::update_text(const QString &text, bool partial_generation)
{
if (partial_generation) {
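Note that the end sequence from the settings is compiled into a std::regex on every invocation of the stop callback, so markers containing regex metacharacters (for example the '|' in <|im_end|>) must be escaped in the settings field. If a purely literal match is ever wanted, one option is to escape the stored string before building the regex; a sketch of such a helper (not part of the commit):

#include <regex>
#include <string>

// Escape regex metacharacters so the configured end sequence matches literally.
static std::string escape_for_regex(const std::string &literal)
{
	static const std::regex metachars(R"([.^$|()\[\]{}*+?\\])");
	return std::regex_replace(literal, metachars, R"(\$&)");
}

// usage: std::regex end_sequence_regex(escape_for_regex(global_llm_config.end_sequence));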
2 changes: 2 additions & 0 deletions src/llm-dock/llm-dock-ui.hpp
@@ -18,13 +18,15 @@ class LLMDockWidgetUI : public QDockWidget {
public slots:
void generate();
void clear();
void stop();
void update_text(const QString &text, bool partial_generation);

signals:
void update_text_signal(const QString &text, bool partial_generation);

private:
Ui::BrainDock *ui;
bool stop_flag = false;
};

#endif // LLMDOCKWIDGETUI_HPP
7 changes: 7 additions & 0 deletions src/llm-dock/ui/dockwidget.ui
@@ -138,6 +138,13 @@
</property>
</widget>
</item>
<item>
<widget class="QPushButton" name="stop">
<property name="text">
<string>Stop</string>
</property>
</widget>
</item>
</layout>
</widget>
</item>
26 changes: 20 additions & 6 deletions src/llm-dock/ui/settingsdialog.ui
@@ -81,40 +81,54 @@
<item row="2" column="0">
<widget class="QLabel" name="label_2">
<property name="text">
<string>System Prompt</string>
<string>Prompt Templ.</string>
</property>
</widget>
</item>
<item row="3" column="0">
<item row="4" column="0">
<widget class="QLabel" name="label_3">
<property name="text">
<string>Max. Tokens</string>
</property>
</widget>
</item>
<item row="3" column="1">
<item row="4" column="1">
<widget class="QLineEdit" name="maxTokens">
<property name="text">
<string>64</string>
</property>
</widget>
</item>
<item row="4" column="0">
<item row="5" column="0">
<widget class="QLabel" name="label_4">
<property name="text">
<string>Temperature</string>
</property>
</widget>
</item>
<item row="4" column="1">
<item row="5" column="1">
<widget class="QLineEdit" name="temperature">
<property name="text">
<string>0.9</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QTextEdit" name="sysPrompt"/>
<widget class="QTextEdit" name="sysPrompt">
<property name="placeholderText">
<string>&lt;|im_start|&gt;system...</string>
</property>
</widget>
</item>
<item row="3" column="0">
<widget class="QLabel" name="label_8">
<property name="text">
<string>End sequence</string>
</property>
</widget>
</item>
<item row="3" column="1">
<widget class="QLineEdit" name="endSeq"/>
</item>
</layout>
</widget>
