fix: support batching
sangjanai committed Jul 31, 2024
1 parent a3a6f15 commit 9e777e3
Showing 3 changed files with 134 additions and 76 deletions.
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
};
resp.set_chunked_content_provider("text/event-stream",
chunked_content_provider,
[](bool) { LOG_INFO << "Done"; });
[](bool) { });
};

const auto handle_load_model = [&](const httplib::Request& req,
152 changes: 89 additions & 63 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
@@ -102,8 +102,7 @@ void RemoveSpecialTokens(std::vector<int32_t>& v, ModelType model_type) {
remove_id(v, static_cast<int32_t>(Llama3Template::kEndOfTurn));
break;
case ModelType::kMistral:
remove_id(v, static_cast<int32_t>(MistralTemplate::kBeginInst));
remove_id(v, static_cast<int32_t>(MistralTemplate::kEndInst));
remove_id(v, static_cast<int32_t>(MistralTemplate::kEos));
break;
default:
remove_id(v, static_cast<int32_t>(OpenhermesTemplate::kImEnd));
@@ -114,8 +113,12 @@ void RemoveSpecialTokens(std::vector<int32_t>& v, ModelType model_type) {
} // namespace
TensorrtllmEngine::~TensorrtllmEngine() {
model_loaded_ = false;
if (thread_.joinable()) {
thread_.join();
if (res_thread_.joinable()) {
res_thread_.join();
}

if (req_thread_.joinable()) {
req_thread_.join();
}
}

@@ -296,21 +299,26 @@ void TensorrtllmEngine::HandleChatCompletion(

runtime_opts_.streaming = true;
runtime_opts_.timeoutMs = 2000;
auto request_id = 0u;
if (executor_->canEnqueueRequests()) {
request_id = EnqueueRequest(runtime_opts_, input_ids_host);
} else {
LOG_WARN << "Could not enqueue requests";
return;
auto request_id = req_id_++;
{
std::lock_guard<std::mutex> l(req_mtx_);
reqs_.emplace(request_id, std::move(input_ids_host));
req_cv_.notify_one();
}

q_->runTaskInQueue([this, cb = std::move(callback), request_id]() {
auto& infer_state = responses_[request_id];
std::chrono::time_point<std::chrono::system_clock> start;
bool first = true;
auto& infer_state = responses_.Get(request_id);
LOG_INFO << "Preparing to run inference task queue...";
while (true) { // Continuously check if the queue is not empty
std::unique_lock<std::mutex> lock(
infer_state.queue_mutex); // Lock the queue for exclusive access
if (!infer_state.texts_to_stream.empty()) {
if (std::exchange(first, false)) {
first = false;
start = std::chrono::system_clock::now();
}
std::string rew_text = infer_state.texts_to_stream.front();
// std::cout << rew_text << std::endl;
infer_state.texts_to_stream.pop();
@@ -363,10 +371,17 @@ }
}
}
// LOG_INFO << res_str;
responses_.erase(request_id);
auto end = std::chrono::system_clock::now();
auto duration_ms =
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
.count();
LOG_INFO << "Inference completed, generated tokens per second: "
<< static_cast<double>(infer_state.token_gen_count) / duration_ms *
1000;
responses_.Erase(request_id);
});

LOG_INFO << "Inference completed";
LOG_TRACE << "Done";
return;
};
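(Side note, not part of the commit: the streaming half of this handler — drain the per-request queue under its mutex, forward chunks until a "[DONE]" sentinel, then report tokens per second — boils down to the standalone sketch below. The queue, mutex, and sentinel mirror the diff; everything else is simplified and single-threaded for illustration.)

```cpp
#include <chrono>
#include <deque>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>

int main() {
  // Stand-ins for InferenceState::texts_to_stream and its queue_mutex.
  std::queue<std::string> texts_to_stream(
      std::deque<std::string>{"Hello", ", ", "world", "[DONE]"});
  std::mutex queue_mutex;
  std::size_t token_gen_count = 0;

  auto start = std::chrono::system_clock::now();
  while (true) {
    std::unique_lock<std::mutex> lock(queue_mutex);
    if (texts_to_stream.empty()) continue;  // real code keeps polling here
    std::string chunk = texts_to_stream.front();
    texts_to_stream.pop();
    lock.unlock();
    if (chunk == "[DONE]") break;  // sentinel pushed by the response thread
    ++token_gen_count;             // the engine counts tokens, not chunks
    std::cout << chunk;            // stands in for the SSE callback
  }
  auto duration_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                         std::chrono::system_clock::now() - start)
                         .count();
  double tps = duration_ms > 0
                   ? static_cast<double>(token_gen_count) / duration_ms * 1000.0
                   : 0.0;
  std::cout << "\ntokens per second: " << tps << "\n";
}
```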

@@ -376,9 +391,11 @@ void TensorrtllmEngine::LoadModel(
model::LoadModelRequest request = model::fromJson(json_body);
std::filesystem::path model_dir = request.model_path;
model_type_ = GetModelType(request.model_path);
n_parallel_ = request.n_parallel;
LOG_DEBUG << "n_parallel: " << n_parallel_;

int ctx_len = request.ctx_len;
// We only support 2 models for now, it is ugly but it works :(
// We only support 3 models for now, it is ugly but it works :(
if (model_type_ == ModelType::kOpenHermes) {
user_prompt_ =
request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
@@ -414,7 +431,7 @@

runtime_opts_.beamWidth = 1;
runtime_opts_.trtEnginePath = request.model_path;
runtime_opts_.maxNewTokens = ctx_len;
runtime_opts_.maxNewTokens = 1024;

auto executor_config = tle::ExecutorConfig(runtime_opts_.beamWidth);
// TODO(sang) try catch?
@@ -424,9 +441,10 @@

model_loaded_ = true;
if (q_ == nullptr) {
q_ = std::make_unique<trantor::ConcurrentTaskQueue>(1, model_id_);
q_ = std::make_unique<trantor::ConcurrentTaskQueue>(n_parallel_, model_id_);
}
thread_ = std::thread(&TensorrtllmEngine::WaitForResponses, this);
res_thread_ = std::thread(&TensorrtllmEngine::WaitForResponses, this);
req_thread_ = std::thread(&TensorrtllmEngine::HandleRequests, this);

// Model loaded successfully
LOG_INFO << "Model " << model_id_ << " loaded successfully from path "
@@ -521,79 +539,87 @@ void TensorrtllmEngine::GetModels(
LOG_INFO << "Running models responded";
}

tle::IdType TensorrtllmEngine::EnqueueRequest(
RuntimeOptions const& runtimeOpts, const tle::VecTokens& input_tokens) {
// LOG_INFO << "EnqueueRequests 1";
void TensorrtllmEngine::HandleRequests() {
tle::OutputConfig outputConfig;
// LOG_INFO << "EnqueueRequests 2";
outputConfig.excludeInputFromOutput = runtimeOpts.excludeInputFromOutput;
// LOG_INFO << "EnqueueRequests 3 " << runtimeOpts.beamWidth;
tle::SamplingConfig samplingConfig(runtimeOpts.beamWidth);

std::vector<tle::Request> requests;
// for (auto& tokens : input_tokens) {
// for (auto t : tokens) {
// std::cout << t << " ";
// }
// std::cout << std::endl;
LOG_INFO << "Creating request with " << input_tokens.size()
<< " input tokens",
input_tokens.size();
auto req = tle::Request(std::move(input_tokens), runtimeOpts.maxNewTokens,
runtimeOpts.streaming, samplingConfig, outputConfig);
req.setStopWords(GetStopWords(model_type_));
// requests.emplace_back(req);
// }
// LOG_INFO << "EnqueueRequests 5";

// Enqueue the requests
auto requestId = executor_->enqueueRequest(req);

return requestId;
outputConfig.excludeInputFromOutput = runtime_opts_.excludeInputFromOutput;
tle::SamplingConfig samplingConfig(runtime_opts_.beamWidth);

while (model_loaded_) {
// process with batch of n_parallel_ or timeout
std::unique_lock lk(req_mtx_);
req_cv_.wait_for(lk, std::chrono::milliseconds(10),
[this] { return reqs_.size() >= n_parallel_; });
// TODO(sang) Better way to do this?
std::vector<tle::Request> requests;
std::vector<int> req_ids;
while (!reqs_.empty()) {
auto req = tle::Request(
std::move(reqs_.front().second), runtime_opts_.maxNewTokens,
runtime_opts_.streaming, samplingConfig, outputConfig);
req.setStopWords(GetStopWords(model_type_));
requests.push_back(req);
req_ids.push_back(reqs_.front().first);
reqs_.pop();
}
if (!requests.empty()) {
// LOG_DEBUG << "Enqueue: " << requests.size() << " " << req_ids.size();
if (executor_->canEnqueueRequests()) {
auto res = executor_->enqueueRequests(requests);
if (res.size() == req_ids.size()) {
for (size_t i = 0; i < res.size(); i++) {
trt2c_ids_[res[i]] = req_ids[i];
}
} else {
LOG_WARN << "Something wrong happened, two sizes should always be "
"the same";
}
}
}
}
}

bool TensorrtllmEngine::WaitForResponses() {
tle::SizeType32 numFinished{0};

// Get the new tokens for each request
// TODO(sang) only works with beamWidth = 1 now
while (model_loaded_) {
std::chrono::milliseconds waitTime(1);
std::chrono::milliseconds waitTime(10);
// Wait for any response
auto responses = executor_->awaitResponses(waitTime);
// Loop over the responses
for (auto const& response : responses) {
auto requestId = response.getRequestId();
// Map back to our request id
auto request_id = trt2c_ids_[response.getRequestId()];

if (!response.hasError()) {
auto result = response.getResult();
numFinished += result.isFinal;
std::lock_guard<std::mutex> guard(responses_[requestId].queue_mutex);
std::lock_guard<std::mutex> guard(
responses_.Get(request_id).queue_mutex);
for (tle::SizeType32 beam = 0; beam < runtime_opts_.beamWidth; ++beam) {
auto& respTokens = result.outputTokenIds.at(beam);
RemoveSpecialTokens(respTokens, model_type_);
auto& resp_tokens = result.outputTokenIds.at(beam);
responses_.Get(request_id).token_gen_count += resp_tokens.size();
RemoveSpecialTokens(resp_tokens, model_type_);
if (model_type_ == ModelType::kLlama3) {
responses_[requestId].texts_to_stream.push(
cortex_tokenizer_->Decode(respTokens));
responses_[requestId].token_gen_count;
responses_.Get(request_id)
.texts_to_stream.push(cortex_tokenizer_->Decode(resp_tokens));
} else {
for (auto res : respTokens) {
responses_[requestId].texts_to_stream.push(
cortex_tokenizer_->DecodeWithSpace(res));
// LOG_INFO << responses_[requestId].texts_to_stream.back();
for (auto res : resp_tokens) {
responses_.Get(request_id)
.texts_to_stream.push(
cortex_tokenizer_->DecodeWithSpace(res));
// LOG_INFO << responses_[request_id].texts_to_stream.back();
}
}
}
if (result.isFinal) {
LOG_INFO << "Request id " << requestId << " is completed.";
responses_[requestId].texts_to_stream.push("[DONE]");
LOG_INFO << "Request id " << request_id << " is completed.";
responses_.Get(request_id).texts_to_stream.push("[DONE]");
}
} else {
// Allow response with error only if awaitResponse processed a terminated request id
std::string err = "ReqId " + std::to_string(response.getRequestId()) +
" has already been processed and was terminated.";
if (response.getErrorMsg() != err) {
TLLM_THROW("Request id %lu encountered error: %s", requestId,
TLLM_THROW("Request id %lu encountered error: %s", request_id,
response.getErrorMsg().c_str());
return false;
}
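(Editor's summary, not part of the commit: taken together, the changes in this file replace per-call enqueueing with a small producer/consumer pipeline — HandleChatCompletion pushes (request id, tokens) onto a queue, HandleRequests drains it in batches of up to n_parallel_ or after a short timeout and records the TensorRT-LLM-to-client id mapping, and WaitForResponses streams results back. A minimal, self-contained sketch of that pattern, with simplified types and not the engine's actual code, looks like this.)

```cpp
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

int main() {
  using Req = std::pair<uint64_t, std::vector<int32_t>>;  // (request id, tokens)
  std::queue<Req> reqs;
  std::mutex req_mtx;
  std::condition_variable req_cv;
  bool running = true;
  const std::size_t n_parallel = 4;

  // Consumer: mirrors HandleRequests() -- wake on a full batch or a timeout,
  // drain everything that is pending, and submit it as one batch.
  std::thread req_thread([&] {
    while (true) {
      std::vector<Req> batch;
      {
        std::unique_lock<std::mutex> lk(req_mtx);
        req_cv.wait_for(lk, std::chrono::milliseconds(10),
                        [&] { return reqs.size() >= n_parallel || !running; });
        if (!running && reqs.empty()) return;
        while (!reqs.empty()) {
          batch.push_back(std::move(reqs.front()));
          reqs.pop();
        }
      }
      if (!batch.empty())
        std::cout << "submitting batch of " << batch.size() << " requests\n";
    }
  });

  // Producer: mirrors HandleChatCompletion() -- push under the lock, notify.
  for (uint64_t id = 0; id < 10; ++id) {
    std::lock_guard<std::mutex> lk(req_mtx);
    reqs.emplace(id, std::vector<int32_t>{1, 2, 3});
    req_cv.notify_one();
  }

  {
    std::lock_guard<std::mutex> lk(req_mtx);
    running = false;
  }
  req_cv.notify_one();
  req_thread.join();
}
```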
56 changes: 44 additions & 12 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h
@@ -1,5 +1,6 @@
#pragma once

#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <memory>
@@ -105,21 +106,27 @@ class SentencePieceTokenizer : public Tokenizer {
}

std::string DecodeWithSpace(const int id) override {
std::lock_guard<std::mutex> l(m_);
std::string text = processor.IdToPiece(id);
ReplaceSubstring(text, "", " ");
return text;
}

std::string Decode(const std::vector<int32_t> ids) override {
std::lock_guard<std::mutex> l(m_);
std::string text = processor.DecodeIds(ids);
return text;
}

std::vector<int> Encode(const std::string& input) override {
std::lock_guard<std::mutex> l(m_);
std::vector<int> ids;
processor.Encode(input, &ids);
return ids;
}

private:
std::mutex m_;
};

class TiktokenTokenizer : public Tokenizer {
@@ -135,14 +142,19 @@ class TiktokenTokenizer : public Tokenizer {
}

std::string Decode(const std::vector<int32_t> ids) override {
std::lock_guard<std::mutex> l(m_);
std::string text = encoder->decode(ids);
return text;
}

std::vector<int> Encode(const std::string& input) override {
std::lock_guard<std::mutex> l(m_);
std::vector<int> ids = encoder->encode(input);
return ids;
}

private:
std::mutex m_;
};
enum class ModelType { kOpenHermes, kLlama3, kMistral };

@@ -207,28 +219,48 @@ class TensorrtllmEngine : public EngineI {
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) final;

GenerationInput::TensorPtr GetTensorSingleStopWordList(int stopToken);
GenerationInput CreateGenerationInput(std::vector<int32_t> inputIds);
GenerationOutput CreateGenerationOutput();
GenerationInput::TensorPtr GetTensorChatMLStopWordList();

private:
bool CheckModelLoaded(
std::function<void(Json::Value&&, Json::Value&&)>& callback);

// Function that enqueues requests
tle::IdType EnqueueRequest(
RuntimeOptions const& runtimeOpts,
const tle::VecTokens& input_tokens);
// Function that handles incoming requests
void HandleRequests();

// Function that waits for responses and stores output tokens
bool WaitForResponses();

std::unique_ptr<Tokenizer> cortex_tokenizer_;
RuntimeOptions runtime_opts_;
std::unique_ptr<tle::Executor> executor_;
std::thread thread_; // worker thread to handle responses
std::unordered_map<tle::IdType, InferenceState> responses_;

// TODO(sang) use taskqueue
// We are using 2 data structures to hold requests and responses
// We also need an unordered_map to map between tensorrt-llm request id to our request id
std::thread res_thread_; // worker thread to handle responses
// TODO(sang) template
struct InfSyncMap {
InferenceState& Get(uint64_t k) {
std::lock_guard<std::mutex> l(m);
return data[k];
}

void Erase(uint64_t k) {
std::lock_guard<std::mutex> l(m);
data.erase(k);
}
std::mutex m;
std::unordered_map<tle::IdType, InferenceState> data;
};
InfSyncMap responses_;

std::thread req_thread_; // worker thread to handle requests
std::queue<std::pair<int, tle::VecTokens>> reqs_;
std::condition_variable req_cv_;
std::mutex req_mtx_;
// map tensorrt request id to our request id
std::unordered_map<uint64_t, uint64_t> trt2c_ids_;

std::atomic<uint64_t> req_id_ = 0;

std::shared_ptr<TllmLogger> logger_;
std::string user_prompt_;
@@ -241,7 +273,7 @@ class TensorrtllmEngine : public EngineI {
std::atomic<bool> model_loaded_;
std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
ModelType model_type_ = ModelType::kOpenHermes;

int n_parallel_ = 1;
};

} // namespace tensorrtllm
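(Editor's note, not part of the commit: InfSyncMap above trades the previous bare unordered_map for a lock around every lookup. The pattern in isolation — with a simplified value type standing in for the real InferenceState — is roughly the following sketch.)

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>

// Simplified stand-in for InferenceState.
struct State {
  std::queue<std::string> texts_to_stream;
  std::size_t token_gen_count = 0;
};

class SyncMap {
 public:
  // Returns (and default-creates on first use) the entry for key k,
  // mirroring unordered_map::operator[] but under a lock.
  State& Get(uint64_t k) {
    std::lock_guard<std::mutex> l(m_);
    return data_[k];
  }
  void Erase(uint64_t k) {
    std::lock_guard<std::mutex> l(m_);
    data_.erase(k);
  }

 private:
  std::mutex m_;
  std::unordered_map<uint64_t, State> data_;
};

int main() {
  SyncMap responses;
  responses.Get(42).texts_to_stream.push("[DONE]");
  std::cout << responses.Get(42).texts_to_stream.front() << "\n";
  responses.Erase(42);
}
```

Note that Get() hands back a reference after releasing the map lock, so per-entry state still needs its own synchronization — which is what InferenceState::queue_mutex provides in the diff — and an entry must not be erased while another thread still holds a reference to it.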
