Skip to content

Commit

Permalink
fix: fix server error due to no encryption_spec_key_name in Matchin…
Browse files Browse the repository at this point in the history
…gEngineIndex `create_tree_ah_index` and `create_brute_force_index`

PiperOrigin-RevId: 582880161
  • Loading branch information
lingyinw authored and copybara-github committed Nov 16, 2023
1 parent dd4b852 commit 595b580
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,14 @@ def _create(
"contentsDeltaUri": contents_delta_uri,
},
index_update_method=index_update_method_enum,
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
),
)

if encryption_spec_key_name:
encryption_spec = gca_encryption_spec.EncryptionSpec(
kms_key_name=encryption_spec_key_name
)
gapic_index.encryption_spec = encryption_spec

if labels:
utils.validate_labels(labels)
gapic_index.labels = labels
Expand Down
82 changes: 81 additions & 1 deletion tests/unit/aiplatform/test_matching_engine_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,50 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
aiplatform.init(project=_TEST_PROJECT)

aiplatform.MatchingEngineIndex.create_tree_ah_index(
display_name=_TEST_INDEX_DISPLAY_NAME,
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
)

config = {
"treeAhConfig": {
"leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
"leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
}
}

expected = gca_index.Index(
display_name=_TEST_INDEX_DISPLAY_NAME,
metadata={
"config": {
"algorithmConfig": config,
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
)

create_index_mock.assert_called_once_with(
parent=_TEST_PARENT,
index=expected,
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -419,7 +463,7 @@ def test_create_brute_force_index(
index_update_method
],
encryption_spec=gca_encryption_spec.EncryptionSpec(
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
),
)

Expand All @@ -429,6 +473,42 @@ def test_create_brute_force_index(
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
def test_create_brute_force_index_backward_compatibility(self, create_index_mock):
aiplatform.init(project=_TEST_PROJECT)

aiplatform.MatchingEngineIndex.create_brute_force_index(
display_name=_TEST_INDEX_DISPLAY_NAME,
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
)

config = {"bruteForceConfig": {}}

expected = gca_index.Index(
display_name=_TEST_INDEX_DISPLAY_NAME,
metadata={
"config": {
"algorithmConfig": config,
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
"approximateNeighborsCount": None,
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
},
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
},
description=_TEST_INDEX_DESCRIPTION,
labels=_TEST_LABELS,
)

create_index_mock.assert_called_once_with(
parent=_TEST_PARENT,
index=expected,
metadata=_TEST_REQUEST_METADATA,
)

@pytest.mark.usefixtures("get_index_mock")
def test_remove_datapoints(self, remove_datapoints_mock):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit 595b580

Please sign in to comment.