From bbec998ea71aa342fee08d0d5fa115ab36a6f60f Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Tue, 16 Aug 2022 15:44:00 -0700 Subject: [PATCH] feat: support model monitoring for batch prediction in Vertex SDK (#1570) * feat: support model monitoring for batch prediction in Vertex SDK * fixed broken tests * fixing syntax error * addressed comments * updated test variable name --- google/cloud/aiplatform/jobs.py | 64 ++++++++++++- .../aiplatform/model_monitoring/alert.py | 24 ++++- .../aiplatform/model_monitoring/objective.py | 96 ++++++++++++------- .../aiplatform/test_model_monitoring.py | 36 +++++-- tests/unit/aiplatform/test_jobs.py | 24 ++++- .../unit/aiplatform/test_model_monitoring.py | 38 +++++--- 6 files changed, 224 insertions(+), 58 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 5d1ac4075c..203362d7a1 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -385,6 +385,13 @@ def create( sync: bool = True, create_request_timeout: Optional[float] = None, batch_size: Optional[int] = None, + model_monitoring_objective_config: Optional[ + "aiplatform.model_monitoring.ObjectiveConfig" + ] = None, + model_monitoring_alert_config: Optional[ + "aiplatform.model_monitoring.AlertConfig" + ] = None, + analysis_instance_schema_uri: Optional[str] = None, ) -> "BatchPredictionJob": """Create a batch prediction job. @@ -551,6 +558,23 @@ def create( but too high value will result in a whole batch not fitting in a machine's memory, and the whole operation will fail. The default value is 64. + model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig): + Optional. The objective config for model monitoring. Passing this parameter enables + monitoring on the model associated with this batch prediction job. + model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig): + Optional. Configures how model monitoring alerts are sent to the user. Right now + only email alert is supported. + analysis_instance_schema_uri (str): + Optional. Only applicable if model_monitoring_objective_config is also passed. + This parameter specifies the YAML schema file uri describing the format of a single + instance that you want Tensorflow Data Validation (TFDV) to + analyze. If this field is empty, all the feature data types are + inferred from predict_instance_schema_uri, meaning that TFDV + will use the data in the exact format as prediction request/response. + If there are any data type differences between predict instance + and TFDV instance, this field can be used to override the schema. + For models trained with Vertex AI, this field must be set as all the + fields in predict instance formatted as string. Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. @@ -601,7 +625,18 @@ def create( f"{predictions_format} is not an accepted prediction format " f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" ) - + # TODO: remove temporary import statements once model monitoring for batch prediction is GA + if model_monitoring_objective_config: + from google.cloud.aiplatform.compat.types import ( + io_v1beta1 as gca_io_compat, + batch_prediction_job_v1beta1 as gca_bp_job_compat, + model_monitoring_v1beta1 as gca_model_monitoring_compat, + ) + else: + from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + batch_prediction_job as gca_bp_job_compat, + ) gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob() # Required Fields @@ -688,6 +723,28 @@ def create( ) ) + # Model Monitoring + if model_monitoring_objective_config: + if model_monitoring_objective_config.drift_detection_config: + _LOGGER.info( + "Drift detection config is currently not supported for monitoring models associated with batch prediction jobs." + ) + if model_monitoring_objective_config.explanation_config: + _LOGGER.info( + "XAI config is currently not supported for monitoring models associated with batch prediction jobs." + ) + gapic_batch_prediction_job.model_monitoring_config = ( + gca_model_monitoring_compat.ModelMonitoringConfig( + objective_configs=[ + model_monitoring_objective_config.as_proto(config_for_bp=True) + ], + alert_config=model_monitoring_alert_config.as_proto( + config_for_bp=True + ), + analysis_instance_schema_uri=analysis_instance_schema_uri, + ) + ) + empty_batch_prediction_job = cls._empty_constructor( project=project, location=location, @@ -702,6 +759,11 @@ def create( sync=sync, create_request_timeout=create_request_timeout, ) + # TODO: b/242108750 + from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + batch_prediction_job as gca_bp_job_compat, + ) @classmethod @base.optional_sync(return_input_arg="empty_batch_prediction_job") diff --git a/google/cloud/aiplatform/model_monitoring/alert.py b/google/cloud/aiplatform/model_monitoring/alert.py index 9eed27ec21..929be280f6 100644 --- a/google/cloud/aiplatform/model_monitoring/alert.py +++ b/google/cloud/aiplatform/model_monitoring/alert.py @@ -17,9 +17,16 @@ from typing import Optional, List from google.cloud.aiplatform_v1.types import ( - model_monitoring as gca_model_monitoring, + model_monitoring as gca_model_monitoring_v1, ) +# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA +from google.cloud.aiplatform_v1beta1.types import ( + model_monitoring as gca_model_monitoring_v1beta1, +) + +gca_model_monitoring = gca_model_monitoring_v1 + class EmailAlertConfig: def __init__( @@ -40,8 +47,19 @@ def __init__( self.enable_logging = enable_logging self.user_emails = user_emails - def as_proto(self): - """Returns EmailAlertConfig as a proto message.""" + # TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA + def as_proto(self, config_for_bp: bool = False): + """Returns EmailAlertConfig as a proto message. + + Args: + config_for_bp (bool): + Optional. Set this parameter to True if the config object + is used for model monitoring on a batch prediction job. + """ + if config_for_bp: + gca_model_monitoring = gca_model_monitoring_v1beta1 + else: + gca_model_monitoring = gca_model_monitoring_v1 user_email_alert_config = ( gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig( user_emails=self.user_emails diff --git a/google/cloud/aiplatform/model_monitoring/objective.py b/google/cloud/aiplatform/model_monitoring/objective.py index a7800a3485..1d7e847eb2 100644 --- a/google/cloud/aiplatform/model_monitoring/objective.py +++ b/google/cloud/aiplatform/model_monitoring/objective.py @@ -18,11 +18,19 @@ from typing import Optional, Dict from google.cloud.aiplatform_v1.types import ( - io as gca_io, - ThresholdConfig as gca_threshold_config, - model_monitoring as gca_model_monitoring, + io as gca_io_v1, + model_monitoring as gca_model_monitoring_v1, ) +# TODO: b/242108750 +from google.cloud.aiplatform_v1beta1.types import ( + io as gca_io_v1beta1, + model_monitoring as gca_model_monitoring_v1beta1, +) + +gca_model_monitoring = gca_model_monitoring_v1 +gca_io = gca_io_v1 + TF_RECORD = "tf-record" CSV = "csv" JSONL = "jsonl" @@ -80,7 +88,6 @@ def __init__( self.attribute_skew_thresholds = attribute_skew_thresholds self.data_format = data_format self.target_field = target_field - self.training_dataset = None def as_proto(self): """Returns _SkewDetectionConfig as a proto message.""" @@ -88,11 +95,13 @@ def as_proto(self): attribution_score_skew_thresholds_mapping = {} if self.skew_thresholds is not None: for key in self.skew_thresholds.keys(): - skew_threshold = gca_threshold_config(value=self.skew_thresholds[key]) + skew_threshold = gca_model_monitoring.ThresholdConfig( + value=self.skew_thresholds[key] + ) skew_thresholds_mapping[key] = skew_threshold if self.attribute_skew_thresholds is not None: for key in self.attribute_skew_thresholds.keys(): - attribution_score_skew_threshold = gca_threshold_config( + attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig( value=self.attribute_skew_thresholds[key] ) attribution_score_skew_thresholds_mapping[ @@ -134,12 +143,16 @@ def as_proto(self): attribution_score_drift_thresholds_mapping = {} if self.drift_thresholds is not None: for key in self.drift_thresholds.keys(): - drift_threshold = gca_threshold_config(value=self.drift_thresholds[key]) + drift_threshold = gca_model_monitoring.ThresholdConfig( + value=self.drift_thresholds[key] + ) drift_thresholds_mapping[key] = drift_threshold if self.attribute_drift_thresholds is not None: for key in self.attribute_drift_thresholds.keys(): - attribution_score_drift_threshold = gca_threshold_config( - value=self.attribute_drift_thresholds[key] + attribution_score_drift_threshold = ( + gca_model_monitoring.ThresholdConfig( + value=self.attribute_drift_thresholds[key] + ) ) attribution_score_drift_thresholds_mapping[ key @@ -186,11 +199,49 @@ def __init__( self.drift_detection_config = drift_detection_config self.explanation_config = explanation_config - def as_proto(self): - """Returns _ObjectiveConfig as a proto message.""" + # TODO: b/242108750 + def as_proto(self, config_for_bp: bool = False): + """Returns _SkewDetectionConfig as a proto message. + + Args: + config_for_bp (bool): + Optional. Set this parameter to True if the config object + is used for model monitoring on a batch prediction job. + """ + if config_for_bp: + gca_io = gca_io_v1beta1 + gca_model_monitoring = gca_model_monitoring_v1beta1 + else: + gca_io = gca_io_v1 + gca_model_monitoring = gca_model_monitoring_v1 training_dataset = None if self.skew_detection_config is not None: - training_dataset = self.skew_detection_config.training_dataset + training_dataset = ( + gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( + target_field=self.skew_detection_config.target_field + ) + ) + if self.skew_detection_config.data_source.startswith("bq:/"): + training_dataset.bigquery_source = gca_io.BigQuerySource( + input_uri=self.skew_detection_config.data_source + ) + elif self.skew_detection_config.data_source.startswith("gs:/"): + training_dataset.gcs_source = gca_io.GcsSource( + uris=[self.skew_detection_config.data_source] + ) + if ( + self.skew_detection_config.data_format is not None + and self.skew_detection_config.data_format + not in [TF_RECORD, CSV, JSONL] + ): + raise ValueError( + "Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s" + % (TF_RECORD, CSV, JSONL) + ) + training_dataset.data_format = self.skew_detection_config.data_format + else: + training_dataset.dataset = self.skew_detection_config.data_source + return gca_model_monitoring.ModelMonitoringObjectiveConfig( training_dataset=training_dataset, training_prediction_skew_detection_config=self.skew_detection_config.as_proto() @@ -271,27 +322,6 @@ def __init__( data_format, ) - training_dataset = ( - gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( - target_field=target_field - ) - ) - if data_source.startswith("bq:/"): - training_dataset.bigquery_source = gca_io.BigQuerySource( - input_uri=data_source - ) - elif data_source.startswith("gs:/"): - training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source]) - if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]: - raise ValueError( - "Unsupported value. `data_format` must be one of %s, %s, or %s" - % (TF_RECORD, CSV, JSONL) - ) - training_dataset.data_format = data_format - else: - training_dataset.dataset = data_source - self.training_dataset = training_dataset - class DriftDetectionConfig(_DriftDetectionConfig): """A class that configures prediction drift detection for models deployed to an endpoint. diff --git a/tests/system/aiplatform/test_model_monitoring.py b/tests/system/aiplatform/test_model_monitoring.py index 77c60153fe..95675c0f20 100644 --- a/tests/system/aiplatform/test_model_monitoring.py +++ b/tests/system/aiplatform/test_model_monitoring.py @@ -24,10 +24,15 @@ from google.api_core import exceptions as core_exceptions from tests.system.aiplatform import e2e_base +from google.cloud.aiplatform_v1.types import ( + io as gca_io, + model_monitoring as gca_model_monitoring, +) + # constants used for testing USER_EMAIL = "" -MODEL_NAME = "churn" -MODEL_NAME2 = "churn2" +MODEL_DISPLAYNAME_KEY = "churn" +MODEL_DISPLAYNAME_KEY2 = "churn2" IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest" ENDPOINT = "us-central1-aiplatform.googleapis.com" CHURN_MODEL_PATH = "gs://mco-mm/churn" @@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state): ) model = aiplatform.Model.upload( - display_name=self._make_display_name(key=MODEL_NAME), + display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY), artifact_uri=CHURN_MODEL_PATH, serving_container_image_uri=IMAGE, ) @@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state): ) model1 = aiplatform.Model.upload( - display_name=self._make_display_name(key=MODEL_NAME), + display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY), artifact_uri=CHURN_MODEL_PATH, serving_container_image_uri=IMAGE, ) model2 = aiplatform.Model.upload( - display_name=self._make_display_name(key=MODEL_NAME), + display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2), artifact_uri=CHURN_MODEL_PATH, serving_container_image_uri=IMAGE, ) shared_state["resources"] = [model1, model2] endpoint = aiplatform.Endpoint.create( - display_name=self._make_display_name(key=MODEL_NAME) + display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY) ) endpoint.deploy( model=model1, machine_type="n1-standard-2", traffic_percentage=100 @@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state): gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[ 0 ].objective_config - assert gca_obj_config.training_dataset == skew_config.training_dataset + + expected_training_dataset = ( + gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( + bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI), + target_field=TARGET, + ) + ) + assert gca_obj_config.training_dataset == expected_training_dataset assert ( gca_obj_config.training_prediction_skew_detection_config == skew_config.as_proto() @@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state): ) assert gapic_job.model_monitoring_alert_config.enable_logging + expected_training_dataset = ( + gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( + bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI), + target_field=TARGET, + ) + ) + for config in gapic_job.model_deployment_monitoring_objective_configs: gca_obj_config = config.objective_config deployed_model_id = config.deployed_model_id assert ( - gca_obj_config.training_dataset - == all_configs[deployed_model_id].skew_detection_config.training_dataset + gca_obj_config.as_proto().training_dataset == expected_training_dataset ) assert ( gca_obj_config.training_prediction_skew_detection_config diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index ca1d68fe35..7381481b30 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -508,7 +508,7 @@ def test_batch_prediction_iter_dirs_invalid_output_info(self): @mock.patch.object(jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.usefixtures("get_batch_prediction_job_mock") - def test_batch_predict_gcs_source_and_dest( + def test_batch_predict_gcs_source_and_dest_with_monitoring( self, create_batch_prediction_job_mock, sync ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) @@ -521,6 +521,8 @@ def test_batch_predict_gcs_source_and_dest( gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=None, + model_monitoring_objective_config=aiplatform.model_monitoring.ObjectiveConfig(), + model_monitoring_alert_config=aiplatform.model_monitoring.EmailAlertConfig(), ) batch_prediction_job.wait_for_resource_creation() @@ -528,6 +530,13 @@ def test_batch_predict_gcs_source_and_dest( batch_prediction_job.wait() # Construct expected request + # TODO: remove temporary import statements once model monitoring for batch prediction is GA + from google.cloud.aiplatform.compat.types import ( + io_v1beta1 as gca_io_compat, + batch_prediction_job_v1beta1 as gca_batch_prediction_job_compat, + model_monitoring_v1beta1 as gca_model_monitoring_compat, + ) + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, model=_TEST_MODEL_NAME, @@ -543,6 +552,14 @@ def test_batch_predict_gcs_source_and_dest( ), predictions_format="jsonl", ), + model_monitoring_config=gca_model_monitoring_compat.ModelMonitoringConfig( + alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( + email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig() + ), + objective_configs=[ + gca_model_monitoring_compat.ModelMonitoringObjectiveConfig() + ], + ), ) create_batch_prediction_job_mock.assert_called_once_with( @@ -550,6 +567,11 @@ def test_batch_predict_gcs_source_and_dest( batch_prediction_job=expected_gapic_batch_prediction_job, timeout=None, ) + # TODO: remove temporary import statements once model monitoring for batch prediction is GA + from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + batch_prediction_job as gca_batch_prediction_job_compat, + ) @mock.patch.object(jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(jobs, "_LOG_WAIT_TIME", 1) diff --git a/tests/unit/aiplatform/test_model_monitoring.py b/tests/unit/aiplatform/test_model_monitoring.py index a29aa060db..87f5e0848b 100644 --- a/tests/unit/aiplatform/test_model_monitoring.py +++ b/tests/unit/aiplatform/test_model_monitoring.py @@ -19,6 +19,11 @@ from google.cloud.aiplatform import model_monitoring +from google.cloud.aiplatform_v1.types import ( + io as gca_io, + model_monitoring as gca_model_monitoring, +) + _TEST_THRESHOLD = 0.1 _TEST_TARGET_FIELD = "target" _TEST_BQ_DATASOURCE = "bq://test/data" @@ -63,6 +68,13 @@ def test_valid_configs(self, data_source, data_format): data_format=data_format, ) + expected_training_dataset = ( + gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( + bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_DATASOURCE), + target_field=_TEST_TARGET_FIELD, + ) + ) + xai_config = model_monitoring.ExplanationConfig() objective_config = model_monitoring.ObjectiveConfig( @@ -71,9 +83,11 @@ def test_valid_configs(self, data_source, data_format): explanation_config=xai_config, ) - assert ( - objective_config.as_proto().training_dataset == skew_config.training_dataset - ) + if data_source == _TEST_BQ_DATASOURCE: + assert ( + objective_config.as_proto().training_dataset + == expected_training_dataset + ) assert ( objective_config.as_proto().training_prediction_skew_detection_config == skew_config.as_proto() @@ -99,14 +113,16 @@ def test_valid_configs(self, data_source, data_format): def test_invalid_data_format(self, data_source, data_format): if data_format == "other": with pytest.raises(ValueError) as e: - model_monitoring.SkewDetectionConfig( - data_source=data_source, - skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, - target_field=_TEST_TARGET_FIELD, - attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, - data_format=data_format, - ) + model_monitoring.ObjectiveConfig( + skew_detection_config=model_monitoring.SkewDetectionConfig( + data_source=data_source, + skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + target_field=_TEST_TARGET_FIELD, + attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + data_format=data_format, + ) + ).as_proto() assert ( - "Unsupported value. `data_format` must be one of tf-record, csv, or jsonl" + "Unsupported value in skew detection config. `data_format` must be one of tf-record, csv, or jsonl" in str(e.value) )