diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index dbc0f85559..c62ea12650 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -186,7 +186,7 @@ CommonCallData::Process(bool rpc_ok) step_ = Steps::FINISH; } - return step_ != Steps::FINISH; + return step_ == Steps::FINISH; } template @@ -338,7 +338,8 @@ CommonHandler::Start() while (cq_->Next(&tag, &ok)) { ICallData* call_data = static_cast(tag); - if (!call_data->Process(ok)) { + const bool tag_finished = call_data->Process(ok); + if (tag_finished) { LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " << call_data->Id(); delete call_data; diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 37a921fa75..d9b6bcb559 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -767,7 +767,7 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) finished = true; } - return !finished; + return finished; } TRITONSERVER_Error* diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index b2ce3f13e2..749fc157c0 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -958,6 +958,7 @@ class InferHandler : public HandlerBase { } virtual void StartNewRequest() = 0; + // Returns whether the request is finished being processed or not. virtual bool Process(State* state, bool rpc_ok) = 0; bool ExecutePrecondition(InferHandler::State* state); @@ -1037,7 +1038,8 @@ InferHandler< while (cq_->Next(&tag, &ok)) { State* state = static_cast(tag); - if (!Process(state, ok)) { + const bool tag_finished = Process(state, ok); + if (tag_finished) { LOG_VERBOSE(1) << "Done for " << Name() << ", " << state->unique_id_; StateRelease(state); } diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 8877694284..b6accd416f 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -107,6 +107,68 @@ StreamOutputBufferAttributes( // RPC. This implementation is tuned towards performance and reducing latency. //============================================================================= +bool +ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) +{ + LOG_VERBOSE(1) << "Process for " << Name() << ", rpc_ok=" << rpc_ok + << ", context " << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_; + + // We need an explicit finish indicator. Can't use 'state->step_' + // because we launch an async thread that could update 'state's + // step_ to be FINISH before this thread exits this function. + bool finished = false; + + if (state->step_ == Steps::START) { + // Transitions to READ state on success, or COMPLETE/FINISH on errors. + finished = RequestStartStep(state, rpc_ok); + } else if (state->step_ == Steps::READ) { + // Transitions to ISSUED state on successfully sending inference request to + // Triton, or COMPLETE/FINISH on errors. The ISSUED state is checked in the + // request's response callback to handle transitioning to writing responses. + finished = RequestReadStep(state, rpc_ok); + } + // We handle the WRITTEN and WRITEREADY states little + // differently depending whether the inference request + // is for a decoupled model or not. This is because the + // grpc contract requires us to call Write() only once + // on a task. Hence, for decoupled writes, we call only + // one write and then wait for another notification from + // the completion queue to execute pending Write()'s, if + // any. 
+ else if (state->step_ == Steps::WRITEREADY) { + // The non-decoupled transition to WRITEREADY state immediately attempts to + // WriteResponseIfReady() and go to WRITTEN state, so only handle the + // decoupled case here. + if (state->is_decoupled_) { + // Transitions to WRITTEN state if no other writes are ongoing, otherwise + // remains in WRITEREADY state and is moved to the back of the task queue. + // If there are no responses left to write, then this transitions to + // COMPLETE/FINISH states. + finished = RequestWriteReadyStepDecoupled(state); + } + } else if (state->step_ == Steps::WRITTEN) { + if (state->is_decoupled_) { + // Transitions to COMPLETE/FINISH state if all responses have been + // written. Otherwise, transitions to WRITEREADY or ISSUED depending on + // whether additional responses are ready to write or not. + finished = RequestWrittenStepDecoupled(state, rpc_ok); + } else { + // Transitions to COMPLETE/FINISH state from WRITTEN state here, because + // there is only one response per-request in the non-decoupled case. + finished = RequestWrittenStepNonDecoupled(state, rpc_ok); + } + } + // COMPLETE step simply marks that we're finished with the request. + else if (state->step_ == Steps::COMPLETE) { + finished = RequestCompleteStep(state); + } + // No special handling currently needed here for remaining states like + // ISSUED and FINISH. + + return finished; +} + void ModelStreamInferHandler::StartNewRequest() { @@ -130,362 +192,367 @@ ModelStreamInferHandler::StartNewRequest() } bool -ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) +ModelStreamInferHandler::RequestStartStep( + InferHandler::State* state, bool rpc_ok) { - LOG_VERBOSE(1) << "Process for " << Name() << ", rpc_ok=" << rpc_ok - << ", context " << state->context_->unique_id_ << ", " - << state->unique_id_ << " step " << state->step_; + // A new stream connection... If RPC failed on a new request then + // the server is shutting down and so we should do nothing. + if (!rpc_ok) { + state->step_ = Steps::FINISH; + return true; + } - // We need an explicit finish indicator. Can't use 'state->step_' - // because we launch an async thread that could update 'state's - // step_ to be FINISH before this thread exits this function. - bool finished = false; + // Start a new request to replace this one... + StartNewRequest(); - if (state->step_ == Steps::START) { - // A new stream connection... If RPC failed on a new request then - // the server is shutting down and so we should do nothing. - if (!rpc_ok) { - state->step_ = Steps::FINISH; - return false; - } + if (ExecutePrecondition(state)) { + // Since this is the start of a connection, 'state' hasn't been + // used yet so use it to read a request off the connection. + state->context_->step_ = Steps::READ; + state->step_ = Steps::READ; + state->context_->responder_->Read(&state->request_, state); + } else { + // Precondition is not satisfied, cancel the stream + state->context_->step_ = Steps::COMPLETE; + state->step_ = Steps::COMPLETE; + ::grpc::Status status = ::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + std::string("This protocol is restricted, expecting header '") + + restricted_kv_.first + "'"); + state->context_->responder_->Finish(status, state); + } - // Start a new request to replace this one... - StartNewRequest(); + // Not finished with a request on the start step unless an error occurs above. 
+ return false; +} - if (ExecutePrecondition(state)) { - // Since this is the start of a connection, 'state' hasn't been - // used yet so use it to read a request off the connection. - state->context_->step_ = Steps::READ; - state->step_ = Steps::READ; - state->context_->responder_->Read(&state->request_, state); - } else { - // Precondition is not satisfied, cancel the stream - state->context_->step_ = Steps::COMPLETE; - state->step_ = Steps::COMPLETE; - ::grpc::Status status = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - state->context_->responder_->Finish(status, state); - return !finished; +void +ModelStreamInferHandler::PrepareAndSendTritonRequest(InferHandler::State* state) +{ + TRITONSERVER_Error* err = nullptr; + const inference::ModelInferRequest& request = state->request_; + int64_t requested_model_version; + err = GetModelVersionFromString( + request.model_version(), &requested_model_version); + + // Record the transaction policy of the model into the current state + // object. + if (err == nullptr) { + uint32_t txn_flags; + err = TRITONSERVER_ServerModelTransactionProperties( + tritonserver_.get(), request.model_name().c_str(), + requested_model_version, &txn_flags, nullptr /* voidp */); + if (err == nullptr) { + state->is_decoupled_ = ((txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0); } + } - } else if (state->step_ == Steps::READ) { - TRITONSERVER_Error* err = nullptr; - const inference::ModelInferRequest& request = state->request_; -#ifdef TRITON_ENABLE_TRACING - state->trace_timestamps_.emplace_back( - std::make_pair("GRPC_WAITREAD_END", TraceManager::CaptureTimestamp())); -#endif // TRITON_ENABLE_TRACING - - // If done reading and no in-flight requests then can finish the - // entire stream. Otherwise just finish this state. - if (!rpc_ok) { - state->context_->step_ = Steps::WRITEREADY; - if (state->context_->IsRequestsCompleted()) { - state->context_->step_ = Steps::COMPLETE; - state->step_ = Steps::COMPLETE; - state->context_->responder_->Finish( - state->context_->finish_ok_ ? ::grpc::Status::OK - : ::grpc::Status::CANCELLED, - state); - } else { - state->step_ = Steps::FINISH; - finished = true; - } - - return !finished; - } + // Request has been successfully read, increment the context request + // counter. + state->context_->IncrementRequestCounter(); - int64_t requested_model_version; - err = GetModelVersionFromString( - request.model_version(), &requested_model_version); + // If the request is not for a model with decoupled transaction policy + // then put it in the context queue so that its response is sent in + // the same order as the request was received. + if (!state->is_decoupled_) { + state->context_->EnqueueForResponse(state); + } - // Record the transaction policy of the model into the current state - // object. - if (err == nullptr) { - uint32_t txn_flags; - err = TRITONSERVER_ServerModelTransactionProperties( - tritonserver_.get(), request.model_name().c_str(), - requested_model_version, &txn_flags, nullptr /* voidp */); - if (err == nullptr) { - state->is_decoupled_ = ((txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0); - } - } + // Create the inference request which contains all the + // input information needed for an inference. 
+ TRITONSERVER_InferenceRequest* irequest = nullptr; + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestNew( + &irequest, tritonserver_.get(), request.model_name().c_str(), + requested_model_version); + } - // Request has been successfully read, increment the context request - // counter. - state->context_->IncrementRequestCounter(); + if (err == nullptr) { + err = SetInferenceRequestMetadata(irequest, request, state->parameters_); + } - // If the request is not for a model with decoupled transaction policy - // then put it in the context queue so that its response is sent in - // the same order as the request was received. - if (!state->is_decoupled_) { - state->context_->EnqueueForResponse(state); - } + if (err == nullptr) { + err = ForwardHeadersAsParameters(irequest, state); + } - // Need to get context here as it is needed below. 'state' can - // complete inference, write response, and finish (which releases - // context) before we make any forward progress.... so need to - // hold onto context here while we know it is good. - std::shared_ptr context = state->context_; + // Will be used to hold the serialized data in case explicit string + // tensors are present in the request. + std::list serialized_data; - // Issue the inference request into server... - auto response_queue_ = state->response_queue_; + if (err == nullptr) { + err = InferGRPCToInput( + tritonserver_, shm_manager_, request, &serialized_data, irequest); + } + if (err == nullptr) { + err = InferAllocatorPayload( + tritonserver_, shm_manager_, request, std::move(serialized_data), + state->response_queue_, &state->alloc_payload_); + } + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest, InferRequestComplete, nullptr /* request_release_userp */); + } + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestSetResponseCallback( + irequest, allocator_, + &state->alloc_payload_ /* response_allocator_userp */, + StreamInferResponseComplete, reinterpret_cast(state)); + } - // Create the inference request which contains all the - // input information needed for an inference. - TRITONSERVER_InferenceRequest* irequest = nullptr; - if (err == nullptr) { - err = TRITONSERVER_InferenceRequestNew( - &irequest, tritonserver_.get(), request.model_name().c_str(), - requested_model_version); + if (err == nullptr) { + TRITONSERVER_InferenceTrace* triton_trace = nullptr; +#ifdef TRITON_ENABLE_TRACING + state->trace_ = + std::move(trace_manager_->SampleTrace(request.model_name())); + if (state->trace_ != nullptr) { + triton_trace = state->trace_->trace_; } +#endif // TRITON_ENABLE_TRACING - if (err == nullptr) { - err = SetInferenceRequestMetadata(irequest, request, state->parameters_); - } + state->step_ = ISSUED; + err = TRITONSERVER_ServerInferAsync( + tritonserver_.get(), irequest, triton_trace); + } - if (err == nullptr) { - err = ForwardHeadersAsParameters(irequest, state); - } + // If there was not an error in issuing the 'state' request then + // state->step_ == ISSUED and inference request has + // initiated... the completion callback will transition to + // WRITEREADY or WRITTEN. If there was an error then enqueue the + // error response and show it to be ready for writing. + if (err != nullptr) { + EnqueueErrorResponse(state, irequest, err); + } +} - // Will be used to hold the serialized data in case explicit string - // tensors are present in the request. 
- std::list serialized_data; +void +ModelStreamInferHandler::EnqueueErrorResponse( + InferHandler::State* state, TRITONSERVER_InferenceRequest* irequest, + TRITONSERVER_Error* err) +{ + const inference::ModelInferRequest& request = state->request_; + inference::ModelStreamInferResponse* response; + if (state->is_decoupled_) { + state->response_queue_->AllocateResponse(); + response = state->response_queue_->GetLastAllocatedResponse(); + } else { + response = state->response_queue_->GetNonDecoupledResponse(); + } - if (err == nullptr) { - err = InferGRPCToInput( - tritonserver_, shm_manager_, request, &serialized_data, irequest); - } - if (err == nullptr) { - err = InferAllocatorPayload( - tritonserver_, shm_manager_, request, std::move(serialized_data), - response_queue_, &state->alloc_payload_); - } - if (err == nullptr) { - err = TRITONSERVER_InferenceRequestSetReleaseCallback( - irequest, InferRequestComplete, nullptr /* request_release_userp */); - } - if (err == nullptr) { - err = TRITONSERVER_InferenceRequestSetResponseCallback( - irequest, allocator_, - &state->alloc_payload_ /* response_allocator_userp */, - StreamInferResponseComplete, reinterpret_cast(state)); - } + // Get request ID for logging in case of error. + std::string log_request_id = request.id(); + if (log_request_id.empty()) { + log_request_id = ""; + } + LOG_VERBOSE(1) << "[request id: " << log_request_id << "] " + << "Infer failed: " << TRITONSERVER_ErrorMessage(err); + + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(irequest), + "deleting GRPC inference request"); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + response->set_error_message(status.error_message()); + + response->mutable_infer_response()->Clear(); + // repopulate the id so that client knows which request failed. + response->mutable_infer_response()->set_id(request.id()); + state->step_ = Steps::WRITEREADY; + if (!state->is_decoupled_) { + state->context_->WriteResponseIfReady(state); + } else { + state->response_queue_->MarkNextResponseComplete(); + state->complete_ = true; + state->context_->PutTaskBackToQueue(state); + } +} - if (err == nullptr) { - TRITONSERVER_InferenceTrace* triton_trace = nullptr; +bool +ModelStreamInferHandler::RequestReadStep( + InferHandler::State* state, bool rpc_ok) +{ + bool finished = false; #ifdef TRITON_ENABLE_TRACING - state->trace_ = - std::move(trace_manager_->SampleTrace(request.model_name())); - if (state->trace_ != nullptr) { - triton_trace = state->trace_->trace_; - } + state->trace_timestamps_.emplace_back( + std::make_pair("GRPC_WAITREAD_END", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - state->step_ = ISSUED; - err = TRITONSERVER_ServerInferAsync( - tritonserver_.get(), irequest, triton_trace); - } - - // If there was not an error in issuing the 'state' request then - // state->step_ == ISSUED and inference request has - // initiated... the completion callback will transition to - // WRITEREADY or WRITTEN. If there was an error then enqueue the - // error response and show it to be ready for writing. - if (err != nullptr) { - inference::ModelStreamInferResponse* response; - if (state->is_decoupled_) { - state->response_queue_->AllocateResponse(); - response = state->response_queue_->GetLastAllocatedResponse(); - } else { - response = state->response_queue_->GetNonDecoupledResponse(); - } - - // Get request ID for logging in case of error. 
- std::string log_request_id = request.id(); - if (log_request_id.empty()) { - log_request_id = ""; - } - LOG_VERBOSE(1) << "[request id: " << log_request_id << "] " - << "Infer failed: " << TRITONSERVER_ErrorMessage(err); - - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestDelete(irequest), - "deleting GRPC inference request"); + // If done reading and no in-flight requests then can finish the + // entire stream. Otherwise just finish this state. + if (!rpc_ok) { + // Mark as WRITEREADY to indicate that there are no reads being processed + // for checking IsRequestsCompleted(). + state->context_->step_ = Steps::WRITEREADY; + finished = Finish(state); + return finished; + } - ::grpc::Status status; - GrpcStatusUtil::Create(&status, err); - TRITONSERVER_ErrorDelete(err); - response->set_error_message(status.error_message()); + // Need to get context here as it is needed below. 'state' can + // complete inference, write response, and finish (which releases + // context) before we make any forward progress.... so need to + // hold onto context here while we know it is good. + std::shared_ptr context = state->context_; - response->mutable_infer_response()->Clear(); - // repopulate the id so that client knows which request failed. - response->mutable_infer_response()->set_id(request.id()); - state->step_ = Steps::WRITEREADY; - if (!state->is_decoupled_) { - state->context_->WriteResponseIfReady(state); - } else { - state->response_queue_->MarkNextResponseComplete(); - state->complete_ = true; - state->context_->PutTaskBackToQueue(state); - } - } + // Read protobuf request, convert it into a Triton request, and execute it. + // Update any related state metadata, such as if model is decoupled or not. + // Send an error response if any error occurs. + PrepareAndSendTritonRequest(state); - // Now that the inference request is in flight, create a copy of - // 'state' and use it to attempt another read from the connection - // (i.e the next request in the stream). - State* next_read_state = - StateNew(tritonserver_.get(), context, Steps::READ); + // Now that the inference request is in flight, create a copy of + // 'state' and use it to attempt another read from the connection + // (i.e the next request in the stream). + State* next_read_state = StateNew(tritonserver_.get(), context, Steps::READ); #ifdef TRITON_ENABLE_TRACING - // Capture a timestamp for the time when we start waiting for this - // next request to read. - // Can't create trace as we don't know the model to be requested, - // track timestamps in 'state' - next_read_state->trace_timestamps_.emplace_back(std::make_pair( - "GRPC_WAITREAD_START", TraceManager::CaptureTimestamp())); + // Capture a timestamp for the time when we start waiting for this + // next request to read. + // Can't create trace as we don't know the model to be requested, + // track timestamps in 'state' + next_read_state->trace_timestamps_.emplace_back( + std::make_pair("GRPC_WAITREAD_START", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - next_read_state->context_->responder_->Read( - &next_read_state->request_, next_read_state); + next_read_state->context_->responder_->Read( + &next_read_state->request_, next_read_state); - } else if (state->step_ == Steps::COMPLETE) { - state->step_ = Steps::FINISH; - finished = true; - } else if (!state->is_decoupled_) { - // We handle the WRITTEN and WRITEREADY states little - // differently depending whether the inference request - // is for a decoupled model or not. 
This is because the - // grpc contract requires us to call Write() only once - // on a task. Hence, for decoupled writes, we call only - // one write and then wait for another notification from - // the completion queue to execute pending Write()'s, if - // any. - - // - // Non-Decoupled state transitions - // - if (state->step_ == Steps::WRITTEN) { - state->context_->ongoing_write_ = false; + // Not finished with a request on the read step unless an error occurs above. + return false; +} + +bool +ModelStreamInferHandler::RequestWrittenStepDecoupled( + InferHandler::State* state, bool rpc_ok) +{ + bool finished = false; + state->context_->ongoing_write_ = false; #ifdef TRITON_ENABLE_TRACING - state->trace_timestamps_.emplace_back( - std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp())); + state->trace_timestamps_.emplace_back( + std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - // If the write failed (for example, client closed the stream) - // mark that the stream did not complete successfully but don't - // cancel right away... need to wait for any pending reads, - // inferences and writes to complete. - if (!rpc_ok) { - LOG_VERBOSE(1) << "Write for " << Name() << ", rpc_ok=" << rpc_ok - << ", context " << state->context_->unique_id_ << ", " - << state->unique_id_ << " step " << state->step_ - << ", failed"; - state->context_->finish_ok_ = false; - } - - // Log an error if 'state' is not the expected next response. Mark - // that the stream did not complete successfully but don't cancel - // right away... need to wait for any pending reads, inferences - // and writes to complete. - if (!state->context_->PopCompletedResponse(state)) { - LOG_ERROR << "Unexpected response for " << Name() - << ", rpc_ok=" << rpc_ok << ", context " - << state->context_->unique_id_ << ", " << state->unique_id_ - << " step " << state->step_; - state->context_->finish_ok_ = false; - } - - // Write the next response if it is ready... - state->context_->WriteResponseIfReady(nullptr); - - // The response for the request has been written completely. - // The counter can be safely decremented. - state->context_->DecrementRequestCounter(); - finished = Finish(state); - - } else if (state->step_ == Steps::COMPLETE) { - state->step_ = Steps::FINISH; - finished = true; - } + // If the write failed (for example, client closed the stream) + // mark that the stream did not complete successfully but don't + // cancel right away... need to wait for any pending reads, + // inferences and writes to complete. + if (!rpc_ok) { + LOG_VERBOSE(1) << "Write for " << Name() << ", rpc_ok=" << rpc_ok + << ", context " << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_ + << ", failed"; + state->context_->finish_ok_ = false; + } + + // Finish the state if all the transactions associated with + // the state have completed. + if (state->IsComplete()) { + state->context_->DecrementRequestCounter(); + finished = Finish(state); } else { - // - // Decoupled state transitions - // - if (state->step_ == Steps::WRITTEN) { - state->context_->ongoing_write_ = false; + std::lock_guard lock(state->step_mtx_); + + // If there is an available response to be written + // to the stream, then transition directly to WRITEREADY + // state and enqueue itself to the completion queue to be + // taken up later. Otherwise, go to ISSUED state and wait + // for the callback to make a response available. 
+ if (state->response_queue_->HasReadyResponse()) { + state->step_ = Steps::WRITEREADY; + state->context_->PutTaskBackToQueue(state); + } else { + state->step_ = Steps::ISSUED; + } + } + + return finished; +} + +bool +ModelStreamInferHandler::RequestWrittenStepNonDecoupled( + InferHandler::State* state, bool rpc_ok) +{ + bool finished = false; + state->context_->ongoing_write_ = false; #ifdef TRITON_ENABLE_TRACING - state->trace_timestamps_.emplace_back( - std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp())); + state->trace_timestamps_.emplace_back( + std::make_pair("GRPC_SEND_END", TraceManager::CaptureTimestamp())); #endif // TRITON_ENABLE_TRACING - // If the write failed (for example, client closed the stream) - // mark that the stream did not complete successfully but don't - // cancel right away... need to wait for any pending reads, - // inferences and writes to complete. - if (!rpc_ok) { - LOG_VERBOSE(1) << "Write for " << Name() << ", rpc_ok=" << rpc_ok - << ", context " << state->context_->unique_id_ << ", " - << state->unique_id_ << " step " << state->step_ - << ", failed"; - state->context_->finish_ok_ = false; - } - - // Finish the state if all the transactions associated with - // the state have completed. - if (state->IsComplete()) { - state->context_->DecrementRequestCounter(); - finished = Finish(state); - } else { - std::lock_guard lock(state->step_mtx_); - - // If there is an available response to be written - // to the stream, then transition directly to WRITEREADY - // state and enqueue itself to the completion queue to be - // taken up later. Otherwise, go to ISSUED state and wait - // for the callback to make a response available. - if (state->response_queue_->HasReadyResponse()) { - state->step_ = Steps::WRITEREADY; - state->context_->PutTaskBackToQueue(state); - } else { - state->step_ = Steps::ISSUED; - } - } - } else if (state->step_ == Steps::WRITEREADY) { - if (state->delay_response_ms_ != 0) { - // Will delay the write of the response by the specified time. - // This can be used to test the flow where there are other - // responses available to be written. - LOG_INFO << "Delaying the write of the response by " - << state->delay_response_ms_ << " ms..."; - std::this_thread::sleep_for( - std::chrono::milliseconds(state->delay_response_ms_)); - } - - // Finish the state if all the transactions associated with - // the state have completed. - if (state->IsComplete()) { - state->context_->DecrementRequestCounter(); - finished = Finish(state); - } else { - // GRPC doesn't allow to issue another write till - // the notification from previous write has been - // delivered. If there is an ongoing write then - // defer writing and place the task at the back - // of the completion queue to be taken up later. - if (!state->context_->ongoing_write_) { - state->context_->ongoing_write_ = true; - state->context_->DecoupledWriteResponse(state); - } else { - state->context_->PutTaskBackToQueue(state); - } - } + // If the write failed (for example, client closed the stream) + // mark that the stream did not complete successfully but don't + // cancel right away... need to wait for any pending reads, + // inferences and writes to complete. + if (!rpc_ok) { + LOG_VERBOSE(1) << "Write for " << Name() << ", rpc_ok=" << rpc_ok + << ", context " << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_ + << ", failed"; + state->context_->finish_ok_ = false; + } + + // Log an error if 'state' is not the expected next response. 
Mark + // that the stream did not complete successfully but don't cancel + // right away... need to wait for any pending reads, inferences + // and writes to complete. + if (!state->context_->PopCompletedResponse(state)) { + LOG_ERROR << "Unexpected response for " << Name() << ", rpc_ok=" << rpc_ok + << ", context " << state->context_->unique_id_ << ", " + << state->unique_id_ << " step " << state->step_; + state->context_->finish_ok_ = false; + } + + // Write the next response if it is ready... + state->context_->WriteResponseIfReady(nullptr); + + // The response for the request has been written completely. + // The counter can be safely decremented. + state->context_->DecrementRequestCounter(); + finished = Finish(state); + return finished; +} + +bool +ModelStreamInferHandler::RequestWriteReadyStepDecoupled( + InferHandler::State* state) +{ + bool finished = false; + if (state->delay_response_ms_ != 0) { + // Will delay the write of the response by the specified time. + // This can be used to test the flow where there are other + // responses available to be written. + LOG_INFO << "Delaying the write of the response by " + << state->delay_response_ms_ << " ms..."; + std::this_thread::sleep_for( + std::chrono::milliseconds(state->delay_response_ms_)); + } + + // Finish the state if all the transactions associated with + // the state have completed. + if (state->IsComplete()) { + state->context_->DecrementRequestCounter(); + finished = Finish(state); + } else { + // GRPC doesn't allow to issue another write till + // the notification from previous write has been + // delivered. If there is an ongoing write then + // defer writing and place the task at the back + // of the completion queue to be taken up later. + if (!state->context_->ongoing_write_) { + state->context_->ongoing_write_ = true; + state->context_->DecoupledWriteResponse(state); + } else { + state->context_->PutTaskBackToQueue(state); } } - return !finished; + return finished; +} + +bool +ModelStreamInferHandler::RequestCompleteStep(InferHandler::State* state) +{ + state->step_ = Steps::FINISH; + return true; } bool @@ -558,7 +625,8 @@ ModelStreamInferHandler::StreamInferResponseComplete( if (response) { inference::ModelInferResponse& infer_response = *(response->mutable_infer_response()); - // Validate Triton iresponse and set grpc/protobuf response fields from it + // Validate Triton iresponse and set grpc/protobuf response fields from + // it err = InferResponseCompleteCommon( state->tritonserver_, iresponse, infer_response, state->alloc_payload_); @@ -633,5 +701,4 @@ ModelStreamInferHandler::StreamInferResponseComplete( state->context_->WriteResponseIfReady(state); } } - }}} // namespace triton::server::grpc diff --git a/src/grpc/stream_infer_handler.h b/src/grpc/stream_infer_handler.h index 60c4530227..b67c0f2bfa 100644 --- a/src/grpc/stream_infer_handler.h +++ b/src/grpc/stream_infer_handler.h @@ -106,6 +106,7 @@ class ModelStreamInferHandler protected: void StartNewRequest() override; + // Returns whether the request is finished being processed or not. 
bool Process(State* state, bool rpc_ok) override; private: @@ -114,6 +115,18 @@ class ModelStreamInferHandler void* userp); bool Finish(State* state); + // Step helpers + bool RequestStartStep(State* state, bool rpc_ok); + bool RequestReadStep(State* state, bool rpc_ok); + bool RequestCompleteStep(State* state); + bool RequestWrittenStepDecoupled(State* state, bool rpc_ok); + bool RequestWrittenStepNonDecoupled(State* state, bool rpc_ok); + bool RequestWriteReadyStepDecoupled(State* state); + void PrepareAndSendTritonRequest(State* state); + void EnqueueErrorResponse( + State* state, TRITONSERVER_InferenceRequest* irequest, + TRITONSERVER_Error* err); + TraceManager* trace_manager_; std::shared_ptr shm_manager_; TRITONSERVER_ResponseAllocator* allocator_;
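
For orientation, a minimal, self-contained sketch of the control-flow convention this diff adopts. This is not Triton code: `Steps`, `SimpleState`, and the helper functions below are simplified stand-ins for `InferHandler::State`, the real `Steps` enum, and the new `Request*Step` helpers. It illustrates the two ideas the change combines: `Process()` now returns `true` when a tag is finished (instead of the old "not finished" return), and the monolithic state machine is split into one helper per step, with the completion-queue loop releasing the state only when `Process()` reports completion.

```cpp
// Simplified stand-in for the refactored handler flow; not the actual
// Triton/gRPC types or signatures.
#include <iostream>
#include <queue>

enum class Steps { START, READ, ISSUED, WRITEREADY, WRITTEN, COMPLETE, FINISH };

struct SimpleState {
  Steps step;
  int id;
};

// Per-step helpers mirror the refactor: each performs the transition for a
// single step and returns whether the request is finished being processed.
bool RequestStartStep(SimpleState* s, bool rpc_ok) {
  if (!rpc_ok) {             // server shutting down: nothing left to do
    s->step = Steps::FINISH;
    return true;
  }
  s->step = Steps::READ;     // otherwise wait for the first request
  return false;
}

bool RequestReadStep(SimpleState* s, bool rpc_ok) {
  if (!rpc_ok) {             // client closed the stream
    s->step = Steps::FINISH;
    return true;
  }
  s->step = Steps::WRITEREADY;  // pretend a response became ready to write
  return false;
}

bool RequestWriteReadyStep(SimpleState* s) {
  s->step = Steps::WRITTEN;  // pretend the response was written
  return false;
}

bool RequestWrittenStep(SimpleState* s) {
  s->step = Steps::COMPLETE;
  return false;
}

bool RequestCompleteStep(SimpleState* s) {
  s->step = Steps::FINISH;
  return true;
}

// Dispatcher in the same shape as the refactored Process(): one branch per
// step, each delegating to a helper and reporting "finished" upward.
bool Process(SimpleState* s, bool rpc_ok) {
  bool finished = false;
  if (s->step == Steps::START) {
    finished = RequestStartStep(s, rpc_ok);
  } else if (s->step == Steps::READ) {
    finished = RequestReadStep(s, rpc_ok);
  } else if (s->step == Steps::WRITEREADY) {
    finished = RequestWriteReadyStep(s);
  } else if (s->step == Steps::WRITTEN) {
    finished = RequestWrittenStep(s);
  } else if (s->step == Steps::COMPLETE) {
    finished = RequestCompleteStep(s);
  }
  return finished;
}

int main() {
  // Stand-in for the completion-queue loop: keep re-enqueueing the tag until
  // Process() says the request is finished, then release it -- the same
  // "if (tag_finished) { ... release ... }" shape as the callers in the diff.
  std::queue<SimpleState*> cq;
  cq.push(new SimpleState{Steps::START, 1});
  while (!cq.empty()) {
    SimpleState* state = cq.front();
    cq.pop();
    const bool tag_finished = Process(state, true /* rpc_ok */);
    if (tag_finished) {
      std::cout << "Done for state " << state->id << "\n";
      delete state;
    } else {
      cq.push(state);  // not finished yet; wait for the next event
    }
  }
  return 0;
}
```

Compiled standalone (e.g. g++ -std=c++17), the loop prints "Done for state 1" after the START, READ, WRITEREADY, WRITTEN, COMPLETE progression; the practical effect of the convention flip is that the caller-side check reads as `if (tag_finished)` rather than the previous double negative `if (!Process(...))`, matching the updated callers in grpc_server.cc and infer_handler.h.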