Skip to content

Commit

Permalink
Fix response iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Jun 13, 2024
1 parent 6d00416 commit e959c98
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 18 deletions.
22 changes: 11 additions & 11 deletions python/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,6 @@ def test_ready(self):
server = tritonserver.Server(self._server_options).start()
self.assertTrue(server.ready())

@pytest.mark.xfail(
tritonserver.__version__ <= "2.46.0",
reason="Known issue on stop: Exit timeout expired. Exiting immediately",
raises=tritonserver.InternalError,
)
def test_stop(self):
server = tritonserver.Server(self._server_options).start(wait_until_ready=True)

Expand Down Expand Up @@ -455,15 +450,20 @@ def test_basic_inference(self):
"bool_input": numpy.random.rand(1, 100).astype(dtype=numpy.bool_),
}

for response in server.model("test").infer(
response_iterator = server.model("test").infer(
inputs=inputs,
output_memory_type="cpu",
raise_on_error=True,
):
for input_name, input_value in inputs.items():
output_value = response.outputs[input_name.replace("input", "output")]
output_value = numpy.from_dlpack(output_value)
numpy.testing.assert_array_equal(input_value, output_value)
)

responses = list(response_iterator)
self.assertTrue(len(responses), 1)
response = responses[0]

for input_name, input_value in inputs.items():
output_value = response.outputs[input_name.replace("input", "output")]
output_value = numpy.from_dlpack(output_value)
numpy.testing.assert_array_equal(input_value, output_value)

# test normal bool
inputs = {"bool_input": [[True, False, False, True]]}
Expand Down
7 changes: 0 additions & 7 deletions python/tritonserver/_api/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from __future__ import annotations

import asyncio
import inspect
import queue
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Optional
Expand Down Expand Up @@ -182,9 +181,6 @@ def _response_callback(self, response, flags, unused):
asyncio.run_coroutine_threadsafe(
self._user_queue.put(response), self._loop
)
if flags == TRITONSERVER_ResponseCompleteFlag.FINAL:
del self._request
self._request = None
except Exception as e:
message = f"Catastrophic failure in response callback: {e}"
LogMessage(LogLevel.ERROR, message)
Expand Down Expand Up @@ -308,9 +304,6 @@ def _response_callback(self, response, flags, unused):
self._queue.put(response)
if self._user_queue is not None:
self._user_queue.put(response)
if flags == TRITONSERVER_ResponseCompleteFlag.FINAL:
del self._request
self._request = None
except Exception as e:
message = f"Catastrophic failure in response callback: {e}"
LogMessage(LogLevel.ERROR, message)
Expand Down

0 comments on commit e959c98

Please sign in to comment.