From 1282598aaa9b48c9f5758387070773217ebaa1c2 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Thu, 3 Oct 2024 17:16:17 -0700 Subject: [PATCH] Fix up, add comments --- src/pb_stub.cc | 35 ++++++++++++----------------------- src/python_be.cc | 5 ++++- src/stub_launcher.cc | 1 - 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 6af07fb3..5bf2a5c2 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -719,12 +719,22 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) ResponseBatch* response_batch_shm_ptr = reinterpret_cast( response_batch.value().data_.get() + sizeof(IPCMessageShm)); - - // If the response sender is already closed, notify the backend NOT to + // The backend will clean up the response factory if there is an error in + // the response batch. It is necessary to handle cases where the response + // sender should have already cleaned up, ensuring the backend does not // delete the response factory again during error handling. if (err_message.find("Response sender has been closed") != std::string::npos) { response_batch_shm_ptr->is_response_factory_deleted = true; + } else if ( + err_message.find("is using the decoupled mode and the execute function " + "must return None") != std::string::npos) { + for (py::handle py_request : py_request_list) { + InferRequest* request = py_request.cast(); + if (request->GetResponseSender()->IsClosed()) { + response_batch_shm_ptr->is_response_factory_deleted = true; + } + } } response_batch_shm_ptr->has_error = true; @@ -788,27 +798,6 @@ Stub::ProcessReturnedResponses( } // Only non-decoupled may return responses. if (IsDecoupled()) { - // For decoupled mode, if before returning from this error, there was - // already a response sent from the response sender, along with the complete - // final flag, then use the `is_response_factory_deleted` flag to notify the - // backend to NOT to delete the response factory again during error - // handling. - for (py::handle py_request : py_requests) { - InferRequest* request = py_request.cast(); - if (request->GetResponseSender()->IsClosed()) { - // Notify the backend to NOT to delete the response factory again during - // error handling. - if (!response_batch) { - response_batch = std::move(shm_pool_->Construct( - sizeof(ResponseBatch) + sizeof(IPCMessageShm))); - } - ResponseBatch* response_batch_shm_ptr = - reinterpret_cast( - response_batch.value().data_.get() + sizeof(IPCMessageShm)); - response_batch_shm_ptr->is_response_factory_deleted = true; - } - } - throw PythonBackendException( "Python model '" + name_ + "' is using the decoupled mode and the execute function must return " diff --git a/src/python_be.cc b/src/python_be.cc index 6b8c2516..40909388 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1362,6 +1362,9 @@ ModelInstanceState::ProcessRequests( reporter.SetBatchStatistics(total_batch_size); if (response_batch_shm_ptr->has_error) { + // Clean up the response factory if an error occurred. The + // `is_response_factory_deleted` flag indicates whether the response factory + // has been deleted for some corner cases. if (!response_batch_shm_ptr->is_response_factory_deleted) { for (uint32_t r = 0; r < request_count; r++) { TRITONBACKEND_ResponseFactory* response_factory = @@ -1396,7 +1399,7 @@ ModelInstanceState::ProcessRequests( // It is possible to have multiple responses batched together in a single // response batch shm, where some of the responses are None due to the // usage of response sender, so only create a TRITONBACKEND_Response - // object for the valid responses, and skip the None responses later. + // object for the valid responses. if (response_shm_handle[i] == 0) { responses->emplace_back(nullptr); } else { diff --git a/src/stub_launcher.cc b/src/stub_launcher.cc index e8d2430f..828228e6 100644 --- a/src/stub_launcher.cc +++ b/src/stub_launcher.cc @@ -728,7 +728,6 @@ TRITONSERVER_Error* StubLauncher::ReceiveMessageFromStub( bi::managed_external_buffer::handle_t& message) { - // message = parent_message_queue_->Pop(); bool success = false; while (!success) { uint64_t timeout_miliseconds = 1000;