From fb527f3aa59ee90fa6306196b328f513ee4b4d9c Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Fri, 18 Aug 2023 01:04:11 -0700
Subject: [PATCH] feat: LLM - Support streaming prediction for text generation
 models

PiperOrigin-RevId: 558068359
---
 .../cloud/aiplatform/_streaming_prediction.py | 166 ++++++++++++++++++
 .../system/aiplatform/test_language_models.py |  14 ++
 tests/unit/aiplatform/test_language_models.py |  67 +++++++
 vertexai/language_models/_language_models.py  |  67 ++++++-
 4 files changed, 313 insertions(+), 1 deletion(-)
 create mode 100644 google/cloud/aiplatform/_streaming_prediction.py

diff --git a/google/cloud/aiplatform/_streaming_prediction.py b/google/cloud/aiplatform/_streaming_prediction.py
new file mode 100644
index 0000000000..cd5a33a491
--- /dev/null
+++ b/google/cloud/aiplatform/_streaming_prediction.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Streaming prediction functions."""
+
+from typing import Any, Dict, Iterator, List, Optional, Sequence
+
+from google.cloud.aiplatform_v1.services import prediction_service
+from google.cloud.aiplatform_v1.types import (
+    prediction_service as prediction_service_types,
+)
+from google.cloud.aiplatform_v1.types import (
+    types as aiplatform_types,
+)
+
+
+def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
+    """Converts a Python value to `Tensor`.
+
+    Args:
+        value: A value to convert.
+
+    Returns:
+        A `Tensor` object.
+    """
+    if value is None:
+        return aiplatform_types.Tensor()
+    # `bool` must be checked before `int`: `bool` is a subclass of `int`,
+    # so booleans would otherwise be encoded as integer tensors.
+    elif isinstance(value, bool):
+        return aiplatform_types.Tensor(bool_val=[value])
+    elif isinstance(value, int):
+        return aiplatform_types.Tensor(int_val=[value])
+    elif isinstance(value, float):
+        return aiplatform_types.Tensor(float_val=[value])
+    elif isinstance(value, str):
+        return aiplatform_types.Tensor(string_val=[value])
+    elif isinstance(value, bytes):
+        return aiplatform_types.Tensor(bytes_val=[value])
+    elif isinstance(value, list):
+        return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
+    elif isinstance(value, dict):
+        return aiplatform_types.Tensor(
+            struct_val={k: value_to_tensor(v) for k, v in value.items()}
+        )
+    raise TypeError(f"Unsupported value type {type(value)}")
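+
+# A minimal sketch of the mapping performed by `value_to_tensor`; the
+# `Tensor` reprs below are abbreviated for illustration:
+#
+#   value_to_tensor(True)       -> Tensor(bool_val=[True])
+#   value_to_tensor(1.5)        -> Tensor(float_val=[1.5])
+#   value_to_tensor("a")        -> Tensor(string_val=["a"])
+#   value_to_tensor([1, 2])     -> Tensor(list_val=[Tensor(int_val=[1]),
+#                                                   Tensor(int_val=[2])])
+#   value_to_tensor({"k": "v"}) -> Tensor(struct_val={"k": Tensor(string_val=["v"])})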
+
+
+def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
+    """Converts `Tensor` to a Python value.
+
+    Args:
+        tensor_pb: A `Tensor` object.
+
+    Returns:
+        A corresponding Python object.
+    """
+    list_of_fields = tensor_pb.ListFields()
+    if not list_of_fields:
+        return None
+    descriptor, value = list_of_fields[0]
+    if descriptor.name == "list_val":
+        return [tensor_to_value(x) for x in value]
+    elif descriptor.name == "struct_val":
+        return {k: tensor_to_value(v) for k, v in value.items()}
+    if not isinstance(value, Sequence):
+        raise TypeError(f"Unexpected non-list tensor value {value}")
+    if len(value) == 1:
+        return value[0]
+    else:
+        return value
+
+
+def predict_stream_of_tensor_lists_from_single_tensor_list(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    tensor_list: List[aiplatform_types.Tensor],
+    parameters_tensor: Optional[aiplatform_types.Tensor] = None,
+) -> Iterator[List[aiplatform_types.Tensor]]:
+    """Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
+
+    Args:
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of the Endpoint or PublisherModel.
+        tensor_list: Model input as a list of `Tensor` objects.
+        parameters_tensor: Optional. Prediction parameters in `Tensor` form.
+
+    Yields:
+        Lists of model prediction `Tensor` objects.
+    """
+    request = prediction_service_types.StreamingPredictRequest(
+        endpoint=endpoint_name,
+        inputs=tensor_list,
+        parameters=parameters_tensor,
+    )
+    for response in prediction_service_client.server_streaming_predict(request=request):
+        yield response.outputs
+
+
+def predict_stream_of_dict_lists_from_single_dict_list(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    dict_list: List[Dict[str, Any]],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[List[Dict[str, Any]]]:
+    """Predicts a stream of lists of dicts from a single list of dicts.
+
+    Args:
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of the Endpoint or PublisherModel.
+        dict_list: Model input as a list of `dict` objects.
+        parameters: Optional. Prediction parameters in `dict` form.
+
+    Yields:
+        Lists of model prediction dicts.
+    """
+    tensor_list = [value_to_tensor(d) for d in dict_list]
+    parameters_tensor = value_to_tensor(parameters) if parameters else None
+    # The loop variable is named distinctly so it does not shadow the
+    # input `tensor_list` built above.
+    for output_tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
+        prediction_service_client=prediction_service_client,
+        endpoint_name=endpoint_name,
+        tensor_list=tensor_list,
+        parameters_tensor=parameters_tensor,
+    ):
+        yield [tensor_to_value(tensor._pb) for tensor in output_tensor_list]
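+
+# Illustrative call pattern for the helper above; the client, endpoint
+# name, and payload values are placeholders rather than SDK defaults:
+#
+#   client = prediction_service.PredictionServiceClient()
+#   for output_dicts in predict_stream_of_dict_lists_from_single_dict_list(
+#       prediction_service_client=client,
+#       endpoint_name="projects/.../publishers/google/models/text-bison@001",
+#       dict_list=[{"prompt": "Hello"}],
+#       parameters={"temperature": 0.0},
+#   ):
+#       ...  # each `output_dicts` is one streamed list of prediction dicts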
+
+
+def predict_stream_of_dicts_from_single_dict(
+    prediction_service_client: prediction_service.PredictionServiceClient,
+    endpoint_name: str,
+    instance: Dict[str, Any],
+    parameters: Optional[Dict[str, Any]] = None,
+) -> Iterator[Dict[str, Any]]:
+    """Predicts a stream of dicts from a single instance dict.
+
+    Args:
+        prediction_service_client: A PredictionServiceClient object.
+        endpoint_name: Resource name of the Endpoint or PublisherModel.
+        instance: A single input instance `dict`.
+        parameters: Optional. Prediction parameters `dict`.
+
+    Yields:
+        Model prediction dicts.
+    """
+    for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
+        prediction_service_client=prediction_service_client,
+        endpoint_name=endpoint_name,
+        dict_list=[instance],
+        parameters=parameters,
+    ):
+        # A single input instance must map to exactly one output dict per
+        # streamed response; this also guards against an empty output list.
+        if len(dict_list) != 1:
+            raise ValueError(
+                f"Expected to receive a single output, but got {dict_list}"
+            )
+        yield dict_list[0]
diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 77d93ab9bd..685b96edcc 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -48,6 +48,20 @@ def test_text_generation(self):
             top_k=5,
         ).text
 
+    def test_text_generation_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextGenerationModel.from_pretrained("google/text-bison@001")
+
+        for response in model.predict_streaming(
+            "What is the best recipe for banana bread? Recipe:",
+            max_output_tokens=128,
+            temperature=0,
+            top_p=1,
+            top_k=5,
+        ):
+            assert response.text
+
     def test_chat_on_chat_model(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 76d6c144d6..056cebfc37 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -28,6 +28,7 @@
 from google.cloud import storage
 
 from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform.utils import gcs_utils
@@ -168,6 +169,34 @@
 1. Preheat oven to 350 degrees F (175 degrees C).""",
 }
 
+_TEST_TEXT_GENERATION_PREDICTION_STREAMING = [
+    {
+        "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.",
+    },
+    {
+        "content": " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.",
+        "safetyAttributes": {"blocked": False, "categories": None, "scores": None},
+    },
+    {
+        "content": " 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45.",
+        "citationMetadata": {
+            "citations": [
+                {
+                    "title": "THEATRUM ARITHMETICO-GEOMETRICUM",
+                    "publicationDate": "1727",
+                    "endIndex": 181,
+                    "startIndex": 12,
+                }
+            ]
+        },
+        "safetyAttributes": {
+            "blocked": True,
+            "categories": ["Finance"],
+            "scores": [0.1],
+        },
+    },
+]
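+
+# The streaming fixture mirrors the wire format: each chunk's "content"
+# concatenates into the full answer, and "safetyAttributes" /
+# "citationMetadata" may accompany any individual chunk.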
45.", + "citationMetadata": { + "citations": [ + { + "title": "THEATRUM ARITHMETICO-GEOMETRICUM", + "publicationDate": "1727", + "endIndex": 181, + "startIndex": 12, + } + ] + }, + "safetyAttributes": { + "blocked": True, + "categories": ["Finance"], + "scores": [0.1], + }, + }, +] + _TEST_CHAT_GENERATION_PREDICTION1 = { "safetyAttributes": [ { @@ -1040,6 +1069,10 @@ class TestLanguageModels: def setup_method(self): reload(initializer) reload(aiplatform) + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) def teardown_method(self): initializer.global_pool.shutdown(wait=True) @@ -1165,6 +1198,40 @@ def test_text_generation_ga(self): assert "topP" not in prediction_parameters assert "topK" not in prediction_parameters + def test_text_generation_model_predict_streaming(self): + """Tests the TextGenerationModel.predict_streaming method.""" + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + response_generator = ( + gca_prediction_service.StreamingPredictResponse( + outputs=[_streaming_prediction.value_to_tensor(response_dict)] + ) + for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING + ) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="server_streaming_predict", + return_value=response_generator, + ): + for response in model.predict_streaming( + "Count to 50", + max_output_tokens=1000, + temperature=0, + top_p=1, + top_k=5, + ): + assert len(response.text) > 10 + @pytest.mark.parametrize( "job_spec", [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB], diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 28d73f4eac..31329011e1 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -15,10 +15,11 @@ """Classes for working with language models.""" import dataclasses -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union import warnings from google.cloud import aiplatform +from google.cloud.aiplatform import _streaming_prediction from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer as aiplatform_initializer from google.cloud.aiplatform import utils as aiplatform_utils @@ -389,6 +390,70 @@ def _batch_predict( ) return results + def predict_streaming( + self, + prompt: str, + *, + max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Iterator[TextGenerationResponse]: + """Gets a streaming model response for a single prompt. + + The result is a stream (generator) of partial responses. + + Args: + prompt: Question to ask the model. + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. 
+            for response in model.predict_streaming(
+                "Count to 50",
+                max_output_tokens=1000,
+                temperature=0,
+                top_p=1,
+                top_k=5,
+            ):
+                assert len(response.text) > 10
+
     @pytest.mark.parametrize(
         "job_spec",
         [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 28d73f4eac..31329011e1 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -15,10 +15,11 @@
 """Classes for working with language models."""
 
 import dataclasses
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 import warnings
 
 from google.cloud import aiplatform
+from google.cloud.aiplatform import _streaming_prediction
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer as aiplatform_initializer
 from google.cloud.aiplatform import utils as aiplatform_utils
@@ -389,6 +390,70 @@ def _batch_predict(
         )
         return results
 
+    def predict_streaming(
+        self,
+        prompt: str,
+        *,
+        max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Gets a streaming model response for a single prompt.
+
+        The result is a stream (generator) of partial responses.
+
+        Args:
+            prompt: Question to ask the model.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability threshold of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_service_client = self._endpoint._prediction_client
+        # Note: "prompt", not "content" like in the non-streaming case. b/294462691
+        instance = {"prompt": prompt}
+        prediction_parameters = {}
+
+        if max_output_tokens:
+            prediction_parameters["maxDecodeSteps"] = max_output_tokens
+
+        if temperature is not None:
+            prediction_parameters["temperature"] = temperature
+
+        if top_p:
+            prediction_parameters["topP"] = top_p
+
+        if top_k:
+            prediction_parameters["topK"] = top_k
+
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._endpoint_name,
+            instance=instance,
+            parameters=prediction_parameters,
+        ):
+            safety_attributes_dict = prediction_dict.get("safetyAttributes", {})
+            prediction_obj = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            yield TextGenerationResponse(
+                text=prediction_dict["content"],
+                _prediction_response=prediction_obj,
+                is_blocked=safety_attributes_dict.get("blocked", False),
+                safety_attributes=dict(
+                    zip(
+                        safety_attributes_dict.get("categories") or [],
+                        safety_attributes_dict.get("scores") or [],
+                    )
+                ),
+            )
+
 
 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""
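
Illustrative end-to-end usage of the new streaming API (a sketch; the
project, location, prompt, and parameter values below are placeholders):

    from google.cloud import aiplatform
    from vertexai.language_models import TextGenerationModel

    aiplatform.init(project="my-project", location="us-central1")
    model = TextGenerationModel.from_pretrained("text-bison@001")

    # Stream partial results as they arrive and stitch them together;
    # each chunk is a TextGenerationResponse with its own safety metadata.
    chunks = []
    for response in model.predict_streaming(
        "What is the best recipe for banana bread? Recipe:",
        max_output_tokens=128,
        temperature=0,
    ):
        chunks.append(response.text)
    full_text = "".join(chunks)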