feat: support model monitoring for batch prediction in Vertex SDK #1570
Changes from 4 commits
@@ -385,6 +385,13 @@ def create(
        sync: bool = True,
        create_request_timeout: Optional[float] = None,
        batch_size: Optional[int] = None,
        model_monitoring_objective_config: Optional[
            "aiplatform.model_monitoring.ObjectiveConfig"
        ] = None,
        model_monitoring_alert_config: Optional[
            "aiplatform.model_monitoring.AlertConfig"
        ] = None,
        analysis_instance_schema_uri: Optional[str] = None,
    ) -> "BatchPredictionJob":
        """Create a batch prediction job.
@@ -551,6 +558,23 @@ def create(
                but too high a value will result in a whole batch not fitting in a machine's memory,
                and the whole operation will fail.
                The default value is 64.
            model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
                Optional. The objective config for model monitoring. Passing this parameter enables
                monitoring on the model associated with this batch prediction job.
            model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
                Optional. Configures how model monitoring alerts are sent to the user. Currently
                only email alerts are supported.
            analysis_instance_schema_uri (str):
                Optional. Only applicable if model_monitoring_objective_config is also passed.
                This parameter specifies the URI of the YAML schema file describing the format of
                a single instance that you want Tensorflow Data Validation (TFDV) to analyze.
                If this field is empty, all feature data types are inferred from
                predict_instance_schema_uri, meaning that TFDV will use the data in the exact
                format of the prediction request/response. If there are any data type differences
                between the predict instance and the TFDV instance, this field can be used to
                override the schema. For models trained with Vertex AI, this field must be set
                with all the fields of the predict instance formatted as strings.
        Returns:
            (jobs.BatchPredictionJob):
                Instantiated representation of the created batch prediction job.
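For orientation, a minimal sketch of how the new parameters might be used together follows; the project, region, bucket, model resource name, label column, and skew thresholds are placeholders, and the pre-GA limitations called out later in this diff (no drift detection or XAI config for batch prediction) still apply.

```python
from google.cloud import aiplatform
from google.cloud.aiplatform import model_monitoring

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/region

# Skew detection compares the batch prediction inputs against a training dataset.
skew_config = model_monitoring.SkewDetectionConfig(
    data_source="bq://my-project.my_dataset.training_table",  # placeholder BigQuery table
    target_field="churned",                                    # placeholder label column
    skew_thresholds={"country": 0.4, "language": 0.4},         # placeholder features/thresholds
)
objective_config = model_monitoring.ObjectiveConfig(skew_detection_config=skew_config)
alert_config = model_monitoring.EmailAlertConfig(user_emails=["monitoring@example.com"])

job = aiplatform.BatchPredictionJob.create(
    job_display_name="churn-batch-prediction-with-monitoring",
    model_name="projects/my-project/locations/us-central1/models/1234567890",  # placeholder model
    gcs_source="gs://my-bucket/batch_inputs/*.jsonl",
    gcs_destination_prefix="gs://my-bucket/batch_outputs/",
    model_monitoring_objective_config=objective_config,
    model_monitoring_alert_config=alert_config,
)
```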
@@ -601,7 +625,18 @@ def create(
                f"{predictions_format} is not an accepted prediction format "
                f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
            )

        # TODO: remove temporary import statements once model monitoring for batch prediction is GA
        if model_monitoring_objective_config:
            from google.cloud.aiplatform.compat.types import (
                io_v1beta1 as gca_io_compat,
                batch_prediction_job_v1beta1 as gca_bp_job_compat,
                model_monitoring_v1beta1 as gca_model_monitoring_compat,
            )
        else:
            from google.cloud.aiplatform.compat.types import (
                io as gca_io_compat,
                batch_prediction_job as gca_bp_job_compat,
            )
        gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

        # Required Fields
@@ -688,6 +723,28 @@ def create(
                )
            )

        # Model Monitoring
        if model_monitoring_objective_config:
            if model_monitoring_objective_config.drift_detection_config:
                _LOGGER.info(
                    "Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
                )
            if model_monitoring_objective_config.explanation_config:
                _LOGGER.info(
                    "XAI config is currently not supported for monitoring models associated with batch prediction jobs."
                )
            gapic_batch_prediction_job.model_monitoring_config = (
                gca_model_monitoring_compat.ModelMonitoringConfig(
                    objective_configs=[
                        model_monitoring_objective_config.as_proto(config_for_bp=True)
                    ],
                    alert_config=model_monitoring_alert_config.as_proto(
                        config_for_bp=True
                    ),
                    analysis_instance_schema_uri=analysis_instance_schema_uri,
                )
            )

        empty_batch_prediction_job = cls._empty_constructor(
            project=project,
            location=location,
@@ -702,6 +759,11 @@ def create(
            sync=sync,
            create_request_timeout=create_request_timeout,
        )
        # TODO: remove temporary re-import statements once model monitoring for batch prediction is GA
        from google.cloud.aiplatform.compat.types import (
            io as gca_io_compat,
            batch_prediction_job as gca_bp_job_compat,
        )

    @classmethod
    @base.optional_sync(return_input_arg="empty_batch_prediction_job")

Review comment (on the TODO above): Since there are several of these throughout the PR, perhaps add a tracking bug number so the TODO comments look like `TODO(b/<bug number>): ...`.

Reply: buganizer component created & assigned
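The conditional v1beta1 import earlier in create() and the re-import just above form a temporary alias override: the pre-GA model_monitoring_config field only exists on the v1beta1 BatchPredictionJob message, so the compat aliases point at v1beta1 while the request proto is built and are then reset to the GA v1 modules. A stripped-down sketch of the same pattern is below; the build_job helper and its flag are hypothetical, while the compat modules are the ones imported in the diff.

```python
# Hypothetical helper illustrating the temporary alias-override pattern used in
# BatchPredictionJob.create().
def build_job(use_preview_monitoring: bool = False):
    if use_preview_monitoring:
        # Pre-GA monitoring fields only exist on the v1beta1 message.
        from google.cloud.aiplatform.compat.types import (
            batch_prediction_job_v1beta1 as gca_bp_job_compat,
        )
    else:
        from google.cloud.aiplatform.compat.types import (
            batch_prediction_job as gca_bp_job_compat,
        )

    # The rest of the function is written once against the alias and works
    # with whichever API version was selected above.
    return gca_bp_job_compat.BatchPredictionJob()
```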
@@ -17,9 +17,16 @@

from typing import Optional, List
from google.cloud.aiplatform_v1.types import (
    model_monitoring as gca_model_monitoring,
    model_monitoring as gca_model_monitoring_v1,
)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
    model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1


class EmailAlertConfig:
    def __init__(

Review comment (on the renamed v1 alias): Should we use …
@@ -40,8 +47,19 @@ def __init__(
        self.enable_logging = enable_logging
        self.user_emails = user_emails

    def as_proto(self):
        """Returns EmailAlertConfig as a proto message."""
    # TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
    def as_proto(self, config_for_bp: Optional[bool] = False):
        """Returns EmailAlertConfig as a proto message.

        Args:
            config_for_bp (bool):
                Optional. Set this parameter to True if the config object
                is used for model monitoring on a batch prediction job.
        """
        if config_for_bp:
            gca_model_monitoring = gca_model_monitoring_v1beta1
        else:
            gca_model_monitoring = gca_model_monitoring_v1
        user_email_alert_config = (
            gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
                user_emails=self.user_emails

Review comment (on the new as_proto signature): This typing annotation should be just `bool`.

Reply: updated
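To illustrate the new flag on the alert config (a sketch; the email address is a placeholder), the same EmailAlertConfig instance can now produce either a v1 or a v1beta1 proto:

```python
from google.cloud.aiplatform import model_monitoring

alert_config = model_monitoring.EmailAlertConfig(
    user_emails=["monitoring@example.com"],  # placeholder address
    enable_logging=True,
)

# Default path: v1 proto, used for model deployment (endpoint) monitoring jobs.
v1_alert_proto = alert_config.as_proto()

# config_for_bp=True: v1beta1 proto, used when the config is attached to a
# batch prediction job while that surface is still pre-GA.
v1beta1_alert_proto = alert_config.as_proto(config_for_bp=True)
```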
@@ -18,11 +18,19 @@
from typing import Optional, Dict

from google.cloud.aiplatform_v1.types import (
    io as gca_io,
    ThresholdConfig as gca_threshold_config,
    model_monitoring as gca_model_monitoring,
    io as gca_io_v1,
    model_monitoring as gca_model_monitoring_v1,
)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
    io as gca_io_v1beta1,
    model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1
gca_io = gca_io_v1

TF_RECORD = "tf-record"
CSV = "csv"
JSONL = "jsonl"
@@ -80,19 +88,20 @@ def __init__(
        self.attribute_skew_thresholds = attribute_skew_thresholds
        self.data_format = data_format
        self.target_field = target_field
        self.training_dataset = None

    def as_proto(self):
        """Returns _SkewDetectionConfig as a proto message."""
        skew_thresholds_mapping = {}
        attribution_score_skew_thresholds_mapping = {}
        if self.skew_thresholds is not None:
            for key in self.skew_thresholds.keys():
                skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
                skew_threshold = gca_model_monitoring.ThresholdConfig(
                    value=self.skew_thresholds[key]
                )
                skew_thresholds_mapping[key] = skew_threshold
        if self.attribute_skew_thresholds is not None:
            for key in self.attribute_skew_thresholds.keys():
                attribution_score_skew_threshold = gca_threshold_config(
                attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
                    value=self.attribute_skew_thresholds[key]
                )
                attribution_score_skew_thresholds_mapping[
@@ -134,12 +143,16 @@ def as_proto(self):
        attribution_score_drift_thresholds_mapping = {}
        if self.drift_thresholds is not None:
            for key in self.drift_thresholds.keys():
                drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
                drift_threshold = gca_model_monitoring.ThresholdConfig(
                    value=self.drift_thresholds[key]
                )
                drift_thresholds_mapping[key] = drift_threshold
        if self.attribute_drift_thresholds is not None:
            for key in self.attribute_drift_thresholds.keys():
                attribution_score_drift_threshold = gca_threshold_config(
                    value=self.attribute_drift_thresholds[key]
                attribution_score_drift_threshold = (
                    gca_model_monitoring.ThresholdConfig(
                        value=self.attribute_drift_thresholds[key]
                    )
                )
                attribution_score_drift_thresholds_mapping[
                    key
@@ -186,11 +199,49 @@ def __init__(
        self.drift_detection_config = drift_detection_config
        self.explanation_config = explanation_config

    def as_proto(self):
        """Returns _ObjectiveConfig as a proto message."""
    # TODO: remove config_for_bp parameter when model monitoring for batch prediction is feature complete and in GA
    def as_proto(self, config_for_bp: Optional[bool] = False):
        """Returns _ObjectiveConfig as a proto message.

        Args:
            config_for_bp (bool):
                Optional. Set this parameter to True if the config object
                is used for model monitoring on a batch prediction job.
        """
        if config_for_bp:
            gca_io = gca_io_v1beta1
            gca_model_monitoring = gca_model_monitoring_v1beta1
        else:
            gca_io = gca_io_v1
            gca_model_monitoring = gca_model_monitoring_v1
        training_dataset = None
        if self.skew_detection_config is not None:
            training_dataset = self.skew_detection_config.training_dataset
            training_dataset = (
                gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
                    target_field=self.skew_detection_config.target_field
                )
            )
            if self.skew_detection_config.data_source.startswith("bq:/"):
                training_dataset.bigquery_source = gca_io.BigQuerySource(
                    input_uri=self.skew_detection_config.data_source
                )
            elif self.skew_detection_config.data_source.startswith("gs:/"):
                training_dataset.gcs_source = gca_io.GcsSource(
                    uris=[self.skew_detection_config.data_source]
                )
                if (
                    self.skew_detection_config.data_format is not None
                    and self.skew_detection_config.data_format
                    not in [TF_RECORD, CSV, JSONL]
                ):
                    raise ValueError(
                        "Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
                        % (TF_RECORD, CSV, JSONL)
                    )
                training_dataset.data_format = self.skew_detection_config.data_format
            else:
                training_dataset.dataset = self.skew_detection_config.data_source

        return gca_model_monitoring.ModelMonitoringObjectiveConfig(
            training_dataset=training_dataset,
            training_prediction_skew_detection_config=self.skew_detection_config.as_proto()

Review comment (on the new as_proto signature): Same comment as above about typing annotation here.

Reply: updated
@@ -271,27 +322,6 @@ def __init__(
            data_format,
        )

        training_dataset = (
            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
                target_field=target_field
            )
        )
        if data_source.startswith("bq:/"):
            training_dataset.bigquery_source = gca_io.BigQuerySource(
                input_uri=data_source
            )
        elif data_source.startswith("gs:/"):
            training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
            if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
                raise ValueError(
                    "Unsupported value. `data_format` must be one of %s, %s, or %s"
                    % (TF_RECORD, CSV, JSONL)
                )
            training_dataset.data_format = data_format
        else:
            training_dataset.dataset = data_source
        self.training_dataset = training_dataset


class DriftDetectionConfig(_DriftDetectionConfig):
    """A class that configures prediction drift detection for models deployed to an endpoint.
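Tying the objective-side changes together: the TrainingDataset proto is now assembled inside as_proto() from the stored data_source, with a bq:// prefix mapping to a BigQuerySource, a gs:// prefix mapping to a GcsSource plus a validated data_format, and anything else treated as a dataset resource name. A sketch of the resulting behavior (the table URI, label column, and threshold are placeholders):

```python
from google.cloud.aiplatform import model_monitoring

skew_config = model_monitoring.SkewDetectionConfig(
    data_source="bq://my-project.my_dataset.training_table",  # placeholder; a "gs://..." source would also need data_format
    target_field="churned",                                    # placeholder label column
    skew_thresholds={"country": 0.4},                          # placeholder feature/threshold
)
objective_config = model_monitoring.ObjectiveConfig(skew_detection_config=skew_config)

# v1 proto for endpoint (model deployment) monitoring jobs.
endpoint_proto = objective_config.as_proto()

# v1beta1 proto for monitoring on a batch prediction job; the TrainingDataset
# (a BigQuerySource here) is built at conversion time rather than in
# SkewDetectionConfig.__init__ as before this change.
bp_proto = objective_config.as_proto(config_for_bp=True)
```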
@@ -24,6 +24,11 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

from google.cloud.aiplatform_v1.types import (
    io as gca_io,
    model_monitoring as gca_model_monitoring,
)

# constants used for testing
USER_EMAIL = ""
MODEL_NAME = "churn"

Review comment (on MODEL_NAME = "churn"): Is this used as the model's display name? If this is indeed intended as the display name of a resource used only for testing, please prefix it with "temp".

Reply: I updated the variable names to make them more specific. But the extra prefixing is unnecessary because the actual display names are being created by …

Review comment: Thanks. In that case can we rename this further and call it something like MODEL_DISPLAYNAME_KEY to prevent the next engineer reading this from misinterpreting it as the actual model display name? More generally, what's the meaning of "churn" in the context of testing?

Reply: I've updated the variable name. Regarding the actual value string, "churn" is just a shorthand used to reference the BQ dataset. The same dataset and pre-trained model were used in the example notebook on the Cloud SDK documentation page and in a separate example for BQML.
@@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
        gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
            0
        ].objective_config
        assert gca_obj_config.training_dataset == skew_config.training_dataset

        expected_training_dataset = (
            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
                bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
                target_field=TARGET,
            )
        )
        assert gca_obj_config.training_dataset == expected_training_dataset
        assert (
            gca_obj_config.training_prediction_skew_detection_config
            == skew_config.as_proto()
@@ -290,12 +302,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
        )
        assert gapic_job.model_monitoring_alert_config.enable_logging

        expected_training_dataset = (
            gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
                bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
                target_field=TARGET,
            )
        )

        for config in gapic_job.model_deployment_monitoring_objective_configs:
            gca_obj_config = config.objective_config
            deployed_model_id = config.deployed_model_id
            assert (
                gca_obj_config.training_dataset
                == all_configs[deployed_model_id].skew_detection_config.training_dataset
                gca_obj_config.as_proto().training_dataset == expected_training_dataset
            )
            assert (
                gca_obj_config.training_prediction_skew_detection_config
Review comment (on the new docstrings): Are these docstrings copied from the source at https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform? Will we be able to remember to update this when/if alerts other than email alerts become supported?

Reply: They're not directly copied from GAPIC. But Jing's team also confirmed that there are no plans for additional alert configs.