Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support model monitoring for batch prediction in Vertex SDK #1570

Merged
merged 8 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ def create(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
model_monitoring_objective_config: Optional[
"aiplatform.model_monitoring.ObjectiveConfig"
] = None,
model_monitoring_alert_config: Optional[
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.

Expand Down Expand Up @@ -551,6 +558,23 @@ def create(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
Optional. The objective config for model monitoring. Passing this parameter enables
monitoring on the model associated with this batch prediction job.
model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
Optional. Configures how model monitoring alerts are sent to the user. Right now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these docstrings copied from the source at https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform? Will we be able to remember to update this when/if alerts other than email alert become supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not directly copied from GAPIC. But Jing's team also confirmed that there are no plans for additional alert configs.

only email alert is supported.
analysis_instance_schema_uri (str):
Optional. Only applicable if model_monitoring_objective_config is also passed.
This parameter specifies the YAML schema file uri describing the format of a single
instance that you want Tensorflow Data Validation (TFDV) to
analyze. If this field is empty, all the feature data types are
inferred from predict_instance_schema_uri, meaning that TFDV
will use the data in the exact format as prediction request/response.
If there are any data type differences between predict instance
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand Down Expand Up @@ -601,7 +625,18 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

# TODO: remove temporary import statements once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_bp_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
Expand Down Expand Up @@ -688,6 +723,28 @@ def create(
)
)

# Model Monitoring
if model_monitoring_objective_config:
if model_monitoring_objective_config.drift_detection_config:
_LOGGER.info(
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
)
if model_monitoring_objective_config.explanation_config:
_LOGGER.info(
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
)
gapic_batch_prediction_job.model_monitoring_config = (
gca_model_monitoring_compat.ModelMonitoringConfig(
objective_configs=[
model_monitoring_objective_config.as_proto(config_for_bp=True)
],
alert_config=model_monitoring_alert_config.as_proto(
config_for_bp=True
),
analysis_instance_schema_uri=analysis_instance_schema_uri,
)
)

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
Expand All @@ -702,6 +759,11 @@ def create(
sync=sync,
create_request_timeout=create_request_timeout,
)
# TODO: remove temporary re-import statements once model monitoring for batch prediction is GA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there are several of these throughout the PR, perhaps add a tracking bug number so the TODO comments look like # TODO(b/.....): remove... with a common bug number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

buganizer component created & assigned

from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

@classmethod
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
Expand Down
24 changes: 21 additions & 3 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@

from typing import Optional, List
from google.cloud.aiplatform_v1.types import (
model_monitoring as gca_model_monitoring,
model_monitoring as gca_model_monitoring_v1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use compat for these imports?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compat by default imports from the GA version of GAPIC, unless DEFAULT_VERSION is set to v1beta1. I tried to see if the import aliases would index into the correct version by simply setting DEFAULT_VERSION = 'v1beta1' and then switching it back to v1 on an ad-hoc basis, but it doesn't dynamically index in the way I was hoping for. I think it's because the symbol table isn't automatically re-written unless we explicitly re-import. So that's why I imported both the v1 and v1beta1 versions explicitly.

)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1


class EmailAlertConfig:
def __init__(
Expand All @@ -40,8 +47,19 @@ def __init__(
self.enable_logging = enable_logging
self.user_emails = user_emails

def as_proto(self):
"""Returns EmailAlertConfig as a proto message."""
# TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
def as_proto(self, config_for_bp: Optional[bool] = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This typing annotation should be just config_for_bp: bool = False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

"""Returns EmailAlertConfig as a proto message.

Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_model_monitoring = gca_model_monitoring_v1
user_email_alert_config = (
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
Expand Down
96 changes: 63 additions & 33 deletions google/cloud/aiplatform/model_monitoring/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
from typing import Optional, Dict

from google.cloud.aiplatform_v1.types import (
io as gca_io,
ThresholdConfig as gca_threshold_config,
model_monitoring as gca_model_monitoring,
io as gca_io_v1,
model_monitoring as gca_model_monitoring_v1,
)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
io as gca_io_v1beta1,
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1
gca_io = gca_io_v1

TF_RECORD = "tf-record"
CSV = "csv"
JSONL = "jsonl"
Expand Down Expand Up @@ -80,19 +88,20 @@ def __init__(
self.attribute_skew_thresholds = attribute_skew_thresholds
self.data_format = data_format
self.target_field = target_field
self.training_dataset = None

def as_proto(self):
"""Returns _SkewDetectionConfig as a proto message."""
skew_thresholds_mapping = {}
attribution_score_skew_thresholds_mapping = {}
if self.skew_thresholds is not None:
for key in self.skew_thresholds.keys():
skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds[key]
)
skew_thresholds_mapping[key] = skew_threshold
if self.attribute_skew_thresholds is not None:
for key in self.attribute_skew_thresholds.keys():
attribution_score_skew_threshold = gca_threshold_config(
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.attribute_skew_thresholds[key]
)
attribution_score_skew_thresholds_mapping[
Expand Down Expand Up @@ -134,12 +143,16 @@ def as_proto(self):
attribution_score_drift_thresholds_mapping = {}
if self.drift_thresholds is not None:
for key in self.drift_thresholds.keys():
drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
drift_threshold = gca_model_monitoring.ThresholdConfig(
value=self.drift_thresholds[key]
)
drift_thresholds_mapping[key] = drift_threshold
if self.attribute_drift_thresholds is not None:
for key in self.attribute_drift_thresholds.keys():
attribution_score_drift_threshold = gca_threshold_config(
value=self.attribute_drift_thresholds[key]
attribution_score_drift_threshold = (
gca_model_monitoring.ThresholdConfig(
value=self.attribute_drift_thresholds[key]
)
)
attribution_score_drift_thresholds_mapping[
key
Expand Down Expand Up @@ -186,11 +199,49 @@ def __init__(
self.drift_detection_config = drift_detection_config
self.explanation_config = explanation_config

def as_proto(self):
"""Returns _ObjectiveConfig as a proto message."""
# TODO: remove config_for_bp parameter when model monitoring for batch prediction is feature complete and in GA
def as_proto(self, config_for_bp: Optional[bool] = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above about typing annotation here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

"""Returns _SkewDetectionConfig as a proto message.

Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_io = gca_io_v1beta1
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_io = gca_io_v1
gca_model_monitoring = gca_model_monitoring_v1
training_dataset = None
if self.skew_detection_config is not None:
training_dataset = self.skew_detection_config.training_dataset
training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=self.skew_detection_config.target_field
)
)
if self.skew_detection_config.data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=self.skew_detection_config.data_source
)
elif self.skew_detection_config.data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(
uris=[self.skew_detection_config.data_source]
)
if (
self.skew_detection_config.data_format is not None
and self.skew_detection_config.data_format
not in [TF_RECORD, CSV, JSONL]
):
raise ValueError(
"Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = self.skew_detection_config.data_format
else:
training_dataset.dataset = self.skew_detection_config.data_source

return gca_model_monitoring.ModelMonitoringObjectiveConfig(
training_dataset=training_dataset,
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
Expand Down Expand Up @@ -271,27 +322,6 @@ def __init__(
data_format,
)

training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=target_field
)
)
if data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=data_source
)
elif data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
raise ValueError(
"Unsupported value. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = data_format
else:
training_dataset.dataset = data_source
self.training_dataset = training_dataset


class DriftDetectionConfig(_DriftDetectionConfig):
"""A class that configures prediction drift detection for models deployed to an endpoint.
Expand Down
24 changes: 21 additions & 3 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

from google.cloud.aiplatform_v1.types import (
io as gca_io,
model_monitoring as gca_model_monitoring,
)

# constants used for testing
USER_EMAIL = ""
MODEL_NAME = "churn"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used as the model's display_name? if so maybe better call it MODEL_DISPLAY_NAME, since "model name" could be understood as its resource full name.

If this is indeed intended as a display name of a resource used only for testing, please prefix it with "temp".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the variable names to make them more specific. But the extra prefixing is unnecessary because the actual display names are created by _make_display_name, implemented in e2e_base.py, which appends a prefix set by the class (in this case the prefix is 'temp_e2e_model_monitoring_test_').

Copy link
Contributor

@dizcology dizcology Aug 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. In that case can we rename this further and call it something like MODEL_DISPLAYNAME_KEY to prevent the next engineer reading this from misinterpreting it as the actual model display name? More generally, what's the meaning of "churn" in the context of testing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the variable name. Regarding the actual value string, "churn" is just a shorthand used to reference the BQ dataset. The same dataset and pre-trained model was used in the example notebook on the Cloud SDK documentation page and a separate example for BQML.

Expand Down Expand Up @@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
0
].objective_config
assert gca_obj_config.training_dataset == skew_config.training_dataset

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)
assert gca_obj_config.training_dataset == expected_training_dataset
assert (
gca_obj_config.training_prediction_skew_detection_config
== skew_config.as_proto()
Expand Down Expand Up @@ -290,12 +302,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
)
assert gapic_job.model_monitoring_alert_config.enable_logging

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)

for config in gapic_job.model_deployment_monitoring_objective_configs:
gca_obj_config = config.objective_config
deployed_model_id = config.deployed_model_id
assert (
gca_obj_config.training_dataset
== all_configs[deployed_model_id].skew_detection_config.training_dataset
gca_obj_config.as_proto().training_dataset == expected_training_dataset
)
assert (
gca_obj_config.training_prediction_skew_detection_config
Expand Down
24 changes: 23 additions & 1 deletion tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def test_batch_prediction_iter_dirs_invalid_output_info(self):
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_gcs_source_and_dest(
def test_batch_predict_gcs_source_and_dest_with_monitoring(
self, create_batch_prediction_job_mock, sync
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
Expand All @@ -521,13 +521,22 @@ def test_batch_predict_gcs_source_and_dest(
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
sync=sync,
create_request_timeout=None,
model_monitoring_objective_config=aiplatform.model_monitoring.ObjectiveConfig(),
model_monitoring_alert_config=aiplatform.model_monitoring.EmailAlertConfig(),
)

batch_prediction_job.wait_for_resource_creation()

batch_prediction_job.wait()

# Construct expected request
# TODO: remove temporary import statements once model monitoring for batch prediction is GA
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_batch_prediction_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)

expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
model=_TEST_MODEL_NAME,
Expand All @@ -543,13 +552,26 @@ def test_batch_predict_gcs_source_and_dest(
),
predictions_format="jsonl",
),
model_monitoring_config=gca_model_monitoring_compat.ModelMonitoringConfig(
alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig(
email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig()
),
objective_configs=[
gca_model_monitoring_compat.ModelMonitoringObjectiveConfig()
],
),
)

create_batch_prediction_job_mock.assert_called_once_with(
parent=_TEST_PARENT,
batch_prediction_job=expected_gapic_batch_prediction_job,
timeout=None,
)
# TODO: remove temporary import statements once model monitoring for batch prediction is GA
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_batch_prediction_job_compat,
)

@mock.patch.object(jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(jobs, "_LOG_WAIT_TIME", 1)
Expand Down
Loading