Skip to content

Commit

Permalink
feat: add remove_datapoints() to MatchingEngineIndex.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582498139
  • Loading branch information
lingyinw authored and copybara-github committed Nov 15, 2023
1 parent 568d833 commit b86a404
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 0 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
index_endpoint as index_endpoint_v1beta1,
hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1,
io as io_v1beta1,
index_service as index_service_v1beta1,
job_service as job_service_v1beta1,
job_state as job_state_v1beta1,
lineage_subgraph as lineage_subgraph_v1beta1,
Expand Down Expand Up @@ -275,6 +276,7 @@
matching_engine_deployed_index_ref_v1beta1,
index_v1beta1,
index_endpoint_v1beta1,
index_service_v1beta1,
match_service_v1beta1,
metadata_service_v1beta1,
metadata_schema_v1beta1,
Expand Down
41 changes: 41 additions & 0 deletions google/cloud/aiplatform/matching_engine/matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.protobuf import field_mask_pb2
from google.cloud.aiplatform import base
from google.cloud.aiplatform.compat.types import (
index_service_v1beta1 as gca_index_service_v1beta1,
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
matching_engine_index as gca_matching_engine_index,
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -661,6 +662,46 @@ def create_brute_force_index(
encryption_spec_key_name=encryption_spec_key_name,
)

def remove_datapoints(
self,
datapoint_ids: Sequence[str],
) -> "MatchingEngineIndex":
"""Remove datapoints for this index.
Args:
datapoints_ids (Sequence[str]):
Required. The list of datapoints ids to be deleted.
Returns:
MatchingEngineIndex - Index resource object
"""
self.wait()

_LOGGER.log_action_start_against_resource(
"Removing datapoints",
"index",
self,
)

remove_lro = self.api_client.remove_datapoints(
gca_index_service_v1beta1.RemoveDatapointsRequest(
index=self.resource_name,
datapoint_ids=datapoint_ids,
)
)

_LOGGER.log_action_started_against_resource_with_lro(
"Remove datapoints", "index", self.__class__, remove_lro
)

self._gca_resource = remove_lro.result(timeout=None)

_LOGGER.log_action_completed_against_resource(
"index", "Removed datapoints", self
)

return self


_INDEX_UPDATE_METHOD_TO_ENUM_VALUE = {
"STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from google.cloud.aiplatform.compat.types import (
index as gca_index,
encryption_spec as gca_encryption_spec,
index_service_v1beta1 as gca_index_service_v1beta1,
)
import constants as test_constants

Expand Down Expand Up @@ -110,6 +111,9 @@
# Encryption spec
_TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC"

# Streaming update
_TEST_DATAPOINT_IDS = ("1", "2")


def uuid_mock():
return uuid.UUID(int=1)
Expand Down Expand Up @@ -192,6 +196,16 @@ def create_index_mock():
yield create_index_mock


@pytest.fixture
def remove_datapoints_mock():
with patch.object(
index_service_client.IndexServiceClient, "remove_datapoints"
) as remove_datapoints_mock:
remove_datapoints_lro_mock = mock.Mock(operation.Operation)
remove_datapoints_mock.return_value = remove_datapoints_lro_mock
yield remove_datapoints_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestMatchingEngineIndex:
def setup_method(self):
Expand Down Expand Up @@ -414,3 +428,19 @@ def test_create_brute_force_index(
index=expected,
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
def test_remove_datapoints(self, remove_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)

my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
my_index.remove_datapoints(
datapoint_ids=_TEST_DATAPOINT_IDS,
)

remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest(
index=_TEST_INDEX_NAME,
datapoint_ids=_TEST_DATAPOINT_IDS,
)

remove_datapoints_mock.assert_called_once_with(remove_datapoints_request)

0 comments on commit b86a404

Please sign in to comment.