Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Preserve Sequence Flags in Ensemble models #369

Merged
9 commits merged on Jun 12, 2024
20 changes: 16 additions & 4 deletions src/ensemble_scheduler/ensemble_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,9 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (response != nullptr) {
RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_InferenceResponseError(response));
uint32_t count;
bool parameter_override = false;
InferenceRequest::SequenceId correlation_id = step_ptr->correlation_id_;
uint32_t flags = step_ptr->flags_;
uint32_t flags = 0;
RETURN_IF_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseParameterCount(response, &count));
for (uint32_t idx = 0; idx < count; idx++) {
Expand All @@ -661,10 +662,12 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
case TRITONSERVER_PARAMETER_INT:
correlation_id = InferenceRequest::SequenceId(
*reinterpret_cast<const uint64_t*>(vvalue));
parameter_override = true;
break;
case TRITONSERVER_PARAMETER_STRING:
correlation_id = InferenceRequest::SequenceId(
std::string(*reinterpret_cast<const char* const*>(vvalue)));
parameter_override = true;
break;
default:
RETURN_IF_TRITONSERVER_ERROR(TRITONSERVER_ErrorNew(
Expand All @@ -683,6 +686,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
}
parameter_override = true;
}
} else if (!strcmp(name, "sequence_end")) {
if (type != TRITONSERVER_PARAMETER_BOOL) {
Expand All @@ -694,6 +698,7 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
}
parameter_override = true;
}
}
}
Expand Down Expand Up @@ -740,9 +745,16 @@ EnsembleContext::ConsumeResponse(const std::unique_ptr<Step>& completed_step)
}

auto& tensor_data = tensor_data_[it->second];
step_ptr->updated_tensors_.emplace(
it->second,
tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
if (parameter_override) {
step_ptr->updated_tensors_.emplace(
it->second,
tensor_data.AddTensor(std::move(tensor), correlation_id, flags));
} else {
step_ptr->updated_tensors_.emplace(
it->second, tensor_data.AddTensor(
std::move(tensor), step_ptr->correlation_id_,
step_ptr->flags_));
}

output_to_tensor.erase(it);
} else {
Expand Down
Loading