From b0b4e6b8243cbdb829288e3fc204d94005f1e8b4 Mon Sep 17 00:00:00 2001
From: A Vertex SDK engineer
Date: Thu, 2 Nov 2023 14:26:07 -0700
Subject: [PATCH] feat: enable grounding to TextGenerationModel predict and predict_async methods

PiperOrigin-RevId: 578978864
---
 .../system/aiplatform/test_language_models.py | 6 +-
 tests/unit/aiplatform/test_language_models.py | 198 +++++++++++++++++-
 vertexai/language_models/__init__.py | 2 +
 vertexai/language_models/_language_models.py | 170 ++++++++++++++-
 4 files changed, 367 insertions(+), 9 deletions(-)

diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index d522f0f09a..10bc88e89b 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -50,7 +50,7 @@ def test_text_generation(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

         model = TextGenerationModel.from_pretrained("google/text-bison@001")
-
+        grounding_source = language_models.GroundingSource.WebSearch()
         assert model.predict(
             "What is the best recipe for banana bread? Recipe:",
             max_output_tokens=128,
@@ -58,6 +58,7 @@ def test_text_generation(self):
             top_p=1.0,
             top_k=5,
             stop_sequences=["# %%"],
+            grounding_source=grounding_source,
         ).text

     def test_text_generation_preview_count_tokens(self):
@@ -77,7 +78,7 @@ async def test_text_generation_model_predict_async(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

         model = TextGenerationModel.from_pretrained("google/text-bison@001")
-
+        grounding_source = language_models.GroundingSource.WebSearch()
         response = await model.predict_async(
             "What is the best recipe for banana bread? Recipe:",
             max_output_tokens=128,
@@ -85,6 +86,7 @@ async def test_text_generation_model_predict_async(self):
             top_p=1.0,
             top_k=5,
             stop_sequences=["# %%"],
+            grounding_source=grounding_source,
         )
         assert response.text

diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 7fffb86c7c..94646a407a 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -16,7 +16,7 @@
 #

 # pylint: disable=protected-access,bad-continuation
-
+import dataclasses
 import json
 import pytest
 from importlib import reload
@@ -74,6 +74,7 @@ from vertexai.language_models import (
     _evaluatable_language_models,
 )
+from vertexai.language_models import GroundingSource
 from google.cloud.aiplatform_v1 import Execution as GapicExecution
 from google.cloud.aiplatform.compat.types import (
     encryption_spec as gca_encryption_spec,
@@ -166,6 +167,53 @@
     },
 }

+_TEST_GROUNDING_WEB_SEARCH = GroundingSource.WebSearch()
+
+_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE = GroundingSource.VertexAISearch(
+    data_store_id="test_datastore", location="global"
+)
+
+_TEST_TEXT_GENERATION_PREDICTION_GROUNDING = {
+    "safetyAttributes": {
+        "categories": ["Violent"],
+        "blocked": False,
+        "scores": [0.10000000149011612],
+    },
+    "groundingMetadata": {
+        "citations": [
+            {"url": "url1", "startIndex": 1, "endIndex": 2},
+            {"url": "url2", "startIndex": 3, "endIndex": 4},
+        ]
+    },
+    "content": """
+Ingredients:
+* 3 cups all-purpose flour
+
+Instructions:
+1.
Preheat oven to 350 degrees F (175 degrees C).""", +} + +_EXPECTED_PARSED_GROUNDING_METADATA = { + "citations": [ + { + "url": "url1", + "start_index": 1, + "end_index": 2, + "title": None, + "license": None, + "publication_date": None, + }, + { + "url": "url2", + "start_index": 3, + "end_index": 4, + "title": None, + "license": None, + "publication_date": None, + }, + ] +} + _TEST_TEXT_GENERATION_PREDICTION = { "safetyAttributes": { "categories": ["Violent"], @@ -342,7 +390,6 @@ def reverse_string_2(s):""", "total_billable_characters": 25, } - _TEST_TEXT_BISON_TRAINING_DF = pd.DataFrame( { "input_text": [ @@ -1392,6 +1439,78 @@ def test_text_generation_multiple_candidates(self): response.candidates[0].text == _TEST_TEXT_GENERATION_PREDICTION["content"] ) + def test_text_generation_multiple_candidates_grounding(self): + """Tests the text generation model with multiple candidates with web grounding.""" + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + # Discrepancy between the number of `instances` and the number of `predictions` + # is a violation of the prediction service invariant, but the service does this. + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ) as mock_predict: + response = model.predict( + "What is the best recipe for banana bread? 
Recipe:", + candidate_count=2, + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert prediction_parameters["candidateCount"] == 2 + assert prediction_parameters["groundingConfig"] == expected_grounding_source + assert ( + response.text == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert len(response.candidates) == 2 + assert ( + response.candidates[0].text + == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert ( + dataclasses.asdict(response.candidates[0].grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA + ) + @pytest.mark.asyncio async def test_text_generation_async(self): """Tests the text generation model.""" @@ -1435,6 +1554,79 @@ async def test_text_generation_async(self): assert prediction_parameters["stopSequences"] == ["\n"] assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + @pytest.mark.asyncio + async def test_text_generation_multiple_candidates_grounding_async(self): + """Tests the text generation model with multiple candidates async with web grounding.""" + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + # Discrepancy between the number of `instances` and the number of `predictions` + # is a violation of the prediction service invariant, but the service does this. + gca_predict_response.predictions.append( + _TEST_TEXT_GENERATION_PREDICTION_GROUNDING + ) + + test_grounding_sources = [ + _TEST_GROUNDING_WEB_SEARCH, + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, + ] + datastore_path = ( + "projects/test-project/locations/global/" + "collections/default_collection/dataStores/test_datastore" + ) + expected_grounding_sources = [ + {"sources": [{"type": "WEB"}]}, + { + "sources": [ + { + "type": "ENTERPRISE", + "enterpriseDatastore": datastore_path, + } + ] + }, + ] + + for test_grounding_source, expected_grounding_source in zip( + test_grounding_sources, expected_grounding_sources + ): + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="predict", + return_value=gca_predict_response, + ) as mock_predict: + response = await model.predict_async( + "What is the best recipe for banana bread? 
Recipe:", + max_output_tokens=128, + temperature=0.0, + top_p=1.0, + top_k=5, + stop_sequences=["\n"], + grounding_source=test_grounding_source, + ) + prediction_parameters = mock_predict.call_args[1]["parameters"] + assert prediction_parameters["maxDecodeSteps"] == 128 + assert prediction_parameters["temperature"] == 0.0 + assert prediction_parameters["topP"] == 1.0 + assert prediction_parameters["topK"] == 5 + assert prediction_parameters["stopSequences"] == ["\n"] + assert prediction_parameters["groundingConfig"] == expected_grounding_source + assert ( + response.text == _TEST_TEXT_GENERATION_PREDICTION_GROUNDING["content"] + ) + assert ( + dataclasses.asdict(response.grounding_metadata) + == _EXPECTED_PARSED_GROUNDING_METADATA + ) + def test_text_generation_model_predict_streaming(self): """Tests the TextGenerationModel.predict_streaming method.""" with mock.patch.object( @@ -1867,7 +2059,7 @@ def test_tune_code_generation_model( _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT ), ): - model = language_models.CodeGenerationModel.from_pretrained( + model = preview_language_models.CodeGenerationModel.from_pretrained( "code-bison@001" ) # The tune_model call needs to be inside the PublisherModel mock diff --git a/vertexai/language_models/__init__.py b/vertexai/language_models/__init__.py index 8d16584ecb..c1991017e8 100644 --- a/vertexai/language_models/__init__.py +++ b/vertexai/language_models/__init__.py @@ -27,6 +27,7 @@ TextEmbeddingModel, TextGenerationModel, TextGenerationResponse, + GroundingSource, ) __all__ = [ @@ -42,4 +43,5 @@ "TextEmbeddingModel", "TextGenerationModel", "TextGenerationResponse", + "GroundingSource", ] diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 7697657d8d..0acf1ce05f 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -14,6 +14,7 @@ # """Classes for working with language models.""" +import abc import dataclasses from typing import ( Any, @@ -688,6 +689,135 @@ class TuningEvaluationSpec: tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None +class _GroundingSourceBase(abc.ABC): + """Interface of grounding source dataclass for grounding.""" + + @abc.abstractmethod + def _to_grounding_source_dict(self) -> Dict[str, Any]: + """construct grounding source into dictionary""" + pass + + +@dataclasses.dataclass +class WebSearch(_GroundingSourceBase): + """WebSearch represents a grounding source using public web search.""" + + _type: str = "WEB" + + def _to_grounding_source_dict(self) -> Dict[str, Any]: + return {"type": self._type} + + +@dataclasses.dataclass +class VertexAISearch(_GroundingSourceBase): + """VertexAISearchDatastore represents a grounding source using Vertex AI Search datastore + Attributes: + data_store_id: Data store ID of the Vertex AI Search datastore. + location: GCP multi region where you have set up your Vertex AI Search data store. Possible values can be `global`, `us`, `eu`, etc. + Learn more about Vertex AI Search location here: + https://cloud.google.com/generative-ai-app-builder/docs/locations + project: The project where you have set up your Vertex AI Search. + If not specified, will assume that your Vertex AI Search is within your current project. 
+ """ + + _data_store_id: str + _location: str + _type: str = "ENTERPRISE" + + def __init__( + self, data_store_id: str, location: str, project: Optional[str] = None + ): + self._data_store_id = data_store_id + self._location = location + self._project = project + + def _get_datastore_path(self) -> str: + _project = self._project or aiplatform_initializer.global_config.project + return ( + f"projects/{_project}/locations/{self._location}" + f"/collections/default_collection/dataStores/{self._data_store_id}" + ) + + def _to_grounding_source_dict(self) -> Dict[str, Any]: + return {"type": self._type, "enterpriseDatastore": self._get_datastore_path()} + + +@dataclasses.dataclass +class GroundingSource: + + WebSearch = WebSearch + VertexAISearch = VertexAISearch + + +@dataclasses.dataclass +class GroundingCitation: + """Citaion used from grounding. + Attributes: + start_index: Index in the prediction output where the citation starts + (inclusive). Must be >= 0 and < end_index. + end_index: Index in the prediction output where the citation ends + (exclusive). Must be > start_index and < len(output). + url: URL associated with this citation. If present, this URL links to the + webpage of the source of this citation. Possible URLs include news + websites, GitHub repos, etc. + title: Title associated with this citation. If present, it refers to the title + of the source of this citation. Possible titles include + news titles, book titles, etc. + license: License associated with this citation. If present, it refers to the + license of the source of this citation. Possible licenses include code + licenses, e.g., mit license. + publication_date: Publication date associated with this citation. If present, it refers to + the date at which the source of this citation was published. + Possible formats are YYYY, YYYY-MM, YYYY-MM-DD. + """ + + start_index: Optional[int] = None + end_index: Optional[int] = None + url: Optional[str] = None + title: Optional[str] = None + license: Optional[str] = None + publication_date: Optional[str] = None + + +@dataclasses.dataclass +class GroundingMetadata: + """Metadata for grounding. + Attributes: + citations: List of grounding citations. + """ + + citations: Optional[List[GroundingCitation]] = None + + def _parse_citation_from_dict( + self, citation_dict_camel: Dict[str, Any] + ) -> GroundingCitation: + _start_index = citation_dict_camel.get("startIndex") + _end_index = citation_dict_camel.get("endIndex") + if _start_index is not None: + _start_index = int(_start_index) + if _end_index is not None: + _end_index = int(_end_index) + _url = citation_dict_camel.get("url") + _title = citation_dict_camel.get("title") + _license = citation_dict_camel.get("license") + _publication_date = citation_dict_camel.get("publicationDate") + + return GroundingCitation( + start_index=_start_index, + end_index=_end_index, + url=_url, + title=_title, + license=_license, + publication_date=_publication_date, + ) + + def __init__(self, response: Optional[Dict[str, Any]] = {}): + self.citations = [ + self._parse_citation_from_dict(citation) + for citation in response.get("citations", []) + ] + + @dataclasses.dataclass class TextGenerationResponse: """TextGenerationResponse represents a response of a language model. @@ -697,6 +827,7 @@ class TextGenerationResponse: safety_attributes: Scores for safety attributes. 
Learn more about the safety attributes here: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions + grounding_metadata: Metadata for grounding. """ __module__ = "vertexai.language_models" @@ -705,12 +836,22 @@ class TextGenerationResponse: _prediction_response: Any is_blocked: bool = False safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict) + grounding_metadata: Optional[GroundingMetadata] = None def __repr__(self): if self.text: return self.text + # Falling back to the full representation + elif self.grounding_metadata is not None: + return ( + "TextGenerationResponse(" + f"text={self.text!r}" + f", is_blocked={self.is_blocked!r}" + f", safety_attributes={self.safety_attributes!r}" + f", grounding_metadata={self.grounding_metadata!r}" + ")" + ) else: - # Falling back to the full representation return ( "TextGenerationResponse(" f"text={self.text!r}" @@ -735,6 +876,7 @@ class MultiCandidateTextGenerationResponse(TextGenerationResponse): safety_attributes: Scores for safety attributes for the first candidate. Learn more about the safety attributes here: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions + grounding_metadata: Grounding metadata for the first candidate. candidates: The candidate responses. Usually contains a single candidate unless `candidate_count` is used. """ @@ -780,6 +922,9 @@ def predict( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Gets model response for a single prompt. @@ -791,6 +936,7 @@ def predict( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of response candidates to return. + grounding_source: If specified, grounding feature will be enabled using the grounding source. Default: None. Returns: A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model. @@ -803,6 +949,7 @@ def predict( top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = self._endpoint.predict( @@ -824,6 +971,9 @@ async def predict_async( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "MultiCandidateTextGenerationResponse": """Asynchronously gets model response for a single prompt. @@ -835,6 +985,7 @@ async def predict_async( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of response candidates to return. + grounding_source: If specified, grounding feature will be enabled using the grounding source. Default: None. Returns: A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model. 
@@ -847,6 +998,7 @@ async def predict_async( top_p=top_p, stop_sequences=stop_sequences, candidate_count=candidate_count, + grounding_source=grounding_source, ) prediction_response = await self._endpoint.predict_async( @@ -966,6 +1118,9 @@ def _create_text_generation_prediction_request( top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, + grounding_source: Optional[ + Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + ] = None, ) -> "_PredictionRequest": """Prepares the text generation request for a single prompt. @@ -977,6 +1132,8 @@ def _create_text_generation_prediction_request( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. stop_sequences: Customized stop sequences to stop the decoding process. candidate_count: Number of candidates to return. + grounding_source: If specified, grounding feature will be enabled using the grounding source. Default: None. + Returns: A `_PredictionRequest` object that contains prediction instance and parameters. @@ -1006,6 +1163,10 @@ def _create_text_generation_prediction_request( if candidate_count is not None: prediction_parameters["candidateCount"] = candidate_count + if grounding_source is not None: + sources = [grounding_source._to_grounding_source_dict()] + prediction_parameters["groundingConfig"] = {"sources": sources} + return _PredictionRequest( instance=instance, parameters=prediction_parameters, @@ -1019,6 +1180,7 @@ def _parse_text_generation_model_response( """Converts the raw text_generation model response to `TextGenerationResponse`.""" prediction = prediction_response.predictions[prediction_idx] safety_attributes_dict = prediction.get("safetyAttributes", {}) + grounding_metadata_dict = prediction.get("groundingMetadata", {}) return TextGenerationResponse( text=prediction["content"], _prediction_response=prediction_response, @@ -1029,6 +1191,7 @@ def _parse_text_generation_model_response( safety_attributes_dict.get("scores") or [], ) ), + grounding_metadata=GroundingMetadata(grounding_metadata_dict), ) @@ -1054,6 +1217,7 @@ def _parse_text_generation_model_multi_candidate_response( _prediction_response=prediction_response, is_blocked=candidates[0].is_blocked, safety_attributes=candidates[0].safety_attributes, + grounding_metadata=candidates[0].grounding_metadata, candidates=candidates, ) @@ -2689,9 +2853,7 @@ class CodeGenerationModel(_CodeGenerationModel, _TunableTextModelMixin): pass -class _PreviewCodeGenerationModel( - CodeGenerationModel, _CountTokensCodeGenerationMixin -): +class _PreviewCodeGenerationModel(CodeGenerationModel, _CountTokensCodeGenerationMixin): __name__ = "CodeGenerationModel" __module__ = "vertexai.preview.language_models"