Skip to content

Commit

Permalink
Add test for sequence model instance update
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed May 21, 2023
1 parent 4f487a0 commit 4278a26
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 19 deletions.
115 changes: 99 additions & 16 deletions qa/L0_model_update/instance_update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tritonclient.utils import InferenceServerException
from models.model_init_del.util import (get_count, reset_count, set_delay,
update_instance_group,
update_sequence_batching,
update_model_file, enable_batching,
disable_batching)

Expand All @@ -43,9 +44,21 @@ class TestInstanceUpdate(unittest.TestCase):
__model_name = "model_init_del"

def setUp(self):
# Initialize client
self.__reset_model()
self.__triton = grpcclient.InferenceServerClient("localhost:8001")

def __reset_model(self):
# Reset counters
reset_count("initialize")
reset_count("finalize")
# Reset batching
disable_batching()
# Reset delays
set_delay("initialize", 0)
set_delay("infer", 0)
# Reset sequence batching
update_sequence_batching("")

def __get_inputs(self, batching=False):
self.assertIsInstance(batching, bool)
if batching:
Expand Down Expand Up @@ -85,14 +98,8 @@ def __check_count(self, kind, expected_count, poll=False):
self.assertEqual(get_count(kind), expected_count)

def __load_model(self, instance_count, instance_config="", batching=False):
# Reset counters
reset_count("initialize")
reset_count("finalize")
# Set batching
enable_batching() if batching else disable_batching()
# Reset delays
set_delay("initialize", 0)
set_delay("infer", 0)
# Load model
self.__update_instance_count(instance_count,
0,
Expand Down Expand Up @@ -143,6 +150,7 @@ def test_add_rm_add_instance(self):
self.__update_instance_count(1, 0, batching=batching) # add
stop()
self.__unload_model(batching=batching)
self.__reset_model() # for next iteration

# Test remove -> add -> remove an instance
def test_rm_add_rm_instance(self):
Expand All @@ -154,6 +162,7 @@ def test_rm_add_rm_instance(self):
self.__update_instance_count(0, 1, batching=batching) # remove
stop()
self.__unload_model(batching=batching)
self.__reset_model() # for next iteration

# Test reduce instance count to zero
def test_rm_instance_to_zero(self):
Expand Down Expand Up @@ -341,15 +350,89 @@ def infer():
# Unload model
self.__unload_model()

# Test for instance update on direct sequence scheduling
@unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
def test_instance_update_on_direct_sequence_scheduling(self):
pass

# Test for instance update on oldest sequence scheduling
@unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
def test_instance_update_on_oldest_sequence_scheduling(self):
pass
# Test wait for in-flight sequence completion and block new sequence
def test_sequence_instance_update(self):
for sequence_batching_type in [
"direct { }\nmax_sequence_idle_microseconds: 10000000",
"oldest { max_candidate_sequences: 4 }\nmax_sequence_idle_microseconds: 10000000"
]:
# Load model
update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}")
update_sequence_batching(sequence_batching_type)
self.__triton.load_model(self.__model_name)
self.__check_count("initialize", 2)
self.__check_count("finalize", 0)
# Basic sequence inference
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=1,
sequence_start=True)
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=1)
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=1,
sequence_end=True)
# Update instance
update_instance_group("{\ncount: 4\nkind: KIND_CPU\n}")
self.__triton.load_model(self.__model_name)
self.__check_count("initialize", 4)
self.__check_count("finalize", 0)
# Start an in-flight sequence
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=1,
sequence_start=True)
# Check update instance will wait for in-flight sequence completion
# and block new sequence from starting.
update_instance_group("{\ncount: 3\nkind: KIND_CPU\n}")
update_complete = [False]
def update():
self.__triton.load_model(self.__model_name)
update_complete[0] = True
self.__check_count("initialize", 4)
self.__check_count("finalize", 1)
infer_complete = [False]
def infer():
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=2,
sequence_start=True)
infer_complete[0] = True
with concurrent.futures.ThreadPoolExecutor() as pool:
# Update should wait until sequence 1 end
update_thread = pool.submit(update)
time.sleep(2) # make sure update has started
self.assertFalse(update_complete[0],
"Unexpected update completion")
# New sequence should wait until update complete
infer_thread = pool.submit(infer)
time.sleep(2) # make sure infer has started
self.assertFalse(infer_complete[0],
"Unexpected infer completion")
# End sequence 1 should unblock update
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=1,
sequence_end=True)
time.sleep(2) # make sure update has returned
self.assertTrue(update_complete[0], "Update possibly stuck")
update_thread.result()
# Update completion should unblock new sequence
time.sleep(2) # make sure infer has returned
self.assertTrue(infer_complete[0], "Infer possibly stuck")
infer_thread.result()
# End sequence 2
self.__triton.infer(self.__model_name,
self.__get_inputs(),
sequence_id=2,
sequence_end=True)
# Unload model
self.__triton.unload_model(self.__model_name)
self.__check_count("initialize", 4)
self.__check_count("finalize", 4, True)
self.__reset_model()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion qa/python_models/model_init_del/config.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ instance_group [
count: 1
kind: KIND_CPU
}
]
] # end instance_group
25 changes: 23 additions & 2 deletions qa/python_models/model_init_del/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,31 @@ def update_instance_group(instance_group_str):
full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
txt = f.read()
txt = txt.split("instance_group [")[0]
txt, post_match = txt.split("instance_group [")
txt += "instance_group [\n"
txt += instance_group_str
txt += "\n]\n"
txt += "\n] # end instance_group\n"
txt += post_match.split("\n] # end instance_group\n")[1]
f.truncate(0)
f.seek(0)
f.write(txt)
return txt

def update_sequence_batching(sequence_batching_str):
full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
txt = f.read()
if "sequence_batching {" in txt:
txt, post_match = txt.split("sequence_batching {")
if sequence_batching_str != "":
txt += "sequence_batching {\n"
txt += sequence_batching_str
txt += "\n} # end sequence_batching\n"
txt += post_match.split("\n} # end sequence_batching\n")[1]
elif sequence_batching_str != "":
txt += "\nsequence_batching {\n"
txt += sequence_batching_str
txt += "\n} # end sequence_batching\n"
f.truncate(0)
f.seek(0)
f.write(txt)
Expand Down

0 comments on commit 4278a26

Please sign in to comment.