feat: Adding support for concurrent explanations
PiperOrigin-RevId: 586740015
vertex-sdk-bot authored and copybara-github committed Nov 30, 2023
1 parent ae3677c commit 8e2ad75
Showing 2 changed files with 338 additions and 0 deletions.
198 changes: 198 additions & 0 deletions google/cloud/aiplatform/preview/models.py
@@ -32,18 +32,36 @@
endpoint_service_client,
)
from google.cloud.aiplatform.compat.types import (
prediction_service_v1beta1 as gca_prediction_service_compat,
deployed_model_ref_v1beta1 as gca_deployed_model_ref_compat,
deployment_resource_pool_v1beta1 as gca_deployment_resource_pool_compat,
explanation_v1beta1 as gca_explanation_compat,
endpoint_v1beta1 as gca_endpoint_compat,
machine_resources_v1beta1 as gca_machine_resources_compat,
model_v1 as gca_model_compat,
)
from google.protobuf import json_format

_DEFAULT_MACHINE_TYPE = "n1-standard-2"

_LOGGER = base.Logger(__name__)


class Prediction(models.Prediction):
"""Prediction class envelopes returned Model predictions and the Model id.
Attributes:
concurrent_explanations:
Map of explanations that were requested concurrently in addition to
the default explanation for the Model's predictions. It has the same
number of elements as instances to be explained. Default is None.
"""

concurrent_explanations: Optional[
Dict[str, Sequence[gca_explanation_compat.Explanation]]
] = None
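Since `concurrent_explanations` maps each caller-chosen method name to one Explanation per input instance, consuming it is a plain dict walk. A minimal sketch (the helper name is ours, not part of the commit):

```python
from typing import Dict


def count_concurrent_explanations(prediction: Prediction) -> Dict[str, int]:
    """Count the explanations returned per concurrently requested method."""
    return {
        method_name: len(explanations)
        for method_name, explanations in (
            prediction.concurrent_explanations or {}
        ).items()
    }
```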


class DeploymentResourcePool(base.VertexAiResourceNounWithFutureManager):
client_class = utils.DeploymentResourcePoolClientWithOverride
_resource_noun = "deploymentResourcePools"
@@ -1013,6 +1031,186 @@ def _deploy_call(

operation_future.result(timeout=None)

def explain(
self,
instances: List[Dict],
parameters: Optional[Dict] = None,
deployed_model_id: Optional[str] = None,
timeout: Optional[float] = None,
explanation_spec_override: Optional[Dict] = None,
concurrent_explanation_spec_override: Optional[Dict] = None,
) -> Prediction:
"""Make a prediction with explanations against this Endpoint.
Example usage:
response = my_endpoint.explain(instances=[...])
my_explanations = response.explanations
Args:
instances (List):
Required. The instances that are the input to the
prediction call. A DeployedModel may have an upper limit
on the number of instances it supports per request; when
that limit is exceeded, the prediction call errors for
AutoML Models, while for customer-created Models the
behaviour is as documented by that Model.
The schema of any single instance may be specified via
Endpoint's DeployedModels'
[Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``instance_schema_uri``.
parameters (Dict):
Optional. The parameters that govern the prediction. The schema of
the parameters may be specified via Endpoint's
DeployedModels' [Model's
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
deployed_model_id (str):
Optional. If specified, this ExplainRequest will be served by the
chosen DeployedModel, overriding this Endpoint's traffic split.
timeout (float): Optional. The timeout for this request in seconds.
explanation_spec_override (Dict):
Optional. Represents overrides to the explanation
specification used when the model was deployed.
The explanation override will be merged with the
model's existing [Explanation Spec
][google.cloud.aiplatform.v1beta1.ExplanationSpec].
concurrent_explanation_spec_override (Dict):
Optional. The ``explain`` endpoint supports multiple
explanations in parallel. Use this field to request
concurrent explanations in addition to the configured
explanation method.
Returns:
prediction (aiplatform.Prediction):
Prediction with returned predictions, explanations, and Model ID.
"""
self.wait()
request = gca_prediction_service_compat.ExplainRequest()
# ExplainRequest.endpoint is required; target this Endpoint explicitly,
# as explain_async below already does.
request.endpoint = self.resource_name

if instances is not None:
request.instances.extend(instances)
if parameters is not None:
request.parameters = parameters
if deployed_model_id is not None:
request.deployed_model_id = deployed_model_id
if explanation_spec_override is not None:
request.explanation_spec_override = explanation_spec_override
if concurrent_explanation_spec_override is not None:
request.concurrent_explanation_spec_override = (
concurrent_explanation_spec_override
)

explain_response = self._prediction_client.select_version("v1beta1").explain(
request, timeout=timeout
)

prediction = Prediction(
predictions=[
json_format.MessageToDict(item)
for item in explain_response.predictions.pb
],
deployed_model_id=explain_response.deployed_model_id,
explanations=explain_response.explanations,
)

prediction.concurrent_explanations = {
    key: entry.explanations
    for key, entry in explain_response.concurrent_explanations.items()
}

return prediction
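A hedged usage sketch for the new parameter (the endpoint resource name, instance schema, and the "sampled_shapley" key are hypothetical; assumes the preview `Endpoint` class in this module and a model deployed with an explanation spec):

```python
from google.cloud.aiplatform.preview import models as preview_models

# Hypothetical endpoint; its deployed model must have an explanation spec.
endpoint = preview_models.Endpoint(
    "projects/my-project/locations/us-central1/endpoints/1234567890"
)

response = endpoint.explain(
    instances=[{"feature_a": 1.0, "feature_b": 2.0}],  # hypothetical schema
    # Request an extra Sampled Shapley explanation alongside the method
    # configured at deploy time; the map key is caller-chosen.
    concurrent_explanation_spec_override={
        "sampled_shapley": {
            "parameters": {"sampled_shapley_attribution": {"path_count": 10}}
        }
    },
)

print(response.explanations)  # explanations from the configured method
print(response.concurrent_explanations["sampled_shapley"])  # extra method
```

The map key ("sampled_shapley" here) is echoed back as the key of `Prediction.concurrent_explanations`.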

async def explain_async(
self,
instances: List[Dict],
*,
parameters: Optional[Dict] = None,
deployed_model_id: Optional[str] = None,
timeout: Optional[float] = None,
explanation_spec_override: Optional[Dict] = None,
concurrent_explanation_spec_override: Optional[Dict] = None,
) -> Prediction:
"""Make a prediction with explanations against this Endpoint.
Example usage:
```
response = await my_endpoint.explain_async(instances=[...])
my_explanations = response.explanations
```
Args:
instances (List):
Required. The instances that are the input to the
prediction call. A DeployedModel may have an upper limit
on the number of instances it supports per request; when
that limit is exceeded, the prediction call errors for
AutoML Models, while for customer-created Models the
behaviour is as documented by that Model.
The schema of any single instance may be specified via
Endpoint's DeployedModels'
[Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``instance_schema_uri``.
parameters (Dict):
Optional. The parameters that govern the prediction. The schema of
the parameters may be specified via Endpoint's
DeployedModels' [Model's
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
deployed_model_id (str):
Optional. If specified, this ExplainRequest will be served by the
chosen DeployedModel, overriding this Endpoint's traffic split.
timeout (float): Optional. The timeout for this request in seconds.
explanation_spec_override (Dict):
Optional. Represents overrides to the explanation
specification used when the model was deployed.
The explanation override will be merged with the
model's existing [Explanation Spec
][google.cloud.aiplatform.v1beta1.ExplanationSpec].
concurrent_explanation_spec_override (Dict):
Optional. The ``explain`` endpoint supports multiple
explanations in parallel. Use this field to request
concurrent explanations in addition to the configured
explanation method.
Returns:
prediction (aiplatform.Prediction):
Prediction with returned predictions, explanations, and Model ID.
"""
self.wait()

request = gca_prediction_service_compat.ExplainRequest(
endpoint=self.resource_name,
instances=instances,
parameters=parameters,
deployed_model_id=deployed_model_id,
explanation_spec_override=explanation_spec_override,
concurrent_explanation_spec_override=concurrent_explanation_spec_override,
)

explain_response = await self._prediction_async_client.select_version(
"v1beta1"
).explain(request, timeout=timeout)

prediction = Prediction(
predictions=[
json_format.MessageToDict(item)
for item in explain_response.predictions.pb
],
deployed_model_id=explain_response.deployed_model_id,
explanations=explain_response.explanations,
)

prediction.concurrent_explanations = {
    key: entry.explanations
    for key, entry in explain_response.concurrent_explanations.items()
}

return prediction
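The async variant follows the same shape; a sketch under the same hypothetical names, run from a fresh event loop:

```python
import asyncio

from google.cloud.aiplatform.preview import models as preview_models


async def explain_with_concurrent_methods() -> None:
    # Same hypothetical endpoint as in the synchronous sketch above.
    endpoint = preview_models.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )
    response = await endpoint.explain_async(
        instances=[{"feature_a": 1.0, "feature_b": 2.0}],
        concurrent_explanation_spec_override={
            "sampled_shapley": {
                "parameters": {"sampled_shapley_attribution": {"path_count": 10}}
            }
        },
    )
    for method_name, explanations in response.concurrent_explanations.items():
        print(method_name, len(explanations))


asyncio.run(explain_with_concurrent_methods())
```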


class Model(aiplatform.Model):
def deploy(