Skip to content

Commit

Permalink
feat: enable grounding to TextGenerationModel predict and predict_asy…
Browse files Browse the repository at this point in the history
…nc methods

PiperOrigin-RevId: 578978864
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 2, 2023
1 parent 2147634 commit b0b4e6b
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 9 deletions.
6 changes: 4 additions & 2 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ def test_text_generation(self):
# Initialise the SDK with the project/location used by the e2e test base.
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = TextGenerationModel.from_pretrained("google/text-bison@001")

# Ground the request on web search. The grounding classes are reached through
# the `GroundingSource` namespace, which is the grounding name this change
# exports from `vertexai.language_models` and the form the unit tests use.
grounding_source = language_models.GroundingSource.WebSearch()
# The grounded prediction must come back with non-empty text.
assert model.predict(
    "What is the best recipe for banana bread? Recipe:",
    max_output_tokens=128,
    temperature=0.0,
    top_p=1.0,
    top_k=5,
    stop_sequences=["# %%"],
    grounding_source=grounding_source,
).text

def test_text_generation_preview_count_tokens(self):
Expand All @@ -77,14 +78,15 @@ async def test_text_generation_model_predict_async(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

model = TextGenerationModel.from_pretrained("google/text-bison@001")

grounding_source = language_models.WebSearchGroundingSource()
response = await model.predict_async(
"What is the best recipe for banana bread? Recipe:",
max_output_tokens=128,
temperature=0.0,
top_p=1.0,
top_k=5,
stop_sequences=["# %%"],
grounding_source=grounding_source,
)
assert response.text

Expand Down
198 changes: 195 additions & 3 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

# pylint: disable=protected-access,bad-continuation

import dataclasses
import json
import pytest
from importlib import reload
Expand Down Expand Up @@ -74,6 +74,7 @@
from vertexai.language_models import (
_evaluatable_language_models,
)
from vertexai.language_models import GroundingSource
from google.cloud.aiplatform_v1 import Execution as GapicExecution
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -166,6 +167,53 @@
},
}

# Web-search grounding source, constructed with no arguments.
_TEST_GROUNDING_WEB_SEARCH = GroundingSource.WebSearch()

# Vertex AI Search grounding source pointing at the "test_datastore" data
# store in the "global" location.
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE = GroundingSource.VertexAISearch(
    location="global",
    data_store_id="test_datastore",
)

_TEST_TEXT_GENERATION_PREDICTION_GROUNDING = {
"safetyAttributes": {
"categories": ["Violent"],
"blocked": False,
"scores": [0.10000000149011612],
},
"groundingMetadata": {
"citations": [
{"url": "url1", "startIndex": 1, "endIndex": 2},
{"url": "url2", "startIndex": 3, "endIndex": 4},
]
},
"content": """
Ingredients:
* 3 cups all-purpose flour
Instructions:
1. Preheat oven to 350 degrees F (175 degrees C).""",
}

_EXPECTED_PARSED_GROUNDING_METADATA = {
"citations": [
{
"url": "url1",
"start_index": 1,
"end_index": 2,
"title": None,
"license": None,
"publication_date": None,
},
{
"url": "url2",
"start_index": 3,
"end_index": 4,
"title": None,
"license": None,
"publication_date": None,
},
]
}

_TEST_TEXT_GENERATION_PREDICTION = {
"safetyAttributes": {
"categories": ["Violent"],
Expand Down Expand Up @@ -342,7 +390,6 @@ def reverse_string_2(s):""",
"total_billable_characters": 25,
}


_TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame(
{
"input_text": [
Expand Down Expand Up @@ -1392,6 +1439,78 @@ def test_text_generation_multiple_candidates(self):
response.candidates[0].text == _TEST_TEXT_GENERATION_PREDICTION["content"]
)

def test_text_generation_multiple_candidates_grounding(self):
    """Checks grounded `predict` calls that request two candidates.

    The call is repeated for each grounding source (web search, then a
    Vertex AI Search data store). Each round verifies the `candidateCount`
    and `groundingConfig` parameters passed to the mocked service, plus the
    text and grounding metadata parsed from the mocked response.
    """
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ):
        model = language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    # The mocked response carries two predictions for one request instance.
    # That mismatch between `instances` and `predictions` violates the
    # prediction service invariant, but the service does this.
    mocked_response = gca_prediction_service.PredictResponse()
    for _ in range(2):
        mocked_response.predictions.append(
            _TEST_TEXT_GENERATION_PREDICTION_GROUNDING
        )

    datastore_path = (
        "projects/test-project/locations/global/"
        "collections/default_collection/dataStores/test_datastore"
    )
    # Pairs of (grounding source given to the SDK, expected `groundingConfig`).
    grounding_cases = [
        (
            _TEST_GROUNDING_WEB_SEARCH,
            {"sources": [{"type": "WEB"}]},
        ),
        (
            _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
            {
                "sources": [
                    {
                        "type": "ENTERPRISE",
                        "enterpriseDatastore": datastore_path,
                    }
                ]
            },
        ),
    ]
    expected_text = _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"]

    for grounding_source, expected_grounding_config in grounding_cases:
        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=mocked_response,
        ) as mock_predict:
            response = model.predict(
                "What is the best recipe for banana bread? Recipe:",
                candidate_count=2,
                grounding_source=grounding_source,
            )

        # Parameters the SDK handed to the mocked `predict` call.
        sent_parameters = mock_predict.call_args[1]["parameters"]
        assert sent_parameters["candidateCount"] == 2
        assert sent_parameters["groundingConfig"] == expected_grounding_config

        # Parsed response: top-level text, candidate count, first candidate.
        assert response.text == expected_text
        assert len(response.candidates) == 2
        first_candidate = response.candidates[0]
        assert first_candidate.text == expected_text
        assert (
            dataclasses.asdict(first_candidate.grounding_metadata)
            == _EXPECTED_PARSED_GROUNDING_METADATA
        )

@pytest.mark.asyncio
async def test_text_generation_async(self):
"""Tests the text generation model."""
Expand Down Expand Up @@ -1435,6 +1554,79 @@ async def test_text_generation_async(self):
assert prediction_parameters["stopSequences"] == ["\n"]
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

@pytest.mark.asyncio
async def test_text_generation_multiple_candidates_grounding_async(self):
    """Checks grounded `predict_async` calls for each grounding source.

    The call is repeated for web search and for a Vertex AI Search data
    store. Each round verifies the generation parameters and
    `groundingConfig` passed to the mocked async service, plus the text and
    grounding metadata parsed from the mocked response.

    NOTE: despite "multiple_candidates" in the name, this test passes no
    `candidate_count` and mocks a single prediction.
    """
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ):
        model = language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    # The mocked response holds exactly one grounded prediction.
    mocked_response = gca_prediction_service.PredictResponse()
    mocked_response.predictions.append(
        _TEST_TEXT_GENERATION_PREDICTION_GROUNDING
    )

    datastore_path = (
        "projects/test-project/locations/global/"
        "collections/default_collection/dataStores/test_datastore"
    )
    # Pairs of (grounding source given to the SDK, expected `groundingConfig`).
    grounding_cases = [
        (
            _TEST_GROUNDING_WEB_SEARCH,
            {"sources": [{"type": "WEB"}]},
        ),
        (
            _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
            {
                "sources": [
                    {
                        "type": "ENTERPRISE",
                        "enterpriseDatastore": datastore_path,
                    }
                ]
            },
        ),
    ]

    for grounding_source, expected_grounding_config in grounding_cases:
        with mock.patch.object(
            target=prediction_service_async_client.PredictionServiceAsyncClient,
            attribute="predict",
            return_value=mocked_response,
        ) as mock_predict:
            response = await model.predict_async(
                "What is the best recipe for banana bread? Recipe:",
                max_output_tokens=128,
                temperature=0.0,
                top_p=1.0,
                top_k=5,
                stop_sequences=["\n"],
                grounding_source=grounding_source,
            )

        # Parameters the SDK handed to the mocked `predict` call; note that
        # `max_output_tokens` is expected under the `maxDecodeSteps` key.
        sent_parameters = mock_predict.call_args[1]["parameters"]
        assert sent_parameters["maxDecodeSteps"] == 128
        assert sent_parameters["temperature"] == 0.0
        assert sent_parameters["topP"] == 1.0
        assert sent_parameters["topK"] == 5
        assert sent_parameters["stopSequences"] == ["\n"]
        assert sent_parameters["groundingConfig"] == expected_grounding_config

        # Parsed response: text and grounding metadata.
        assert (
            response.text == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"]
        )
        assert (
            dataclasses.asdict(response.grounding_metadata)
            == _EXPECTED_PARSED_GROUNDING_METADATA
        )

def test_text_generation_model_predict_streaming(self):
"""Tests the TextGenerationModel.predict_streaming method."""
with mock.patch.object(
Expand Down Expand Up @@ -1867,7 +2059,7 @@ def test_tune_code_generation_model(
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.CodeGenerationModel.from_pretrained(
model = preview_language_models.CodeGenerationModel.from_pretrained(
"code-bison@001"
)
# The tune_model call needs to be inside the PublisherModel mock
Expand Down
2 changes: 2 additions & 0 deletions vertexai/language_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TextEmbeddingModel,
TextGenerationModel,
TextGenerationResponse,
GroundingSource,
)

__all__ = [
Expand All @@ -42,4 +43,5 @@
"TextEmbeddingModel",
"TextGenerationModel",
"TextGenerationResponse",
"GroundingSource",
]
Loading

0 comments on commit b0b4e6b

Please sign in to comment.