From b86a4046c2cd0c189efc609bd6319f8da76cd6e7 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Tue, 14 Nov 2023 18:07:27 -0800 Subject: [PATCH] feat: add `remove_datapoints()` to `MatchingEngineIndex`. PiperOrigin-RevId: 582498139 --- .../cloud/aiplatform/compat/types/__init__.py | 2 + .../matching_engine/matching_engine_index.py | 41 +++++++++++++++++++ .../aiplatform/test_matching_engine_index.py | 30 ++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py index 83387b064f..f543fb0114 100644 --- a/google/cloud/aiplatform/compat/types/__init__.py +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -52,6 +52,7 @@ index_endpoint as index_endpoint_v1beta1, hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1, io as io_v1beta1, + index_service as index_service_v1beta1, job_service as job_service_v1beta1, job_state as job_state_v1beta1, lineage_subgraph as lineage_subgraph_v1beta1, @@ -275,6 +276,7 @@ matching_engine_deployed_index_ref_v1beta1, index_v1beta1, index_endpoint_v1beta1, + index_service_v1beta1, match_service_v1beta1, metadata_service_v1beta1, metadata_schema_v1beta1, diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index 07eff0d325..c0713a83b9 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -21,6 +21,7 @@ from google.protobuf import field_mask_pb2 from google.cloud.aiplatform import base from google.cloud.aiplatform.compat.types import ( + index_service_v1beta1 as gca_index_service_v1beta1, matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, matching_engine_index as gca_matching_engine_index, encryption_spec as gca_encryption_spec, @@ -661,6 +662,46 @@ def create_brute_force_index( encryption_spec_key_name=encryption_spec_key_name, ) + def remove_datapoints( + self, + datapoint_ids: Sequence[str], + ) -> "MatchingEngineIndex": + """Remove datapoints for this index. + + Args: + datapoints_ids (Sequence[str]): + Required. The list of datapoints ids to be deleted. + + Returns: + MatchingEngineIndex - Index resource object + """ + self.wait() + + _LOGGER.log_action_start_against_resource( + "Removing datapoints", + "index", + self, + ) + + remove_lro = self.api_client.remove_datapoints( + gca_index_service_v1beta1.RemoveDatapointsRequest( + index=self.resource_name, + datapoint_ids=datapoint_ids, + ) + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Remove datapoints", "index", self.__class__, remove_lro + ) + + self._gca_resource = remove_lro.result(timeout=None) + + _LOGGER.log_action_completed_against_resource( + "index", "Removed datapoints", self + ) + + return self + _INDEX_UPDATE_METHOD_TO_ENUM_VALUE = { "STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE, diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 40d12ef0a9..36320a13b3 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -35,6 +35,7 @@ from google.cloud.aiplatform.compat.types import ( index as gca_index, encryption_spec as gca_encryption_spec, + index_service_v1beta1 as gca_index_service_v1beta1, ) import constants as test_constants @@ -110,6 +111,9 @@ # Encryption spec _TEST_ENCRYPTION_SPEC_KEY_NAME = "TEST_ENCRYPTION_SPEC" +# Streaming update +_TEST_DATAPOINT_IDS = ("1", "2") + def uuid_mock(): return uuid.UUID(int=1) @@ -192,6 +196,16 @@ def create_index_mock(): yield create_index_mock +@pytest.fixture +def remove_datapoints_mock(): + with patch.object( + index_service_client.IndexServiceClient, "remove_datapoints" + ) as remove_datapoints_mock: + remove_datapoints_lro_mock = mock.Mock(operation.Operation) + remove_datapoints_mock.return_value = remove_datapoints_lro_mock + yield remove_datapoints_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestMatchingEngineIndex: def setup_method(self): @@ -414,3 +428,19 @@ def test_create_brute_force_index( index=expected, metadata=_TEST_REQUEST_METADATA, ) + + @pytest.mark.usefixtures("get_index_mock") + def test_remove_datapoints(self, remove_datapoints_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID) + my_index.remove_datapoints( + datapoint_ids=_TEST_DATAPOINT_IDS, + ) + + remove_datapoints_request = gca_index_service_v1beta1.RemoveDatapointsRequest( + index=_TEST_INDEX_NAME, + datapoint_ids=_TEST_DATAPOINT_IDS, + ) + + remove_datapoints_mock.assert_called_once_with(remove_datapoints_request)