From bd7f1a7349f6a690ad2ab2eea2f442e8f4b9c9b3 Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Fri, 27 Oct 2023 14:25:18 -0700
Subject: [PATCH] fix: LLM - Fixed the async streaming

Fixes https://github.com/googleapis/python-aiplatform/issues/2853

PiperOrigin-RevId: 577305447
---
 google/cloud/aiplatform/_streaming_prediction.py | 2 +-
 tests/unit/aiplatform/test_language_models.py    | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/google/cloud/aiplatform/_streaming_prediction.py b/google/cloud/aiplatform/_streaming_prediction.py
index 9c39f0d1b5..16a1d1331b 100644
--- a/google/cloud/aiplatform/_streaming_prediction.py
+++ b/google/cloud/aiplatform/_streaming_prediction.py
@@ -130,7 +130,7 @@ async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
         inputs=tensor_list,
         parameters=parameters_tensor,
     )
-    async for response in prediction_service_async_client.server_streaming_predict(
+    async for response in await prediction_service_async_client.server_streaming_predict(
         request=request
     ):
         yield response.outputs
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 3766eb8a7d..ab96e57d4d 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -1484,12 +1484,15 @@ async def test_text_generation_model_predict_streaming_async(self):
             "text-bison@001"
         )
 
-        async def mock_server_streaming_predict_async(*args, **kwargs):
+        async def mock_server_streaming_predict_async_iter(*args, **kwargs):
             for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING:
                 yield gca_prediction_service.StreamingPredictResponse(
                     outputs=[_streaming_prediction.value_to_tensor(response_dict)]
                 )
 
+        async def mock_server_streaming_predict_async(*args, **kwargs):
+            return mock_server_streaming_predict_async_iter(*args, **kwargs)
+
         with mock.patch.object(
             target=prediction_service_async_client.PredictionServiceAsyncClient,
             attribute="server_streaming_predict",
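
A minimal sketch of the calling convention the fix depends on, using hypothetical names rather than the Vertex AI SDK's API: when an async streaming call is modeled as a coroutine that returns an async iterable, the call must be awaited before it is iterated with `async for`. The same shape explains the test change above, where the mock is split into an async generator plus a coroutine wrapper that returns it.

```python
# Sketch only; the names below are hypothetical stand-ins, not the GAPIC client API.
import asyncio
from typing import AsyncIterator


async def _stream_chunks() -> AsyncIterator[str]:
    # Stands in for the server-side streaming response iterator (an async generator).
    for chunk in ("Hello", " ", "world"):
        yield chunk


async def server_streaming_predict(request: str) -> AsyncIterator[str]:
    # Like the async client in the patch: a coroutine whose awaited *result*
    # is the async iterable, rather than an async iterable itself.
    return _stream_chunks()


async def main() -> None:
    # Iterating the un-awaited coroutine ("async for x in server_streaming_predict(...)")
    # fails with a TypeError; awaiting it first yields the stream to iterate.
    async for chunk in await server_streaming_predict(request="hi"):
        print(chunk, end="")
    print()


asyncio.run(main())
```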