diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
index da155496ae..6e6423e99e 100644
--- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
+++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Sequence, Tuple
 
 from google.auth import credentials as auth_credentials
@@ -51,6 +51,25 @@ class MatchNeighbor:
     distance: float
 
 
+@dataclass
+class Namespace:
+    """Namespace specifies the rules for determining the datapoints that are eligible for each matching query, overall query is an AND across namespaces.
+    Args:
+        name (str):
+            Required. The name of this Namespace.
+        allow_tokens (List(str)):
+            Optional. The allowed tokens in the namespace.
+        deny_tokens (List(str)):
+            Optional. The denied tokens in the namespace. When a token is denied, then matches will be excluded whenever the other datapoint has that token.
+            For example, if a query specifies [Namespace("color", ["red","blue"], ["purple"])], then that query will match datapoints that are red or blue,
+            but if those points are also purple, then they will be excluded even if they are red/blue.
+    """
+
+    name: str
+    allow_tokens: list = field(default_factory=list)
+    deny_tokens: list = field(default_factory=list)
+
+
 class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
     """Matching Engine index endpoint resource for Vertex AI."""
 
@@ -796,7 +815,11 @@ def description(self) -> str:
         return self._gca_resource.description
 
     def match(
-        self, deployed_index_id: str, queries: List[List[float]], num_neighbors: int = 1
+        self,
+        deployed_index_id: str,
+        queries: List[List[float]],
+        num_neighbors: int = 1,
+        filter: Optional[List[Namespace]] = [],
     ) -> List[List[MatchNeighbor]]:
         """Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
 
@@ -808,6 +831,11 @@ def match(
             num_neighbors (int):
                 Required. The number of nearest neighbors to be retrieved from database for
                 each query.
+            filter (List[Namespace]):
+                Optional. A list of Namespaces for filtering the matching results.
+                For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
+                that satisfy "red color" but not include datapoints with "squared shape".
+                Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
 
         Returns:
             List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
@@ -836,16 +864,22 @@ def match(
             match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
         )
         batch_request_for_index.deployed_index_id = deployed_index_id
-        batch_request_for_index.requests.extend(
-            [
-                match_service_pb2.MatchRequest(
-                    num_neighbors=num_neighbors,
-                    deployed_index_id=deployed_index_id,
-                    float_val=query,
-                )
-                for query in queries
-            ]
-        )
+        requests = []
+        for query in queries:
+            request = match_service_pb2.MatchRequest(
+                num_neighbors=num_neighbors,
+                deployed_index_id=deployed_index_id,
+                float_val=query,
+            )
+            for namespace in filter:
+                restrict = match_service_pb2.Namespace()
+                restrict.name = namespace.name
+                restrict.allow_tokens.extend(namespace.allow_tokens)
+                restrict.deny_tokens.extend(namespace.deny_tokens)
+                request.restricts.append(restrict)
+            requests.append(request)
+
+        batch_request_for_index.requests.extend(requests)
         batch_request.requests.append(batch_request_for_index)
 
         # Perform the request
diff --git a/tests/system/aiplatform/test_matching_engine_index.py b/tests/system/aiplatform/test_matching_engine_index.py
index 7a4b76feef..110baa37ab 100644
--- a/tests/system/aiplatform/test_matching_engine_index.py
+++ b/tests/system/aiplatform/test_matching_engine_index.py
@@ -18,6 +18,9 @@
 import uuid
 
 from google.cloud import aiplatform
+from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+    Namespace,
+)
 from tests.system.aiplatform import e2e_base
 
 
@@ -161,6 +164,8 @@
     -0.021106,
 ]
 
+_TEST_FILTER = [Namespace("name", ["allow_token"], ["deny_token"])]
+
 
 class TestMatchingEngine(e2e_base.TestEndToEnd):
 
@@ -283,6 +288,16 @@ def test_create_get_list_matching_engine_index(self, shared_state):
 
         # assert results[0][0].id == 870
 
+        # TODO: Test `my_index_endpoint.match` with filter.
+        # This requires uploading a new content of the Matching Engine Index to Cloud Storage.
+        # results = my_index_endpoint.match(
+        #     deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+        #     queries=[_TEST_MATCH_QUERY],
+        #     num_neighbors=1,
+        #     filter=_TEST_FILTER,
+        # )
+        # assert results[0][0].id == 9999
+
         # Undeploy index
         my_index_endpoint = my_index_endpoint.undeploy_index(
             deployed_index_id=deployed_index.id
diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
index 231ab20ae0..21589c4267 100644
--- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
+++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py
@@ -24,6 +24,10 @@
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
+from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+    Namespace,
+)
 from google.cloud.aiplatform.compat.types import (
     matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
     index_endpoint as gca_index_endpoint,
@@ -37,6 +41,8 @@
 
 from google.protobuf import field_mask_pb2
 
+import grpc
+
 import pytest
 
 # project
@@ -210,6 +216,9 @@
     ]
 ]
 _TEST_NUM_NEIGHBOURS = 1
+_TEST_FILTER = [
+    Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
+]
 
 
 def uuid_mock():
@@ -380,6 +389,33 @@ def create_index_endpoint_mock():
         yield create_index_endpoint_mock
 
 
+@pytest.fixture
+def index_endpoint_match_queries_mock():
+    with patch.object(
+        grpc._channel._UnaryUnaryMultiCallable,
+        "__call__",
+    ) as index_endpoint_match_queries_mock:
+        index_endpoint_match_queries_mock.return_value = (
+            match_service_pb2.BatchMatchResponse(
+                responses=[
+                    match_service_pb2.BatchMatchResponse.BatchMatchResponsePerIndex(
+                        deployed_index_id="1",
+                        responses=[
+                            match_service_pb2.MatchResponse(
+                                neighbor=[
+                                    match_service_pb2.MatchResponse.Neighbor(
+                                        id="1", distance=0.1
+                                    )
+                                ]
+                            )
+                        ],
+                    )
+                ]
+            )
+        )
+        yield index_endpoint_match_queries_mock
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestMatchingEngineIndexEndpoint:
     def setup_method(self):
@@ -617,3 +653,42 @@ def test_delete_index_endpoint_with_force(
         delete_index_endpoint_mock.assert_called_once_with(
             name=_TEST_INDEX_ENDPOINT_NAME
         )
+
+    @pytest.mark.usefixtures("get_index_endpoint_mock")
+    def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
+        aiplatform.init(project=_TEST_PROJECT)
+
+        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
+            index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
+        )
+
+        my_index_endpoint.match(
+            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+            queries=_TEST_QUERIES,
+            num_neighbors=_TEST_NUM_NEIGHBOURS,
+            filter=_TEST_FILTER,
+        )
+
+        batch_request = match_service_pb2.BatchMatchRequest(
+            requests=[
+                match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
+                    deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+                    requests=[
+                        match_service_pb2.MatchRequest(
+                            num_neighbors=_TEST_NUM_NEIGHBOURS,
+                            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
+                            float_val=_TEST_QUERIES[0],
+                            restricts=[
+                                match_service_pb2.Namespace(
+                                    name="class",
+                                    allow_tokens=["token_1"],
+                                    deny_tokens=["token_2"],
+                                )
+                            ],
+                        )
+                    ],
+                )
+            ]
+        )
+
+        index_endpoint_match_queries_mock.assert_called_with(batch_request)
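
For reference, a minimal usage sketch of the new `filter` parameter, based only on the API added in this patch. It assumes an index endpoint with a deployed index already exists; the project, location, endpoint resource name, deployed index ID, and query vector below are placeholders, and the query dimensionality must match the deployed index.

from google.cloud import aiplatform
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
    Namespace,
)

# Placeholder project and resource identifiers; substitute real values.
aiplatform.init(project="my-project", location="us-central1")

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
    index_endpoint_name="projects/123/locations/us-central1/indexEndpoints/456"
)

# Keep only datapoints tagged "red" in the "color" namespace, and drop any
# result tagged "squared" in the "shape" namespace (mirrors the docstring example).
matches = my_index_endpoint.match(
    deployed_index_id="my_deployed_index",
    queries=[[0.1, 0.2, 0.3]],  # placeholder embedding; dimensions must match the index
    num_neighbors=5,
    filter=[Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])],
)

for neighbor in matches[0]:
    print(neighbor.id, neighbor.distance)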