diff --git a/google/cloud/aiplatform/explain/__init__.py b/google/cloud/aiplatform/explain/__init__.py index 8167e80a4a..ec7f15f003 100644 --- a/google/cloud/aiplatform/explain/__init__.py +++ b/google/cloud/aiplatform/explain/__init__.py @@ -42,6 +42,8 @@ SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution SmoothGradConfig = explanation_compat.SmoothGradConfig XraiAttribution = explanation_compat.XraiAttribution +Presets = explanation_compat.Presets +Examples = explanation_compat.Examples __all__ = ( @@ -58,4 +60,6 @@ "SmoothGradConfig", "Visualization", "XraiAttribution", + "Presets", + "Examples", ) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index e444070c98..842ab70b5c 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -152,6 +152,59 @@ _TEST_EXPLANATION_PARAMETERS = ( test_constants.ModelConstants._TEST_EXPLANATION_PARAMETERS ) +_TEST_EXPLANATION_METADATA_EXAMPLES = explain.ExplanationMetadata( + outputs={"embedding": {"output_tensor_name": "embedding"}}, + inputs={ + "my_input": { + "input_tensor_name": "bytes_inputs", + "encoding": "IDENTITY", + "modality": "image", + }, + "id": {"input_tensor_name": "id", "encoding": "IDENTITY"}, + }, +) +_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS = explain.ExplanationParameters( + { + "examples": { + "example_gcs_source": { + "gcs_source": { + "uris": ["gs://example-bucket/folder/instance1.jsonl"], + }, + }, + "neighbor_count": 10, + "presets": {"query": "FAST", "modality": "TEXT"}, + } + } +) +_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG = explain.ExplanationParameters( + { + "examples": { + "example_gcs_source": { + "gcs_source": { + "uris": ["gs://example-bucket/folder/instance1.jsonl"], + }, + }, + "neighbor_count": 10, + "nearest_neighbor_search_config": [ + { + "contentsDeltaUri": "", + "config": { + "dimensions": 50, + "approximateNeighborsCount": 10, + "distanceMeasureType": "SQUARED_L2_DISTANCE", + "featureNormType": "NONE", + "algorithmConfig": { + "treeAhConfig": { + "leafNodeEmbeddingCount": 1000, + "leafNodesToSearchPercent": 100, + } + }, + }, + } + ], + } + } +) # CMEK encryption _TEST_ENCRYPTION_KEY_NAME = "key_1234" @@ -1119,6 +1172,80 @@ def test_upload_with_parameters_without_metadata( timeout=None, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_with_parameters_for_examples_presets( + self, upload_model_mock, get_model_mock, sync + ): + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS, + explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + container_spec=container_spec, + explanation_spec=gca_model.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA_EXAMPLES, + parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS, + ), + version_aliases=["default"], + ) + + upload_model_mock.assert_called_once_with( + request=gca_model_service.UploadModelRequest( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ), + timeout=None, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_with_parameters_for_examples_full_config( + self, upload_model_mock, get_model_mock, sync + ): + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG, + explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + container_spec=container_spec, + explanation_spec=gca_model.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA_EXAMPLES, + parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG, + ), + version_aliases=["default"], + ) + + upload_model_mock.assert_called_once_with( + request=gca_model_service.UploadModelRequest( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ), + timeout=None, + ) + @pytest.mark.parametrize("sync", [True, False]) def test_upload_uploads_and_gets_model_with_all_args( self, upload_model_mock, get_model_mock, sync