diff --git a/google/cloud/aiplatform/_streaming_prediction.py b/google/cloud/aiplatform/_streaming_prediction.py
index 9c39f0d1b52..16a1d1331be 100644
--- a/google/cloud/aiplatform/_streaming_prediction.py
+++ b/google/cloud/aiplatform/_streaming_prediction.py
@@ -130,7 +130,7 @@ async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
         inputs=tensor_list,
         parameters=parameters_tensor,
     )
-    async for response in prediction_service_async_client.server_streaming_predict(
+    async for response in await prediction_service_async_client.server_streaming_predict(
         request=request
     ):
         yield response.outputs
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 3766eb8a7d0..ab96e57d4d4 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -1484,12 +1484,15 @@ async def test_text_generation_model_predict_streaming_async(self):
             "text-bison@001"
         )
 
-        async def mock_server_streaming_predict_async(*args, **kwargs):
+        async def mock_server_streaming_predict_async_iter(*args, **kwargs):
             for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING:
                 yield gca_prediction_service.StreamingPredictResponse(
                     outputs=[_streaming_prediction.value_to_tensor(response_dict)]
                 )
 
+        async def mock_server_streaming_predict_async(*args, **kwargs):
+            return mock_server_streaming_predict_async_iter(*args, **kwargs)
+
         with mock.patch.object(
             target=prediction_service_async_client.PredictionServiceAsyncClient,
             attribute="server_streaming_predict",