Skip to content

Commit

Permalink
Fix up, add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
krishung5 committed Oct 4, 2024
1 parent 95519a1 commit 1282598
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 25 deletions.
35 changes: 12 additions & 23 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -719,12 +719,22 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
ResponseBatch* response_batch_shm_ptr = reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));


// If the response sender has already been closed, notify the backend NOT to
// delete the response factory again during error handling. The backend
// normally cleans up the response factory when there is an error in the
// response batch, so it is necessary to handle cases where the response
// sender has already cleaned up, ensuring the backend does not delete the
// response factory a second time.
if (err_message.find("Response sender has been closed") !=
std::string::npos) {
response_batch_shm_ptr->is_response_factory_deleted = true;
} else if (
err_message.find("is using the decoupled mode and the execute function "
"must return None") != std::string::npos) {
for (py::handle py_request : py_request_list) {
InferRequest* request = py_request.cast<InferRequest*>();
if (request->GetResponseSender()->IsClosed()) {
response_batch_shm_ptr->is_response_factory_deleted = true;
}
}
}

response_batch_shm_ptr->has_error = true;
Expand Down Expand Up @@ -788,27 +798,6 @@ Stub::ProcessReturnedResponses(
}
// Only non-decoupled may return responses.
if (IsDecoupled()) {
// For decoupled mode, if before returning from this error, there was
// already a response sent from the response sender, along with the complete
// final flag, then use the `is_response_factory_deleted` flag to notify the
// backend to NOT to delete the response factory again during error
// handling.
for (py::handle py_request : py_requests) {
InferRequest* request = py_request.cast<InferRequest*>();
if (request->GetResponseSender()->IsClosed()) {
// Notify the backend to NOT to delete the response factory again during
// error handling.
if (!response_batch) {
response_batch = std::move(shm_pool_->Construct<char>(
sizeof(ResponseBatch) + sizeof(IPCMessageShm)));
}
ResponseBatch* response_batch_shm_ptr =
reinterpret_cast<ResponseBatch*>(
response_batch.value().data_.get() + sizeof(IPCMessageShm));
response_batch_shm_ptr->is_response_factory_deleted = true;
}
}

throw PythonBackendException(
"Python model '" + name_ +
"' is using the decoupled mode and the execute function must return "
Expand Down
5 changes: 4 additions & 1 deletion src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,9 @@ ModelInstanceState::ProcessRequests(
reporter.SetBatchStatistics(total_batch_size);

if (response_batch_shm_ptr->has_error) {
// Clean up the response factory if an error occurred. The
// `is_response_factory_deleted` flag indicates whether the response factory
// has already been deleted in certain corner cases (e.g. when the response
// sender cleaned it up before the error was reported), in which case the
// backend must not delete it again.
if (!response_batch_shm_ptr->is_response_factory_deleted) {
for (uint32_t r = 0; r < request_count; r++) {
TRITONBACKEND_ResponseFactory* response_factory =
Expand Down Expand Up @@ -1396,7 +1399,7 @@ ModelInstanceState::ProcessRequests(
// It is possible to have multiple responses batched together in a single
// response batch shm, where some of the responses are None due to the
// usage of response sender, so only create a TRITONBACKEND_Response
// object for the valid responses, and skip the None responses later.
// object for the valid responses.
if (response_shm_handle[i] == 0) {
responses->emplace_back(nullptr);
} else {
Expand Down
1 change: 0 additions & 1 deletion src/stub_launcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,6 @@ TRITONSERVER_Error*
StubLauncher::ReceiveMessageFromStub(
bi::managed_external_buffer::handle_t& message)
{
// message = parent_message_queue_->Pop();
bool success = false;
while (!success) {
uint64_t timeout_miliseconds = 1000;
Expand Down

0 comments on commit 1282598

Please sign in to comment.