Skip to content

Commit

Permalink
feat: GenAI - Added system_instruction and tools support to `GenerativeModel.count_tokens`

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 669000052
  • Loading branch information
happy-qiao authored and copybara-github committed Aug 29, 2024
1 parent 20f2cad commit 50fca69
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
33 changes: 33 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,36 @@ def test_compute_tokens_from_text(self, api_endpoint_env_name):
assert token_info.role
# Lightly validate that the tokens are not Base64 encoded
assert b"=" not in token_info.tokens

def test_count_tokens_from_text(self):
    """Verifies that system_instruction and tools both increase token counts."""
    model_without_si = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
    model_with_si = generative_models.GenerativeModel(
        GEMINI_MODEL_NAME, system_instruction=["You are a chatbot."]
    )
    weather_tool = generative_models.Tool(
        function_declarations=[
            generative_models.FunctionDeclaration.from_func(get_current_weather)
        ],
    )
    prompts = ["Why is sky blue?", "Explain it like I'm 5."]

    baseline = model_without_si.count_tokens(prompts)
    with_si = model_with_si.count_tokens(prompts)
    with_si_and_tool = model_with_si.count_tokens(
        prompts,
        tools=[weather_tool],
    )

    # The system instruction is counted on top of the user prompt.
    assert with_si.total_tokens > baseline.total_tokens
    assert (
        with_si.total_billable_characters > baseline.total_billable_characters
    )
    # The tool declaration adds further tokens on top of prompt + instruction.
    assert with_si_and_tool.total_tokens > with_si.total_tokens
    assert (
        with_si_and_tool.total_billable_characters
        > with_si.total_billable_characters
    )
25 changes: 21 additions & 4 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ async def async_generator():
return async_generator()

def count_tokens(
    self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
) -> gapic_prediction_service_types.CountTokensResponse:
    """Counts tokens.

    Args:
        contents: Contents to send to the model.
            Supports either a list of Content objects (passing a raw list)
            or a value that can be converted to a list of Content objects:
            * str, Image, Part,
            * List[Union[str, Image, Part]],
            * List[Content]
        tools: A list of tools (functions) that the model can try calling.

    Returns:
        A CountTokensResponse object that has the following attributes:
            total_tokens: The total number of tokens counted across all
                instances from the request.
            total_billable_characters: The total number of billable
                characters counted across all instances from the request.
    """
    # Reuse the generate_content request preparation so that contents,
    # the model's system_instruction, and any tools are all normalized
    # the same way before being counted.
    request = self._prepare_request(
        contents=contents,
        tools=tools,
    )
    return self._prediction_client.count_tokens(
        request=gapic_prediction_service_types.CountTokensRequest(
            endpoint=self._prediction_resource_name,
            model=self._prediction_resource_name,
            contents=request.contents,
            system_instruction=request.system_instruction,
            tools=request.tools,
        )
    )

async def count_tokens_async(
    self,
    contents: ContentsType,
    *,
    tools: Optional[List["Tool"]] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
    """Counts tokens asynchronously.

    Args:
        contents: Contents to send to the model.
            Supports either a list of Content objects (passing a raw list)
            or a value that can be converted to a list of Content objects:
            * str, Image, Part,
            * List[Union[str, Image, Part]],
            * List[Content]
        tools: A list of tools (functions) that the model can try calling.

    Returns:
        An awaitable for a CountTokensResponse object that has the following attributes:
            total_tokens: The total number of tokens counted across all
                instances from the request.
            total_billable_characters: The total number of billable
                characters counted across all instances from the request.
    """
    # Same request preparation as the synchronous count_tokens: contents,
    # system_instruction, and tools are normalized once and forwarded.
    request = self._prepare_request(
        contents=contents,
        tools=tools,
    )
    return await self._prediction_async_client.count_tokens(
        request=gapic_prediction_service_types.CountTokensRequest(
            endpoint=self._prediction_resource_name,
            model=self._prediction_resource_name,
            contents=request.contents,
            system_instruction=request.system_instruction,
            tools=request.tools,
        )
    )

Expand Down

0 comments on commit 50fca69

Please sign in to comment.