Skip to content

Commit

Permalink
ci: Fix L0_model_control_stress_vllm (#79)
Browse files Browse the repository at this point in the history
yinggeh authored Dec 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 2f5bfbd commit d061556
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions samples/client.py
Original file line number Diff line number Diff line change
@@ -38,13 +38,21 @@

class LLMClient:
def __init__(self, flags: argparse.Namespace):
self._client = grpcclient.InferenceServerClient(
url=flags.url, verbose=flags.verbose
)
self._flags = flags
self._loop = asyncio.get_event_loop()
self._results_dict = {}

def get_triton_client(self):
    """Create and return a gRPC InferenceServerClient for the configured server.

    Builds a fresh client from ``self._flags.url`` / ``self._flags.verbose``
    on every call (the caller owns the returned client's lifetime).

    Returns:
        grpcclient.InferenceServerClient: a newly created client.

    Exits:
        Terminates the process with a nonzero status if the gRPC channel
        cannot be created.
    """
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=self._flags.url,
            verbose=self._flags.verbose,
        )
    except Exception as e:
        # Fatal: without a channel nothing else can proceed. Report on
        # stderr and exit nonzero — a bare sys.exit() would exit with
        # status 0 and falsely signal success to callers/CI.
        print("channel creation failed: " + str(e), file=sys.stderr)
        sys.exit(1)

    return triton_client

async def async_request_iterator(
self, prompts, sampling_parameters, exclude_input_in_output
):
@@ -65,8 +73,9 @@ async def async_request_iterator(

async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output):
try:
triton_client = self.get_triton_client()
# Start streaming
response_iterator = self._client.stream_infer(
response_iterator = triton_client.stream_infer(
inputs_iterator=self.async_request_iterator(
prompts, sampling_parameters, exclude_input_in_output
),
@@ -138,7 +147,7 @@ async def run(self):
print("FAIL: vLLM example")

def run_async(self):
self._loop.run_until_complete(self.run())
asyncio.run(self.run())

def create_request(
self,

0 comments on commit d061556

Please sign in to comment.