diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index 05145d7205..d522f0f09a 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -159,6 +159,29 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+
+        chat = chat_model.start_chat()
+
+        chat.send_message("What should I do today?")
+
+        response_with_history = chat.count_tokens("Any ideas?")
+
+        response_without_history = chat_model.start_chat().count_tokens(
+            "What should I do today?"
+        )
+
+        assert (
+            response_with_history.total_tokens > response_without_history.total_tokens
+        )
+        assert (
+            response_with_history.total_billable_characters
+            > response_without_history.total_billable_characters
+        )
+
     @pytest.mark.asyncio
     async def test_chat_model_async(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 917ed590ff..4d29e0ac73 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -2377,6 +2377,44 @@ def test_chat_model_send_message_streaming(self):
         assert chat.message_history[2].content == message_text1
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    def test_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.ChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_chat(self):
         """Tests the code chat model."""
         aiplatform.init(
@@ -2577,6 +2615,46 @@ def test_code_chat_model_send_message_streaming(self):
         assert chat.message_history[0].content == message_text1
         assert chat.message_history[1].author == chat.MODEL_AUTHOR
 
+    def test_code_chat_model_preview_count_tokens(self):
+        """Tests count_tokens on the preview code chat model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_language_models.CodeChatModel.from_pretrained(
+                "codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+        assert isinstance(chat, preview_language_models.CodeChatSession)
+
+        gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
+            total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
+            total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
+                "total_billable_characters"
+            ],
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client_v1beta1.PredictionServiceClient,
+            attribute="count_tokens",
+            return_value=gca_count_tokens_response,
+        ):
+            response = chat.count_tokens("What is the best recipe for banana bread?")
+
+        assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
+        assert (
+            response.total_billable_characters
+            == _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
+        )
+
     def test_code_generation(self):
         """Tests code generation with the code generation model."""
         aiplatform.init(
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 50c59ed339..4b4d0ec0ff 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -222,7 +222,9 @@ def tune_model(
         if eval_spec.evaluation_data:
             if isinstance(eval_spec.evaluation_data, str):
                 if eval_spec.evaluation_data.startswith("gs://"):
-                    tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
+                    tuning_parameters[
+                        "evaluation_data_uri"
+                    ] = eval_spec.evaluation_data
                 else:
                     raise ValueError("evaluation_data should be a GCS URI")
             else:
@@ -627,7 +629,7 @@ def count_tokens(
     ) -> CountTokensResponse:
         """Counts the tokens and billable characters for a given prompt.
 
-        Note: this does not make a request to the model, it only counts the tokens
+        Note: this does not make a prediction request to the model; it only counts the tokens
         in the request.
 
         Args:
@@ -802,7 +804,9 @@ def predict(
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     async def predict_async(
         self,
@@ -844,7 +848,9 @@ async def predict_async(
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     def predict_streaming(
         self,
@@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of `InputOutputTextPair` objects containing structured
+                example messages that show the model how to respond to the conversation.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of the highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `ChatSession` object.
+        """
+        return _PreviewChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class CodeChatModel(_ChatModelBase):
     """CodeChatModel represents a model that is capable of completing code.
@@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):
 
     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
 
+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewCodeChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of `InputOutputTextPair` objects containing structured
+                example messages that show the model how to respond to the conversation.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of the highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `CodeChatSession` object.
+        """
+        return _PreviewCodeChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+
 
 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
@@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
         )
 
 
+class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
+    """A mixin class for adding count_tokens to ChatSession."""
+
+    def count_tokens(
+        self,
+        message: str,
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for the provided chat message and
+        any message history, context, or examples set on the chat session.
+
+        If you've called `send_message()` in the current chat session before calling
+        `count_tokens()`, the response will include the total tokens and characters for
+        the previously sent message and the one in the `count_tokens()` request.
+        To count the tokens for a single message, call `count_tokens()` right after
+        calling `start_chat()` and before calling `send_message()`.
+
+        Note: this does not make a prediction request to the model; it only counts the
+        tokens in the request.
+
+        Examples::
+
+            model = ChatModel.from_pretrained("chat-bison@001")
+            chat_session = model.start_chat()
+            count_tokens_response = chat_session.count_tokens("How's it going?")
+
+            count_tokens_response.total_tokens
+            count_tokens_response.total_billable_characters
+
+        Args:
+            message (str):
+                Required. A chat message to count tokens for. For example: "How's it going?"
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+
+        count_tokens_request = self._prepare_request(message=message)
+
+        count_tokens_response = self._model._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._model._endpoint_name,
+            instances=[count_tokens_request.instance],
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
+class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
+class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
 
@@ -2361,7 +2510,9 @@ def predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     async def predict_async(
         self,
@@ -2400,7 +2551,9 @@ async def predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )
 
     def predict_streaming(
         self,
diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py
index e3048d21a7..7fd673b924 100644
--- a/vertexai/preview/language_models.py
+++ b/vertexai/preview/language_models.py
@@ -16,7 +16,9 @@
 
 from vertexai.language_models._language_models import (
     _PreviewChatModel,
+    _PreviewChatSession,
     _PreviewCodeChatModel,
+    _PreviewCodeChatSession,
     _PreviewCodeGenerationModel,
     _PreviewTextEmbeddingModel,
     _PreviewTextGenerationModel,
@@ -43,7 +45,9 @@
 
 ChatModel = _PreviewChatModel
+ChatSession = _PreviewChatSession
 CodeChatModel = _PreviewCodeChatModel
+CodeChatSession = _PreviewCodeChatSession
 CodeGenerationModel = _PreviewCodeGenerationModel
 TextGenerationModel = _PreviewTextGenerationModel
 TextEmbeddingModel = _PreviewTextEmbeddingModel
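A minimal usage sketch of the preview count_tokens surface added by this diff, assuming an authenticated environment; the project and location values are placeholders:

    from google.cloud import aiplatform
    from vertexai.preview.language_models import ChatModel

    aiplatform.init(project="my-project", location="us-central1")

    chat = ChatModel.from_pretrained("chat-bison@001").start_chat()

    # count_tokens() issues a CountTokens call rather than a prediction request;
    # it tallies the message plus any context, examples, and message history.
    response = chat.count_tokens("How's it going?")
    print(response.total_tokens, response.total_billable_characters)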