Skip to content

Commit

Permalink
feat: LLM - Added count_tokens support to ChatModel (preview)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575006811
  • Loading branch information
sararob authored and copybara-github committed Oct 19, 2023
1 parent eb6071f commit 01989b1
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 6 deletions.
23 changes: 23 additions & 0 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@ def test_chat_on_chat_model(self):
assert chat.message_history[2].content == message2
assert chat.message_history[3].author == chat.MODEL_AUTHOR

def test_chat_model_preview_count_tokens(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

chat_model = ChatModel.from_pretrained("google/chat-bison@001")

chat = chat_model.start_chat()

chat.send_message("What should I do today?")

response_with_history = chat.count_tokens("Any ideas?")

response_without_history = chat_model.start_chat().count_tokens(
"What should I do today?"
)

assert (
response_with_history.total_tokens > response_without_history.total_tokens
)
assert (
response_with_history.total_billable_characters
> response_without_history.total_billable_characters
)

@pytest.mark.asyncio
async def test_chat_model_async(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,44 @@ def test_chat_model_send_message_streaming(self):
assert chat.message_history[2].content == message_text1
assert chat.message_history[3].author == chat.MODEL_AUTHOR

def test_chat_model_preview_count_tokens(self):
"""Tests the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CHAT_BISON_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

chat = model.start_chat()
assert isinstance(chat, preview_language_models.ChatSession)

gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
"total_billable_characters"
],
)

with mock.patch.object(
target=prediction_service_client_v1beta1.PredictionServiceClient,
attribute="count_tokens",
return_value=gca_count_tokens_response,
):
response = chat.count_tokens("What is the best recipe for banana bread?")

assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
assert (
response.total_billable_characters
== _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
)

def test_code_chat(self):
"""Tests the code chat model."""
aiplatform.init(
Expand Down Expand Up @@ -2577,6 +2615,46 @@ def test_code_chat_model_send_message_streaming(self):
assert chat.message_history[0].content == message_text1
assert chat.message_history[1].author == chat.MODEL_AUTHOR

def test_code_chat_model_preview_count_tokens(self):
"""Tests the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.CodeChatModel.from_pretrained(
"codechat-bison@001"
)

chat = model.start_chat()
assert isinstance(chat, preview_language_models.CodeChatSession)

gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
"total_billable_characters"
],
)

with mock.patch.object(
target=prediction_service_client_v1beta1.PredictionServiceClient,
attribute="count_tokens",
return_value=gca_count_tokens_response,
):
response = chat.count_tokens("What is the best recipe for banana bread?")

assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
assert (
response.total_billable_characters
== _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
)

def test_code_generation(self):
"""Tests code generation with the code generation model."""
aiplatform.init(
Expand Down
165 changes: 159 additions & 6 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def tune_model(
if eval_spec.evaluation_data:
if isinstance(eval_spec.evaluation_data, str):
if eval_spec.evaluation_data.startswith("gs://"):
tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
tuning_parameters[
"evaluation_data_uri"
] = eval_spec.evaluation_data
else:
raise ValueError("evaluation_data should be a GCS URI")
else:
Expand Down Expand Up @@ -627,7 +629,7 @@ def count_tokens(
) -> CountTokensResponse:
"""Counts the tokens and billable characters for a given prompt.
Note: this does not make a request to the model, it only counts the tokens
Note: this does not make a prediction request to the model, it only counts the tokens
in the request.
Args:
Expand Down Expand Up @@ -802,7 +804,9 @@ def predict(
parameters=prediction_request.parameters,
)

return _parse_text_generation_model_multi_candidate_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(
prediction_response
)

async def predict_async(
self,
Expand Down Expand Up @@ -844,7 +848,9 @@ async def predict_async(
parameters=prediction_request.parameters,
)

return _parse_text_generation_model_multi_candidate_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(
prediction_response
)

def predict_streaming(
self,
Expand Down Expand Up @@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):

_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

def start_chat(
self,
*,
context: Optional[str] = None,
examples: Optional[List[InputOutputTextPair]] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
) -> "_PreviewChatSession":
"""Starts a chat session with the model.
Args:
context: Context shapes how the model responds throughout the conversation.
For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
examples: List of structured messages to the model to learn how to respond to the conversation.
A list of `InputOutputTextPair` objects.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
message_history: A list of previously sent and received messages.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `ChatSession` object.
"""
return _PreviewChatSession(
model=self,
context=context,
examples=examples,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
message_history=message_history,
stop_sequences=stop_sequences,
)


class CodeChatModel(_ChatModelBase):
"""CodeChatModel represents a model that is capable of completing code.
Expand Down Expand Up @@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):

_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

def start_chat(
self,
*,
context: Optional[str] = None,
examples: Optional[List[InputOutputTextPair]] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
message_history: Optional[List[ChatMessage]] = None,
stop_sequences: Optional[List[str]] = None,
) -> "_PreviewCodeChatSession":
"""Starts a chat session with the model.
Args:
context: Context shapes how the model responds throughout the conversation.
For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style
examples: List of structured messages to the model to learn how to respond to the conversation.
A list of `InputOutputTextPair` objects.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
message_history: A list of previously sent and received messages.
stop_sequences: Customized stop sequences to stop the decoding process.
Returns:
A `ChatSession` object.
"""
return _PreviewCodeChatSession(
model=self,
context=context,
examples=examples,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
message_history=message_history,
stop_sequences=stop_sequences,
)


class _ChatSessionBase:
"""_ChatSessionBase is a base class for all chat sessions."""
Expand Down Expand Up @@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
)


class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
"""A mixin class for adding count_tokens to ChatSession."""

def count_tokens(
self,
message: str,
) -> CountTokensResponse:
"""Counts the tokens and billable characters for the provided chat message and any message history,
context, or examples set on the chat session.
If you've called `send_message()` in the current chat session before calling `count_tokens()`, the
response will include the total tokens and characters for the previously sent message and the one in the
`count_tokens()` request. To count the tokens for a single message, call `count_tokens()` right after
calling `start_chat()` before calling `send_message()`.
Note: this does not make a prediction request to the model, it only counts the tokens
in the request.
Examples::
model = ChatModel.from_pretrained("chat-bison@001")
chat_session = model.start_chat()
count_tokens_response = chat_session.count_tokens("How's it going?")
count_tokens_response.total_tokens
count_tokens_response.total_billable_characters
Args:
message (str):
Required. A chat message to count tokens or. For example: "How's it going?"
Returns:
A `CountTokensResponse` object that contains the number of tokens
in the text and the number of billable characters.
"""

count_tokens_request = self._prepare_request(message=message)

count_tokens_response = self._model._endpoint._prediction_client.select_version(
"v1beta1"
).count_tokens(
endpoint=self._model._endpoint_name,
instances=[count_tokens_request.instance],
)

return CountTokensResponse(
total_tokens=count_tokens_response.total_tokens,
total_billable_characters=count_tokens_response.total_billable_characters,
_count_tokens_response=count_tokens_response,
)


class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):

__module__ = "vertexai.preview.language_models"


class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):

__module__ = "vertexai.preview.language_models"


class ChatSession(_ChatSessionBase):
"""ChatSession represents a chat session with a language model.
Expand Down Expand Up @@ -2361,7 +2510,9 @@ def predict(
instances=[prediction_request.instance],
parameters=prediction_request.parameters,
)
return _parse_text_generation_model_multi_candidate_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(
prediction_response
)

async def predict_async(
self,
Expand Down Expand Up @@ -2400,7 +2551,9 @@ async def predict_async(
instances=[prediction_request.instance],
parameters=prediction_request.parameters,
)
return _parse_text_generation_model_multi_candidate_response(prediction_response)
return _parse_text_generation_model_multi_candidate_response(
prediction_response
)

def predict_streaming(
self,
Expand Down
4 changes: 4 additions & 0 deletions vertexai/preview/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from vertexai.language_models._language_models import (
_PreviewChatModel,
_PreviewChatSession,
_PreviewCodeChatModel,
_PreviewCodeChatSession,
_PreviewCodeGenerationModel,
_PreviewTextEmbeddingModel,
_PreviewTextGenerationModel,
Expand All @@ -43,7 +45,9 @@


ChatModel = _PreviewChatModel
ChatSession = _PreviewChatSession
CodeChatModel = _PreviewCodeChatModel
CodeChatSession = _PreviewCodeChatSession
CodeGenerationModel = _PreviewCodeGenerationModel
TextGenerationModel = _PreviewTextGenerationModel
TextEmbeddingModel = _PreviewTextEmbeddingModel
Expand Down

0 comments on commit 01989b1

Please sign in to comment.