From 595b580bfaa238b63f61cb69a7829094c747aaea Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Wed, 15 Nov 2023 19:23:50 -0800 Subject: [PATCH] fix: fix server error due to no `encryption_spec_key_name` in MatchingEngineIndex `create_tree_ah_index` and `create_brute_force_index` PiperOrigin-RevId: 582880161 --- .../matching_engine/matching_engine_index.py | 9 +- .../aiplatform/test_matching_engine_index.py | 82 ++++++++++++++++++- 2 files changed, 87 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index c0713a83b9..9e30b7f1b6 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -196,11 +196,14 @@ def _create( "contentsDeltaUri": contents_delta_uri, }, index_update_method=index_update_method_enum, - encryption_spec=gca_encryption_spec.EncryptionSpec( - kms_key_name=encryption_spec_key_name - ), ) + if encryption_spec_key_name: + encryption_spec = gca_encryption_spec.EncryptionSpec( + kms_key_name=encryption_spec_key_name + ) + gapic_index.encryption_spec = encryption_spec + if labels: utils.validate_labels(labels) gapic_index.labels = labels diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 36320a13b3..33072ad6fa 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -369,6 +369,50 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_mock") + def test_create_tree_ah_index_backward_compatibility(self, create_index_mock): + aiplatform.init(project=_TEST_PROJECT) + + aiplatform.MatchingEngineIndex.create_tree_ah_index( + display_name=_TEST_INDEX_DISPLAY_NAME, + contents_delta_uri=_TEST_CONTENTS_DELTA_URI, + dimensions=_TEST_INDEX_CONFIG_DIMENSIONS, + approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE, + leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT, + leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + config = { + "treeAhConfig": { + "leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT, + "leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT, + } + } + + expected = gca_index.Index( + display_name=_TEST_INDEX_DISPLAY_NAME, + metadata={ + "config": { + "algorithmConfig": config, + "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS, + "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE, + }, + "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI, + }, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + create_index_mock.assert_called_once_with( + parent=_TEST_PARENT, + index=expected, + metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("get_index_mock") @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.parametrize( @@ -419,7 +463,7 @@ def test_create_brute_force_index( index_update_method ], encryption_spec=gca_encryption_spec.EncryptionSpec( - kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME + kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME, ), ) @@ -429,6 +473,42 @@ def test_create_brute_force_index( metadata=_TEST_REQUEST_METADATA, ) + @pytest.mark.usefixtures("get_index_mock") + def test_create_brute_force_index_backward_compatibility(self, create_index_mock): + aiplatform.init(project=_TEST_PROJECT) + + aiplatform.MatchingEngineIndex.create_brute_force_index( + display_name=_TEST_INDEX_DISPLAY_NAME, + contents_delta_uri=_TEST_CONTENTS_DELTA_URI, + dimensions=_TEST_INDEX_CONFIG_DIMENSIONS, + distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + config = {"bruteForceConfig": {}} + + expected = gca_index.Index( + display_name=_TEST_INDEX_DISPLAY_NAME, + metadata={ + "config": { + "algorithmConfig": config, + "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS, + "approximateNeighborsCount": None, + "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE, + }, + "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI, + }, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + create_index_mock.assert_called_once_with( + parent=_TEST_PARENT, + index=expected, + metadata=_TEST_REQUEST_METADATA, + ) + @pytest.mark.usefixtures("get_index_mock") def test_remove_datapoints(self, remove_datapoints_mock): aiplatform.init(project=_TEST_PROJECT)