Skip to content

Commit

Permalink
feat: add preview count_tokens method to CodeGenerationModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575318395
  • Loading branch information
sararob authored and copybara-github committed Oct 20, 2023
1 parent 01989b1 commit 96e7f7d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
37 changes: 37 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2771,6 +2771,43 @@ def test_code_generation_multiple_candidates(self):
response.candidates[0].text == _TEST_CODE_GENERATION_PREDICTION["content"]
)

def test_code_generation_preview_count_tokens(self):
"""Tests the count_tokens method in CodeGenerationModel."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.CodeGenerationModel.from_pretrained(
"code-gecko@001"
)

gca_count_tokens_response = gca_prediction_service_v1beta1.CountTokensResponse(
total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
"total_billable_characters"
],
)

with mock.patch.object(
target=prediction_service_client_v1beta1.PredictionServiceClient,
attribute="count_tokens",
return_value=gca_count_tokens_response,
):
response = model.count_tokens("def reverse_string(s):")

assert response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
assert (
response.total_billable_characters
== _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
)

def test_code_completion(self):
"""Tests code completion with the code generation model."""
aiplatform.init(
Expand Down
42 changes: 41 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2648,7 +2648,47 @@ async def predict_streaming_async(
yield _parse_text_generation_model_response(prediction_obj)


class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
class _CountTokensCodeGenerationMixin(_LanguageModel):
"""Mixin for code generation models that support the CountTokens API"""

def count_tokens(
self,
prefix: str,
*,
suffix: Optional[str] = None,
) -> CountTokensResponse:
"""Counts the tokens and billable characters for a given code generation prompt.
Note: this does not make a prediction request to the model, it only counts the tokens
in the request.
Args:
prefix (str): Code before the current point.
suffix (str): Code after the current point.
Returns:
A `CountTokensResponse` object that contains the number of tokens
in the text and the number of billable characters.
"""
prediction_request = {"prefix": prefix, "suffix": suffix}

count_tokens_response = self._endpoint._prediction_client.select_version(
"v1beta1"
).count_tokens(
endpoint=self._endpoint_name,
instances=[prediction_request],
)

return CountTokensResponse(
total_tokens=count_tokens_response.total_tokens,
total_billable_characters=count_tokens_response.total_billable_characters,
_count_tokens_response=count_tokens_response,
)


class _PreviewCodeGenerationModel(
CodeGenerationModel, _TunableModelMixin, _CountTokensCodeGenerationMixin
):
__name__ = "CodeGenerationModel"
__module__ = "vertexai.preview.language_models"

Expand Down

0 comments on commit 96e7f7d

Please sign in to comment.