diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 685b96edcc..00b8b70137 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -100,6 +100,44 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_send_message_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+        chat = chat_model.start_chat(
+            context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
+            examples=[
+                InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            temperature=0.0,
+        )
+
+        message1 = "Are my favorite movies based on a book series?"
+        for response in chat.send_message_streaming(message1):
+            assert response.text
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
+        message2 = "When were these books published?"
+        for response2 in chat.send_message_streaming(
+            message2,
+            temperature=0.1,
+        ):
+            assert response2.text
+        assert len(chat.message_history) == 4
+        assert chat.message_history[2].author == chat.USER_AUTHOR
+        assert chat.message_history[2].content == message2
+        assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
     def test_text_embedding(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 056cebfc37..e0c2051a66 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -228,6 +228,33 @@
     ],
 }
 
+_TEST_CHAT_PREDICTION_STREAMING = [
+    {
+        "candidates": [
+            {
+                "author": "1",
+                "content": "1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.",
+            }
+        ],
+        "safetyAttributes": [{"blocked": False, "categories": None, "scores": None}],
+    },
+    {
+        "candidates": [
+            {
+                "author": "1",
+                "content": " 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.",
+            }
+        ],
+        "safetyAttributes": [
+            {
+                "blocked": True,
+                "categories": ["Finance"],
+                "scores": [0.1],
+            }
+        ],
+    },
+]
+
 _TEST_CODE_GENERATION_PREDICTION = {
     "safetyAttributes": {
         "categories": [],
@@ -1735,6 +1762,86 @@ def test_chat_ga(self):
         assert prediction_parameters["topK"] == message_top_k
         assert prediction_parameters["topP"] == message_top_p
 
+    def test_chat_model_send_message_streaming(self):
+        """Tests send_message_streaming on the chat generation model."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        chat = model.start_chat(
+            context="""
+            My name is Ned.
+            You are my personal assistant.
+            My favorite movies are Lord of the Rings and Hobbit.
+            """,
+            examples=[
+                language_models.InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                language_models.InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            message_history=[
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.USER_AUTHOR,
+                    content="Question 1?",
+                ),
+                language_models.ChatMessage(
+                    author=preview_language_models.ChatSession.MODEL_AUTHOR,
+                    content="Answer 1.",
+                ),
+            ],
+            temperature=0.0,
+        )
+
+        # Using a list instead of a generator so that it can be reused.
+        response_generator = [
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+        ]
+
+        message_temperature = 0.2
+        message_max_output_tokens = 200
+        message_top_k = 2
+        message_top_p = 0.2
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+
+            for idx, response in enumerate(
+                chat.send_message_streaming(
+                    message=message_text1,
+                    max_output_tokens=message_max_output_tokens,
+                    temperature=message_temperature,
+                    top_k=message_top_k,
+                    top_p=message_top_p,
+                )
+            ):
+                assert len(response.text) > 10
+                # New messages are not added until the response is fully read
+                if idx + 1 < len(response_generator):
+                    assert len(chat.message_history) == 2
+
+        # New messages are only added after the response is fully read
+        assert len(chat.message_history) == 4
+        assert chat.message_history[2].author == chat.USER_AUTHOR
+        assert chat.message_history[2].content == message_text1
+        assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
     def test_code_chat(self):
         """Tests the code chat model."""
         aiplatform.init(
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 31329011e1..1784c66b70 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -88,6 +88,13 @@ def _model_resource_name(self) -> str:
         return self._endpoint.list_models()[0].model
 
 
+@dataclasses.dataclass
+class _PredictionRequest:
+    """A single-instance prediction request."""
+    instance: Dict[str, Any]
+    parameters: Optional[Dict[str, Any]] = None
+
+
 class _TunableModelMixin(_LanguageModel):
     """Model that can be tuned."""
 
@@ -915,7 +922,7 @@ def message_history(self) -> List[ChatMessage]:
         """List of previous messages."""
         return self._message_history
 
-    def send_message(
+    def _prepare_request(
         self,
         message: str,
         *,
@@ -923,8 +930,8 @@
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-    ) -> "TextGenerationResponse":
-        """Sends message to the language model and gets a response.
+    ) -> _PredictionRequest:
+        """Prepares a request for the language model.
 
         Args:
             message: Message to send to the model
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
                 Uses the value specified when calling `ChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
                 Uses the value specified when calling `ChatModel.start_chat` by default.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `_PredictionRequest` object.
""" prediction_parameters = {} @@ -986,27 +993,87 @@ def send_message( for example in self._examples ] - prediction_response = self._model._endpoint.predict( - instances=[prediction_instance], + return _PredictionRequest( + instance=prediction_instance, parameters=prediction_parameters, ) - prediction = prediction_response.predictions[0] + @classmethod + def _parse_chat_prediction_response( + cls, + prediction_response: aiplatform.models.Prediction, + prediction_idx: int = 0, + candidate_idx: int = 0, + ) -> TextGenerationResponse: + """Parses prediction response for chat models. + + Args: + prediction_response: Prediction response received from the model + prediction_idx: Index of the prediction to parse. + candidate_idx: Index of the candidate to parse. + + Returns: + A `TextGenerationResponse` object. + """ + prediction = prediction_response.predictions[prediction_idx] # ! Note: For chat models, the safetyAttributes is a list. - safety_attributes = prediction["safetyAttributes"][0] - response_obj = TextGenerationResponse( - text=prediction["candidates"][0]["content"] + safety_attributes = prediction["safetyAttributes"][candidate_idx] + return TextGenerationResponse( + text=prediction["candidates"][candidate_idx]["content"] if prediction.get("candidates") else None, _prediction_response=prediction_response, is_blocked=safety_attributes.get("blocked", False), safety_attributes=dict( zip( - safety_attributes.get("categories", []), - safety_attributes.get("scores", []), + # Unlike with normal prediction, in streaming prediction + # categories and scores can be None + safety_attributes.get("categories") or [], + safety_attributes.get("scores") or [], ) ), ) + + def send_message( + self, + message: str, + *, + max_output_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> "TextGenerationResponse": + """Sends message to the language model and gets a response. + + Args: + message: Message to send to the model + max_output_tokens: Max length of the output text in tokens. Range: [1, 1024]. + Uses the value specified when calling `ChatModel.start_chat` by default. + temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0. + Uses the value specified when calling `ChatModel.start_chat` by default. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40. + Uses the value specified when calling `ChatModel.start_chat` by default. + top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. + Uses the value specified when calling `ChatModel.start_chat` by default. + + Returns: + A `TextGenerationResponse` object that contains the text produced by the model. 
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_response = self._model._endpoint.predict(
+            instances=[prediction_request.instance],
+            parameters=prediction_request.parameters,
+        )
+        response_obj = self._parse_chat_prediction_response(
+            prediction_response=prediction_response
+        )
 
         response_text = response_obj.text
         self._message_history.append(
@@ -1018,6 +1085,71 @@ def send_message(
 
         return response_obj
 
+    def send_message_streaming(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            top_p: The cumulative probability of highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Yields:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        prediction_request = self._prepare_request(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        prediction_service_client = self._model._endpoint._prediction_client
+
+        full_response_text = ""
+
+        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+            prediction_service_client=prediction_service_client,
+            endpoint_name=self._model._endpoint_name,
+            instance=prediction_request.instance,
+            parameters=prediction_request.parameters,
+        ):
+            prediction_response = aiplatform.models.Prediction(
+                predictions=[prediction_dict],
+                deployed_model_id="",
+            )
+            text_generation_response = self._parse_chat_prediction_response(
+                prediction_response=prediction_response
+            )
+            full_response_text += text_generation_response.text
+            yield text_generation_response
+
+        # The question and answer are only added to the history once the answer
+        # has been read fully; otherwise the stored answer would be truncated.
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
+        )
+
 
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
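Note for reviewers: a minimal usage sketch of the new `send_message_streaming` surface (not part of the diff). It assumes a GCP project with the Vertex AI API enabled; the project ID is a placeholder, and the GA `vertexai.language_models` import path is used as in the unit tests above.

from google.cloud import aiplatform
from vertexai.language_models import ChatModel

# Placeholder project; any region where chat-bison@001 is available works.
aiplatform.init(project="my-project", location="us-central1")

chat = ChatModel.from_pretrained("chat-bison@001").start_chat(temperature=0.0)

# Partial responses are yielded as the service streams them back.
streamed_text = ""
for response in chat.send_message_streaming("Are my favorite movies based on a book series?"):
    streamed_text += response.text
    print(response.text, end="", flush=True)

# The history is updated only after the stream is fully consumed, so an
# interrupted stream never leaves a truncated answer in message_history.
assert len(chat.message_history) == 2
assert chat.message_history[1].content == streamed_text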
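The `or []` fallbacks in `_parse_chat_prediction_response` matter because, as `_TEST_CHAT_PREDICTION_STREAMING` above shows, streamed chunks can carry `safetyAttributes` whose `categories` and `scores` are None. A sketch of that behavior, building a `Prediction` by hand the way the unit test does and calling the private helper directly (illustrative values only):

from google.cloud import aiplatform
from vertexai.language_models import ChatSession

chunk = {
    "candidates": [{"author": "1", "content": "partial answer"}],
    "safetyAttributes": [{"blocked": False, "categories": None, "scores": None}],
}
prediction_response = aiplatform.models.Prediction(
    predictions=[chunk],
    deployed_model_id="",
)
response = ChatSession._parse_chat_prediction_response(prediction_response)

assert response.text == "partial answer"
assert response.is_blocked is False
# None categories/scores fall back to empty lists, so zip() yields no pairs.
assert response.safety_attributes == {}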