Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make matching engine API public #1192

Merged
merged 23 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
Feature,
Featurestore,
)
from google.cloud.aiplatform.matching_engine import (
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import Model
Expand Down Expand Up @@ -105,6 +109,8 @@
"EntityType",
"Feature",
"Featurestore",
"MatchingEngineIndex",
"MatchingEngineIndexEndpoint",
"ImageDataset",
"HyperparameterTuningJob",
"Model",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# limitations under the License.
#

from google.cloud.aiplatform._matching_engine.matching_engine_index import (
from google.cloud.aiplatform.matching_engine.matching_engine_index import (
MatchingEngineIndex,
)
from google.cloud.aiplatform._matching_engine.matching_engine_index_config import (
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
BruteForceConfig as MatchingEngineBruteForceAlgorithmConfig,
MatchingEngineIndexConfig as MatchingEngineIndexConfig,
TreeAhConfig as MatchingEngineTreeAhAlgorithmConfig,
)
from google.cloud.aiplatform._matching_engine.matching_engine_index_endpoint import (
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
MatchingEngineIndexEndpoint,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
from google.cloud.aiplatform._matching_engine import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2

import grpc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
matching_engine_index as gca_matching_engine_index,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform._matching_engine import matching_engine_index_config
from google.cloud.aiplatform.matching_engine import matching_engine_index_config
from google.cloud.aiplatform import utils

_LOGGER = base.Logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import _matching_engine
from google.cloud.aiplatform import matching_engine
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
machine_resources as gca_machine_resources_compat,
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
)
from google.cloud.aiplatform._matching_engine import match_service_pb2
from google.cloud.aiplatform._matching_engine import match_service_pb2_grpc
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
from google.protobuf import field_mask_pb2

import grpc
Expand Down Expand Up @@ -432,7 +432,7 @@ def _build_deployed_index(

def deploy_index(
self,
index: _matching_engine.MatchingEngineIndex,
index: matching_engine.MatchingEngineIndex,
deployed_index_id: str,
display_name: Optional[str] = None,
machine_type: Optional[str] = None,
Expand Down
10 changes: 7 additions & 3 deletions tests/system/aiplatform/e2e_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import pytest
import uuid

from typing import Any, Dict, Generator

from google.api_core import exceptions
Expand All @@ -29,8 +30,7 @@
from google.cloud.aiplatform import initializer

_PROJECT = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
_PROJECT_NUMBER = os.getenv("PROJECT_NUMBER")
_VPC_NETWORK_NAME = os.getenv("private-net")
_VPC_NETWORK_URI = os.getenv("_VPC_NETWORK_URI")
_LOCATION = "us-central1"


Expand Down Expand Up @@ -136,7 +136,10 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
# Bring all Endpoints to the front of the list
# Ensures Models are undeployed first before we attempt deletion
shared_state["resources"].sort(
key=lambda r: 1 if isinstance(r, aiplatform.Endpoint) else 2
key=lambda r: 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked to Morgan earlier about some ideas for cleaning this up; it will only get uglier as time passes.

if isinstance(r, aiplatform.Endpoint)
or isinstance(r, aiplatform.MatchingEngineIndexEndpoint)
else 2
)

for resource in shared_state["resources"]:
Expand All @@ -146,6 +149,7 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
(
aiplatform.Endpoint,
aiplatform.Featurestore,
aiplatform.MatchingEngineIndexEndpoint,
),
):
# For endpoint, undeploy model then delete endpoint
Expand Down
81 changes: 75 additions & 6 deletions tests/system/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

import uuid
import pytest

from google.cloud import aiplatform

Expand Down Expand Up @@ -52,10 +51,6 @@
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name"
_TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint"

_TEST_INDEX_ENDPOINT_VPC_NETWORK = "projects/{}/global/networks/{}".format(
e2e_base._PROJECT_NUMBER, e2e_base._VPC_NETWORK_NAME
)

# DEPLOYED INDEX
_TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}"
_TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}"
Expand Down Expand Up @@ -167,7 +162,6 @@
]


@pytest.mark.skip(reason="TestMatchingEngine not available")
class TestMatchingEngine(e2e_base.TestEndToEnd):

_temp_prefix = "temp_vertex_sdk_e2e_matching_engine_test"
Expand Down Expand Up @@ -226,9 +220,84 @@ def test_create_get_list_matching_engine_index(self, shared_state):

assert updated_index.name == get_index.name

# Create endpoint and check that it is listed
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
network=e2e_base._VPC_NETWORK_URI,
labels=_TEST_LABELS,
)
assert my_index_endpoint.resource_name in [
index_endpoint.resource_name
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
]

assert my_index_endpoint.labels == _TEST_LABELS
assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME
assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION

shared_state["resources"].append(my_index_endpoint)

# Deploy endpoint
my_index_endpoint = my_index_endpoint.deploy_index(
index=index,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
)

# Update endpoint
updated_index_endpoint = my_index_endpoint.update(
display_name=_TEST_DISPLAY_NAME_UPDATE,
description=_TEST_DESCRIPTION_UPDATE,
labels=_TEST_LABELS_UPDATE,
)

assert updated_index_endpoint.labels == _TEST_LABELS_UPDATE
assert updated_index_endpoint.display_name == _TEST_DISPLAY_NAME_UPDATE
assert updated_index_endpoint.description == _TEST_DESCRIPTION_UPDATE

# Mutate deployed index
my_index_endpoint.mutate_deployed_index(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED,
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
)

deployed_index = my_index_endpoint.deployed_indexes[0]

assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID
assert deployed_index.index == index.resource_name
assert (
deployed_index.automatic_resources.min_replica_count
== _TEST_MIN_REPLICA_COUNT_UPDATED
)
assert (
deployed_index.automatic_resources.max_replica_count
== _TEST_MAX_REPLICA_COUNT_UPDATED
)

# TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't know how to run the whole system test "inside" the VPC network.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am asking around.

# results = my_index_endpoint.match(
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY]
# )

# assert results[0][0].id == 870

# Undeploy index
my_index_endpoint = my_index_endpoint.undeploy_index(
deployed_index_id=deployed_index.id
)

# Delete index and check that it is no longer listed
index.delete()
list_indexes = aiplatform.MatchingEngineIndex.list()
assert get_index.resource_name not in [
index.resource_name for index in list_indexes
]

# Delete index endpoint and check that it is no longer listed
my_index_endpoint.delete()
assert my_index_endpoint.resource_name not in [
index_endpoint.resource_name
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
]
1 change: 0 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def create_index_mock():
yield create_index_mock


@pytest.mark.skip(reason="MatchingEngineIndex not available")
class TestMatchingEngineIndex:
def setup_method(self):
reload(initializer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def create_index_endpoint_mock():
yield create_index_endpoint_mock


@pytest.mark.skip(reason="MatchingEngineIndexEndpoint not available")
class TestMatchingEngineIndexEndpoint:
def setup_method(self):
reload(initializer)
Expand Down