fix: support Mistral v0.3 #52

Merged · 4 commits · Jul 4, 2024
@@ -8,9 +8,9 @@ struct LoadModelRequest {
int ctx_len = 2048;
int n_parallel = 1;
std::string model_path;
std::string user_prompt = "<|im_end|>\n<|im_start|>user\n";
std::string ai_prompt = "<|im_end|>\n<|im_start|>user\n";
std::string system_prompt = "<|im_end|>\n<|im_start|>user\n";
std::string user_prompt = "";
std::string ai_prompt = "";
std::string system_prompt = "";
};

inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
@@ -19,9 +19,9 @@ inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
request.ctx_len = json_body->get("ctx_len", 2048).asInt();
request.n_parallel = json_body->get("n_parallel", 1).asInt();
request.model_path = json_body->get("model_path", "").asString();
request.user_prompt = json_body->get("user_prompt", "<|im_end|>\n<|im_start|>user\n").asString();
request.ai_prompt = json_body->get("ai_prompt", "<|im_end|>\n<|im_start|>assistant\n").asString();
request.system_prompt = json_body->get("system_prompt", "<|im_start|>system\n").asString();
request.user_prompt = json_body->get("user_prompt", "").asString();
request.ai_prompt = json_body->get("ai_prompt", "").asString();
request.system_prompt = json_body->get("system_prompt", "").asString();
}
return request;
}
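
Note: with the hard-coded ChatML strings replaced by empty defaults, the engine now decides the prompt template at load time. A minimal sketch of that fallback, mirroring the LoadModel change further down in this PR (the helper name ApplyPromptDefaults is an assumption, not code from the diff):

#include <string>

// Sketch only: substitute the OpenHermes ChatML template when the request
// leaves the prompts empty; Mistral v0.3 keeps them empty because its prompt
// is built from [INST]/[/INST] control tokens instead of a text template.
inline void ApplyPromptDefaults(LoadModelRequest& req, bool is_openhermes) {
    if (!is_openhermes) return;
    if (req.user_prompt.empty())   req.user_prompt   = "<|im_end|>\n<|im_start|>user\n";
    if (req.ai_prompt.empty())     req.ai_prompt     = "<|im_end|>\n<|im_start|>assistant\n";
    if (req.system_prompt.empty()) req.system_prompt = "<|im_start|>system\n";
}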
180 changes: 124 additions & 56 deletions in cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
@@ -21,40 +21,66 @@
using json = nlohmann::json;
using namespace tensorrtllm;

namespace {
constexpr const int k200OK = 200;
constexpr const int k400BadRequest = 400;
constexpr const int k409Conflict = 409;
constexpr const int k500InternalServerError = 500;

// https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
// stopWordsList
// 'im', '_' , 'end', '</s>', '<|im_end|>'
const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
const std::string kOhSystemPrompt = "<|im_start|>system\n";
const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000} , {"<|im_start|>", 32001}};

// '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
const std::vector<int32_t> kMistral_V0_3_StopWords
= {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};

enum class MistralTemplate: int32_t {
kBos = 1,
kEos = 2,
kBeginInst = 3,
kEndInst = 4
};

constexpr const int k200OK = 200;
constexpr const int k400BadRequest = 400;
constexpr const int k409Conflict = 409;
constexpr const int k500InternalServerError = 500;

// TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
bool IsOpenhermes(const std::string& s) {
if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
return false;
}
return true;
}
}
TensorrtllmEngine::~TensorrtllmEngine() {}

void RemoveId(std::vector<int>& vec, int id) {
vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
}

bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
if (infer_state->IsComplete()) {
bool HandleMatch(std::string const& rew_text,
std::shared_ptr<InferenceState> infer_state,
std::function<void(Json::Value&&, Json::Value&&)> cb,
bool is_openhermes) {
if (infer_state->IsComplete(is_openhermes)) {
return false;
}
if (infer_state->stop_word_match_len == 0) {
if (rew_text.find('<') != std::string::npos) { // Found "<" anywhere in the text
if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
(!is_openhermes && rew_text.find('[') != std::string::npos)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->prev_text = rew_text;
return true;
}
}
else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
} else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->prev_text = rew_text;
return true;
}
else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
} else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
infer_state->prev_text = rew_text;
return true;
}
else {
} else {
infer_state->Reset();
return false; // Reset to start if sequence breaks
}
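
Note: the flattened vectors kOpenhermesStopWords and kMistral_V0_3_StopWords above follow TensorRT-LLM's stopWordsList layout of shape {1, 2, N}: the first row is the concatenation of all stop-sequence token IDs, the second row is the cumulative end offset of each sequence, padded with -1 up to N. A small sketch of how such a buffer can be assembled (helper name and signature are assumptions, not part of this PR):

#include <cstdint>
#include <vector>

// Sketch: flatten a set of stop sequences into the [tokens | offsets] layout
// expected by GenerationInput::stopWordsList (copied to GPU as shape {1, 2, N}).
std::vector<int32_t> BuildStopWordsList(const std::vector<std::vector<int32_t>>& stop_seqs) {
    std::vector<int32_t> tokens;
    std::vector<int32_t> offsets;
    for (const auto& seq : stop_seqs) {
        tokens.insert(tokens.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(tokens.size())); // cumulative end offset
    }
    while (offsets.size() < tokens.size()) {
        offsets.push_back(-1); // pad so both rows have the same length N
    }
    std::vector<int32_t> flat = tokens;
    flat.insert(flat.end(), offsets.begin(), offsets.end());
    return flat;
}

For example, BuildStopWordsList({{321, 28730, 416}, {2}, {32000}}) reproduces kOpenhermesStopWords.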
@@ -67,19 +93,21 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
}

GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
std::vector<int32_t> stop_words_tokens
= {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length
return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
if(is_openhermes_) {
return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size()/2)}), MemoryType::kGPU);
} else {
return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size()/2)}), MemoryType::kGPU);
}
}

GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
int input_len = input_ids_host.size();
std::vector<int32_t> input_lengths_host(batchSize, input_len);
std::vector<int32_t> input_lengths_host(batch_size_, input_len);
GenerationInput::TensorPtr input_lengths
= gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU);
= gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU);
GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom(
input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU);
GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()};
input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU);
GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()};
generation_input.stopWordsList = GetTensorChatMLStopWordList();

LOG_INFO << "Create generation input successfully";
@@ -102,27 +130,34 @@ void InferenceThread(
TensorrtllmEngine* self,
SamplingConfig sampling_config,
int input_len,
int outputLen) {
int outputLen, bool is_openhermes) {

// Input preparation
LOG_INFO << "Inference thread started";
GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
GenerationOutput generation_output = self->CreateGenerationOutput();

// Define the callback to stream each generated token
generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
LOG_INFO << "Generating tokenizer in thread";
// LOG_INFO << "Generating tokenizer in thread";
// Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
int output_length = output_ids->getShape().d[2]; // Get the length of output IDs based on the tensor shape
// Copy output IDs from GPU to host for printing
std::vector<int32_t> output_idsHost(output_length);
self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
// Find the last non-zero value in the output IDs starting from the end of the input sequence
std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());

RemoveId(output_idsHostDecode, 0);
RemoveId(output_idsHostDecode, 32000);
RemoveId(output_idsHostDecode, 32001);
if(is_openhermes) {
for(auto const& [_, v]: kOpenhermesTemplate) {
RemoveId(output_idsHostDecode, v);
}
} else {
RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
}
std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
@@ -192,29 +227,47 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function<void(Json::Value&&, Json:

void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
inferences::ChatCompletionRequest request = inferences::fromJson(json_body);
std::string formatted_input = pre_prompt;
std::string formatted_input = pre_prompt_;
nlohmann::json data;
// data["stream"] = completion.stream;
// data["n_predict"] = completion.max_tokens;
data["presence_penalty"] = request.presence_penalty;
Json::Value const& messages = request.messages;

// tokens for Mistral v0.3
// TODO(sang): too much hard code here, need to refactor it soon
std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};

// Format the input from user
int msg_count = 0;
for (auto const& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = user_prompt;
role = user_prompt_;
std::string content = message["content"].asString();
formatted_input += role + content;
if(!is_openhermes_) {
auto new_tokens = cortex_tokenizer->Encode(content);
new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
}
}
else if (input_role == "assistant") {
role = ai_prompt;
role = ai_prompt_;
std::string content = message["content"].asString();
formatted_input += role + content;
if(!is_openhermes_) {
auto new_tokens = cortex_tokenizer->Encode(content);
if(msg_count == messages.size() - 1) {
new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
}
tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
}
}
else if (input_role == "system") {
role = system_prompt;
role = system_prompt_;
std::string content = message["content"].asString();
formatted_input = role + content + formatted_input;
}
@@ -223,13 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
std::string content = message["content"].asString();
formatted_input += role + content;
}
msg_count++;
}
formatted_input += ai_prompt;
formatted_input += ai_prompt_;
// LOG_INFO << formatted_input;
// Format the input from user

std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();

std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
std::vector<int32_t> input_ids_host;
if(is_openhermes_) {
input_ids_host = cortex_tokenizer->Encode(formatted_input);
} else {
input_ids_host = tokens;
}

int const input_len = input_ids_host.size();
int const outputLen = request.max_tokens - input_len;

@@ -243,23 +304,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
// Input preparation

std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
inference_thread.detach(); // Detach the thread to allow it to run independently

q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
// std::string res_str;
LOG_INFO << "Preparing to run inference task queue...";
while (true) { // Continuously check if the queue is not empty
std::unique_lock<std::mutex> lock(infer_state->queue_mutex); // Lock the queue for exclusive access
if (!infer_state->texts_to_stream.empty()) {
std::string rew_text = infer_state->texts_to_stream.front();
// res_str += rew_text;
infer_state->texts_to_stream.pop();
if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
continue;
};

if (rew_text == "[DONE]") {
const std::string str
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
+ "\n\n" + "data: [DONE]" + "\n\n";

infer_state->is_finished = true;
@@ -275,10 +338,10 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
break;
}
const std::string text_to_stream
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

lock.unlock(); // Unlock as soon as possible
infer_state->prev_text = rew_text;
// std::cout << rew_text;

Json::Value resp_data;
resp_data["data"] = text_to_stream;
@@ -293,6 +356,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
lock.unlock();
}
}
// LOG_INFO << res_str;
});

LOG_INFO << "Inference completed";
@@ -302,16 +366,20 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
model::LoadModelRequest request = model::fromJson(json_body);
std::filesystem::path model_dir = request.model_path;
is_openhermes_ = IsOpenhermes(request.model_path);

int ctx_len = request.ctx_len;
this->user_prompt = request.user_prompt;
this->ai_prompt = request.ai_prompt;
this->system_prompt = request.system_prompt;
this->model_id_ = GetModelId(*json_body);
// We only support 2 models for now, it is ugly but it works :(
if(is_openhermes_) {
user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
}
model_id_ = GetModelId(*json_body);

logger = std::make_shared<TllmLogger>();
logger->setLevel(nvinfer1::ILogger::Severity::kINFO);
initTrtLlmPlugins(logger.get());
logger_ = std::make_shared<TllmLogger>();
logger_->setLevel(nvinfer1::ILogger::Severity::kINFO);
initTrtLlmPlugins(logger_.get());

std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model";
cortex_tokenizer = std::make_unique<Tokenizer>(tokenizer_model_name.string());
@@ -320,20 +388,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
std::filesystem::path json_file_name = model_dir / "config.json";
auto json = GptJsonConfig::parse(json_file_name);
auto config = json.getModelConfig();
model_config = std::make_unique<GptModelConfig>(config);
model_config_ = std::make_unique<GptModelConfig>(config);
auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism());
LOG_INFO << "Loaded config from " << json_file_name.string();
// auto dtype = model_config->getDataType();

// Currently doing fixed session config
session_config.maxBatchSize = batchSize;
session_config.maxBeamWidth = 1; // Fixed for simplicity
session_config.maxSequenceLength = ctx_len;
session_config.cudaGraphMode = true; // Fixed for simplicity
session_config_.maxBatchSize = batch_size_;
session_config_.maxBeamWidth = 1; // Fixed for simplicity
session_config_.maxSequenceLength = ctx_len;
session_config_.cudaGraphMode = true; // Fixed for simplicity

// Init gpt_session
auto model_path = model_dir / json.engineFilename(world_config, model_id_);
gpt_session = std::make_unique<GptSession>(session_config, *model_config, world_config, model_path.string(), logger);
gpt_session = std::make_unique<GptSession>(session_config_, *model_config_, world_config, model_path.string(), logger_);

model_loaded_ = true;
if (q_ == nullptr) {
Expand Down Expand Up @@ -365,8 +433,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std:
gpt_session.reset();
cortex_tokenizer.reset();
q_.reset();
model_config.reset();
logger.reset();
model_config_.reset();
logger_.reset();
model_loaded_ = false;

Json::Value json_resp;
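
Note: a load_model request body like the following would exercise the new Mistral v0.3 path end to end, since IsOpenhermes() returns false whenever the model path contains "mistral" or "Mistral" (the path below is a made-up example, and the prompt fields are simply omitted so the new empty defaults apply):

// {
//   "model_path": "/models/mistral-7b-instruct-v0.3-trtllm",
//   "ctx_len": 2048,
//   "n_parallel": 1
// }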