Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support filters in matching engine vector matching #1608

Merged
merged 10 commits into from
Aug 26, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple

from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -51,6 +51,25 @@ class MatchNeighbor:
distance: float


@dataclass
class Namespace:
    """Restriction applied to a matching query within one named namespace.

    A Namespace specifies the rules for determining the datapoints that are
    eligible for each matching query. When several namespaces are given, the
    overall query is an AND across namespaces.

    Args:
        name (str):
            Required. The name of this Namespace.
        allow_tokens (List[str]):
            Optional. The allowed tokens in the namespace.
        deny_tokens (List[str]):
            Optional. The denied tokens in the namespace. When a token is
            denied, then matches will be excluded whenever the other
            datapoint has that token.

            For example, if a query specifies
            [Namespace("color", ["red", "blue"], ["purple"])], then that
            query will match datapoints that are red or blue, but if those
            points are also purple, then they will be excluded even if they
            are red/blue.
    """

    name: str
    # default_factory gives every instance its own list, so appending to one
    # Namespace's tokens never affects another instance.
    allow_tokens: List[str] = field(default_factory=list)
    deny_tokens: List[str] = field(default_factory=list)


class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
"""Matching Engine index endpoint resource for Vertex AI."""

Expand Down Expand Up @@ -796,7 +815,11 @@ def description(self) -> str:
return self._gca_resource.description

def match(
self, deployed_index_id: str, queries: List[List[float]], num_neighbors: int = 1
self,
deployed_index_id: str,
queries: List[List[float]],
num_neighbors: int = 1,
filter: Optional[List[Namespace]] = [],
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
Expand All @@ -808,6 +831,11 @@ def match(
num_neighbors (int):
Required. The number of nearest neighbors to be retrieved from database for
each query.
filter (List[Namespace]):
Optional. A list of Namespaces for filtering the matching results.
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
whose "color" is "red", excluding any datapoint whose "shape" is "squared".
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
Expand Down Expand Up @@ -836,16 +864,22 @@ def match(
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
)
batch_request_for_index.deployed_index_id = deployed_index_id
batch_request_for_index.requests.extend(
[
match_service_pb2.MatchRequest(
num_neighbors=num_neighbors,
deployed_index_id=deployed_index_id,
float_val=query,
)
for query in queries
]
)
requests = []
for query in queries:
request = match_service_pb2.MatchRequest(
num_neighbors=num_neighbors,
deployed_index_id=deployed_index_id,
float_val=query,
)
for namespace in filter:
restrict = match_service_pb2.Namespace()
restrict.name = namespace.name
restrict.allow_tokens.extend(namespace.allow_tokens)
restrict.deny_tokens.extend(namespace.deny_tokens)
request.restricts.append(restrict)
requests.append(request)

batch_request_for_index.requests.extend(requests)
batch_request.requests.append(batch_request_for_index)

# Perform the request
Expand Down
15 changes: 15 additions & 0 deletions tests/system/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import uuid

from google.cloud import aiplatform
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
)

from tests.system.aiplatform import e2e_base

Expand Down Expand Up @@ -161,6 +164,8 @@
-0.021106,
]

# Filter for the filtered `match` call in the end-to-end test below; that call
# is currently commented out pending new index content (see the TODO there).
_TEST_FILTER = [Namespace("name", ["allow_token"], ["deny_token"])]


class TestMatchingEngine(e2e_base.TestEndToEnd):

Expand Down Expand Up @@ -283,6 +288,16 @@ def test_create_get_list_matching_engine_index(self, shared_state):

# assert results[0][0].id == 870

# TODO: Test `my_index_endpoint.match` with filter.
# This requires uploading new Matching Engine Index content to Cloud Storage.
# results = my_index_endpoint.match(
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
# queries=[_TEST_MATCH_QUERY],
# num_neighbors=1,
# filter=_TEST_FILTER,
# )
# assert results[0][0].id == 9999

# Undeploy index
my_index_endpoint = my_index_endpoint.undeploy_index(
deployed_index_id=deployed_index.id
Expand Down
75 changes: 75 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
)
from google.cloud.aiplatform.compat.types import (
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
index_endpoint as gca_index_endpoint,
Expand All @@ -37,6 +41,8 @@

from google.protobuf import field_mask_pb2

import grpc

import pytest

# project
Expand Down Expand Up @@ -210,6 +216,9 @@
]
]
# Number of neighbors requested per query in the match test.
_TEST_NUM_NEIGHBOURS = 1
# Single-namespace filter passed to `match` in the match test: namespace
# "class" allows "token_1" and denies "token_2".
_TEST_FILTER = [
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
]


def uuid_mock():
Expand Down Expand Up @@ -380,6 +389,33 @@ def create_index_endpoint_mock():
yield create_index_endpoint_mock


@pytest.fixture
def index_endpoint_match_queries_mock():
    """Patch the gRPC unary-unary ``__call__`` and yield the resulting mock.

    While the fixture is active, the patched call returns a canned
    ``BatchMatchResponse`` containing a single neighbor (id "1",
    distance 0.1) for deployed index "1".
    """
    neighbor = match_service_pb2.MatchResponse.Neighbor(id="1", distance=0.1)
    per_index_response = (
        match_service_pb2.BatchMatchResponse.BatchMatchResponsePerIndex(
            deployed_index_id="1",
            responses=[match_service_pb2.MatchResponse(neighbor=[neighbor])],
        )
    )
    canned_response = match_service_pb2.BatchMatchResponse(
        responses=[per_index_response]
    )

    with patch.object(
        grpc._channel._UnaryUnaryMultiCallable, "__call__"
    ) as grpc_call_mock:
        grpc_call_mock.return_value = canned_response
        yield grpc_call_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestMatchingEngineIndexEndpoint:
def setup_method(self):
Expand Down Expand Up @@ -617,3 +653,42 @@ def test_delete_index_endpoint_with_force(
delete_index_endpoint_mock.assert_called_once_with(
name=_TEST_INDEX_ENDPOINT_NAME
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_match_queries(self, index_endpoint_match_queries_mock):
    """Check that ``match`` forwards the filter as restricts on the request."""
    aiplatform.init(project=_TEST_PROJECT)
    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    endpoint.match(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=_TEST_QUERIES,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
    )

    # Build the expected request bottom-up: restrict -> per-query request
    # -> per-index batch -> overall batch.
    expected_restrict = match_service_pb2.Namespace(
        name="class",
        allow_tokens=["token_1"],
        deny_tokens=["token_2"],
    )
    expected_match_request = match_service_pb2.MatchRequest(
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        float_val=_TEST_QUERIES[0],
        restricts=[expected_restrict],
    )
    expected_per_index = (
        match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
            requests=[expected_match_request],
        )
    )
    expected_batch = match_service_pb2.BatchMatchRequest(
        requests=[expected_per_index]
    )

    index_endpoint_match_queries_mock.assert_called_with(expected_batch)