feat: LLM - Support streaming prediction for text generation models
PiperOrigin-RevId: 558068359
Ark-kun authored and copybara-github committed Aug 18, 2023
1 parent 8df5185 commit fb527f3
Showing 4 changed files with 313 additions and 1 deletion.
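For orientation: the feature wraps the `server_streaming_predict` RPC on the v1 `PredictionServiceClient`, which returns an iterator of partial responses. A minimal sketch of the raw RPC flow this commit builds on (the resource name and regional endpoint are illustrative placeholders, not part of this commit):

from google.cloud.aiplatform_v1.services import prediction_service
from google.cloud.aiplatform_v1.types import (
    prediction_service as prediction_service_types,
)
from google.cloud.aiplatform_v1.types import types as aiplatform_types

# Illustrative resource name; a real Endpoint or PublisherModel name is required.
endpoint = (
    "projects/my-project/locations/us-central1/"
    "publishers/google/models/text-bison@001"
)

client = prediction_service.PredictionServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
request = prediction_service_types.StreamingPredictRequest(
    endpoint=endpoint,
    inputs=[
        aiplatform_types.Tensor(
            struct_val={"prompt": aiplatform_types.Tensor(string_val=["Count to 10"])}
        )
    ],
)
for response in client.server_streaming_predict(request=request):
    # Each partial response carries a list of `Tensor` outputs.
    print(response.outputs)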
166 changes: 166 additions & 0 deletions google/cloud/aiplatform/_streaming_prediction.py
@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Streaming prediction functions."""

from typing import Any, Dict, Iterator, List, Optional, Sequence

from google.cloud.aiplatform_v1.services import prediction_service
from google.cloud.aiplatform_v1.types import (
prediction_service as prediction_service_types,
)
from google.cloud.aiplatform_v1.types import (
types as aiplatform_types,
)


def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
    """Converts a Python value to `Tensor`.

    Args:
        value: A value to convert.

    Returns:
        A `Tensor` object.
    """
    if value is None:
        return aiplatform_types.Tensor()
    # Note: `bool` must be checked before `int` since `bool` is a subclass of `int`.
    elif isinstance(value, bool):
        return aiplatform_types.Tensor(bool_val=[value])
    elif isinstance(value, int):
        return aiplatform_types.Tensor(int_val=[value])
    elif isinstance(value, float):
        return aiplatform_types.Tensor(float_val=[value])
    elif isinstance(value, str):
        return aiplatform_types.Tensor(string_val=[value])
    elif isinstance(value, bytes):
        return aiplatform_types.Tensor(bytes_val=[value])
    elif isinstance(value, list):
        return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
    elif isinstance(value, dict):
        return aiplatform_types.Tensor(
            struct_val={k: value_to_tensor(v) for k, v in value.items()}
        )
    raise TypeError(f"Unsupported value type {type(value)}")


def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
    """Converts `Tensor` to a Python value.

    Args:
        tensor_pb: A `Tensor` object.

    Returns:
        A corresponding Python object.
    """
    list_of_fields = tensor_pb.ListFields()
    if not list_of_fields:
        return None
    descriptor, value = list_of_fields[0]
    if descriptor.name == "list_val":
        return [tensor_to_value(x) for x in value]
    elif descriptor.name == "struct_val":
        return {k: tensor_to_value(v) for k, v in value.items()}
    if not isinstance(value, Sequence):
        raise TypeError(f"Unexpected non-list tensor value {value}")
    if len(value) == 1:
        return value[0]
    else:
        return value


def predict_stream_of_tensor_lists_from_single_tensor_list(
    prediction_service_client: prediction_service.PredictionServiceClient,
    endpoint_name: str,
    tensor_list: List[aiplatform_types.Tensor],
    parameters_tensor: Optional[aiplatform_types.Tensor] = None,
) -> Iterator[List[aiplatform_types.Tensor]]:
    """Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.

    Args:
        prediction_service_client: A PredictionServiceClient object.
        endpoint_name: Resource name of the Endpoint or PublisherModel.
        tensor_list: Model input as a list of `Tensor` objects.
        parameters_tensor: Optional. Prediction parameters in `Tensor` form.

    Yields:
        Lists of model prediction `Tensor` objects, one list per streamed response.
    """
    request = prediction_service_types.StreamingPredictRequest(
        endpoint=endpoint_name,
        inputs=tensor_list,
        parameters=parameters_tensor,
    )
    for response in prediction_service_client.server_streaming_predict(request=request):
        yield response.outputs


def predict_stream_of_dict_lists_from_single_dict_list(
    prediction_service_client: prediction_service.PredictionServiceClient,
    endpoint_name: str,
    dict_list: List[Dict[str, Any]],
    parameters: Optional[Dict[str, Any]] = None,
) -> Iterator[List[Dict[str, Any]]]:
    """Predicts a stream of lists of dicts from a single list of dicts.

    Args:
        prediction_service_client: A PredictionServiceClient object.
        endpoint_name: Resource name of the Endpoint or PublisherModel.
        dict_list: Model input as a list of `dict` objects.
        parameters: Optional. Prediction parameters in `dict` form.

    Yields:
        Lists of model prediction dicts, one list per streamed response.
    """
    tensor_list = [value_to_tensor(d) for d in dict_list]
    parameters_tensor = value_to_tensor(parameters) if parameters else None
    # Use a distinct loop variable so the input `tensor_list` is not shadowed.
    for output_tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
        prediction_service_client=prediction_service_client,
        endpoint_name=endpoint_name,
        tensor_list=tensor_list,
        parameters_tensor=parameters_tensor,
    ):
        yield [tensor_to_value(tensor._pb) for tensor in output_tensor_list]


def predict_stream_of_dicts_from_single_dict(
    prediction_service_client: prediction_service.PredictionServiceClient,
    endpoint_name: str,
    instance: Dict[str, Any],
    parameters: Optional[Dict[str, Any]] = None,
) -> Iterator[Dict[str, Any]]:
    """Predicts a stream of dicts from a single instance dict.

    Args:
        prediction_service_client: A PredictionServiceClient object.
        endpoint_name: Resource name of the Endpoint or PublisherModel.
        instance: A single input instance `dict`.
        parameters: Optional. Prediction parameters `dict`.

    Yields:
        Model prediction dicts, one per streamed response.
    """
    for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
        prediction_service_client=prediction_service_client,
        endpoint_name=endpoint_name,
        dict_list=[instance],
        parameters=parameters,
    ):
        # A single input instance should produce exactly one output per response.
        if len(dict_list) != 1:
            raise ValueError(
                f"Expected to receive a single output, but got {dict_list}"
            )
        yield dict_list[0]
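As an aside, the two converters above are intended to round-trip JSON-like values. A small illustrative sketch, not part of the committed file (0.5 and 0.25 are chosen because they are exact in float32, so the round trip compares equal):

from google.cloud.aiplatform import _streaming_prediction

tensor = _streaming_prediction.value_to_tensor(
    {"content": "Hello", "scores": [0.5, 0.25]}
)
# `tensor_to_value` operates on the raw proto, hence `._pb`.
value = _streaming_prediction.tensor_to_value(tensor._pb)
assert value == {"content": "Hello", "scores": [0.5, 0.25]}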
14 changes: 14 additions & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -48,6 +48,20 @@ def test_text_generation(self):
            top_k=5,
        ).text

    def test_text_generation_streaming(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = TextGenerationModel.from_pretrained("google/text-bison@001")

        for response in model.predict_streaming(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        ):
            assert response.text

    def test_chat_on_chat_model(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

67 changes: 67 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -28,6 +28,7 @@
from google.cloud import storage

from google.cloud import aiplatform
from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import gcs_utils
@@ -168,6 +169,34 @@
1. Preheat oven to 350 degrees F (175 degrees C).""",
}

_TEST_TEXT_GENERATION_PREDICTION_STREAMING = [
    {
        "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.",
    },
    {
        "content": " 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.",
        "safetyAttributes": {"blocked": False, "categories": None, "scores": None},
    },
    {
        "content": " 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45.",
        "citationMetadata": {
            "citations": [
                {
                    "title": "THEATRUM ARITHMETICO-GEOMETRICUM",
                    "publicationDate": "1727",
                    "endIndex": 181,
                    "startIndex": 12,
                }
            ]
        },
        "safetyAttributes": {
            "blocked": True,
            "categories": ["Finance"],
            "scores": [0.1],
        },
    },
]

_TEST_CHAT_GENERATION_PREDICTION1 = {
"safetyAttributes": [
{
@@ -1040,6 +1069,10 @@ class TestLanguageModels:
    def setup_method(self):
        reload(initializer)
        reload(aiplatform)
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)
@@ -1165,6 +1198,40 @@ def test_text_generation_ga(self):
assert "topP" not in prediction_parameters
assert "topK" not in prediction_parameters

def test_text_generation_model_predict_streaming(self):
"""Tests the TextGenerationModel.predict_streaming method."""
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextGenerationModel.from_pretrained(
"text-bison@001"
)

response_generator = (
gca_prediction_service.StreamingPredictResponse(
outputs=[_streaming_prediction.value_to_tensor(response_dict)]
)
for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING
)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="server_streaming_predict",
return_value=response_generator,
):
for response in model.predict_streaming(
"Count to 50",
max_output_tokens=1000,
temperature=0,
top_p=1,
top_k=5,
):
assert len(response.text) > 10

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
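The mock wiring above generalizes to other streaming scenarios. A hedged sketch reusing this test module's names (`model` is assumed to be constructed via `from_pretrained` as in the test above; the chunk strings are illustrative):

fake_stream = (
    gca_prediction_service.StreamingPredictResponse(
        outputs=[_streaming_prediction.value_to_tensor({"content": chunk})]
    )
    for chunk in ["Hello", ", world!"]
)
with mock.patch.object(
    target=prediction_service_client.PredictionServiceClient,
    attribute="server_streaming_predict",
    return_value=fake_stream,
):
    chunks = [r.text for r in model.predict_streaming("Say hello")]
assert "".join(chunks) == "Hello, world!"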
67 changes: 66 additions & 1 deletion vertexai/language_models/_language_models.py
@@ -15,10 +15,11 @@
"""Classes for working with language models."""

import dataclasses
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
import warnings

from google.cloud import aiplatform
from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import utils as aiplatform_utils
@@ -389,6 +390,70 @@ def _batch_predict(
        )
        return results

    def predict_streaming(
        self,
        prompt: str,
        *,
        max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> Iterator[TextGenerationResponse]:
        """Gets a streaming model response for a single prompt.

        The result is a stream (generator) of partial responses.

        Args:
            prompt: Question to ask the model.
            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.

        Yields:
            A stream of `TextGenerationResponse` objects that contain partial
            responses produced by the model.
        """
        prediction_service_client = self._endpoint._prediction_client
        # Note: "prompt", not "content" like in the non-streaming case. b/294462691
        instance = {"prompt": prompt}
        prediction_parameters = {}

        if max_output_tokens:
            prediction_parameters["maxDecodeSteps"] = max_output_tokens

        if temperature is not None:
            prediction_parameters["temperature"] = temperature

        if top_p:
            prediction_parameters["topP"] = top_p

        if top_k:
            prediction_parameters["topK"] = top_k

        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
            prediction_service_client=prediction_service_client,
            endpoint_name=self._endpoint_name,
            instance=instance,
            parameters=prediction_parameters,
        ):
            safety_attributes_dict = prediction_dict.get("safetyAttributes", {})
            prediction_obj = aiplatform.models.Prediction(
                predictions=[prediction_dict],
                deployed_model_id="",
            )
            yield TextGenerationResponse(
                text=prediction_dict["content"],
                _prediction_response=prediction_obj,
                is_blocked=safety_attributes_dict.get("blocked", False),
                safety_attributes=dict(
                    zip(
                        safety_attributes_dict.get("categories") or [],
                        safety_attributes_dict.get("scores") or [],
                    )
                ),
            )


class _ModelWithBatchPredict(_LanguageModel):
"""Model that supports batch prediction."""
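Taken together, a minimal consumer-side sketch of the new API (project, location, and the prompt are placeholder values; the model name follows the tests above):

from google.cloud import aiplatform
from vertexai.language_models import TextGenerationModel

aiplatform.init(project="my-project", location="us-central1")
model = TextGenerationModel.from_pretrained("text-bison@001")

streamed_chunks = []
for response in model.predict_streaming(
    "What is the best recipe for banana bread? Recipe:",
    max_output_tokens=128,
    temperature=0,
):
    # Each `response` is a partial TextGenerationResponse.
    streamed_chunks.append(response.text)

print("".join(streamed_chunks))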
