Skip to content

Commit

Permalink
feat: add support for find_neighbors/read_index_datapoints in matchin…
Browse files Browse the repository at this point in the history
…g engine public endpoint

PiperOrigin-RevId: 527357229
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Apr 26, 2023
1 parent a8ba666 commit e3a87f0
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 3 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
services.specialist_pool_service_client = (
services.specialist_pool_service_client_v1beta1
)
services.match_service_client = services.match_service_client_v1beta1
services.metadata_service_client = services.metadata_service_client_v1beta1
services.tensorboard_service_client = services.tensorboard_service_client_v1beta1
services.index_service_client = services.index_service_client_v1beta1
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/compat/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from google.cloud.aiplatform_v1beta1.services.job_service import (
client as job_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.match_service import (
client as match_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.metadata_service import (
client as metadata_service_client_v1beta1,
)
Expand Down Expand Up @@ -129,6 +132,7 @@
index_service_client_v1beta1,
index_endpoint_service_client_v1beta1,
job_service_client_v1beta1,
match_service_client_v1beta1,
model_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
lineage_subgraph as lineage_subgraph_v1beta1,
machine_resources as machine_resources_v1beta1,
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1,
match_service as match_service_v1beta1,
metadata_schema as metadata_schema_v1beta1,
metadata_service as metadata_service_v1beta1,
metadata_store as metadata_store_v1beta1,
Expand Down Expand Up @@ -260,6 +261,7 @@
matching_engine_deployed_index_ref_v1beta1,
index_v1beta1,
index_endpoint_v1beta1,
match_service_v1beta1,
metadata_service_v1beta1,
metadata_schema_v1beta1,
metadata_store_v1beta1,
Expand Down
12 changes: 10 additions & 2 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def get_client_options(
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
api_path_override: Optional[str] = None,
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Expand All @@ -289,6 +290,7 @@ def get_client_options(
Vertex AI.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_base_path_override (str): Optional. Override default API base path.
api_path_override (str): Optional. Override default api path.
Returns:
clients_options (google.api_core.client_options.ClientOptions):
A ClientOptions object set with regionalized API endpoint, i.e.
Expand All @@ -311,9 +313,12 @@ def get_client_options(
else constants.API_BASE_PATH
)

return client_options.ClientOptions(
api_endpoint=f"{region}-{service_base_path}"
api_endpoint = (
f"{region}-{service_base_path}"
if not api_path_override
else api_path_override
)
return client_options.ClientOptions(api_endpoint=api_endpoint)

def common_location_path(
self, project: Optional[str] = None, location: Optional[str] = None
Expand Down Expand Up @@ -345,6 +350,7 @@ def create_client(
location_override: Optional[str] = None,
prediction_client: bool = False,
api_base_path_override: Optional[str] = None,
api_path_override: Optional[str] = None,
appended_user_agent: Optional[List[str]] = None,
) -> utils.VertexAiServiceClientWithOverride:
"""Instantiates a given VertexAiServiceClient with optional
Expand All @@ -358,6 +364,7 @@ def create_client(
location_override (str): Optional. location override.
prediction_client (str): Optional. flag to use a prediction endpoint.
api_base_path_override (str): Optional. Override default api base path.
api_path_override (str): Optional. Override default api path.
appended_user_agent (List[str]):
Optional. User agent appended in the client info. If more than one, it will be
separated by spaces.
Expand All @@ -383,6 +390,7 @@ def create_client(
location_override=location_override,
prediction_client=prediction_client,
api_base_path_override=api_base_path_override,
api_path_override=api_path_override,
),
"client_info": client_info,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from google.cloud.aiplatform.compat.types import (
machine_resources as gca_machine_resources_compat,
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
match_service_v1beta1 as gca_match_service_v1beta1,
index_v1beta1 as gca_index_v1beta1,
)
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import (
Expand Down Expand Up @@ -127,6 +129,9 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)

if self.public_endpoint_domain_name:
self._public_match_client = self._instantiate_public_match_client()

@classmethod
def create(
cls,
Expand Down Expand Up @@ -344,6 +349,22 @@ def _create(

return index_obj

def _instantiate_public_match_client(
    self,
) -> utils.MatchClientWithOverride:
    """Builds the match client that talks to this endpoint's public domain.

    The client is produced by ``initializer.global_config.create_client``
    using this endpoint's credentials and location, with the API path
    overridden by ``public_endpoint_domain_name``.

    Returns:
        match_client (match_service_client.MatchServiceClient):
            Initialized match client with optional overrides.
    """
    create_client = initializer.global_config.create_client
    public_match_client = create_client(
        client_class=utils.MatchClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
        # Route requests to the public endpoint's DNS name rather than the
        # default regional API endpoint.
        api_path_override=self.public_endpoint_domain_name,
    )
    return public_match_client

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
Expand Down Expand Up @@ -928,6 +949,124 @@ def description(self) -> str:
self._assert_gca_resource_is_available()
return self._gca_resource.description

def find_neighbors(
    self,
    *,
    deployed_index_id: str,
    queries: List[List[float]],
    num_neighbors: int = 10,
    filter: Optional[List[Namespace]] = None,
) -> List[List[MatchNeighbor]]:
    """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.

    ```
    Example usage:
        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id'
        )
        my_index_endpoint.find_neighbors(deployed_index_id="public_test1", queries= [[1, 1]],)
    ```

    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to match the queries against.
        queries (List[List[float]]):
            Required. A list of queries. Each query is a list of floats, representing a single embedding.
        num_neighbors (int):
            Optional. The number of nearest neighbors to be retrieved from database for
            each query. Defaults to 10.
        filter (List[Namespace]):
            Optional. A list of Namespaces for filtering the matching results.
            For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
            that satisfy "red color" but not include datapoints with "squared shape".
            Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
            ``None`` (the default) and an empty list both mean "no filtering".

    Returns:
        List[List[MatchNeighbor]] - A list of nearest neighbors for each query.

    Raises:
        ValueError: If no public match client is available on this endpoint.
    """
    # The constructor only assigns `_public_match_client` when the endpoint
    # has a public domain name, so the attribute may be missing entirely.
    # Use getattr so callers get the intended ValueError, not AttributeError.
    public_match_client = getattr(self, "_public_match_client", None)
    if not public_match_client:
        raise ValueError(
            "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
        )

    # Treat None as "no restrictions"; this also avoids a shared mutable
    # default argument.
    namespaces = filter or []

    # Create the FindNeighbors request
    find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
    find_neighbors_request.index_endpoint = self.resource_name
    find_neighbors_request.deployed_index_id = deployed_index_id

    for query in queries:
        find_neighbors_query = (
            gca_match_service_v1beta1.FindNeighborsRequest.Query()
        )
        find_neighbors_query.neighbor_count = num_neighbors
        datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
        # Every query carries the same set of namespace restrictions.
        for namespace in namespaces:
            restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
            restrict.namespace = namespace.name
            restrict.allow_list.extend(namespace.allow_tokens)
            restrict.deny_list.extend(namespace.deny_tokens)
            datapoint.restricts.append(restrict)
        find_neighbors_query.datapoint = datapoint
        find_neighbors_request.queries.append(find_neighbors_query)

    response = public_match_client.find_neighbors(find_neighbors_request)

    # Wrap the results in MatchNeighbor objects and return
    return [
        [
            MatchNeighbor(
                id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
            )
            for neighbor in embedding_neighbors.neighbors
        ]
        for embedding_neighbors in response.nearest_neighbors
    ]

def read_index_datapoints(
    self,
    *,
    deployed_index_id: str,
    ids: Optional[List[str]] = None,
) -> List[gca_index_v1beta1.IndexDatapoint]:
    """Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.

    ```
    Example Usage:
        my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name='projects/123/locations/us-central1/index_endpoint/my_index_id'
        )
        my_index_endpoint.read_index_datapoints(deployed_index_id="public_test1", ids= ["606431", "896688"],)
    ```

    Args:
        deployed_index_id (str):
            Required. The ID of the DeployedIndex to match the queries against.
        ids (List[str]):
            Optional. IDs of the datapoints to be searched for. ``None``
            (the default) is treated as an empty list.

    Returns:
        List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.

    Raises:
        ValueError: If no public match client is available on this endpoint.
    """
    # The constructor only assigns `_public_match_client` when the endpoint
    # has a public domain name, so the attribute may be missing entirely.
    # Use getattr so callers get the intended ValueError, not AttributeError.
    public_match_client = getattr(self, "_public_match_client", None)
    if not public_match_client:
        raise ValueError(
            "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
        )

    # Create the ReadIndexDatapoints request
    read_index_datapoints_request = (
        gca_match_service_v1beta1.ReadIndexDatapointsRequest()
    )
    read_index_datapoints_request.index_endpoint = self.resource_name
    read_index_datapoints_request.deployed_index_id = deployed_index_id

    # `ids or []` avoids a shared mutable default; extend() replaces the
    # per-item append loop whose variable shadowed the builtin `id`.
    read_index_datapoints_request.ids.extend(ids or [])

    response = public_match_client.read_index_datapoints(
        read_index_datapoints_request
    )

    # Wrap the results and return
    return response.datapoints

def match(
self,
deployed_index_id: str,
Expand Down
9 changes: 9 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
index_service_client_v1beta1,
index_endpoint_service_client_v1beta1,
job_service_client_v1beta1,
match_service_client_v1beta1,
metadata_service_client_v1beta1,
model_service_client_v1beta1,
pipeline_service_client_v1beta1,
Expand Down Expand Up @@ -85,6 +86,7 @@
prediction_service_client_v1beta1.PredictionServiceClient,
pipeline_service_client_v1beta1.PipelineServiceClient,
job_service_client_v1beta1.JobServiceClient,
match_service_client_v1beta1.MatchServiceClient,
metadata_service_client_v1beta1.MetadataServiceClient,
tensorboard_service_client_v1beta1.TensorboardServiceClient,
vizier_service_client_v1beta1.VizierServiceClient,
Expand Down Expand Up @@ -598,6 +600,12 @@ class PredictionClientWithOverride(ClientWithOverride):
)


class MatchClientWithOverride(ClientWithOverride):
    """Client-with-override wrapper for the v1beta1 ``MatchServiceClient``."""

    _is_temporary = False
    _default_version = compat.V1BETA1
    # Only the v1beta1 MatchServiceClient is mapped here.
    _version_map = (
        (compat.V1BETA1, match_service_client_v1beta1.MatchServiceClient),
    )


class MetadataClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.DEFAULT_VERSION
Expand Down Expand Up @@ -632,6 +640,7 @@ class VizierClientWithOverride(ClientWithOverride):
FeaturestoreClientWithOverride,
JobClientWithOverride,
ModelClientWithOverride,
MatchClientWithOverride,
PipelineClientWithOverride,
PipelineJobClientWithOverride,
PredictionClientWithOverride,
Expand Down
Loading

0 comments on commit e3a87f0

Please sign in to comment.