feat: LLM - Support streaming prediction for chat models
PiperOrigin-RevId: 558246099
Ark-kun authored and copybara-github committed Aug 18, 2023
1 parent fb527f3 commit ce60cf7
Showing 3 changed files with 289 additions and 12 deletions.
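
For orientation, here is a minimal usage sketch of the streaming chat API this commit introduces, distilled from the system test below. It is not part of the commit itself; the project and location values are placeholders, and the import paths assume the vertexai SDK at or after this change.

# Minimal usage sketch (not part of this commit). Placeholders: "my-project",
# "us-central1". Assumes the vertexai SDK at or after this change.
from google.cloud import aiplatform
from vertexai.language_models import ChatModel, InputOutputTextPair

aiplatform.init(project="my-project", location="us-central1")

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat(
    context="You are my personal assistant.",
    examples=[
        InputOutputTextPair(
            input_text="Who do you work for?",
            output_text="I work for Ned.",
        ),
    ],
)

# Each yielded TextGenerationResponse carries a partial answer; the question
# and the concatenated full answer are appended to chat.message_history only
# after the stream has been read to the end.
for response in chat.send_message_streaming("Are my favorite movies based on a book series?"):
    print(response.text, end="")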
38 changes: 38 additions & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -100,6 +100,44 @@ def test_chat_on_chat_model(self):
        assert chat.message_history[2].content == message2
        assert chat.message_history[3].author == chat.MODEL_AUTHOR

    def test_chat_model_send_message_streaming(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
        chat = chat_model.start_chat(
            context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
            examples=[
                InputOutputTextPair(
                    input_text="Who do you work for?",
                    output_text="I work for Ned.",
                ),
                InputOutputTextPair(
                    input_text="What do I like?",
                    output_text="Ned likes watching movies.",
                ),
            ],
            temperature=0.0,
        )

        message1 = "Are my favorite movies based on a book series?"
        for response in chat.send_message_streaming(message1):
            assert response.text
        assert len(chat.message_history) == 2
        assert chat.message_history[0].author == chat.USER_AUTHOR
        assert chat.message_history[0].content == message1
        assert chat.message_history[1].author == chat.MODEL_AUTHOR

        message2 = "When were these books published?"
        for response2 in chat.send_message_streaming(
            message2,
            temperature=0.1,
        ):
            assert response2.text
        assert len(chat.message_history) == 4
        assert chat.message_history[2].author == chat.USER_AUTHOR
        assert chat.message_history[2].content == message2
        assert chat.message_history[3].author == chat.MODEL_AUTHOR

    def test_text_embedding(self):
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

107 changes: 107 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -228,6 +228,33 @@
    ],
}

_TEST_CHAT_PREDICTION_STREAMING = [
    {
        "candidates": [
            {
                "author": "1",
                "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.",
            }
        ],
        "safetyAttributes": [{"blocked": False, "categories": None, "scores": None}],
    },
    {
        "candidates": [
            {
                "author": "1",
                "content": " 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.",
            }
        ],
        "safetyAttributes": [
            {
                "blocked": True,
                "categories": ["Finance"],
                "scores": [0.1],
            }
        ],
    },
]

_TEST_CODE_GENERATION_PREDICTION = {
    "safetyAttributes": {
        "categories": [],
@@ -1735,6 +1762,86 @@ def test_chat_ga(self):
            assert prediction_parameters["topK"] == message_top_k
            assert prediction_parameters["topP"] == message_top_p

    def test_chat_model_send_message_streaming(self):
        """Tests the chat generation model."""
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _CHAT_BISON_PUBLISHER_MODEL_DICT
            ),
        ):
            model = language_models.ChatModel.from_pretrained("chat-bison@001")

        chat = model.start_chat(
            context="""
            My name is Ned.
            You are my personal assistant.
            My favorite movies are Lord of the Rings and Hobbit.
            """,
            examples=[
                language_models.InputOutputTextPair(
                    input_text="Who do you work for?",
                    output_text="I work for Ned.",
                ),
                language_models.InputOutputTextPair(
                    input_text="What do I like?",
                    output_text="Ned likes watching movies.",
                ),
            ],
            message_history=[
                language_models.ChatMessage(
                    author=preview_language_models.ChatSession.USER_AUTHOR,
                    content="Question 1?",
                ),
                language_models.ChatMessage(
                    author=preview_language_models.ChatSession.MODEL_AUTHOR,
                    content="Answer 1.",
                ),
            ],
            temperature=0.0,
        )

        # Using list instead of a generator so that it can be reused.
        response_generator = [
            gca_prediction_service.StreamingPredictResponse(
                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
            )
            for response_dict in _TEST_CHAT_PREDICTION_STREAMING
        ]

        message_temperature = 0.2
        message_max_output_tokens = 200
        message_top_k = 2
        message_top_p = 0.2

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="server_streaming_predict",
            return_value=response_generator,
        ):
            message_text1 = "Are my favorite movies based on a book series?"

            for idx, response in enumerate(
                chat.send_message_streaming(
                    message=message_text1,
                    max_output_tokens=message_max_output_tokens,
                    temperature=message_temperature,
                    top_k=message_top_k,
                    top_p=message_top_p,
                )
            ):
                assert len(response.text) > 10
                # New messages are not added until the response is fully read
                if idx + 1 < len(response_generator):
                    assert len(chat.message_history) == 2

            # New messages are only added after the response is fully read
            assert len(chat.message_history) == 4
            assert chat.message_history[2].author == chat.USER_AUTHOR
            assert chat.message_history[2].content == message_text1
            assert chat.message_history[3].author == chat.MODEL_AUTHOR

    def test_code_chat(self):
        """Tests the code chat model."""
        aiplatform.init(
156 changes: 144 additions & 12 deletions vertexai/language_models/_language_models.py
@@ -88,6 +88,13 @@ def _model_resource_name(self) -> str:
         return self._endpoint.list_models()[0].model
 
 
+@dataclasses.dataclass
+class _PredictionRequest:
+    """A single-instance prediction request."""
+    instance: Dict[str, Any]
+    parameters: Optional[Dict[str, Any]] = None
+
+
 class _TunableModelMixin(_LanguageModel):
     """Model that can be tuned."""
 
@@ -915,16 +922,16 @@ def message_history(self) -> List[ChatMessage]:
         """List of previous messages."""
         return self._message_history
 
-    def send_message(
+    def _prepare_request(
         self,
         message: str,
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-    ) -> "TextGenerationResponse":
-        """Sends message to the language model and gets a response.
+    ) -> _PredictionRequest:
+        """Prepares a request for the language model.
 
         Args:
             message: Message to send to the model
@@ -938,7 +945,7 @@ def send_message
                 Uses the value specified when calling `ChatModel.start_chat` by default.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `_PredictionRequest` object.
         """
         prediction_parameters = {}
 
@@ -986,27 +993,87 @@ def send_message(
                 for example in self._examples
             ]
 
-        prediction_response = self._model._endpoint.predict(
-            instances=[prediction_instance],
+        return _PredictionRequest(
+            instance=prediction_instance,
             parameters=prediction_parameters,
         )
 
-        prediction = prediction_response.predictions[0]
+    @classmethod
+    def _parse_chat_prediction_response(
+        cls,
+        prediction_response: aiplatform.models.Prediction,
+        prediction_idx: int = 0,
+        candidate_idx: int = 0,
+    ) -> TextGenerationResponse:
+        """Parses prediction response for chat models.
+
+        Args:
+            prediction_response: Prediction response received from the model
+            prediction_idx: Index of the prediction to parse.
+            candidate_idx: Index of the candidate to parse.
+
+        Returns:
+            A `TextGenerationResponse` object.
+        """
+        prediction = prediction_response.predictions[prediction_idx]
         # ! Note: For chat models, the safetyAttributes is a list.
-        safety_attributes = prediction["safetyAttributes"][0]
-        response_obj = TextGenerationResponse(
-            text=prediction["candidates"][0]["content"]
+        safety_attributes = prediction["safetyAttributes"][candidate_idx]
+        return TextGenerationResponse(
+            text=prediction["candidates"][candidate_idx]["content"]
             if prediction.get("candidates")
             else None,
             _prediction_response=prediction_response,
             is_blocked=safety_attributes.get("blocked", False),
             safety_attributes=dict(
                 zip(
-                    safety_attributes.get("categories", []),
-                    safety_attributes.get("scores", []),
+                    # Unlike with normal prediction, in streaming prediction
+                    # categories and scores can be None
+                    safety_attributes.get("categories") or [],
+                    safety_attributes.get("scores") or [],
                 )
             ),
         )
 
+    def send_message(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> "TextGenerationResponse":
+        """Sends message to the language model and gets a response.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A `TextGenerationResponse` object that contains the text produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_response = self._model._endpoint.predict(
+            instances=[prediction_request.instance],
+            parameters=prediction_request.parameters,
+        )
+        response_obj = self._parse_chat_prediction_response(
+            prediction_response=prediction_response
+        )
         response_text = response_obj.text
 
         self._message_history.append(
Expand All @@ -1018,6 +1085,71 @@ def send_message(

return response_obj

def send_message_streaming(
self,
message: str,
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> Iterator[TextGenerationResponse]:
"""Sends message to the language model and gets a streamed response.
The response is only added to the history once it's fully read.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
prediction_request = self._prepare_request(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)

prediction_service_client = self._model._endpoint._prediction_client

full_response_text = ""

for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
prediction_service_client=prediction_service_client,
endpoint_name=self._model._endpoint_name,
instance=prediction_request.instance,
parameters=prediction_request.parameters,
):
prediction_response = aiplatform.models.Prediction(
predictions=[prediction_dict],
deployed_model_id="",
)
text_generation_response = self._parse_chat_prediction_response(
prediction_response=prediction_response
)
full_response_text += text_generation_response.text
yield text_generation_response

# We only add the question and answer to the history if/when the answer
# was read fully. Otherwise, the answer would have been truncated.
self._message_history.append(
ChatMessage(content=message, author=self.USER_AUTHOR)
)
self._message_history.append(
ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
)


class ChatSession(_ChatSessionBase):
"""ChatSession represents a chat session with a language model.
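
Since send_message_streaming is implemented as a generator, the history bookkeeping at the end of the method only executes once the caller has consumed the stream completely. Below is a short sketch of that behavior; it is not part of the commit, and the chat object is assumed to be a session started via ChatModel.start_chat, as in the tests above.

# Sketch of the generator semantics (not part of the commit); `chat` is a
# session started via ChatModel.start_chat, as in the tests above.
chunks = []
for partial in chat.send_message_streaming("When were these books published?"):
    # While the stream is still being read, chat.message_history is unchanged.
    chunks.append(partial.text)

# Only after the loop completes does the session append the user message and
# the concatenated model response to chat.message_history.
full_text = "".join(chunks)

# Note: breaking out of the loop early abandons the generator, so the code
# after its internal loop never runs and the exchange is not added to history.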
