Skip to content

Commit

Permalink
Merge branch 'main' into mh-seq2seq
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu authored May 9, 2022
2 parents 2cd2d97 + 469db6b commit a29dce3
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 20 deletions.
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
Feature,
Featurestore,
)
from google.cloud.aiplatform.matching_engine import (
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import Model
Expand Down Expand Up @@ -105,6 +109,8 @@
"EntityType",
"Feature",
"Featurestore",
"MatchingEngineIndex",
"MatchingEngineIndexEndpoint",
"ImageDataset",
"HyperparameterTuningJob",
"Model",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# limitations under the License.
#

from google.cloud.aiplatform._matching_engine.matching_engine_index import (
from google.cloud.aiplatform.matching_engine.matching_engine_index import (
MatchingEngineIndex,
)
from google.cloud.aiplatform._matching_engine.matching_engine_index_config import (
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
BruteForceConfig as MatchingEngineBruteForceAlgorithmConfig,
MatchingEngineIndexConfig as MatchingEngineIndexConfig,
TreeAhConfig as MatchingEngineTreeAhAlgorithmConfig,
)
from google.cloud.aiplatform._matching_engine.matching_engine_index_endpoint import (
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
MatchingEngineIndexEndpoint,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
from google.cloud.aiplatform._matching_engine import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2

import grpc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
matching_engine_index as gca_matching_engine_index,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform._matching_engine import matching_engine_index_config
from google.cloud.aiplatform.matching_engine import matching_engine_index_config
from google.cloud.aiplatform import utils

_LOGGER = base.Logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import _matching_engine
from google.cloud.aiplatform import matching_engine
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
machine_resources as gca_machine_resources_compat,
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
)
from google.cloud.aiplatform._matching_engine import match_service_pb2
from google.cloud.aiplatform._matching_engine import match_service_pb2_grpc
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
from google.protobuf import field_mask_pb2

import grpc
Expand Down Expand Up @@ -432,7 +432,7 @@ def _build_deployed_index(

def deploy_index(
self,
index: _matching_engine.MatchingEngineIndex,
index: matching_engine.MatchingEngineIndex,
deployed_index_id: str,
display_name: Optional[str] = None,
machine_type: Optional[str] = None,
Expand Down
10 changes: 7 additions & 3 deletions tests/system/aiplatform/e2e_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import pytest
import uuid

from typing import Any, Dict, Generator

from google.api_core import exceptions
Expand All @@ -29,8 +30,7 @@
from google.cloud.aiplatform import initializer

_PROJECT = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
_PROJECT_NUMBER = os.getenv("PROJECT_NUMBER")
_VPC_NETWORK_NAME = os.getenv("private-net")
_VPC_NETWORK_URI = os.getenv("_VPC_NETWORK_URI")
_LOCATION = "us-central1"


Expand Down Expand Up @@ -136,7 +136,10 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
# Bring all Endpoints to the front of the list
# Ensures Models are undeployed first before we attempt deletion
shared_state["resources"].sort(
key=lambda r: 1 if isinstance(r, aiplatform.Endpoint) else 2
key=lambda r: 1
if isinstance(r, aiplatform.Endpoint)
or isinstance(r, aiplatform.MatchingEngineIndexEndpoint)
else 2
)

for resource in shared_state["resources"]:
Expand All @@ -146,6 +149,7 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
(
aiplatform.Endpoint,
aiplatform.Featurestore,
aiplatform.MatchingEngineIndexEndpoint,
),
):
# For endpoint, undeploy model then delete endpoint
Expand Down
81 changes: 75 additions & 6 deletions tests/system/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

import uuid
import pytest

from google.cloud import aiplatform

Expand Down Expand Up @@ -52,10 +51,6 @@
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name"
_TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint"

_TEST_INDEX_ENDPOINT_VPC_NETWORK = "projects/{}/global/networks/{}".format(
e2e_base._PROJECT_NUMBER, e2e_base._VPC_NETWORK_NAME
)

# DEPLOYED INDEX
_TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}"
_TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}"
Expand Down Expand Up @@ -167,7 +162,6 @@
]


@pytest.mark.skip(reason="TestMatchingEngine not available")
class TestMatchingEngine(e2e_base.TestEndToEnd):

_temp_prefix = "temp_vertex_sdk_e2e_matching_engine_test"
Expand Down Expand Up @@ -226,9 +220,84 @@ def test_create_get_list_matching_engine_index(self, shared_state):

assert updated_index.name == get_index.name

# Create endpoint and check that it is listed
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
network=e2e_base._VPC_NETWORK_URI,
labels=_TEST_LABELS,
)
assert my_index_endpoint.resource_name in [
index_endpoint.resource_name
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
]

assert my_index_endpoint.labels == _TEST_LABELS
assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME
assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION

shared_state["resources"].append(my_index_endpoint)

# Deploy endpoint
my_index_endpoint = my_index_endpoint.deploy_index(
index=index,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
)

# Update endpoint
updated_index_endpoint = my_index_endpoint.update(
display_name=_TEST_DISPLAY_NAME_UPDATE,
description=_TEST_DESCRIPTION_UPDATE,
labels=_TEST_LABELS_UPDATE,
)

assert updated_index_endpoint.labels == _TEST_LABELS_UPDATE
assert updated_index_endpoint.display_name == _TEST_DISPLAY_NAME_UPDATE
assert updated_index_endpoint.description == _TEST_DESCRIPTION_UPDATE

# Mutate deployed index
my_index_endpoint.mutate_deployed_index(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED,
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
)

deployed_index = my_index_endpoint.deployed_indexes[0]

assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID
assert deployed_index.index == index.resource_name
assert (
deployed_index.automatic_resources.min_replica_count
== _TEST_MIN_REPLICA_COUNT_UPDATED
)
assert (
deployed_index.automatic_resources.max_replica_count
== _TEST_MAX_REPLICA_COUNT_UPDATED
)

# TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC.
# results = my_index_endpoint.match(
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY]
# )

# assert results[0][0].id == 870

# Undeploy index
my_index_endpoint = my_index_endpoint.undeploy_index(
deployed_index_id=deployed_index.id
)

# Delete index and check that it is no longer listed
index.delete()
list_indexes = aiplatform.MatchingEngineIndex.list()
assert get_index.resource_name not in [
index.resource_name for index in list_indexes
]

# Delete index endpoint and check that it is no longer listed
my_index_endpoint.delete()
assert my_index_endpoint.resource_name not in [
index_endpoint.resource_name
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
]
1 change: 0 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def create_index_mock():
yield create_index_mock


@pytest.mark.skip(reason="MatchingEngineIndex not available")
class TestMatchingEngineIndex:
def setup_method(self):
reload(initializer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def create_index_endpoint_mock():
yield create_index_endpoint_mock


@pytest.mark.skip(reason="MatchingEngineIndexEndpoint not available")
class TestMatchingEngineIndexEndpoint:
def setup_method(self):
reload(initializer)
Expand Down

0 comments on commit a29dce3

Please sign in to comment.