From 3d8835e1dbc48502246fc5ae141f465e0ac7ae90 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Mon, 11 Dec 2023 11:07:02 -0800 Subject: [PATCH] fix: `read_index_endpoint` private endpoint support. PiperOrigin-RevId: 589880026 --- .../matching_engine_index_endpoint.py | 36 +++++++------ .../test_matching_engine_index_endpoint.py | 50 ++++++++++++++++++- 2 files changed, 71 insertions(+), 15 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index c726b1c96c..7f8d2673a6 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -1285,24 +1285,32 @@ def read_index_datapoints( """ if not self._public_match_client: # Call private match service stub with BatchGetEmbeddings request - response = self._batch_get_embeddings( + embeddings = self._batch_get_embeddings( deployed_index_id=deployed_index_id, ids=ids ) - return [ - gca_index_v1beta1.IndexDatapoint( + + response = [] + for embedding in embeddings: + index_datapoint = gca_index_v1beta1.IndexDatapoint( datapoint_id=embedding.id, feature_vector=embedding.float_val, - restricts=gca_index_v1beta1.IndexDatapoint.Restriction( - namespace=embedding.restricts.name, - allow_list=embedding.restricts.allow_tokens, - ), - deny_list=embedding.restricts.deny_tokens, - crowding_attributes=gca_index_v1beta1.CrowdingEmbedding( - str(embedding.crowding_tag) - ), + restricts=[ + gca_index_v1beta1.IndexDatapoint.Restriction( + namespace=restrict.name, + allow_list=restrict.allow_tokens, + deny_list=restrict.deny_tokens, + ) + for restrict in embedding.restricts + ], ) - for embedding in response.embeddings - ] + if embedding.crowding_attribute: + index_datapoint.crowding_tag = ( + gca_index_v1beta1.IndexDatapoint.CrowdingTag( + crowding_attribute=str(embedding.crowding_attribute) + ) + ) + response.append(index_datapoint) + return response # Create the ReadIndexDatapoints request read_index_datapoints_request = ( @@ -1326,7 +1334,7 @@ def _batch_get_embeddings( *, deployed_index_id: str, ids: List[str] = [], - ) -> List[List[match_service_pb2.Embedding]]: + ) -> List[match_service_pb2.Embedding]: """ Reads the datapoints/vectors of the given IDs on the specified index which is deployed to private endpoint. diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 5aac0a0b3e..0cb9c63024 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -246,6 +246,26 @@ _TEST_RETURN_FULL_DATAPOINT = True _TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name" _TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"] +_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [ + gca_index_v1beta1.IndexDatapoint( + datapoint_id="1", + feature_vector=[0.1, 0.2, 0.3], + restricts=[ + gca_index_v1beta1.IndexDatapoint.Restriction( + namespace="class", + allow_list=["token_1"], + deny_list=["token_2"], + ) + ], + ), + gca_index_v1beta1.IndexDatapoint( + datapoint_id="2", + feature_vector=[0.5, 0.2, 0.3], + crowding_tag=gca_index_v1beta1.IndexDatapoint.CrowdingTag( + crowding_attribute="1" + ), + ), +] def uuid_mock(): @@ -505,7 +525,13 @@ def index_endpoint_batch_get_embeddings_mock(): match_service_pb2.Embedding( id="1", float_val=[0.1, 0.2, 0.3], - crowding_attribute=1, + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], ), match_service_pb2.Embedding( id="2", @@ -1249,3 +1275,25 @@ def test_index_endpoint_batch_get_embeddings( ) index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request) + + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_index_endpoint_find_neighbors_for_private( + self, index_endpoint_batch_get_embeddings_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + response = my_index_endpoint.read_index_datapoints( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"] + ) + + batch_request = match_service_pb2.BatchGetEmbeddingsRequest( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"] + ) + + index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request) + + assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE