Skip to content

Commit

Permalink
fix: Preserve Sequence Flags in Ensemble models (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
indrajit96 authored Jun 12, 2024
1 parent 8398c86 commit 6d00416
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/ensemble_scheduler/ensemble_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,9 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (response != nullptr) {
RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_InferenceResponseError(response));
uint32_t count;
bool parameter_override = false;
InferenceRequest::SequenceId correlation_id = step_ptr->correlation_id_;
uint32_t flags = step_ptr->flags_;
uint32_t flags = 0;
RETURN_IF_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseParameterCount(response, &count));
for (uint32_t idx = 0; idx < count; idx++) {
Expand All @@ -661,10 +662,12 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
case TRITONSERVER_PARAMETER_INT:
correlation_id = InferenceRequest::SequenceId(
*reinterpret_cast<const uint64_t*>(vvalue));
parameter_override = true;
break;
case TRITONSERVER_PARAMETER_STRING:
correlation_id = InferenceRequest::SequenceId(
std::string(*reinterpret_cast<const char* const*>(vvalue)));
parameter_override = true;
break;
default:
RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_ErrorNew(
Expand All @@ -683,6 +686,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
}
parameter_override = true;
}
} else if (!strcmp(name, "sequence_end")) {
if (type != TRITONSERVER_PARAMETER_BOOL) {
Expand All @@ -694,6 +698,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
}
parameter_override = true;
}
}
}
Expand Down Expand Up @@ -740,9 +745,16 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
}

auto& tensor_data = tensor_data_[it->second];
step_ptr->updated_tensors_.emplace(
it->second,
tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
if (parameter_override) {
step_ptr->updated_tensors_.emplace(
it->second,
tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
} else {
step_ptr->updated_tensors_.emplace(
it->second, tensor_data.AddTensor(
std::move(tensor), step_ptr->correlation_id_,
step_ptr->flags_));
}

output_to_tensor.erase(it);
} else {
Expand Down

0 comments on commit 6d00416

Please sign in to comment.