Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tanmayv25 committed Sep 13, 2023
1 parent 3fbbe52 commit 8771ed0
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions qa/L0_request_cancellation/client_cancellation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,20 @@ def test_grpc_async_infer(self):

self._record_start_time_ms()

# Expect inference to pass successfully for a large timeout
# value
future = triton_client.async_infer(
model_name=self.model_name_,
inputs=self.inputs_,
callback=partial(callback, user_data),
outputs=self.outputs_,
)
time.sleep(2)
future.cancel()
with self.assertRaises(InferenceServerException) as cm:
future = triton_client.async_infer(
model_name=self.model_name_,
inputs=self.inputs_,
callback=partial(callback, user_data),
outputs=self.outputs_,
)
time.sleep(2)
future.cancel()

# Wait until the results is captured via callback
data_item = user_data._completed_requests.get()
self.assertEqual(type(data_item), grpcclient.CancelledError)
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
self.assertIn("Locally cancelled by application!", str(cm.exception))

self._record_end_time_ms()
self._test_runtime_duration(5000)
Expand All @@ -132,20 +132,22 @@ def test_grpc_stream_infer(self):
self._prepare_request()
user_data = UserData()

# The model is configured to take three seconds to send the
# response. Expect an exception for small timeout values.
triton_client.start_stream(callback=partial(callback, user_data))
self._record_start_time_ms()
for i in range(1):
triton_client.async_stream_infer(
model_name=self.model_name_, inputs=self.inputs_, outputs=self.outputs_
)

time.sleep(2)
triton_client.stop_stream(cancel_requests=True)

data_item = user_data._completed_requests.get()
self.assertEqual(type(data_item), grpcclient.CancelledError)
with self.assertRaises(InferenceServerException) as cm:
for i in range(1):
triton_client.async_stream_infer(
model_name=self.model_name_,
inputs=self.inputs_,
outputs=self.outputs_,
)
time.sleep(2)
triton_client.stop_stream(cancel_requests=True)
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
self.assertIn("Locally cancelled by application!", str(cm.exception))

self._record_end_time_ms()
self._test_runtime_duration(5000)
Expand All @@ -160,26 +162,31 @@ async def cancel_request(call):
await asyncio.sleep(2)
self.assertTrue(call.cancel())

async def handle_response(call):
async def handle_response(generator):
with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
await call
_ = await anext(generator)

async def test_aio_infer(self):
triton_client = aiogrpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)
self._prepare_request()
self._record_start_time_ms()
# Expect inference to pass successfully for a large timeout
# value
call = await triton_client.infer(

generator = triton_client.infer(
model_name=self.model_name_,
inputs=self.inputs_,
outputs=self.outputs_,
get_call_obj=True,
)
asyncio.create_task(handle_response(call))
asyncio.create_task(cancel_request(call))
grpc_call = await anext(generator)

tasks = []
tasks.append(asyncio.create_task(handle_response(generator)))
tasks.append(asyncio.create_task(cancel_request(grpc_call)))

for task in tasks:
await task

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

self._record_end_time_ms()
self._test_runtime_duration(5000)
Expand Down Expand Up @@ -211,7 +218,7 @@ async def async_request_iterator():
response_iterator = triton_client.stream_infer(
inputs_iterator=async_request_iterator(), get_call_obj=True
)
streaming_call = await response_iterator.__anext__()
streaming_call = await anext(response_iterator)

async def cancel_streaming(streaming_call):
await asyncio.sleep(2)
Expand All @@ -222,8 +229,12 @@ async def handle_response(response_iterator):
async for response in response_iterator:
self.assertTrue(False, "Received an unexpected response!")

asyncio.create_task(handle_response(response_iterator))
asyncio.create_task(cancel_streaming(streaming_call))
tasks = []
tasks.append(asyncio.create_task(handle_response(response_iterator)))
tasks.append(asyncio.create_task(cancel_streaming(streaming_call)))

for task in tasks:
await task

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

self._record_end_time_ms()
self._test_runtime_duration(5000)
Expand Down

0 comments on commit 8771ed0

Please sign in to comment.