# Utility class that generates N requests with appropriate sequence flags
class RequestGenerator:
    """Iterator yielding one ``(start, end, index, value)`` tuple per request.

    ``start`` is True only for the first request and ``end`` only for the
    last one. ``index`` counts up from 0 and ``value`` is
    ``init_value + index``. The object is its own iterator, so it can be
    consumed only once.

    Args:
        init_value: Value reported for the first request.
        num_requests: Total number of requests to generate. Zero or a
            negative number produces an empty sequence.
    """

    def __init__(self, init_value, num_requests) -> None:
        self.count = 0
        self.init_value = init_value
        self.num_requests = num_requests

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Nothing to release. Defined so that `with RequestGenerator(...)`
        # is usable; returning None lets any exception propagate.
        return None

    def __iter__(self):
        return self

    def __next__(self) -> tuple:
        # `>=` rather than `==`: a negative num_requests must end the
        # iteration immediately instead of counting upwards forever.
        if self.count >= self.num_requests:
            raise StopIteration
        index = self.count
        self.count += 1
        is_first = index == 0
        is_last = index == self.num_requests - 1
        return is_first, is_last, index, self.init_value + index
def test_ensemble_sequence_flags(self):
    """Check the sequence flags echoed back for a 3-request sequence.

    Sends three streaming requests that share one correlation ID to the
    ensemble model and verifies the ``sequence_start`` / ``sequence_end``
    response parameters: only the first response reports START and only
    the last one reports END.
    """
    # Local import so this block does not depend on a file-level import.
    import queue

    num_requests = 3
    response_timeout_secs = 10
    # 3 requests are made: expect START to be true for the 1st request
    # only and END to be true for the last request only.
    expected_flags = [[True, False], [False, False], [False, True]]
    completed = queue.Queue()

    def callback(result, error):
        # Runs on the client's stream thread; just hand the outcome to the
        # test thread so failures are reported by the test itself.
        completed.put((result, error))

    correlation_id = random.randint(1, 2**31 - 1)

    # create input tensors
    input0_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)
    input1_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)
    inputs = [
        grpcclient.InferInput("INPUT0", input0_data.shape, "INT32"),
        grpcclient.InferInput("INPUT1", input1_data.shape, "INT32"),
    ]
    inputs[0].set_data_from_numpy(input0_data)
    inputs[1].set_data_from_numpy(input1_data)

    # create output tensors
    outputs = [grpcclient.InferRequestedOutput("OUTPUT0")]

    response_flags = []
    triton_client = grpcclient.InferenceServerClient("localhost:8001")
    try:
        triton_client.start_stream(callback=callback)
        request_generator = RequestGenerator(0, num_requests)
        for sequence_start, sequence_end, count, _ in request_generator:
            triton_client.async_stream_infer(
                model_name="ensemble_add_sub_int32_int32_int32",
                inputs=inputs,
                outputs=outputs,
                request_id=f"{correlation_id}_{count}",
                sequence_id=correlation_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
            )

        # Wait for each response explicitly instead of sleeping for a fixed
        # time and hoping that every callback has already fired.
        for _ in range(num_requests):
            try:
                result, error = completed.get(timeout=response_timeout_secs)
            except queue.Empty:
                self.fail("timed out waiting for a sequence response")
            self.assertIsNone(error, f"unexpected inference error: {error}")
            parameters = result.get_response().parameters
            response_flags.append(
                [
                    parameters["sequence_start"].bool_param,
                    parameters["sequence_end"].bool_param,
                ]
            )
    finally:
        # Always release the stream and the channel, even on failure.
        triton_client.stop_stream()
        triton_client.close()

    self.assertEqual(
        expected_flags,
        response_flags,
        "unexpected sequence flags mismatch error",
    )
# Run the ensemble model with sequence flags set on the requests and verify
# the sequence flags reported in the responses.
run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat $SERVER_LOG
    exit 1
fi

# Errors are handled by hand below, so do not abort on the first failure.
set +e
if python $SIMPLE_TEST_PY EnsembleTest.test_ensemble_sequence_flags >>$CLIENT_LOG 2>&1; then
    # The test process succeeded; make sure exactly one test result was recorded.
    if ! check_test_results $TEST_RESULT_FILE 1; then
        cat $CLIENT_LOG
        echo -e "\n***\n*** Test Result Verification Failed\n***"
        RET=1
    fi
else
    RET=1
fi
set -e

kill $SERVER_PID
wait $SERVER_PID