From a4068935837bc5ce7fa21b3f59d2bea48b82d2a1 Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Wed, 1 Nov 2023 19:39:02 -0700
Subject: [PATCH 01/10] Add testing

---
 .../request_rescheduling/test.sh              | 102 ++++++++++++++
 qa/L0_backend_python/test.sh                  |   2 +-
 qa/python_models/bls/model.py                 |  24 ++--
 .../generate_sequence/config.pbtxt            |  51 +++++++
 qa/python_models/generate_sequence/model.py   | 132 ++++++++++++++++++
 .../request_rescheduling/config.pbtxt         |  38 +++++
 .../request_rescheduling/model.py             | 123 ++++++++++++++++
 .../request_rescheduling_addsub/config.pbtxt  |  61 ++++++++
 .../request_rescheduling_addsub/model.py      |  83 +++++++++++
 .../request_rescheduling_error/config.pbtxt   |  49 +++++++
 .../request_rescheduling_error/model.py       |  68 +++++++++
 11 files changed, 720 insertions(+), 13 deletions(-)
 create mode 100755 qa/L0_backend_python/request_rescheduling/test.sh
 create mode 100644 qa/python_models/generate_sequence/config.pbtxt
 create mode 100644 qa/python_models/generate_sequence/model.py
 create mode 100644 qa/python_models/request_rescheduling/config.pbtxt
 create mode 100644 qa/python_models/request_rescheduling/model.py
 create mode 100644 qa/python_models/request_rescheduling_addsub/config.pbtxt
 create mode 100644 qa/python_models/request_rescheduling_addsub/model.py
 create mode 100644 qa/python_models/request_rescheduling_error/config.pbtxt
 create mode 100644 qa/python_models/request_rescheduling_error/model.py

diff --git a/qa/L0_backend_python/request_rescheduling/test.sh b/qa/L0_backend_python/request_rescheduling/test.sh
new file mode 100755
index 0000000000..f819c78e0c
--- /dev/null
+++ b/qa/L0_backend_python/request_rescheduling/test.sh
@@ -0,0 +1,102 @@
+#!/bin/bash
+# Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+CLIENT_PY=../python_unittest.py
+CLIENT_LOG="./request_rescheduling_client.log"
+EXPECTED_NUM_TESTS="1"
+TEST_RESULT_FILE='test_results.txt'
+source ../../common/util.sh
+
+TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
+SERVER=${TRITON_DIR}/bin/tritonserver
+BACKEND_DIR=${TRITON_DIR}/backends
+
+RET=0
+# This variable is used to print out the correct server log for each sub-test.
+SUB_TEST_RET=0
+rm -fr *.log ./models *.txt
+
+# pip3 uninstall -y torch
+# pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
+
+mkdir -p models/request_rescheduling/1/
+cp ../../python_models/request_rescheduling/model.py models/request_rescheduling/1/
+cp ../../python_models/request_rescheduling/config.pbtxt models/request_rescheduling
+
+mkdir -p models/request_rescheduling_error/1/
+cp ../../python_models/request_rescheduling_error/model.py models/request_rescheduling_error/1/
+cp ../../python_models/request_rescheduling_error/config.pbtxt models/request_rescheduling_error
+
+mkdir -p models/request_rescheduling_addsub/1/
+cp ../../python_models/request_rescheduling_addsub/model.py models/request_rescheduling_addsub/1/
+cp ../../python_models/request_rescheduling_addsub/config.pbtxt models/request_rescheduling_addsub
+
+mkdir -p models/generate_sequence/1/
+cp ../../python_models/generate_sequence/model.py models/generate_sequence/1/
+cp ../../python_models/generate_sequence/config.pbtxt models/generate_sequence
+
+SERVER_LOG="./request_rescheduling_server.log"
+# SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1 --model-control-mode=explicit --load-model=request_rescheduling"
+SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1"
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    cat $SERVER_LOG
+    exit 1
+fi
+
+export MODEL_NAME='request_rescheduling'
+
+set +e
+python3 $CLIENT_PY >> $CLIENT_LOG 2>&1
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** request_rescheduling unit test FAILED. \n***"
+    cat $CLIENT_LOG
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification Failed\n***"
+        RET=1
+    fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+
+
+if [ $RET -eq 1 ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Request Rescheduling test FAILED. \n***"
+else
+    echo -e "\n***\n*** Request Rescheduling test PASSED. \n***"
+fi
+
+exit $RET
diff --git a/qa/L0_backend_python/test.sh b/qa/L0_backend_python/test.sh
index 23c2ce75b4..755d6b9ed5 100755
--- a/qa/L0_backend_python/test.sh
+++ b/qa/L0_backend_python/test.sh
@@ -423,7 +423,7 @@ if [ "$TEST_JETSON" == "0" ]; then
     fi
 fi
 
-SUBTESTS="lifecycle restart model_control examples argument_validation logging custom_metrics"
+SUBTESTS="lifecycle restart model_control examples argument_validation logging custom_metrics request_rescheduling"
 for TEST in ${SUBTESTS}; do
     # Run each subtest in a separate virtual environment to avoid conflicts
     # between dependencies.
diff --git a/qa/python_models/bls/model.py b/qa/python_models/bls/model.py
index 024bbbe550..dbd2b822db 100644
--- a/qa/python_models/bls/model.py
+++ b/qa/python_models/bls/model.py
@@ -220,7 +220,7 @@ def _send_bls_sequence_requests(self, correlation_id, is_decoupled):
                 infer_request.flags(), pb_utils.TRITONSERVER_REQUEST_FLAG_SEQUENCE_START
             )
             infer_response = infer_request.exec()
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             output = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT")
             self.assertFalse(output.is_cpu())
             output = from_dlpack(output.to_dlpack()).to("cpu").cpu().detach().numpy()
@@ -242,7 +242,7 @@ def _send_bls_sequence_requests(self, correlation_id, is_decoupled):
                         next(infer_responses)
                 else:
                     infer_response = infer_request.exec()
-                self.assertFalse(infer_response.has_error(), infer_response.error())
+                self.assertFalse(infer_response.has_error())
 
                 # The new output is the previous output + the current input
                 expected_output = output[0] + i
@@ -275,7 +275,7 @@ def _send_bls_sequence_requests(self, correlation_id, is_decoupled):
             else:
                 infer_response = infer_request.exec()
 
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             expected_output = output[0] + input.as_numpy()[0]
             output = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT")
             self.assertFalse(output.is_cpu())
@@ -345,7 +345,7 @@ def _get_gpu_bls_outputs(self, input0_pb, input1_pb, is_decoupled):
         else:
             infer_response = infer_request.exec()
 
-        self.assertFalse(infer_response.has_error(), infer_response.error())
+        self.assertFalse(infer_response.has_error())
 
         output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT0")
         output1 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT1")
@@ -401,7 +401,7 @@ def test_zero_length_io(self):
         else:
             infer_response = infer_request.exec()
 
-        self.assertFalse(infer_response.has_error(), infer_response.error())
+        self.assertFalse(infer_response.has_error())
 
         output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT0")
         self.assertTrue(np.all(output0 == input0))
@@ -439,7 +439,7 @@ def bls_tensor_lifecycle_helper(self):
                     next(infer_responses)
             else:
                 infer_response = infer_request.exec()
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
 
             output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT0")
             np.testing.assert_equal(
@@ -497,7 +497,7 @@ def bls_tensor_lifecycle_helper(self):
             else:
                 infer_response = infer_request.exec()
 
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
 
             output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT0")
             output0_pytorch = from_dlpack(output0.to_dlpack())
@@ -677,7 +677,7 @@ def _test_response_iterator_square(
         expected_output_cnt = np.array([expected_output_cnt], dtype=np.int32)
 
         for infer_response in response_iterator:
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             if len(infer_response.output_tensors()) > 0:
                 output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
                 self.assertIsNotNone(output0)
@@ -710,7 +710,7 @@ def test_response_iterator(self):
             # case 1. Use Next() to get the next response first, then use
             # for-loop to get the remaining responses.
             infer_response = next(infer_responses)
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
             self.assertIsNotNone(output0)
             self.assertEqual(response_value, output0.as_numpy())
@@ -734,7 +734,7 @@ def test_response_iterator(self):
             # get the remaining responses.
             response_count = 0
             for infer_response in infer_responses:
-                self.assertFalse(infer_response.has_error(), infer_response.error())
+                self.assertFalse(infer_response.has_error())
                 output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
                 self.assertIsNotNone(output0)
                 self.assertEqual(response_value, output0.as_numpy())
@@ -744,7 +744,7 @@ def test_response_iterator(self):
                     break
 
             infer_response = next(infer_responses)
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
             self.assertIsNotNone(output0)
             self.assertEqual(response_value, output0.as_numpy())
@@ -759,7 +759,7 @@ def test_response_iterator(self):
             infer_responses = infer_request.exec(decoupled=True)
 
             infer_response = next(infer_responses)
-            self.assertFalse(infer_response.has_error(), infer_response.error())
+            self.assertFalse(infer_response.has_error())
             output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
             self.assertIsNotNone(output0)
             self.assertEqual(response_value, output0.as_numpy())
diff --git a/qa/python_models/generate_sequence/config.pbtxt b/qa/python_models/generate_sequence/config.pbtxt
new file mode 100644
index 0000000000..9f2e08ada3
--- /dev/null
+++ b/qa/python_models/generate_sequence/config.pbtxt
@@ -0,0 +1,51 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "generate_sequence"
+backend: "python"
+max_batch_size: 0
+model_transaction_policy {
+  decoupled: True
+}
+input [
+  {
+    name: "IN"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+output [
+  {
+    name: "OUT"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+sequence_batching {
+  generative_sequence : true
+}
+
+instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/generate_sequence/model.py b/qa/python_models/generate_sequence/model.py
new file mode 100644
index 0000000000..f07b14e6ec
--- /dev/null
+++ b/qa/python_models/generate_sequence/model.py
@@ -0,0 +1,132 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    """
+    This model takes 1 input tensor, an INT32 [ 1 ] input named "INPUT", and
+    produces an output tensor "OUTPUT" with the same shape as the input tensor.
+    The input value indicates the total number of responses to be generated and
+    the output value indicates the number of remaining responses. For example,
+    if the request input has value 2, the model will:
+        - Send a response with value 1.
+        - Release request with RESCHEDULE flag.
+        - When the same request is executed again, send the last response with value 0.
+        - Release request with ALL flag.
+    """
+
+    def initialize(self, args):
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+            model_config
+        )
+        if not using_decoupled:
+            raise pb_utils.TritonModelException(
+                """the model `{}` can generate any number of responses per request,
+                enable decoupled transaction policy in model configuration to
+                serve this model""".format(
+                    args["model_name"]
+                )
+            )
+
+        # Get IN configuration
+        in_config = pb_utils.get_input_config_by_name(model_config, "IN")
+
+        # Validate the shape and data type of IN
+        in_shape = in_config["dims"]
+        if (len(in_shape) != 1) or (in_shape[0] != 1):
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the shape of 'IN' to be
+                [1], got {}""".format(
+                    args["model_name"], in_shape
+                )
+            )
+        if in_config["data_type"] != "TYPE_INT32":
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the data_type of 'IN' to be
+                'TYPE_INT32', got {}""".format(
+                    args["model_name"], in_config["data_type"]
+                )
+            )
+
+        # Get OUT configuration
+        out_config = pb_utils.get_output_config_by_name(model_config, "OUT")
+
+        # Validate the shape and data type of OUT
+        out_shape = out_config["dims"]
+        if (len(out_shape) != 1) or (out_shape[0] != 1):
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the shape of 'OUT' to be
+                [1], got {}""".format(
+                    args["model_name"], out_shape
+                )
+            )
+        if out_config["data_type"] != "TYPE_INT32":
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the data_type of 'OUT' to be
+                'TYPE_INT32', got {}""".format(
+                    args["model_name"], out_config["data_type"]
+                )
+            )
+
+        self.remaining_response = 0
+        self.reset_flag = True
+
+    def execute(self, requests):
+        for request in requests:
+            in_input = pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()
+
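+            # reset_flag is True only when a new sequence starts; for a
+            # rescheduled request it stays False so the countdown below
+            # continues from the previous execution.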
+            if self.reset_flag:
+                self.remaining_response = in_input[0]
+                self.reset_flag = False
+
+            response_sender = request.get_response_sender()
+
+            self.remaining_response -= 1
+
+            out_output = pb_utils.Tensor(
+                "OUT", np.array([self.remaining_response], np.int32)
+            )
+            response = pb_utils.InferenceResponse(output_tensors=[out_output])
+
+            if self.remaining_response <= 0:
+                response_sender.send(
+                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                )
+                self.reset_flag = True
+            else:
+                request.set_release_flags(
+                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
+                )
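+                # Because the RESCHEDULE release flag is set above, Triton will
+                # invoke execute() again with this same request after it is
+                # released, allowing the remaining responses to be generated.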
+                response_sender.send(response)
+
+        return None
diff --git a/qa/python_models/request_rescheduling/config.pbtxt b/qa/python_models/request_rescheduling/config.pbtxt
new file mode 100644
index 0000000000..a962a0595a
--- /dev/null
+++ b/qa/python_models/request_rescheduling/config.pbtxt
@@ -0,0 +1,38 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "request_rescheduling"
+backend: "python"
+
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_FP32
+    dims: [ 16 ]
+  }
+]
+
+instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/request_rescheduling/model.py b/qa/python_models/request_rescheduling/model.py
new file mode 100644
index 0000000000..91eb58f3ed
--- /dev/null
+++ b/qa/python_models/request_rescheduling/model.py
@@ -0,0 +1,123 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+import unittest
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class REQUESTRESCHEDULETest(unittest.TestCase):
+    # def setUp(self):
+    #     pb_utils.unload_model("generate_sequence")
+    #     pb_utils.load_model("generate_sequence")
+
+    def test_wrong_return_type(self):
+        input0 = pb_utils.Tensor("INPUT0", (np.random.randn(*[4])).astype(np.float32))
+        infer_request = pb_utils.InferenceRequest(
+            model_name="request_rescheduling_error",
+            inputs=[input0],
+            requested_output_names=["OUTPUT0"],
+        )
+
+        infer_response = infer_request.exec()
+        self.assertTrue(infer_response.has_error())
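+        # The model sets the RESCHEDULE release flag but returns an
+        # InferenceResponse instead of `None`, which the backend reports
+        # as an error.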
+        self.assertIn(
+            "Expected a None object in the execute function return list for reschduled request",
+            infer_response.error().message(),
+        )
+
+    def test_non_decoupled_e2e(self):
+        input0_np = np.random.randn(*[16])
+        input0_np = input0_np.astype(np.float32)
+        input1_np = np.random.randn(*[16])
+        input1_np = input1_np.astype(np.float32)
+        input0 = pb_utils.Tensor("INPUT0", input0_np)
+        input1 = pb_utils.Tensor("INPUT1", input1_np)
+        infer_request = pb_utils.InferenceRequest(
+            model_name="request_rescheduling_addsub",
+            inputs=[input0, input1],
+            requested_output_names=["OUTPUT0", "OUTPUT1"],
+        )
+        infer_response = infer_request.exec()
+
+        self.assertFalse(infer_response.has_error())
+
+        output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT0")
+        output1 = pb_utils.get_output_tensor_by_name(infer_response, "OUTPUT1")
+
+        self.assertIsNotNone(output0)
+        self.assertIsNotNone(output1)
+
+        expected_output_0 = input0.as_numpy() + input1.as_numpy()
+        expected_output_1 = input0.as_numpy() - input1.as_numpy()
+
+        self.assertEqual(expected_output_0[0], output0.as_numpy()[0])
+        self.assertEqual(expected_output_1[0], output1.as_numpy()[0])
+
+    def test_decoupled_e2e(self):
+        input_value = 3
+        input0 = pb_utils.Tensor("IN", np.array([input_value], dtype=np.int32))
+        infer_request = pb_utils.InferenceRequest(
+            model_name="generate_sequence",
+            inputs=[input0],
+            requested_output_names=["OUT"],
+        )
+        infer_responses = infer_request.exec(decoupled=True)
+
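+        # Each response reports the number of remaining responses, counting
+        # down from input_value - 1 to 0.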
+        expected_output = input_value - 1
+
+        if infer_responses:
+            for infer_response in infer_responses:
+                self.assertFalse(infer_response.has_error())
+
+                if len(infer_response.output_tensors()) > 0:
+                    output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
+                    self.assertIsNotNone(output0)
+                    print("output0.as_numpy()[0]: ", output0.as_numpy()[0], flush=True)
+
+                    self.assertEqual(expected_output, output0.as_numpy()[0])
+                    expected_output -= 1
+
+
+class TritonPythonModel:
+    def execute(self, requests):
+        responses = []
+        for _ in requests:
+            # Run the unittest and store the results in InferenceResponse.
+            test = unittest.main("model", exit=False)
+            responses.append(
+                pb_utils.InferenceResponse(
+                    [
+                        pb_utils.Tensor(
+                            "OUTPUT0",
+                            np.array([test.result.wasSuccessful()], dtype=np.float16),
+                        )
+                    ]
+                )
+            )
+        return responses
diff --git a/qa/python_models/request_rescheduling_addsub/config.pbtxt b/qa/python_models/request_rescheduling_addsub/config.pbtxt
new file mode 100644
index 0000000000..e07e603f07
--- /dev/null
+++ b/qa/python_models/request_rescheduling_addsub/config.pbtxt
@@ -0,0 +1,61 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "request_rescheduling_addsub"
+backend: "python"
+
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_FP32
+    dims: [ 16 ]
+  }
+]
+input [
+  {
+    name: "INPUT1"
+    data_type: TYPE_FP32
+    dims: [ 16 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_FP32
+    dims: [ 16 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT1"
+    data_type: TYPE_FP32
+    dims: [ 16 ]
+  }
+]
+sequence_batching {
+  generative_sequence : true
+}
+instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/request_rescheduling_addsub/model.py b/qa/python_models/request_rescheduling_addsub/model.py
new file mode 100644
index 0000000000..099dd83e3c
--- /dev/null
+++ b/qa/python_models/request_rescheduling_addsub/model.py
@@ -0,0 +1,83 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
+        output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1")
+
+        self.output0_dtype = pb_utils.triton_string_to_numpy(
+            output0_config["data_type"]
+        )
+        self.output1_dtype = pb_utils.triton_string_to_numpy(
+            output1_config["data_type"]
+        )
+
+        self.idx = 0
+
+    def execute(self, requests):
+        """This function is called on inference request."""
+
+        output0_dtype = self.output0_dtype
+        output1_dtype = self.output1_dtype
+
+        responses = []
+
+        for request in requests:
+            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+            in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1")
+
+            out_0, out_1 = (
+                in_0.as_numpy() + in_1.as_numpy(),
+                in_0.as_numpy() - in_1.as_numpy(),
+            )
+
+            out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype))
+            out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1.astype(output1_dtype))
+
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[out_tensor_0, out_tensor_1]
+            )
+
+            # Explicitly reschedule the first request
+            if self.idx == 0:
+                request.set_release_flags(
+                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
+                )
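+                # A request released with the RESCHEDULE flag must have `None`
+                # as its corresponding entry in the returned response list.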
+                responses.append(None)
+                self.idx += 1
+            else:
+                responses.append(inference_response)
+
+        return responses
diff --git a/qa/python_models/request_rescheduling_error/config.pbtxt b/qa/python_models/request_rescheduling_error/config.pbtxt
new file mode 100644
index 0000000000..7d50b261c3
--- /dev/null
+++ b/qa/python_models/request_rescheduling_error/config.pbtxt
@@ -0,0 +1,49 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "request_rescheduling_error"
+backend: "python"
+
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+
+sequence_batching {
+  generative_sequence : true
+}
+
+instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/request_rescheduling_error/model.py b/qa/python_models/request_rescheduling_error/model.py
new file mode 100644
index 0000000000..68a9ab8a6c
--- /dev/null
+++ b/qa/python_models/request_rescheduling_error/model.py
@@ -0,0 +1,68 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
+
+        self.output0_dtype = pb_utils.triton_string_to_numpy(
+            output0_config["data_type"]
+        )
+
+    def execute(self, requests):
+        output0_dtype = self.output0_dtype
+
+        responses = []
+
+        for request in requests:
+            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+
+            out_0 = in_0.as_numpy()
+
+            # Create output tensors. You need pb_utils.Tensor
+            # objects to create pb_utils.InferenceResponse.
+            out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype))
+
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[out_tensor_0]
+            )
+
+            request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE)
+            # A rescheduled request should have `None` appended to the response
+            # list; appending a response here intentionally violates that rule
+            # so the error path can be tested.
+            responses.append(inference_response)
+
+        return responses
+
+    def finalize(self):
+        pass

From b8e679d04e007c4c7024ba9e6fdf0a329c23f95e Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Thu, 2 Nov 2023 02:19:02 -0700
Subject: [PATCH 02/10] Fix up

---
 qa/L0_backend_python/request_rescheduling/test.sh | 1 -
 qa/python_models/request_rescheduling/model.py    | 4 ----
 2 files changed, 5 deletions(-)

diff --git a/qa/L0_backend_python/request_rescheduling/test.sh b/qa/L0_backend_python/request_rescheduling/test.sh
index f819c78e0c..7e2829abc8 100755
--- a/qa/L0_backend_python/request_rescheduling/test.sh
+++ b/qa/L0_backend_python/request_rescheduling/test.sh
@@ -60,7 +60,6 @@ cp ../../python_models/generate_sequence/model.py models/generate_sequence/1/
 cp ../../python_models/generate_sequence/config.pbtxt models/generate_sequence
 
 SERVER_LOG="./request_rescheduling_server.log"
-# SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1 --model-control-mode=explicit --load-model=request_rescheduling"
 SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1"
 
 run_server
diff --git a/qa/python_models/request_rescheduling/model.py b/qa/python_models/request_rescheduling/model.py
index 91eb58f3ed..9749e93de4 100644
--- a/qa/python_models/request_rescheduling/model.py
+++ b/qa/python_models/request_rescheduling/model.py
@@ -32,10 +32,6 @@
 
 
 class REQUESTRESCHEDULETest(unittest.TestCase):
-    # def setUp(self):
-    #     pb_utils.unload_model("generate_sequence")
-    #     pb_utils.load_model("generate_sequence")
-
     def test_wrong_return_type(self):
         input0 = pb_utils.Tensor("INPUT0", (np.random.randn(*[4])).astype(np.float32))
         infer_request = pb_utils.InferenceRequest(

From 58c2f7b24a6f0437d39d4fe887094e09d073b031 Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Mon, 6 Nov 2023 18:24:47 -0800
Subject: [PATCH 03/10] Enhance testing

---
 qa/L0_backend_python/python_unittest.py       |   3 +-
 .../request_rescheduling/test.sh              |  39 ++---
 .../config.pbtxt                              |   2 +-
 .../model.py                                  |  60 ++++++-
 .../config.pbtxt                              |   2 +-
 .../model.py                                  |   5 +-
 .../request_rescheduling_addsub/model.py      |   1 -
 .../request_rescheduling_cases/config.pbtxt   |  51 ++++++
 .../request_rescheduling_cases/model.py       | 148 ++++++++++++++++++
 .../config.pbtxt                              |   2 +-
 .../model.py                                  |   1 -
 11 files changed, 284 insertions(+), 30 deletions(-)
 rename qa/python_models/{request_rescheduling => bls_request_rescheduling}/config.pbtxt (98%)
 rename qa/python_models/{request_rescheduling => bls_request_rescheduling}/model.py (67%)
 rename qa/python_models/{generate_sequence => generative_sequence}/config.pbtxt (98%)
 rename qa/python_models/{generate_sequence => generative_sequence}/model.py (97%)
 create mode 100644 qa/python_models/request_rescheduling_cases/config.pbtxt
 create mode 100644 qa/python_models/request_rescheduling_cases/model.py
 rename qa/python_models/{request_rescheduling_error => wrong_return_type}/config.pbtxt (98%)
 rename qa/python_models/{request_rescheduling_error => wrong_return_type}/model.py (99%)

diff --git a/qa/L0_backend_python/python_unittest.py b/qa/L0_backend_python/python_unittest.py
index bff4dd57da..a00ee1cb99 100755
--- a/qa/L0_backend_python/python_unittest.py
+++ b/qa/L0_backend_python/python_unittest.py
@@ -68,6 +68,7 @@ def test_python_unittest(self):
                 model_name == "bls"
                 or model_name == "bls_memory"
                 or model_name == "bls_memory_async"
+                or model_name == "bls_request_rescheduling"
             ):
                 # For these tests, the memory region size will be grown. Because of
                 # this we need to use the shared memory probe only on the later
@@ -75,7 +76,7 @@ def test_python_unittest(self):
                 self._run_unittest(model_name)
 
                 # [FIXME] See DLIS-3684
-                self._run_unittest(model_name)
+                # self._run_unittest(model_name)
                 with self._shm_leak_detector.Probe() as shm_probe:
                     self._run_unittest(model_name)
             else:
diff --git a/qa/L0_backend_python/request_rescheduling/test.sh b/qa/L0_backend_python/request_rescheduling/test.sh
index 7e2829abc8..b290c90bb1 100755
--- a/qa/L0_backend_python/request_rescheduling/test.sh
+++ b/qa/L0_backend_python/request_rescheduling/test.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -36,31 +36,34 @@ SERVER=${TRITON_DIR}/bin/tritonserver
 BACKEND_DIR=${TRITON_DIR}/backends
 
 RET=0
-# This variable is used to print out the correct server log for each sub-test.
-SUB_TEST_RET=0
-rm -fr *.log ./models *.txt
 
-# pip3 uninstall -y torch
-# pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
+rm -fr *.log ./models *.txt
 
-mkdir -p models/request_rescheduling/1/
-cp ../../python_models/request_rescheduling/model.py models/request_rescheduling/1/
-cp ../../python_models/request_rescheduling/config.pbtxt models/request_rescheduling
+pip3 uninstall -y torch
+pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
 
-mkdir -p models/request_rescheduling_error/1/
-cp ../../python_models/request_rescheduling_error/model.py models/request_rescheduling_error/1/
-cp ../../python_models/request_rescheduling_error/config.pbtxt models/request_rescheduling_error
+mkdir -p models/bls_request_rescheduling/1/
+cp ../../python_models/bls_request_rescheduling/model.py models/bls_request_rescheduling/1/
+cp ../../python_models/bls_request_rescheduling/config.pbtxt models/bls_request_rescheduling
 
 mkdir -p models/request_rescheduling_addsub/1/
 cp ../../python_models/request_rescheduling_addsub/model.py models/request_rescheduling_addsub/1/
 cp ../../python_models/request_rescheduling_addsub/config.pbtxt models/request_rescheduling_addsub
 
-mkdir -p models/generate_sequence/1/
-cp ../../python_models/generate_sequence/model.py models/generate_sequence/1/
-cp ../../python_models/generate_sequence/config.pbtxt models/generate_sequence
+mkdir -p models/generative_sequence/1/
+cp ../../python_models/generative_sequence/model.py models/generative_sequence/1/
+cp ../../python_models/generative_sequence/config.pbtxt models/generative_sequence
+
+mkdir -p models/request_rescheduling_cases/1/
+cp ../../python_models/request_rescheduling_cases/model.py models/request_rescheduling_cases/1/
+cp ../../python_models/request_rescheduling_cases/config.pbtxt models/request_rescheduling_cases
+
+mkdir -p models/wrong_return_type/1/
+cp ../../python_models/wrong_return_type/model.py models/wrong_return_type/1/
+cp ../../python_models/wrong_return_type/config.pbtxt models/wrong_return_type
 
 SERVER_LOG="./request_rescheduling_server.log"
-SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1"
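+# Explicit model control is needed so the tests can unload and reload models via BLS.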
+SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --load-model=* --log-verbose=1"
 
 run_server
 if [ "$SERVER_PID" == "0" ]; then
@@ -69,12 +72,12 @@ if [ "$SERVER_PID" == "0" ]; then
     exit 1
 fi
 
-export MODEL_NAME='request_rescheduling'
+export MODEL_NAME='bls_request_rescheduling'
 
 set +e
 python3 $CLIENT_PY >> $CLIENT_LOG 2>&1
 if [ $? -ne 0 ]; then
-    echo -e "\n***\n*** request_rescheduling unit test FAILED. \n***"
+    echo -e "\n***\n*** bls_request_rescheduling test FAILED. \n***"
     cat $CLIENT_LOG
     RET=1
 else
diff --git a/qa/python_models/request_rescheduling/config.pbtxt b/qa/python_models/bls_request_rescheduling/config.pbtxt
similarity index 98%
rename from qa/python_models/request_rescheduling/config.pbtxt
rename to qa/python_models/bls_request_rescheduling/config.pbtxt
index a962a0595a..84f8658f7f 100644
--- a/qa/python_models/request_rescheduling/config.pbtxt
+++ b/qa/python_models/bls_request_rescheduling/config.pbtxt
@@ -24,7 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-name: "request_rescheduling"
+name: "bls_request_rescheduling"
 backend: "python"
 
 output [
diff --git a/qa/python_models/request_rescheduling/model.py b/qa/python_models/bls_request_rescheduling/model.py
similarity index 67%
rename from qa/python_models/request_rescheduling/model.py
rename to qa/python_models/bls_request_rescheduling/model.py
index 9749e93de4..c219aeb58b 100644
--- a/qa/python_models/request_rescheduling/model.py
+++ b/qa/python_models/bls_request_rescheduling/model.py
@@ -25,6 +25,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import sys
+import time
 import unittest
 
 import numpy as np
@@ -35,7 +36,7 @@ class REQUESTRESCHEDULETest(unittest.TestCase):
     def test_wrong_return_type(self):
         input0 = pb_utils.Tensor("INPUT0", (np.random.randn(*[4])).astype(np.float32))
         infer_request = pb_utils.InferenceRequest(
-            model_name="request_rescheduling_error",
+            model_name="wrong_return_type",
             inputs=[input0],
             requested_output_names=["OUTPUT0"],
         )
@@ -76,10 +77,15 @@ def test_non_decoupled_e2e(self):
         self.assertEqual(expected_output_1[0], output1.as_numpy()[0])
 
     def test_decoupled_e2e(self):
+        model_name = "generative_sequence"
+        # Reload the model to reset the flag for multiple iterations
+        pb_utils.unload_model(model_name)
+        pb_utils.load_model(model_name)
+
         input_value = 3
         input0 = pb_utils.Tensor("IN", np.array([input_value], dtype=np.int32))
         infer_request = pb_utils.InferenceRequest(
-            model_name="generate_sequence",
+            model_name=model_name,
             inputs=[input0],
             requested_output_names=["OUT"],
         )
@@ -94,11 +100,59 @@ def test_decoupled_e2e(self):
                 if len(infer_response.output_tensors()) > 0:
                     output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
                     self.assertIsNotNone(output0)
-                    print("output0.as_numpy()[0]: ", output0.as_numpy()[0], flush=True)
 
                     self.assertEqual(expected_output, output0.as_numpy()[0])
                     expected_output -= 1
 
+    def test_send_final_flag_before_rescheduling_request(self):
+        model_name = "request_rescheduling_cases"
+        # Reload the model to reset the flag for multiple iterations
+        pb_utils.unload_model(model_name)
+        pb_utils.load_model(model_name)
+
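+        # Case 0: the model sends the final response flag before releasing the
+        # request with the RESCHEDULE flag; each response should carry the case
+        # value.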
+        case_value = 0
+        input0 = pb_utils.Tensor("IN", np.array([case_value], dtype=np.int32))
+        infer_request = pb_utils.InferenceRequest(
+            model_name=model_name,
+            inputs=[input0],
+            requested_output_names=["OUT"],
+        )
+        infer_responses = infer_request.exec(decoupled=True)
+        for infer_response in infer_responses:
+            self.assertFalse(infer_response.has_error())
+
+            if len(infer_response.output_tensors()) > 0:
+                output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
+                self.assertIsNotNone(output0)
+
+                self.assertEqual(case_value, output0.as_numpy()[0])
+
+    def test_process_request_in_different_thread(self):
+        model_name = "request_rescheduling_cases"
+        # Reload the model to reset the flag for multiple iterations
+        pb_utils.unload_model(model_name)
+        pb_utils.load_model(model_name)
+
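+        # Case 1: the rescheduled request is processed in a different thread;
+        # the responses count down from the case value.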
+        case_value = 1
+        input0 = pb_utils.Tensor("IN", np.array([case_value], dtype=np.int32))
+        infer_request = pb_utils.InferenceRequest(
+            model_name=model_name,
+            inputs=[input0],
+            requested_output_names=["OUT"],
+        )
+        infer_responses = infer_request.exec(decoupled=True)
+
+        expected_output = case_value
+        for infer_response in infer_responses:
+            self.assertFalse(infer_response.has_error())
+
+            if len(infer_response.output_tensors()) > 0:
+                output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
+                self.assertIsNotNone(output0)
+
+                self.assertEqual(expected_output, output0.as_numpy()[0])
+                expected_output -= 1
+
 
 class TritonPythonModel:
     def execute(self, requests):
diff --git a/qa/python_models/generate_sequence/config.pbtxt b/qa/python_models/generative_sequence/config.pbtxt
similarity index 98%
rename from qa/python_models/generate_sequence/config.pbtxt
rename to qa/python_models/generative_sequence/config.pbtxt
index 9f2e08ada3..46eca99f61 100644
--- a/qa/python_models/generate_sequence/config.pbtxt
+++ b/qa/python_models/generative_sequence/config.pbtxt
@@ -24,7 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-name: "generate_sequence"
+name: "generative_sequence"
 backend: "python"
 max_batch_size: 0
 model_transaction_policy {
diff --git a/qa/python_models/generate_sequence/model.py b/qa/python_models/generative_sequence/model.py
similarity index 97%
rename from qa/python_models/generate_sequence/model.py
rename to qa/python_models/generative_sequence/model.py
index f07b14e6ec..c45f82a607 100644
--- a/qa/python_models/generate_sequence/model.py
+++ b/qa/python_models/generative_sequence/model.py
@@ -32,8 +32,8 @@
 
 class TritonPythonModel:
     """
-    This model takes 1 input tensor, an INT32 [ 1 ] input named "INPUT", and
-    produces an output tensor "OUTPUT" with the same shape as the input tensor.
+    This model takes 1 input tensor, an INT32 [ 1 ] input named "IN", and
+    produces an output tensor "OUT" with the same shape as the input tensor.
     The input value indicates the total number of responses to be generated and
     the output value indicates the number of remaining responses. For example,
     if the request input has value 2, the model will:
@@ -122,7 +122,6 @@ def execute(self, requests):
                 response_sender.send(
                     response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                 )
-                self.reset_flag = True
             else:
                 request.set_release_flags(
                     pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
diff --git a/qa/python_models/request_rescheduling_addsub/model.py b/qa/python_models/request_rescheduling_addsub/model.py
index 099dd83e3c..fb7b0ac9c7 100644
--- a/qa/python_models/request_rescheduling_addsub/model.py
+++ b/qa/python_models/request_rescheduling_addsub/model.py
@@ -26,7 +26,6 @@
 
 import json
 
-import numpy as np
 import triton_python_backend_utils as pb_utils
 
 
diff --git a/qa/python_models/request_rescheduling_cases/config.pbtxt b/qa/python_models/request_rescheduling_cases/config.pbtxt
new file mode 100644
index 0000000000..19b6db68f3
--- /dev/null
+++ b/qa/python_models/request_rescheduling_cases/config.pbtxt
@@ -0,0 +1,51 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "request_rescheduling_cases"
+backend: "python"
+max_batch_size: 0
+model_transaction_policy {
+  decoupled: True
+}
+input [
+  {
+    name: "IN"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+output [
+  {
+    name: "OUT"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+sequence_batching {
+  generative_sequence : true
+}
+
+instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/request_rescheduling_cases/model.py b/qa/python_models/request_rescheduling_cases/model.py
new file mode 100644
index 0000000000..615996c8f2
--- /dev/null
+++ b/qa/python_models/request_rescheduling_cases/model.py
@@ -0,0 +1,148 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import threading
+import time
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+            model_config
+        )
+        if not using_decoupled:
+            raise pb_utils.TritonModelException(
+                """the model `{}` can generate any number of responses per request,
+                enable decoupled transaction policy in model configuration to
+                serve this model""".format(
+                    args["model_name"]
+                )
+            )
+
+        # Get IN configuration
+        in_config = pb_utils.get_input_config_by_name(model_config, "IN")
+
+        # Validate the shape and data type of IN
+        in_shape = in_config["dims"]
+        if (len(in_shape) != 1) or (in_shape[0] != 1):
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the shape of 'IN' to be
+                [1], got {}""".format(
+                    args["model_name"], in_shape
+                )
+            )
+        if in_config["data_type"] != "TYPE_INT32":
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the data_type of 'IN' to be
+                'TYPE_INT32', got {}""".format(
+                    args["model_name"], in_config["data_type"]
+                )
+            )
+
+        # Get OUT configuration
+        out_config = pb_utils.get_output_config_by_name(model_config, "OUT")
+
+        # Validate the shape and data type of OUT
+        out_shape = out_config["dims"]
+        if (len(out_shape) != 1) or (out_shape[0] != 1):
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the shape of 'OUT' to be
+                [1], got {}""".format(
+                    args["model_name"], out_shape
+                )
+            )
+        if out_config["data_type"] != "TYPE_INT32":
+            raise pb_utils.TritonModelException(
+                """the model `{}` requires the data_type of 'OUT' to be
+                'TYPE_INT32', got {}""".format(
+                    args["model_name"], out_config["data_type"]
+                )
+            )
+
+        self.idx = 0
+        self.inflight_thread_count = 0
+        self.inflight_thread_count_lck = threading.Lock()
+
+    def execute(self, requests):
+        for request in requests:
+            case = pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()
+
+            if case[0] == 0:
+                self.send_final_flag_before_rescheduling_request(request)
+            elif case[0] == 1:
+                self.process_request_thread(request)
+            else:
+                raise pb_utils.TritonModelException("Unknown test case.")
+
+        return None
+
+    def send_final_flag_before_rescheduling_request(self, request):
+        response_sender = request.get_response_sender()
+        if self.idx == 0:
+            out_output = pb_utils.Tensor("OUT", np.array([0], np.int32))
+            response = pb_utils.InferenceResponse(output_tensors=[out_output])
+            response_sender.send(response)
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+            request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE)
+            self.idx = 1
+
+    def process_request_thread(self, request):
+        thread = threading.Thread(
+            target=self.response_thread,
+            args=(
+                request.get_response_sender(),
+                pb_utils.get_input_tensor_by_name(request, "IN").as_numpy(),
+            ),
+        )
+
+        thread.daemon = True
+
+        with self.inflight_thread_count_lck:
+            self.inflight_thread_count += 1
+
+        if self.idx == 0:
+            request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE)
+            thread.start()
+            self.idx = 1
+
+    def response_thread(self, response_sender, in_input):
+        output_value = in_input[0]
+        while output_value >= 0:
+            out_output = pb_utils.Tensor("OUT", np.array([output_value], np.int32))
+            response = pb_utils.InferenceResponse(output_tensors=[out_output])
+            response_sender.send(response)
+            output_value -= 1
+
+        response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+
+        with self.inflight_thread_count_lck:
+            self.inflight_thread_count -= 1
diff --git a/qa/python_models/request_rescheduling_error/config.pbtxt b/qa/python_models/wrong_return_type/config.pbtxt
similarity index 98%
rename from qa/python_models/request_rescheduling_error/config.pbtxt
rename to qa/python_models/wrong_return_type/config.pbtxt
index 7d50b261c3..2405d66e82 100644
--- a/qa/python_models/request_rescheduling_error/config.pbtxt
+++ b/qa/python_models/wrong_return_type/config.pbtxt
@@ -24,7 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-name: "request_rescheduling_error"
+name: "wrong_return_type"
 backend: "python"
 
 input [
diff --git a/qa/python_models/request_rescheduling_error/model.py b/qa/python_models/wrong_return_type/model.py
similarity index 99%
rename from qa/python_models/request_rescheduling_error/model.py
rename to qa/python_models/wrong_return_type/model.py
index 68a9ab8a6c..c5e6f660fc 100644
--- a/qa/python_models/request_rescheduling_error/model.py
+++ b/qa/python_models/wrong_return_type/model.py
@@ -26,7 +26,6 @@
 
 import json
 
-import numpy as np
 import triton_python_backend_utils as pb_utils
 
 

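The generative_sequence model touched above exercises the request rescheduling pattern end to end: each execute() pass sends one response and either closes the response stream or asks Triton to hand the same request back. A minimal sketch of that loop follows, assuming only the decoupled response-sender API and release flags already used in these diffs; the single-sequence bookkeeping and the "IN"/"OUT" tensor names mirror the test models and are illustrative, not the shipped model code.

    import numpy as np
    import triton_python_backend_utils as pb_utils


    class TritonPythonModel:
        def initialize(self, args):
            # Responses still owed to the in-flight request; None means no
            # request is currently being rescheduled (single-sequence sketch,
            # assumes IN >= 1).
            self.remaining = None

        def execute(self, requests):
            for request in requests:
                sender = request.get_response_sender()
                if self.remaining is None:
                    # First visit: the input value is the total response count.
                    in_tensor = pb_utils.get_input_tensor_by_name(request, "IN")
                    self.remaining = in_tensor.as_numpy()[0]
                self.remaining -= 1
                out = pb_utils.Tensor("OUT", np.array([self.remaining], np.int32))
                sender.send(pb_utils.InferenceResponse(output_tensors=[out]))
                if self.remaining <= 0:
                    # Done: close the stream and release the request normally.
                    sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                    self.remaining = None
                else:
                    # Not done: ask Triton to reschedule this same request so the
                    # next execute() call continues the countdown.
                    request.set_release_flags(
                        pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
                    )
            return None
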
From c6e3ebf2c80832543825fb9d98c7d8b18723c5b6 Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Mon, 6 Nov 2023 18:41:16 -0800
Subject: [PATCH 04/10] Fix up

---
 qa/python_models/bls_request_rescheduling/model.py   | 2 --
 qa/python_models/request_rescheduling_cases/model.py | 1 -
 2 files changed, 3 deletions(-)

diff --git a/qa/python_models/bls_request_rescheduling/model.py b/qa/python_models/bls_request_rescheduling/model.py
index c219aeb58b..af630a4e83 100644
--- a/qa/python_models/bls_request_rescheduling/model.py
+++ b/qa/python_models/bls_request_rescheduling/model.py
@@ -24,8 +24,6 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import sys
-import time
 import unittest
 
 import numpy as np
diff --git a/qa/python_models/request_rescheduling_cases/model.py b/qa/python_models/request_rescheduling_cases/model.py
index 615996c8f2..c23d889fd7 100644
--- a/qa/python_models/request_rescheduling_cases/model.py
+++ b/qa/python_models/request_rescheduling_cases/model.py
@@ -26,7 +26,6 @@
 
 import json
 import threading
-import time
 
 import numpy as np
 import triton_python_backend_utils as pb_utils

From a6b5be16d9dd77ada89f654e10f269c893708be4 Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Wed, 8 Nov 2023 01:47:09 -0800
Subject: [PATCH 05/10] Revert test changes

---
 qa/L0_backend_python/python_unittest.py       |   2 +-
 .../request_rescheduling/test.sh              |   7 -
 .../bls_request_rescheduling/model.py         |  54 +------
 .../request_rescheduling_cases/config.pbtxt   |  51 ------
 .../request_rescheduling_cases/model.py       | 147 ------------------
 5 files changed, 6 insertions(+), 255 deletions(-)
 delete mode 100644 qa/python_models/request_rescheduling_cases/config.pbtxt
 delete mode 100644 qa/python_models/request_rescheduling_cases/model.py

diff --git a/qa/L0_backend_python/python_unittest.py b/qa/L0_backend_python/python_unittest.py
index a00ee1cb99..c956412f9d 100755
--- a/qa/L0_backend_python/python_unittest.py
+++ b/qa/L0_backend_python/python_unittest.py
@@ -76,7 +76,7 @@ def test_python_unittest(self):
                 self._run_unittest(model_name)
 
                 # [FIXME] See DLIS-3684
-                # self._run_unittest(model_name)
+                self._run_unittest(model_name)
                 with self._shm_leak_detector.Probe() as shm_probe:
                     self._run_unittest(model_name)
             else:
diff --git a/qa/L0_backend_python/request_rescheduling/test.sh b/qa/L0_backend_python/request_rescheduling/test.sh
index b290c90bb1..cecf2b2812 100755
--- a/qa/L0_backend_python/request_rescheduling/test.sh
+++ b/qa/L0_backend_python/request_rescheduling/test.sh
@@ -39,9 +39,6 @@ RET=0
 
 rm -fr *.log ./models *.txt
 
-pip3 uninstall -y torch
-pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
-
 mkdir -p models/bls_request_rescheduling/1/
 cp ../../python_models/bls_request_rescheduling/model.py models/bls_request_rescheduling/1/
 cp ../../python_models/bls_request_rescheduling/config.pbtxt models/bls_request_rescheduling
@@ -54,10 +51,6 @@ mkdir -p models/generative_sequence/1/
 cp ../../python_models/generative_sequence/model.py models/generative_sequence/1/
 cp ../../python_models/generative_sequence/config.pbtxt models/generative_sequence
 
-mkdir -p models/request_rescheduling_cases/1/
-cp ../../python_models/request_rescheduling_cases/model.py models/request_rescheduling_cases/1/
-cp ../../python_models/request_rescheduling_cases/config.pbtxt models/request_rescheduling_cases
-
 mkdir -p models/wrong_return_type/1/
 cp ../../python_models/wrong_return_type/model.py models/wrong_return_type/1/
 cp ../../python_models/wrong_return_type/config.pbtxt models/wrong_return_type
diff --git a/qa/python_models/bls_request_rescheduling/model.py b/qa/python_models/bls_request_rescheduling/model.py
index af630a4e83..5599618c71 100644
--- a/qa/python_models/bls_request_rescheduling/model.py
+++ b/qa/python_models/bls_request_rescheduling/model.py
@@ -24,6 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import time
 import unittest
 
 import numpy as np
@@ -78,6 +79,10 @@ def test_decoupled_e2e(self):
         model_name = "generative_sequence"
         # Reload the model to reset the flag for multiple iterations
         pb_utils.unload_model(model_name)
+        # TODO: Make this more robust to wait until fully unloaded
+        print("Sleep 10 seconds to make sure model finishes unloading...", flush=True)
+        time.sleep(10)
+        print("Done sleeping.", flush=True)
         pb_utils.load_model(model_name)
 
         input_value = 3
@@ -102,55 +107,6 @@ def test_decoupled_e2e(self):
                     self.assertEqual(expected_output, output0.as_numpy()[0])
                     expected_output -= 1
 
-    def test_send_final_flag_before_rescheduling_request(self):
-        model_name = "request_rescheduling_cases"
-        # Reload the model to reset the flag for multiple iterations
-        pb_utils.unload_model(model_name)
-        pb_utils.load_model(model_name)
-
-        case_value = 0
-        input0 = pb_utils.Tensor("IN", np.array([case_value], dtype=np.int32))
-        infer_request = pb_utils.InferenceRequest(
-            model_name=model_name,
-            inputs=[input0],
-            requested_output_names=["OUT"],
-        )
-        infer_responses = infer_request.exec(decoupled=True)
-        for infer_response in infer_responses:
-            self.assertFalse(infer_response.has_error())
-
-            if len(infer_response.output_tensors()) > 0:
-                output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
-                self.assertIsNotNone(output0)
-
-                self.assertEqual(case_value, output0.as_numpy()[0])
-
-    def test_process_request_in_different_thread(self):
-        model_name = "request_rescheduling_cases"
-        # Reload the model to reset the flag for multiple iterations
-        pb_utils.unload_model(model_name)
-        pb_utils.load_model(model_name)
-
-        case_value = 1
-        input0 = pb_utils.Tensor("IN", np.array([case_value], dtype=np.int32))
-        infer_request = pb_utils.InferenceRequest(
-            model_name=model_name,
-            inputs=[input0],
-            requested_output_names=["OUT"],
-        )
-        infer_responses = infer_request.exec(decoupled=True)
-
-        expected_output = case_value
-        for infer_response in infer_responses:
-            self.assertFalse(infer_response.has_error())
-
-            if len(infer_response.output_tensors()) > 0:
-                output0 = pb_utils.get_output_tensor_by_name(infer_response, "OUT")
-                self.assertIsNotNone(output0)
-
-                self.assertEqual(expected_output, output0.as_numpy()[0])
-                expected_output -= 1
-
 
 class TritonPythonModel:
     def execute(self, requests):
diff --git a/qa/python_models/request_rescheduling_cases/config.pbtxt b/qa/python_models/request_rescheduling_cases/config.pbtxt
deleted file mode 100644
index 19b6db68f3..0000000000
--- a/qa/python_models/request_rescheduling_cases/config.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "request_rescheduling_cases"
-backend: "python"
-max_batch_size: 0
-model_transaction_policy {
-  decoupled: True
-}
-input [
-  {
-    name: "IN"
-    data_type: TYPE_INT32
-    dims: [ 1 ]
-  }
-]
-output [
-  {
-    name: "OUT"
-    data_type: TYPE_INT32
-    dims: [ 1 ]
-  }
-]
-sequence_batching {
-  generative_sequence : true
-}
-
-instance_group [{ kind: KIND_CPU }]
diff --git a/qa/python_models/request_rescheduling_cases/model.py b/qa/python_models/request_rescheduling_cases/model.py
deleted file mode 100644
index c23d889fd7..0000000000
--- a/qa/python_models/request_rescheduling_cases/model.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import json
-import threading
-
-import numpy as np
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
-    def initialize(self, args):
-        self.model_config = model_config = json.loads(args["model_config"])
-
-        using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
-            model_config
-        )
-        if not using_decoupled:
-            raise pb_utils.TritonModelException(
-                """the model `{}` can generate any number of responses per request,
-                enable decoupled transaction policy in model configuration to
-                serve this model""".format(
-                    args["model_name"]
-                )
-            )
-
-        # Get IN configuration
-        in_config = pb_utils.get_input_config_by_name(model_config, "IN")
-
-        # Validate the shape and data type of IN
-        in_shape = in_config["dims"]
-        if (len(in_shape) != 1) or (in_shape[0] != 1):
-            raise pb_utils.TritonModelException(
-                """the model `{}` requires the shape of 'IN' to be
-                [1], got {}""".format(
-                    args["model_name"], in_shape
-                )
-            )
-        if in_config["data_type"] != "TYPE_INT32":
-            raise pb_utils.TritonModelException(
-                """the model `{}` requires the data_type of 'IN' to be
-                'TYPE_INT32', got {}""".format(
-                    args["model_name"], in_config["data_type"]
-                )
-            )
-
-        # Get OUT configuration
-        out_config = pb_utils.get_output_config_by_name(model_config, "OUT")
-
-        # Validate the shape and data type of OUT
-        out_shape = out_config["dims"]
-        if (len(out_shape) != 1) or (out_shape[0] != 1):
-            raise pb_utils.TritonModelException(
-                """the model `{}` requires the shape of 'OUT' to be
-                [1], got {}""".format(
-                    args["model_name"], out_shape
-                )
-            )
-        if out_config["data_type"] != "TYPE_INT32":
-            raise pb_utils.TritonModelException(
-                """the model `{}` requires the data_type of 'OUT' to be
-                'TYPE_INT32', got {}""".format(
-                    args["model_name"], out_config["data_type"]
-                )
-            )
-
-        self.idx = 0
-        self.inflight_thread_count = 0
-        self.inflight_thread_count_lck = threading.Lock()
-
-    def execute(self, requests):
-        for request in requests:
-            case = pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()
-
-            if case[0] == 0:
-                self.send_final_flag_before_rescheduling_request(request)
-            elif case[0] == 1:
-                self.process_request_thread(request)
-            else:
-                raise pb_utils.TritonModelException("Unknown test case.")
-
-        return None
-
-    def send_final_flag_before_rescheduling_request(self, request):
-        response_sender = request.get_response_sender()
-        if self.idx == 0:
-            out_output = pb_utils.Tensor("OUT", np.array([0], np.int32))
-            response = pb_utils.InferenceResponse(output_tensors=[out_output])
-            response_sender.send(response)
-            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-            request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE)
-            self.idx = 1
-
-    def process_request_thread(self, request):
-        thread = threading.Thread(
-            target=self.response_thread,
-            args=(
-                request.get_response_sender(),
-                pb_utils.get_input_tensor_by_name(request, "IN").as_numpy(),
-            ),
-        )
-
-        thread.daemon = True
-
-        with self.inflight_thread_count_lck:
-            self.inflight_thread_count += 1
-
-        if self.idx == 0:
-            request.set_release_flags(pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE)
-            thread.start()
-            self.idx = 1
-
-    def response_thread(self, response_sender, in_input):
-        output_value = in_input[0]
-        while output_value >= 0:
-            out_output = pb_utils.Tensor("OUT", np.array([output_value], np.int32))
-            response = pb_utils.InferenceResponse(output_tensors=[out_output])
-            response_sender.send(response)
-            output_value -= 1
-
-        response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-
-        with self.inflight_thread_count_lck:
-            self.inflight_thread_count -= 1

From a426da33f80e5fc6300c6d95fc7ec88a5b148f3a Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Wed, 8 Nov 2023 15:56:17 -0800
Subject: [PATCH 06/10] Add grpc endpoint test

---
 .../grpc_endpoint_test.py                     | 114 ++++++++++++++++++
 .../request_rescheduling/test.sh              |  19 +++
 .../bls_request_rescheduling/model.py         |  24 ++--
 3 files changed, 148 insertions(+), 9 deletions(-)
 create mode 100755 qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py

diff --git a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
new file mode 100755
index 0000000000..3e75aec8d3
--- /dev/null
+++ b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+
+sys.path.append("../../common")
+
+import json
+
+# GRPC streaming helpers.
+import queue
+import unittest
+from functools import partial
+
+import numpy as np
+import requests
+import test_util as tu
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import InferenceServerException
+
+
+class UserData:
+    def __init__(self):
+        self._completed_requests = queue.Queue()
+
+
+def callback(user_data, result, error):
+    if error:
+        user_data._completed_requests.put(error)
+    else:
+        user_data._completed_requests.put(result)
+
+
+class GRPCENDPOINTTest(tu.TestResultCollector):
+    def test_grpc_decoupled(self, sequence_id=0, sequence_start=False):
+        user_data = UserData()
+        with grpcclient.InferenceServerClient("localhost:8001") as triton_client:
+            # Reload the model to reset the flag
+            triton_client.unload_model("generative_sequence")
+            triton_client.load_model("generative_sequence")
+
+            triton_client.start_stream(callback=partial(callback, user_data))
+            inputs = []
+            inputs.append(grpcclient.InferInput("IN", [1], "INT32"))
+            inputs[0].set_data_from_numpy(np.array([3], dtype=np.int32))
+
+            triton_client.async_stream_infer(
+                model_name="generative_sequence",
+                inputs=inputs,
+                sequence_id=sequence_id,
+                sequence_start=sequence_start,
+            )
+            res_count = 3
+            while res_count > 0:
+                data_item = user_data._completed_requests.get()
+                res_count -= 1
+                if type(data_item) == InferenceServerException:
+                    raise data_item
+                else:
+                    self.assertEqual(res_count, data_item.as_numpy("OUT")[0])
+            self.assertEqual(0, res_count)
+
+    def test_grpc_non_decoupled(self, sequence_id=0, sequence_start=False):
+        with grpcclient.InferenceServerClient("localhost:8001") as triton_client:
+            # Reload the model to reset the flag
+            triton_client.unload_model("request_rescheduling_addsub")
+            triton_client.load_model("request_rescheduling_addsub")
+
+            inputs = []
+            inputs.append(grpcclient.InferInput("INPUT0", [16], "FP32"))
+            inputs.append(grpcclient.InferInput("INPUT1", [16], "FP32"))
+            input0_val = np.random.randn(*[16]).astype(np.float32)
+            input1_val = np.random.randn(*[16]).astype(np.float32)
+            inputs[0].set_data_from_numpy(input0_val)
+            inputs[1].set_data_from_numpy(input1_val)
+
+            results = triton_client.infer(
+                model_name="request_rescheduling_addsub",
+                inputs=inputs,
+            )
+
+            output0_data = results.as_numpy("OUTPUT0")
+            output1_data = results.as_numpy("OUTPUT1")
+
+            self.assertTrue(np.array_equal(output0_data, input0_val + input1_val))
+            self.assertTrue(np.array_equal(output1_data, input0_val - input1_val))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/qa/L0_backend_python/request_rescheduling/test.sh b/qa/L0_backend_python/request_rescheduling/test.sh
index cecf2b2812..b181e58d2e 100755
--- a/qa/L0_backend_python/request_rescheduling/test.sh
+++ b/qa/L0_backend_python/request_rescheduling/test.sh
@@ -83,6 +83,25 @@ else
 fi
 set -e
 
+GRPC_TEST_PY=./grpc_endpoint_test.py
+EXPECTED_NUM_TESTS="2"
+
+set +e
+python3 $GRPC_TEST_PY >> $CLIENT_LOG 2>&1
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** GRPC Endpoint test FAILED. \n***"
+    cat $CLIENT_LOG
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification Failed\n***"
+        RET=1
+    fi
+fi
+set -e
+
 kill $SERVER_PID
 wait $SERVER_PID
 
diff --git a/qa/python_models/bls_request_rescheduling/model.py b/qa/python_models/bls_request_rescheduling/model.py
index 5599618c71..3d2a950dfa 100644
--- a/qa/python_models/bls_request_rescheduling/model.py
+++ b/qa/python_models/bls_request_rescheduling/model.py
@@ -31,7 +31,16 @@
 import triton_python_backend_utils as pb_utils
 
 
-class REQUESTRESCHEDULETest(unittest.TestCase):
+class REQUESTRESCHEDULINGTest(unittest.TestCase):
+    def _reload_model(self, model_name):
+        # Reload the model to reset the flag for multiple iterations
+        pb_utils.unload_model(model_name)
+        # TODO: Make this more robust to wait until fully unloaded
+        print("Sleep 10 seconds to make sure model finishes unloading...", flush=True)
+        time.sleep(10)
+        print("Done sleeping.", flush=True)
+        pb_utils.load_model(model_name)
+
     def test_wrong_return_type(self):
         input0 = pb_utils.Tensor("INPUT0", (np.random.randn(*[4])).astype(np.float32))
         infer_request = pb_utils.InferenceRequest(
@@ -48,6 +57,9 @@ def test_wrong_return_type(self):
         )
 
     def test_non_decoupled_e2e(self):
+        model_name = "request_rescheduling_addsub"
+        self._reload_model(model_name)
+
         input0_np = np.random.randn(*[16])
         input0_np = input0_np.astype(np.float32)
         input1_np = np.random.randn(*[16])
@@ -55,7 +67,7 @@ def test_non_decoupled_e2e(self):
         input0 = pb_utils.Tensor("INPUT0", input0_np)
         input1 = pb_utils.Tensor("INPUT1", input1_np)
         infer_request = pb_utils.InferenceRequest(
-            model_name="request_rescheduling_addsub",
+            model_name=model_name,
             inputs=[input0, input1],
             requested_output_names=["OUTPUT0", "OUTPUT1"],
         )
@@ -77,13 +89,7 @@ def test_non_decoupled_e2e(self):
 
     def test_decoupled_e2e(self):
         model_name = "generative_sequence"
-        # Reload the model to reset the flag for multiple iterations
-        pb_utils.unload_model(model_name)
-        # TODO: Make this more robust to wait until fully unloaded
-        print("Sleep 10 seconds to make sure model finishes unloading...", flush=True)
-        time.sleep(10)
-        print("Done sleeping.", flush=True)
-        pb_utils.load_model(model_name)
+        self._reload_model(model_name)
 
         input_value = 3
         input0 = pb_utils.Tensor("IN", np.array([input_value], dtype=np.int32))

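The _reload_model helper introduced above still relies on a fixed 10-second sleep (see the TODO) to let the unload finish. One way the wait could be made more robust, assuming the model loading API also exposes a readiness query such as pb_utils.is_model_ready (an assumption; it is not used anywhere in these patches), is to poll instead of sleeping a fixed amount:

    import time

    import triton_python_backend_utils as pb_utils


    def _reload_model(self, model_name, timeout_s=30.0, poll_s=0.5):
        pb_utils.unload_model(model_name)
        # Assumption: pb_utils.is_model_ready(model_name=...) reports whether the
        # model is still serving; poll until it drops out of the ready state or
        # the timeout expires, then load it again.
        deadline = time.time() + timeout_s
        while time.time() < deadline and pb_utils.is_model_ready(model_name=model_name):
            time.sleep(poll_s)
        pb_utils.load_model(model_name)
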
From 1d44de2ad1b5cbd0654d5c655c7e0c0bba2ef23c Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Wed, 8 Nov 2023 15:59:38 -0800
Subject: [PATCH 07/10] Remove unused import

---
 qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
index 3e75aec8d3..4b02fe536d 100755
--- a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
+++ b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
@@ -29,8 +29,6 @@
 
 sys.path.append("../../common")
 
-import json
-
 # GRPC streaming helpers.
 import queue
 import unittest

From 937bac13e195f215ec933ab1110a07b2b82b5d2c Mon Sep 17 00:00:00 2001
From: krishung5 <krish@nvidia.com>
Date: Wed, 8 Nov 2023 17:39:47 -0800
Subject: [PATCH 08/10] Remove unused import

---
 qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
index 4b02fe536d..6e75db4912 100755
--- a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
+++ b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
@@ -35,7 +35,6 @@
 from functools import partial
 
 import numpy as np
-import requests
 import test_util as tu
 import tritonclient.grpc as grpcclient
 from tritonclient.utils import InferenceServerException

From 025393dbfa2fc787580c46667b90603ba60dd1e0 Mon Sep 17 00:00:00 2001
From: Kris Hung <krish@nvidia.com>
Date: Thu, 9 Nov 2023 10:53:43 -0800
Subject: [PATCH 09/10] Update
 qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py

Co-authored-by: Iman Tabrizian <iman.tabrizian@gmail.com>
---
 qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
index 6e75db4912..fb6e4a8cf2 100755
--- a/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
+++ b/qa/L0_backend_python/request_rescheduling/grpc_endpoint_test.py
@@ -52,7 +52,7 @@ def callback(user_data, result, error):
         user_data._completed_requests.put(result)
 
 
-class GRPCENDPOINTTest(tu.TestResultCollector):
+class GrpcEndpointTest(tu.TestResultCollector):
     def test_grpc_decoupled(self, sequence_id=0, sequence_start=False):
         user_data = UserData()
         with grpcclient.InferenceServerClient("localhost:8001") as triton_client:

From 9825b9649cc1c38d9f4d3b65537fde98f822ea29 Mon Sep 17 00:00:00 2001
From: Kris Hung <krish@nvidia.com>
Date: Thu, 9 Nov 2023 10:53:49 -0800
Subject: [PATCH 10/10] Update
 qa/python_models/bls_request_rescheduling/model.py

Co-authored-by: Iman Tabrizian <iman.tabrizian@gmail.com>
---
 qa/python_models/bls_request_rescheduling/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/qa/python_models/bls_request_rescheduling/model.py b/qa/python_models/bls_request_rescheduling/model.py
index 3d2a950dfa..28aa3bd44c 100644
--- a/qa/python_models/bls_request_rescheduling/model.py
+++ b/qa/python_models/bls_request_rescheduling/model.py
@@ -31,7 +31,7 @@
 import triton_python_backend_utils as pb_utils
 
 
-class REQUESTRESCHEDULINGTest(unittest.TestCase):
+class RequestReschedulingTest(unittest.TestCase):
     def _reload_model(self, model_name):
         # Reload the model to reset the flag for multiple iterations
         pb_utils.unload_model(model_name)