diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index bfd3c9b2e2..30d93fa4f9 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -640,10 +640,11 @@ InferRequestComplete( { LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete"; + RequestReleasePayload* request_release_payload = + static_cast(userp); + if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) { - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestDelete(request), - "deleting GRPC inference request"); + delete request_release_payload; } } @@ -861,6 +862,12 @@ ModelInferHandler::Execute(InferHandler::State* state) } if (err == nullptr) { + state->inference_request_ = { + irequest, [](TRITONSERVER_InferenceRequest* request) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(request), + "deleting gRPC inference request"); + }}; err = SetInferenceRequestMetadata(irequest, request, state->parameters_); } @@ -881,9 +888,13 @@ ModelInferHandler::Execute(InferHandler::State* state) tritonserver_, shm_manager_, request, std::move(serialized_data), response_queue, &state->alloc_payload_); } + + auto request_release_payload = + std::make_unique(state->inference_request_); if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetReleaseCallback( - irequest, InferRequestComplete, nullptr /* request_release_userp */); + irequest, InferRequestComplete, + request_release_payload.get() /* request_release_userp */); } if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetResponseCallback( @@ -922,16 +933,14 @@ ModelInferHandler::Execute(InferHandler::State* state) // COMPLETE or CANCELLED. Recording the state and the irequest // to handle gRPC stream cancellation. if (err == nullptr) { - state->context_->InsertInflightState(state, irequest); + state->context_->InsertInflightState(state); + // The payload will be cleaned in request release callback. + request_release_payload.release(); } else { // If error go immediately to COMPLETE. LOG_VERBOSE(1) << "[request id: " << request_id << "] " << "Infer failed: " << TRITONSERVER_ErrorMessage(err); - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestDelete(irequest), - "deleting GRPC inference request"); - ::grpc::Status status; GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index d9bf17a068..36783e5912 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -88,6 +88,17 @@ class Barrier { size_t generation_; }; +// Simple structure that carries the userp payload needed for +// request release callback. +struct RequestReleasePayload final { + explicit RequestReleasePayload( + const std::shared_ptr& inference_request) + : inference_request_(inference_request){}; + + private: + std::shared_ptr inference_request_ = nullptr; +}; + // // ResponseQueue // @@ -715,15 +726,9 @@ class InferHandlerState { // Inserts the state to a set tracking active requests // within the server core. Should only be called when // the request was successfully enqueued on Triton. - void InsertInflightState( - InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest) + void InsertInflightState(InferHandlerStateType* state) { std::lock_guard lock(mu_); - // The irequest_ptr_ will get populated when it is - // marked as active which means the request has been - // successfully enqueued to Triton core using - // TRITONSERVER_ServerInferAsync. - state->irequest_ptr_ = irequest; inflight_states_.insert(state); } @@ -748,7 +753,7 @@ class InferHandlerState { if (state->step_ != Steps::CANCELLED && state->step_ != Steps::COMPLETE) { LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_; - if (state->irequest_ptr_ == nullptr) { + if (state->inference_request_.get() == nullptr) { // The context might be holding some states that have // not been issued to Triton core. Need to skip calling // issuing cancellation for such requests. @@ -758,7 +763,8 @@ class InferHandlerState { // Assuming if RequestComplete callback is run asynchronously // before this point. TRITONSERVER_Error* err = nullptr; - err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_); + err = TRITONSERVER_InferenceRequestCancel( + state->inference_request_.get()); // TODO: Add request id to the message if (err != nullptr) { LOG_INFO << "Failed to cancel the request: " @@ -1023,7 +1029,6 @@ class InferHandlerState { unique_id_ = NEXT_UNIQUE_ID; context_ = context; step_ = start_step; - irequest_ptr_ = nullptr; cb_count_ = 0; is_decoupled_ = false; complete_ = false; @@ -1042,6 +1047,7 @@ class InferHandlerState { void Release() { context_ = nullptr; + inference_request_.reset(); ClearTraceTimestamps(); } @@ -1077,7 +1083,10 @@ class InferHandlerState { Steps step_; std::recursive_mutex step_mtx_; - TRITONSERVER_InferenceRequest* irequest_ptr_; + // Shared pointer to the inference request object. The lifetime of + // inference request object is extended till all the responses from + // the request are processed and the request is released. + std::shared_ptr inference_request_; #ifdef TRITON_ENABLE_TRACING std::shared_ptr trace_; diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 9e564d8322..9c162ad644 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -265,6 +265,12 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) } if (err == nullptr) { + state->inference_request_ = { + irequest, [](TRITONSERVER_InferenceRequest* request) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(request), + "deleting gRPC inference request"); + }}; err = SetInferenceRequestMetadata(irequest, request, state->parameters_); } @@ -285,9 +291,13 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) tritonserver_, shm_manager_, request, std::move(serialized_data), response_queue_, &state->alloc_payload_); } + + auto request_release_payload = + std::make_unique(state->inference_request_); if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetReleaseCallback( - irequest, InferRequestComplete, nullptr /* request_release_userp */); + irequest, InferRequestComplete, + request_release_payload.get() /* request_release_userp */); } if (err == nullptr) { err = TRITONSERVER_InferenceRequestSetResponseCallback( @@ -317,7 +327,9 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) // WRITEREADY or WRITTEN or CANCELLED. Recording the state and the // irequest to handle gRPC stream cancellation. if (err == nullptr) { - state->context_->InsertInflightState(state, irequest); + state->context_->InsertInflightState(state); + // The payload will be cleaned in request release callback. + request_release_payload.release(); } else { // If there was an error then enqueue the error response and show // it to be ready for writing. @@ -337,10 +349,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) LOG_VERBOSE(1) << "[request id: " << log_request_id << "] " << "Infer failed: " << TRITONSERVER_ErrorMessage(err); - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestDelete(irequest), - "deleting GRPC inference request"); - ::grpc::Status status; GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err);