Extend request objects' lifetime and fix possible segmentation fault (#6620)

* Extend request objects' lifetime

* Remove explicit TRITONSERVER_InferenceRequestDelete

* Format fix

* Include the inference_request_ initialization to cover RequestNew

---------

Co-authored-by: Neelay Shah <[email protected]>
tanmayv25 and nnshah1 authored Nov 22, 2023
1 parent 738996f commit 4b34a48
Showing 3 changed files with 52 additions and 26 deletions.
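
The heart of the change: rather than deleting the request with an explicit TRITONSERVER_InferenceRequestDelete inside the release callback, the handler now keeps the request in a std::shared_ptr whose custom deleter performs the delete, and hands Triton core a heap-allocated RequestReleasePayload (as userp) that pins one extra reference. Below is a minimal, self-contained sketch of that ownership pattern; FakeRequest, ReleasePayload, and OnRequestRelease are simplified stand-ins for the Triton types and callbacks, not the real API.

#include <iostream>
#include <memory>

struct FakeRequest {};  // stand-in for TRITONSERVER_InferenceRequest

// Stand-in for RequestReleasePayload: exists only to carry one
// shared_ptr reference across the C-style callback boundary.
struct ReleasePayload final {
  explicit ReleasePayload(std::shared_ptr<FakeRequest> req)
      : req_(std::move(req)) {}

 private:
  std::shared_ptr<FakeRequest> req_;
};

// Stand-in for InferRequestComplete: the core invokes this when it
// releases the request; deleting the payload drops its reference.
void OnRequestRelease(void* userp) {
  delete static_cast<ReleasePayload*>(userp);
}

int main() {
  // The handler state holds one reference; the custom deleter stands
  // in for TRITONSERVER_InferenceRequestDelete.
  std::shared_ptr<FakeRequest> state_ref(new FakeRequest, [](FakeRequest* r) {
    std::cout << "deleting inference request\n";
    delete r;
  });

  // On a successful enqueue, ownership of the payload moves to the
  // release callback (unique_ptr::release), exactly as in the diff.
  auto payload = std::make_unique<ReleasePayload>(state_ref);
  void* userp = payload.release();

  state_ref.reset();        // handler state lets go first...
  OnRequestRelease(userp);  // ...the request is deleted only here
  return 0;
}

Whichever side drops the last reference last, the request is deleted exactly once, which is what closes the segfault window.
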
src/grpc/infer_handler.cc (27 changes: 18 additions & 9 deletions)
@@ -640,10 +640,11 @@ InferRequestComplete(
 {
   LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete";
 
+  RequestReleasePayload* request_release_payload =
+      static_cast<RequestReleasePayload*>(userp);
+
   if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting GRPC inference request");
+    delete request_release_payload;
   }
 }

@@ -861,6 +862,12 @@ ModelInferHandler::Execute(InferHandler::State* state)
   }
 
   if (err == nullptr) {
+    state->inference_request_ = {
+        irequest, [](TRITONSERVER_InferenceRequest* request) {
+          LOG_TRITONSERVER_ERROR(
+              TRITONSERVER_InferenceRequestDelete(request),
+              "deleting gRPC inference request");
+        }};
     err = SetInferenceRequestMetadata(irequest, request, state->parameters_);
   }

@@ -881,9 +888,13 @@ ModelInferHandler::Execute(InferHandler::State* state)
         tritonserver_, shm_manager_, request, std::move(serialized_data),
         response_queue, &state->alloc_payload_);
   }
+
+  auto request_release_payload =
+      std::make_unique<RequestReleasePayload>(state->inference_request_);
   if (err == nullptr) {
     err = TRITONSERVER_InferenceRequestSetReleaseCallback(
-        irequest, InferRequestComplete, nullptr /* request_release_userp */);
+        irequest, InferRequestComplete,
+        request_release_payload.get() /* request_release_userp */);
   }
   if (err == nullptr) {
     err = TRITONSERVER_InferenceRequestSetResponseCallback(
Expand Down Expand Up @@ -922,16 +933,14 @@ ModelInferHandler::Execute(InferHandler::State* state)
// COMPLETE or CANCELLED. Recording the state and the irequest
// to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
state->context_->InsertInflightState(state);
// The payload will be cleaned in request release callback.
request_release_payload.release();
} else {
// If error go immediately to COMPLETE.
LOG_VERBOSE(1) << "[request id: " << request_id << "] "
<< "Infer failed: " << TRITONSERVER_ErrorMessage(err);

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(irequest),
"deleting GRPC inference request");

::grpc::Status status;
GrpcStatusUtil::Create(&status, err);
TRITONSERVER_ErrorDelete(err);
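
The success/error split above is what removes the double-free risk: exactly one owner ever frees the payload. A small sketch of the two paths, with hypothetical stand-ins (ReleasePayload, OnRequestRelease) in place of the real Triton calls:

#include <memory>

struct ReleasePayload {};  // hypothetical stand-in for RequestReleasePayload

// Hypothetical stand-in for InferRequestComplete.
void OnRequestRelease(void* userp) {
  delete static_cast<ReleasePayload*>(userp);
}

void Execute(bool enqueue_succeeds) {
  auto payload = std::make_unique<ReleasePayload>();
  if (enqueue_succeeds) {
    // Success path: the core now owns the payload through the release
    // callback, so the unique_ptr gives up ownership.
    void* userp = payload.release();
    OnRequestRelease(userp);  // simulate the core releasing the request
  }
  // Error path: the unique_ptr still owns the payload and frees it at
  // scope exit; no explicit TRITONSERVER_InferenceRequestDelete needed.
}

int main() {
  Execute(true);   // ownership handed to the callback
  Execute(false);  // ownership retained, freed by RAII
}
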
src/grpc/infer_handler.h (31 changes: 20 additions & 11 deletions)
@@ -88,6 +88,17 @@ class Barrier {
   size_t generation_;
 };
 
+// Simple structure that carries the userp payload needed for
+// request release callback.
+struct RequestReleasePayload final {
+  explicit RequestReleasePayload(
+      const std::shared_ptr<TRITONSERVER_InferenceRequest>& inference_request)
+      : inference_request_(inference_request){};
+
+ private:
+  std::shared_ptr<TRITONSERVER_InferenceRequest> inference_request_ = nullptr;
+};
+
 //
 // ResponseQueue
 //
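
Note that RequestReleasePayload has no methods beyond its constructor; its only job is to pin one shared_ptr reference until the release callback deletes it. A sketch of that reference counting, with Req and Payload as simplified stand-ins:

#include <cassert>
#include <memory>

struct Req {};  // stand-in for TRITONSERVER_InferenceRequest

// Stand-in mirroring RequestReleasePayload: holds one reference, nothing more.
struct Payload final {
  explicit Payload(const std::shared_ptr<Req>& r) : r_(r) {}

 private:
  std::shared_ptr<Req> r_;
};

int main() {
  auto state_ref = std::make_shared<Req>();
  auto* payload = new Payload(state_ref);  // use_count == 2
  delete payload;                          // back to 1; Req still alive
  assert(state_ref.use_count() == 1);
  state_ref.reset();  // last reference gone: Req is destroyed here
  return 0;
}
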
@@ -715,15 +726,9 @@ class InferHandlerState {
   // Inserts the state to a set tracking active requests
   // within the server core. Should only be called when
   // the request was successfully enqueued on Triton.
-  void InsertInflightState(
-      InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest)
+  void InsertInflightState(InferHandlerStateType* state)
   {
     std::lock_guard<std::recursive_mutex> lock(mu_);
-    // The irequest_ptr_ will get populated when it is
-    // marked as active which means the request has been
-    // successfully enqueued to Triton core using
-    // TRITONSERVER_ServerInferAsync.
-    state->irequest_ptr_ = irequest;
     inflight_states_.insert(state);
   }

@@ -748,7 +753,7 @@ class InferHandlerState {
       if (state->step_ != Steps::CANCELLED &&
           state->step_ != Steps::COMPLETE) {
         LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
-        if (state->irequest_ptr_ == nullptr) {
+        if (state->inference_request_.get() == nullptr) {
           // The context might be holding some states that have
           // not been issued to Triton core. Need to skip calling
           // issuing cancellation for such requests.
@@ -758,7 +763,8 @@ class InferHandlerState {
           // Assuming if RequestComplete callback is run asynchronously
           // before this point.
           TRITONSERVER_Error* err = nullptr;
-          err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
+          err = TRITONSERVER_InferenceRequestCancel(
+              state->inference_request_.get());
           // TODO: Add request id to the message
           if (err != nullptr) {
             LOG_INFO << "Failed to cancel the request: "
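
With the raw irequest_ptr_ gone, the cancellation path goes through the shared_ptr instead, which both distinguishes never-enqueued states and keeps the request object alive for the duration of the cancel call. A sketch with simplified stand-ins (Req, State, IssueCancellation; the real call is TRITONSERVER_InferenceRequestCancel):

#include <iostream>
#include <memory>

struct Req {};  // stand-in for TRITONSERVER_InferenceRequest

// Stand-in for the handler state: inference_request_ stays null until
// the request has actually been enqueued to the core.
struct State {
  std::shared_ptr<Req> inference_request_;
};

void IssueCancellation(State& state) {
  if (state.inference_request_.get() == nullptr) {
    // Mirrors the diff: skip states whose requests never reached core.
    std::cout << "skipping: request was never enqueued\n";
    return;
  }
  // The real code cancels via the raw pointer from .get(); the
  // shared_ptr guarantees the object stays valid during the call.
  std::cout << "cancelling request\n";
}

int main() {
  State never_enqueued;
  State live;
  live.inference_request_ = std::make_shared<Req>();
  IssueCancellation(never_enqueued);  // skipped
  IssueCancellation(live);            // cancelled
  return 0;
}
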
@@ -1023,7 +1029,6 @@ class InferHandlerState {
     unique_id_ = NEXT_UNIQUE_ID;
     context_ = context;
     step_ = start_step;
-    irequest_ptr_ = nullptr;
     cb_count_ = 0;
     is_decoupled_ = false;
     complete_ = false;
@@ -1042,6 +1047,7 @@ class InferHandlerState {
   void Release()
   {
     context_ = nullptr;
+    inference_request_.reset();
     ClearTraceTimestamps();
   }

@@ -1077,7 +1083,10 @@ class InferHandlerState {
   Steps step_;
   std::recursive_mutex step_mtx_;
 
-  TRITONSERVER_InferenceRequest* irequest_ptr_;
+  // Shared pointer to the inference request object. The lifetime of
+  // inference request object is extended till all the responses from
+  // the request are processed and the request is released.
+  std::shared_ptr<TRITONSERVER_InferenceRequest> inference_request_;
 
 #ifdef TRITON_ENABLE_TRACING
   std::shared_ptr<TraceManager::Trace> trace_;
src/grpc/stream_infer_handler.cc (20 changes: 14 additions & 6 deletions)
@@ -265,6 +265,12 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
   }
 
   if (err == nullptr) {
+    state->inference_request_ = {
+        irequest, [](TRITONSERVER_InferenceRequest* request) {
+          LOG_TRITONSERVER_ERROR(
+              TRITONSERVER_InferenceRequestDelete(request),
+              "deleting gRPC inference request");
+        }};
     err = SetInferenceRequestMetadata(irequest, request, state->parameters_);
   }

@@ -285,9 +291,13 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
         tritonserver_, shm_manager_, request, std::move(serialized_data),
         response_queue_, &state->alloc_payload_);
   }
+
+  auto request_release_payload =
+      std::make_unique<RequestReleasePayload>(state->inference_request_);
   if (err == nullptr) {
     err = TRITONSERVER_InferenceRequestSetReleaseCallback(
-        irequest, InferRequestComplete, nullptr /* request_release_userp */);
+        irequest, InferRequestComplete,
+        request_release_payload.get() /* request_release_userp */);
   }
   if (err == nullptr) {
     err = TRITONSERVER_InferenceRequestSetResponseCallback(
Expand Down Expand Up @@ -317,7 +327,9 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// WRITEREADY or WRITTEN or CANCELLED. Recording the state and the
// irequest to handle gRPC stream cancellation.
if (err == nullptr) {
state->context_->InsertInflightState(state, irequest);
state->context_->InsertInflightState(state);
// The payload will be cleaned in request release callback.
request_release_payload.release();
} else {
// If there was an error then enqueue the error response and show
// it to be ready for writing.
@@ -337,10 +349,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
       LOG_VERBOSE(1) << "[request id: " << log_request_id << "] "
                      << "Infer failed: " << TRITONSERVER_ErrorMessage(err);
 
-      LOG_TRITONSERVER_ERROR(
-          TRITONSERVER_InferenceRequestDelete(irequest),
-          "deleting GRPC inference request");
-
       ::grpc::Status status;
       GrpcStatusUtil::Create(&status, err);
       TRITONSERVER_ErrorDelete(err);
