def _instantiate_private_match_service_stub(
    self,
    deployed_index_id: Optional[str] = None,
    ip_address: Optional[str] = None,
) -> match_service_pb2_grpc.MatchServiceStub:
    """Helper method to instantiate private match service stub.

    Exactly one connection mode applies: when ``ip_address`` is provided the
    endpoint must be a private service connect endpoint; otherwise the server
    ip is resolved from the deployed index (private service access).

    Args:
        deployed_index_id (str):
            Optional. Required for private service access endpoint.
            The user specified ID of the DeployedIndex.
        ip_address (str):
            Optional. Required for private service connect. The ip address
            the forwarding rule makes use of.

    Returns:
        stub (match_service_pb2_grpc.MatchServiceStub):
            Initialized match service stub.

    Raises:
        RuntimeError: No deployed index with id deployed_index_id found.
        ValueError: Should not set ip address for networks other than
            private service connect.
    """
    if ip_address:
        # An explicit ip address is only valid for Private Service Connect.
        # NOTE: each message is ONE string (implicit concatenation) — the
        # previous version passed two comma-separated arguments to
        # ValueError, which rendered the message as a tuple.
        if self.public_endpoint_domain_name:
            raise ValueError(
                "MatchingEngineIndexEndpoint is set to use "
                "public network. Could not establish connection using "
                "provided ip address"
            )
        if self.private_service_access_network:
            raise ValueError(
                "MatchingEngineIndexEndpoint is set to use "
                "private service access network. Could not establish "
                "connection using provided ip address"
            )
    else:
        # Private Service Access: look up the gRPC server address from the
        # deployed index identified by `deployed_index_id`.
        deployed_indexes = [
            deployed_index
            for deployed_index in self.deployed_indexes
            if deployed_index.id == deployed_index_id
        ]

        if not deployed_indexes:
            raise RuntimeError(
                f"No deployed index with id '{deployed_index_id}' found"
            )

        ip_address = deployed_indexes[0].private_endpoints.match_grpc_address

    if ip_address not in self._match_grpc_stub_cache:
        # Cache one channel/stub per server address so repeated queries
        # reuse the underlying gRPC connection.
        channel = grpc.insecure_channel("{}:10000".format(ip_address))
        self._match_grpc_stub_cache[
            ip_address
        ] = match_service_pb2_grpc.MatchServiceStub(channel)
    return self._match_grpc_stub_cache[ip_address]
@property
def private_service_access_network(self) -> Optional[str]:
    """Private service access network of this index endpoint, if any."""
    self._assert_gca_resource_is_available()
    return self._gca_resource.network

@property
def private_service_connect_ip_address(self) -> Optional[str]:
    """Private service connect ip address set by the user, if any."""
    return self._private_service_connect_ip_address

@private_service_connect_ip_address.setter
def private_service_connect_ip_address(self, ip_address: str) -> None:
    """Setter for the private service connect ip address."""
    # Stored on the instance; consumed by
    # _instantiate_private_match_service_stub when building the gRPC stub.
    self._private_service_connect_ip_address = ip_address
""" stub = self._instantiate_private_match_service_stub( - deployed_index_id=deployed_index_id + deployed_index_id=deployed_index_id, + ip_address=self._private_service_connect_ip_address, ) # Create the batch match request diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 6b6af65e68..76ec65692c 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -246,6 +246,7 @@ _TEST_RETURN_FULL_DATAPOINT = True _TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name" _TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"] +_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5" _TEST_READ_INDEX_DATAPOINTS_RESPONSE = [ gca_index_v1beta1.IndexDatapoint( datapoint_id="1", @@ -1137,6 +1138,54 @@ def test_private_index_endpoint_find_neighbor_queries( ) index_endpoint_match_queries_mock.assert_called_with(batch_match_request) + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_index_private_service_connect_endpoint_match_queries( + self, index_endpoint_match_queries_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint.private_service_connect_ip_address = ( + _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS + ) + my_index_endpoint.match( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=_TEST_QUERIES, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + ) + + batch_request = match_service_pb2.BatchMatchRequest( + requests=[ + match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + requests=[ + match_service_pb2.MatchRequest( + num_neighbors=_TEST_NUM_NEIGHBOURS, + 
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + float_val=_TEST_QUERIES[0], + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + ) + ], + ) + ] + ) + + index_endpoint_match_queries_mock.assert_called_with(batch_request) + @pytest.mark.usefixtures("get_index_public_endpoint_mock") def test_index_public_endpoint_match_queries( self, index_public_endpoint_match_queries_mock @@ -1330,7 +1379,7 @@ def test_index_endpoint_batch_get_embeddings( index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request) @pytest.mark.usefixtures("get_index_endpoint_mock") - def test_index_private_endpoint_read_index_datapoints( + def test_index_endpoint_find_neighbors_for_private_service_access( self, index_endpoint_batch_get_embeddings_mock ): aiplatform.init(project=_TEST_PROJECT) @@ -1350,3 +1399,29 @@ def test_index_private_endpoint_read_index_datapoints( index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request) assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE + + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_index_endpoint_find_neighbors_for_private_service_connect( + self, index_endpoint_batch_get_embeddings_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint.private_service_connect_ip = ( + _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS + ) + response = my_index_endpoint.read_index_datapoints( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + ids=["1", "2"], + ) + + batch_request = match_service_pb2.BatchGetEmbeddingsRequest( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"] + ) + + index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request) + + assert response == 
_TEST_READ_INDEX_DATAPOINTS_RESPONSE From d0f65fddb2d610924d6f72f92b51ca079b53b1be Mon Sep 17 00:00:00 2001 From: Matthew Tang Date: Tue, 9 Jan 2024 14:02:26 -0800 Subject: [PATCH 2/7] chore: Reset explanation_config in the system test test_model_monitoring PiperOrigin-RevId: 597038105 --- tests/system/aiplatform/test_model_monitoring.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/system/aiplatform/test_model_monitoring.py b/tests/system/aiplatform/test_model_monitoring.py index 6cc6dbed26..3d36e8e2ed 100644 --- a/tests/system/aiplatform/test_model_monitoring.py +++ b/tests/system/aiplatform/test_model_monitoring.py @@ -410,6 +410,8 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state): def test_mdm_notification_channel_alert_config(self, shared_state): self.endpoint = shared_state["resources"][0] aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + # Reset objective_config.explanation_config + objective_config.explanation_config = None # test model monitoring configurations job = aiplatform.ModelDeploymentMonitoringJob.create( display_name=self._make_display_name(key=JOB_NAME), From ad8d9c1df17578de3b893ebe46d00d457960da00 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Tue, 9 Jan 2024 23:04:16 -0800 Subject: [PATCH 3/7] feat: Add `return_full_datapoint` for `MatchEngineIndexEndpoint` `match()`. 
PiperOrigin-RevId: 597148566 --- .../matching_engine_index_endpoint.py | 8 ++++++ .../test_matching_engine_index_endpoint.py | 27 ++++++++++--------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 5501c07d2f..9e003c1d5f 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -1262,6 +1262,7 @@ def find_neighbors( per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count, approx_num_neighbors=approx_num_neighbors, fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, + return_full_datapoint=return_full_datapoint, ) # Create the FindNeighbors request @@ -1434,6 +1435,7 @@ def match( per_crowding_attribute_num_neighbors: Optional[int] = None, approx_num_neighbors: Optional[int] = None, fraction_leaf_nodes_to_search_override: Optional[float] = None, + return_full_datapoint: bool = False, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index for private endpoint only. @@ -1465,6 +1467,11 @@ def match( query time allows user to tune search performance. This value increase result in both search accuracy and latency increase. The value should be between 0.0 and 1.0. + return_full_datapoint (bool): + Optional. If set to true, the full datapoints (including all + vector values and of the nearest neighbors are returned. + Note that returning full datapoint will significantly increase the + latency and cost of the query. Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. 
@@ -1502,6 +1509,7 @@ def match( per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors, approx_num_neighbors=approx_num_neighbors, fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, + embedding_enabled=return_full_datapoint, ) requests.append(request) diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 76ec65692c..184dec8a38 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -1060,6 +1060,7 @@ def test_private_index_endpoint_match_queries( per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, ) batch_request = match_service_pb2.BatchMatchRequest( @@ -1081,6 +1082,7 @@ def test_private_index_endpoint_match_queries( per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + embedding_enabled=_TEST_RETURN_FULL_DATAPOINT, ) for i in range(len(_TEST_QUERIES)) ], @@ -1096,11 +1098,11 @@ def test_private_index_endpoint_find_neighbor_queries( ): aiplatform.init(project=_TEST_PROJECT) - my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + my_private_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( index_endpoint_name=_TEST_INDEX_ENDPOINT_ID ) - my_pubic_index_endpoint.find_neighbors( + my_private_index_endpoint.find_neighbors( deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, @@ -1130,6 +1132,7 @@ def test_private_index_endpoint_find_neighbor_queries( 
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + embedding_enabled=_TEST_RETURN_FULL_DATAPOINT, ) for test_query in _TEST_QUERIES ], @@ -1187,16 +1190,16 @@ def test_index_private_service_connect_endpoint_match_queries( index_endpoint_match_queries_mock.assert_called_with(batch_request) @pytest.mark.usefixtures("get_index_public_endpoint_mock") - def test_index_public_endpoint_match_queries( + def test_index_public_endpoint_find_neighbors_queries( self, index_public_endpoint_match_queries_mock ): aiplatform.init(project=_TEST_PROJECT) - my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( index_endpoint_name=_TEST_INDEX_ENDPOINT_ID ) - my_pubic_index_endpoint.find_neighbors( + my_public_index_endpoint.find_neighbors( deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, @@ -1208,7 +1211,7 @@ def test_index_public_endpoint_match_queries( ) find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest( - index_endpoint=my_pubic_index_endpoint.resource_name, + index_endpoint=my_public_index_endpoint.resource_name, deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[ gca_match_service_v1beta1.FindNeighborsRequest.Query( @@ -1241,11 +1244,11 @@ def test_index_public_endpoint_match_queries_with_numeric_filtering( ): aiplatform.init(project=_TEST_PROJECT) - my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( index_endpoint_name=_TEST_INDEX_ENDPOINT_ID ) - my_pubic_index_endpoint.find_neighbors( + my_public_index_endpoint.find_neighbors( deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=_TEST_QUERIES, num_neighbors=_TEST_NUM_NEIGHBOURS, @@ -1258,7 +1261,7 @@ def 
@dataclasses.dataclass
class InlineContext(_GroundingSourceBase):
    """Grounding source backed by caller-provided inline context.

    Attributes:
        inline_context: The content used as inline context.
    """

    inline_context: str
    # Fixed discriminator for this source kind; excluded from __init__/repr.
    _type: str = dataclasses.field(default="INLINE", init=False, repr=False)

    def _to_grounding_source_dict(self) -> Dict[str, Any]:
        # Serialize to the wire shape expected by the prediction service:
        # one source entry carrying the type tag, plus the raw context text.
        source_entry = {"type": self._type}
        return {
            "sources": [source_entry],
            "inlineContext": self.inline_context,
        }
None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> _PredictionRequest: """Prepares a request for the language model. @@ -2289,7 +2327,11 @@ def send_message( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> "MultiCandidateTextGenerationResponse": """Sends message to the language model and gets a response. @@ -2352,7 +2394,11 @@ async def send_message_async( stop_sequences: Optional[List[str]] = None, candidate_count: Optional[int] = None, grounding_source: Optional[ - Union[GroundingSource.WebSearch, GroundingSource.VertexAISearch] + Union[ + GroundingSource.WebSearch, + GroundingSource.VertexAISearch, + GroundingSource.InlineContext, + ] ] = None, ) -> "MultiCandidateTextGenerationResponse": """Asynchronously sends message to the language model and gets a response. 
From ec23c963c0ae29644ec00749f7888ca698e16a88 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Wed, 10 Jan 2024 14:26:41 -0800 Subject: [PATCH 5/7] chore: GenAI - Fixed the `ChatSession.start_chat` type annotation PiperOrigin-RevId: 597356414 --- vertexai/generative_models/_generative_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 9423b6ed20..4f6dddf2ca 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -621,7 +621,7 @@ async def count_tokens_async( def start_chat( self, *, - history: Optional[List[gapic_content_types.Content]] = None, + history: Optional[List["Content"]] = None, ) -> "ChatSession": """Creates a stateful chat session. From bbdd9e26b8a97e8c00f5ad42df8562974b3e5cce Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 11 Jan 2024 01:04:45 -0800 Subject: [PATCH 6/7] chore: GenAI - Added GenAI system tests PiperOrigin-RevId: 597478674 --- .../system/vertexai/test_generative_models.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 tests/system/vertexai/test_generative_models.py diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py new file mode 100644 index 0000000000..5dd4f30523 --- /dev/null +++ b/tests/system/vertexai/test_generative_models.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class TestGenerativeModels(e2e_base.TestEndToEnd):
    """System tests for generative models."""

    _temp_prefix = "temp_generative_models_test_"

    def setup_method(self):
        # Re-initialize the SDK with explicit ADC credentials so every test
        # talks to the configured project/location.
        super().setup_method()
        credentials, _ = auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"]
        )
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
            credentials=credentials,
        )

    def test_generate_content_from_text(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        result = gemini.generate_content("Why is sky blue?")
        assert result.text

    @pytest.mark.asyncio
    async def test_generate_content_async(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        result = await gemini.generate_content_async("Why is sky blue?")
        assert result.text

    def test_generate_content_streaming(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        # Every streamed chunk should carry text.
        for chunk in gemini.generate_content("Why is sky blue?", stream=True):
            assert chunk.text

    @pytest.mark.asyncio
    async def test_generate_content_streaming_async(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        chunks = await gemini.generate_content_async(
            "Why is sky blue?",
            stream=True,
        )
        async for chunk in chunks:
            assert chunk.text

    def test_generate_content_with_parameters(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        config = generative_models.GenerationConfig(
            temperature=0.1,
            top_p=0.95,
            top_k=20,
            candidate_count=1,
            max_output_tokens=100,
            stop_sequences=["STOP!"],
        )
        safety = {
            generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
            generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_NONE,
        }
        result = gemini.generate_content(
            contents="Why is sky blue?",
            generation_config=config,
            safety_settings=safety,
        )
        assert result.text

    def test_generate_content_from_list_of_content_dict(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        result = gemini.generate_content(
            contents=[{"role": "user", "parts": [{"text": "Why is sky blue?"}]}]
        )
        assert result.text

    def test_generate_content_from_remote_image(self):
        vision = generative_models.GenerativeModel("gemini-pro-vision")
        cat_image = generative_models.Part.from_uri(
            uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
            mime_type="image/jpeg",
        )
        result = vision.generate_content(cat_image)
        assert result.text
        assert "cat" in result.text

    def test_generate_content_from_text_and_remote_image(self):
        vision = generative_models.GenerativeModel("gemini-pro-vision")
        cat_image = generative_models.Part.from_uri(
            uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
            mime_type="image/jpeg",
        )
        result = vision.generate_content(
            contents=["What is shown in this image?", cat_image],
        )
        assert result.text
        assert "cat" in result.text

    def test_generate_content_from_text_and_remote_video(self):
        vision = generative_models.GenerativeModel("gemini-pro-vision")
        animals_video = generative_models.Part.from_uri(
            uri="gs://cloud-samples-data/video/animals.mp4",
            mime_type="video/mp4",
        )
        result = vision.generate_content(
            contents=["What is in the video?", animals_video],
        )
        assert result.text
        assert "Zootopia" in result.text

    # Chat

    def test_send_message_from_text(self):
        gemini = generative_models.GenerativeModel("gemini-pro")
        chat = gemini.start_chat()

        # Each exchange appends a user turn and a model turn to the history.
        first_reply = chat.send_message("I really like fantasy books.")
        assert first_reply.text
        assert len(chat.history) == 2

        second_reply = chat.send_message("What things do I like?.")
        assert second_reply.text
        assert len(chat.history) == 4
- dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_TEXT_DATASET_ID) + dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_IMAGE_DATASET_ID) split = { - "training_fraction": 0.6, - "validation_fraction": 0.2, - "test_fraction": 0.2, + "training_filter": "labels.aiplatform.googleapis.com/ml_use=training", + "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation", + "test_filter": "labels.aiplatform.googleapis.com/ml_use=test", } export_data_response = dataset.export_data_for_custom_training( output_dir=f"gs://{staging_bucket.name}", - annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml", + annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml", split=split, ) # Ensure three output paths (training, validation and test) are provided assert len(export_data_response["exported_files"]) == 3 - # Ensure data stats are calculated and present - assert export_data_response["data_stats"]["training_data_items_count"] > 0 + # Ensure data stats are calculated and correct + assert export_data_response["data_stats"]["training_data_items_count"] == 40 def test_update_dataset(self): """Create a new dataset and use update() method to change its display_name, labels, and description.