Skip to content

Commit

Permalink
feat: add sdk support for xai example-based explanations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 545803031
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jul 5, 2023
1 parent 718f04b commit f9ca1d5
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 0 deletions.
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
SmoothGradConfig = explanation_compat.SmoothGradConfig
XraiAttribution = explanation_compat.XraiAttribution
Presets = explanation_compat.Presets
Examples = explanation_compat.Examples


__all__ = (
Expand All @@ -58,4 +60,6 @@
"SmoothGradConfig",
"Visualization",
"XraiAttribution",
"Presets",
"Examples",
)
127 changes: 127 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,59 @@
_TEST_EXPLANATION_PARAMETERS = (
test_constants.ModelConstants._TEST_EXPLANATION_PARAMETERS
)
_TEST_EXPLANATION_METADATA_EXAMPLES = explain.ExplanationMetadata(
outputs={"embedding": {"output_tensor_name": "embedding"}},
inputs={
"my_input": {
"input_tensor_name": "bytes_inputs",
"encoding": "IDENTITY",
"modality": "image",
},
"id": {"input_tensor_name": "id", "encoding": "IDENTITY"},
},
)
_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS = explain.ExplanationParameters(
{
"examples": {
"example_gcs_source": {
"gcs_source": {
"uris": ["gs://example-bucket/folder/instance1.jsonl"],
},
},
"neighbor_count": 10,
"presets": {"query": "FAST", "modality": "TEXT"},
}
}
)
_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG = explain.ExplanationParameters(
{
"examples": {
"example_gcs_source": {
"gcs_source": {
"uris": ["gs://example-bucket/folder/instance1.jsonl"],
},
},
"neighbor_count": 10,
"nearest_neighbor_search_config": [
{
"contentsDeltaUri": "",
"config": {
"dimensions": 50,
"approximateNeighborsCount": 10,
"distanceMeasureType": "SQUARED_L2_DISTANCE",
"featureNormType": "NONE",
"algorithmConfig": {
"treeAhConfig": {
"leafNodeEmbeddingCount": 1000,
"leafNodesToSearchPercent": 100,
}
},
},
}
],
}
}
)

# CMEK encryption
_TEST_ENCRYPTION_KEY_NAME = "key_1234"
Expand Down Expand Up @@ -1119,6 +1172,80 @@ def test_upload_with_parameters_without_metadata(
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_upload_with_parameters_for_examples_presets(
self, upload_model_mock, get_model_mock, sync
):
my_model = models.Model.upload(
display_name=_TEST_MODEL_NAME,
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS,
explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
sync=sync,
)

if not sync:
my_model.wait()

container_spec = gca_model.ModelContainerSpec(
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
)

managed_model = gca_model.Model(
display_name=_TEST_MODEL_NAME,
container_spec=container_spec,
explanation_spec=gca_model.explanation.ExplanationSpec(
metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_PRESETS,
),
version_aliases=["default"],
)

upload_model_mock.assert_called_once_with(
request=gca_model_service.UploadModelRequest(
parent=initializer.global_config.common_location_path(),
model=managed_model,
),
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_upload_with_parameters_for_examples_full_config(
self, upload_model_mock, get_model_mock, sync
):
my_model = models.Model.upload(
display_name=_TEST_MODEL_NAME,
serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
explanation_parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG,
explanation_metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
sync=sync,
)

if not sync:
my_model.wait()

container_spec = gca_model.ModelContainerSpec(
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
)

managed_model = gca_model.Model(
display_name=_TEST_MODEL_NAME,
container_spec=container_spec,
explanation_spec=gca_model.explanation.ExplanationSpec(
metadata=_TEST_EXPLANATION_METADATA_EXAMPLES,
parameters=_TEST_EXPLANATION_PARAMETERS_EXAMPLES_FULL_CONFIG,
),
version_aliases=["default"],
)

upload_model_mock.assert_called_once_with(
request=gca_model_service.UploadModelRequest(
parent=initializer.global_config.common_location_path(),
model=managed_model,
),
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
def test_upload_uploads_and_gets_model_with_all_args(
self, upload_model_mock, get_model_mock, sync
Expand Down

0 comments on commit f9ca1d5

Please sign in to comment.