From 697dfc03e1f9fa597a34bfbaa29a1f8583339487 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Tue, 24 Sep 2024 13:37:57 +0000 Subject: [PATCH] Fix response factory cleanup --- src/infer_request.h | 1 + src/python_be.cc | 28 +++++++++++++++++++++------- src/response_sender.h | 1 + 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/infer_request.h b/src/infer_request.h index c67e2fb0..f368d692 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -96,6 +96,7 @@ class InferRequest { InferenceTrace& GetTrace(); uint32_t ReleaseFlags(); void SetReleaseFlags(const uint32_t& flags); + intptr_t GetResponseFactoryAddress() { return response_factory_address_; } #ifdef TRITON_PB_STUB std::shared_ptr Exec(const bool is_decoupled); diff --git a/src/python_be.cc b/src/python_be.cc index b5334aa2..cc073a29 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1089,6 +1089,17 @@ ModelInstanceState::ResponseSendDecoupled( ResponseSendMessage* send_message_payload = reinterpret_cast(send_message.data_.get()); std::unique_ptr error_message; + ScopedDefer response_factory_deleter([send_message_payload] { + if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + send_message_payload->response_factory_address); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory(reinterpret_cast( + response_factory)); + } + }); ScopedDefer _([send_message_payload] { { bi::scoped_lock guard{send_message_payload->mu}; @@ -1214,13 +1225,6 @@ ModelInstanceState::ResponseSendDecoupled( SetErrorForResponseSendMessage( send_message_payload, WrapTritonErrorInSharedPtr(error), error_message); } - - if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { - std::unique_ptr< - TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> - lresponse_factory( - reinterpret_cast(response_factory)); - } } TRITONSERVER_Error* @@ -1357,6 +1361,16 @@ ModelInstanceState::ProcessRequests( (*responses)[r] = nullptr; continue; } + { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + pb_infer_requests[r]->GetResponseFactoryAddress()); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory( + reinterpret_cast( + response_factory)); + } infer_response = InferResponse::LoadFromSharedMemory( Stub()->ShmPool(), response_shm_handle[r], false /* open_cuda_handle */); diff --git a/src/response_sender.h b/src/response_sender.h index 6ca7e997..7fce9dd2 100644 --- a/src/response_sender.h +++ b/src/response_sender.h @@ -43,6 +43,7 @@ class ResponseSender { const std::set& requested_output_names, std::unique_ptr& shm_pool, const std::shared_ptr& pb_cancel); + intptr_t ResponseFactory() { return response_factory_address_; } ~ResponseSender(); void Send(std::shared_ptr response, const uint32_t flags); bool IsCancelled();