Skip to content

Commit

Permalink
fix: rewind text if does not match pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
sangjanai committed Jul 2, 2024
1 parent 3ac131a commit c7c8516
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
55 changes: 43 additions & 12 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,61 @@ void RemoveId(std::vector<int>& vec, int id) {
vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
}

bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state, bool is_openhermes) {
bool HandleMatch(std::string const& rew_text,
std::shared_ptr<InferenceState> infer_state,
std::function<void(Json::Value&&, Json::Value&&)> cb,
bool is_openhermes) {
if (infer_state->IsComplete(is_openhermes)) {
infer_state->rewind_strs.clear();
return false;
}
if (infer_state->stop_word_match_len == 0) {
if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
(!is_openhermes && rew_text.find('[') != std::string::npos)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->prev_text = rew_text;
infer_state->rewind_strs.push_back(rew_text);
return true;
}
}
else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
} else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->prev_text = rew_text;
infer_state->rewind_strs.push_back(rew_text);
return true;
}
else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
} else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
infer_state->prev_text = rew_text;
// response cache data
for(auto const& s: infer_state->rewind_strs) {
// std::cout << s;
const std::string text_to_stream
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
Json::Value resp_data;
resp_data["data"] = text_to_stream;
Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}
infer_state->rewind_strs.clear();
infer_state->rewind_strs.push_back(rew_text);
return true;
}
else {
} else {
infer_state->Reset();
// response cache data
for(auto const& s: infer_state->rewind_strs) {
// std::cout << s;
const std::string text_to_stream
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
Json::Value resp_data;
resp_data["data"] = text_to_stream;
Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}
infer_state->rewind_strs.clear();
return false; // Reset to start if sequence breaks
}
return false;
Expand Down Expand Up @@ -313,7 +344,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
std::string rew_text = infer_state->texts_to_stream.front();
// res_str += rew_text;
infer_state->texts_to_stream.pop();
if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") {
if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
continue;
};

Expand All @@ -338,7 +369,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

lock.unlock(); // Unlock as soon as possible
infer_state->prev_text = rew_text;
// std::cout << rew_text;

Json::Value resp_data;
resp_data["data"] = text_to_stream;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,17 @@ class Tokenizer {

struct InferenceState {
int prev_pos{0};
std::string prev_text;
bool is_finished;
std::queue<std::string> texts_to_stream;
std::mutex queue_mutex; // Mutex to protect access to textsToStream
size_t stop_word_match_len = 0;
std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
int token_gen_count = 0;
std::vector<std::string> rewind_strs;

void Reset() {
stop_word_match_len = 0;
prev_text = "";
stop_word_match_len = 0;
}

bool IsComplete(bool is_openhermes) const {
Expand Down

0 comments on commit c7c8516

Please sign in to comment.