Skip to content

Commit

Permalink
test: Add test for sequence flags in ensemble streaming inference (#7344
Browse files Browse the repository at this point in the history
) (#7359)
indrajit96 authored Jun 17, 2024
1 parent e33ccfe commit 7766e0c
Showing 2 changed files with 101 additions and 4 deletions.
77 changes: 76 additions & 1 deletion qa/L0_simple_ensemble/ensemble_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -26,7 +26,13 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import random
import sys
import time
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient

sys.path.append("../common")
sys.path.append("../clients")
@@ -40,6 +46,29 @@
import tritonhttpclient


# Utility class that generates N requests with the appropriate sequence flags
class RequestGenerator:
    """Iterator producing the sequence flags for a fixed number of requests.

    Each step yields a tuple ``(start, end, index, value)`` where:

    * ``start`` is True only for the first request,
    * ``end`` is True only for the last request,
    * ``index`` is the zero-based position of the request,
    * ``value`` is ``init_value + index``.

    The object is its own iterator, so it can be consumed only once. It can
    also be used as a context manager; entering returns the generator itself
    and exiting does nothing.
    """

    def __init__(self, init_value, num_requests) -> None:
        """Store the starting value and the number of requests to produce."""
        self.count = 0
        self.init_value = init_value
        self.num_requests = num_requests

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # No resources to release; returning False lets exceptions propagate.
        return False

    def __iter__(self):
        return self

    def __next__(self) -> tuple:
        """Return the flags for the next request.

        Raises:
            StopIteration: once ``num_requests`` items have been produced.
        """
        # ">=" (rather than "==") so that a zero or negative num_requests
        # yields an empty sequence instead of iterating forever.
        if self.count >= self.num_requests:
            raise StopIteration
        index = self.count
        self.count = index + 1
        start = index == 0
        end = index == self.num_requests - 1
        return start, end, index, self.init_value + index


class EnsembleTest(tu.TestResultCollector):
def _get_infer_count_per_version(self, model_name):
triton_client = tritonhttpclient.InferenceServerClient(
@@ -102,6 +131,52 @@ def test_ensemble_add_sub_one_output(self):
elif infer_count[1] == 0:
self.assertTrue(False, "unexpeced zero infer count for 'simple' version 2")

def test_ensemble_sequence_flags(self):
    """Check that sequence START/END flags are reported on ensemble responses.

    Sends three streaming requests that share one correlation ID to the
    "ensemble_add_sub_int32_int32_int32" model and compares the
    "sequence_start" / "sequence_end" response parameters, in arrival
    order, with the flags that were sent.
    """
    request_generator = RequestGenerator(0, 3)
    # 3 requests are made; expect the START of the 1st request to be true
    # and the END of the last request to be true
    expected_flags = [[True, False], [False, False], [False, True]]
    response_flags = []
    stream_errors = []

    def callback(start_time, result, error):
        # The stream callback receives either a result or an error. Record
        # errors instead of dereferencing a missing result on the stream
        # thread, so the failure is reported by the test itself.
        if error is not None:
            stream_errors.append(error)
            return
        response = result.get_response()
        response_flags.append(
            [
                response.parameters["sequence_start"].bool_param,
                response.parameters["sequence_end"].bool_param,
            ]
        )

    start_time = time.time()
    triton_client = grpcclient.InferenceServerClient("localhost:8001")
    try:
        triton_client.start_stream(callback=partial(callback, start_time))
        correlation_id = random.randint(1, 2**31 - 1)
        # create input tensors
        input0_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)
        input1_data = np.random.randint(0, 100, size=(1, 16), dtype=np.int32)

        inputs = [
            grpcclient.InferInput("INPUT0", input0_data.shape, "INT32"),
            grpcclient.InferInput("INPUT1", input1_data.shape, "INT32"),
        ]

        inputs[0].set_data_from_numpy(input0_data)
        inputs[1].set_data_from_numpy(input1_data)

        # create output tensors
        outputs = [grpcclient.InferRequestedOutput("OUTPUT0")]
        for sequence_start, sequence_end, count, _ in request_generator:
            triton_client.async_stream_infer(
                model_name="ensemble_add_sub_int32_int32_int32",
                inputs=inputs,
                outputs=outputs,
                request_id=f"{correlation_id}_{count}",
                sequence_id=correlation_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
            )
        # Give the server time to deliver all responses to the callback.
        time.sleep(2)
    finally:
        # Always release the stream and the channel, even if a request
        # could not be sent.
        triton_client.stop_stream()
        triton_client.close()

    self.assertFalse(
        stream_errors, f"unexpected stream inference errors: {stream_errors}"
    )
    self.assertEqual(
        expected_flags, response_flags, "unexpeced sequence flags mismatch error"
    )


if __name__ == "__main__":
logging.basicConfig(stream=sys.stderr)
28 changes: 25 additions & 3 deletions qa/L0_simple_ensemble/test.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright 2019-2024, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -69,6 +69,30 @@ set -e
kill $SERVER_PID
wait $SERVER_PID

# Run ensemble model with sequence flags and verify response sequence
# Start a server instance for this test case; run_server is expected to
# leave SERVER_PID at "0" when startup fails.
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

# Disable exit-on-error so a failing test is recorded in RET instead of
# aborting the script before the server is shut down.
set +e
python $SIMPLE_TEST_PY EnsembleTest.test_ensemble_sequence_flags >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
RET=1
else
# The test process exited cleanly; also check the result file, passing 1
# as the expected count for the single test case run above.
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

# Stop the server and wait for it to exit before the next test case.
kill $SERVER_PID
wait $SERVER_PID

# Run ensemble model with only one output requested
run_server
@@ -78,8 +102,6 @@ if [ "$SERVER_PID" == "0" ]; then
exit 1
fi

RET=0

set +e
python $SIMPLE_TEST_PY EnsembleTest.test_ensemble_add_sub_one_output >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then

0 comments on commit 7766e0c

Please sign in to comment.