From 7da4164697ac01ac94a45b34086facfd0d360f1b Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Mon, 19 Dec 2022 19:14:20 -0800 Subject: [PATCH] feat: Add default skew threshold to be an optional input at _SkewDetectionConfig and also mark the target_field and data_source of skew config to optional. PiperOrigin-RevId: 496543878 --- .../aiplatform/model_monitoring/objective.py | 63 +++++++++------ .../unit/aiplatform/test_model_monitoring.py | 78 ++++++++++++++++--- 2 files changed, 106 insertions(+), 35 deletions(-) diff --git a/google/cloud/aiplatform/model_monitoring/objective.py b/google/cloud/aiplatform/model_monitoring/objective.py index 89916417d1..d81ff72633 100644 --- a/google/cloud/aiplatform/model_monitoring/objective.py +++ b/google/cloud/aiplatform/model_monitoring/objective.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Optional, Dict +from typing import Optional, Dict, Union from google.cloud.aiplatform_v1.types import ( io as gca_io_v1, @@ -39,27 +39,30 @@ class _SkewDetectionConfig: def __init__( self, - data_source: str, - skew_thresholds: Dict[str, float], - target_field: str, - attribute_skew_thresholds: Dict[str, float], + data_source: Optional[str] = None, + skew_thresholds: Union[Dict[str, float], float, None] = None, + target_field: Optional[str] = None, + attribute_skew_thresholds: Optional[Dict[str, float]] = None, data_format: Optional[str] = None, ): """Base class for training-serving skew detection. Args: data_source (str): - Required. Path to training dataset. + Optional. Path to training dataset. - skew_thresholds (Dict[str, float]): + skew_thresholds: Union[Dict[str, float], float, None]: Optional. Key is the feature name and value is the threshold. If a feature needs to be monitored for skew, a value threshold must be configured for that feature. The threshold here is against feature distribution distance between the - training and prediction feature. + training and prediction feature. If a float is passed, + then all features will be monitored using the same + threshold. If None is passed, all feature will be monitored + using alert threshold 0.3 (Backend default). target_field (str): - Required. The target field name the model is to + Optional. The target field name the model is to predict. This field will be excluded when doing Predict and (or) Explain for the training data. @@ -93,12 +96,18 @@ def as_proto(self): """Returns _SkewDetectionConfig as a proto message.""" skew_thresholds_mapping = {} attribution_score_skew_thresholds_mapping = {} + default_skew_threshold = None if self.skew_thresholds is not None: - for key in self.skew_thresholds.keys(): - skew_threshold = gca_model_monitoring.ThresholdConfig( - value=self.skew_thresholds[key] + if isinstance(self.skew_thresholds, float): + default_skew_threshold = gca_model_monitoring.ThresholdConfig( + value=self.skew_thresholds ) - skew_thresholds_mapping[key] = skew_threshold + else: + for key in self.skew_thresholds.keys(): + skew_threshold = gca_model_monitoring.ThresholdConfig( + value=self.skew_thresholds[key] + ) + skew_thresholds_mapping[key] = skew_threshold if self.attribute_skew_thresholds is not None: for key in self.attribute_skew_thresholds.keys(): attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig( @@ -110,6 +119,7 @@ def as_proto(self): return gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig( skew_thresholds=skew_thresholds_mapping, attribution_score_skew_thresholds=attribution_score_skew_thresholds_mapping, + default_skew_threshold=default_skew_threshold, ) @@ -266,9 +276,9 @@ class SkewDetectionConfig(_SkewDetectionConfig): def __init__( self, - data_source: str, - target_field: str, - skew_thresholds: Optional[Dict[str, float]] = None, + data_source: Optional[str] = None, + target_field: Optional[str] = None, + skew_thresholds: Union[Dict[str, float], float, None] = None, attribute_skew_thresholds: Optional[Dict[str, float]] = None, data_format: Optional[str] = None, ): @@ -276,20 +286,23 @@ def __init__( Args: data_source (str): - Required. Path to training dataset. + Optional. Path to training dataset. target_field (str): - Required. The target field name the model is to + Optional. The target field name the model is to predict. This field will be excluded when doing Predict and (or) Explain for the training data. - skew_thresholds (Dict[str, float]): + skew_thresholds: Union[Dict[str, float], float, None]: Optional. Key is the feature name and value is the threshold. If a feature needs to be monitored for skew, a value threshold must be configured for that feature. The threshold here is against feature distribution distance between the - training and prediction feature. + training and prediction feature. If a float is passed, + then all features will be monitored using the same + threshold. If None is passed, all feature will be monitored + using alert threshold 0.3 (Backend default). attribute_skew_thresholds (Dict[str, float]): Optional. Key is the feature name and value is the @@ -315,11 +328,11 @@ def __init__( ValueError for unsupported data formats. """ super().__init__( - data_source, - skew_thresholds, - target_field, - attribute_skew_thresholds, - data_format, + data_source=data_source, + skew_thresholds=skew_thresholds, + target_field=target_field, + attribute_skew_thresholds=attribute_skew_thresholds, + data_format=data_format, ) diff --git a/tests/unit/aiplatform/test_model_monitoring.py b/tests/unit/aiplatform/test_model_monitoring.py index 87f5e0848b..4d81a04c6d 100644 --- a/tests/unit/aiplatform/test_model_monitoring.py +++ b/tests/unit/aiplatform/test_model_monitoring.py @@ -24,26 +24,79 @@ model_monitoring as gca_model_monitoring, ) -_TEST_THRESHOLD = 0.1 _TEST_TARGET_FIELD = "target" _TEST_BQ_DATASOURCE = "bq://test/data" _TEST_GCS_DATASOURCE = "gs://test/data" _TEST_OTHER_DATASOURCE = "" -_TEST_KEY = "key" +_TEST_DRIFT_TRESHOLD = {"key": 0.2} _TEST_EMAIL1 = "test1" _TEST_EMAIL2 = "test2" _TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"] _TEST_SAMPLING_RATE = 0.8 _TEST_MONITORING_INTERVAL = 1 +_TEST_SKEW_THRESHOLDS = [None, 0.2, {"key": 0.1}] +_TEST_ATTRIBUTE_SKEW_THRESHOLDS = [None, {"key": 0.1}] class TestModelMonitoringConfigs: + """Tests for model monitoring configs.""" + @pytest.mark.parametrize( "data_source", [_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE], ) @pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS) - def test_valid_configs(self, data_source, data_format): + @pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS) + def test_skew_config_proto_value(self, data_source, data_format, skew_thresholds): + """Tests if skew config can be constrctued properly to gapic proto.""" + attribute_skew_thresholds = {"key": 0.1} + skew_config = model_monitoring.SkewDetectionConfig( + data_source=data_source, + skew_thresholds=skew_thresholds, + target_field=_TEST_TARGET_FIELD, + attribute_skew_thresholds=attribute_skew_thresholds, + data_format=data_format, + ) + # data_format and data source are not used at + # TrainingPredictionSkewDetectionConfig. + if isinstance(skew_thresholds, dict): + expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig( + skew_thresholds={ + key: gca_model_monitoring.ThresholdConfig(value=val) + for key, val in skew_thresholds.items() + }, + attribution_score_skew_thresholds={ + key: gca_model_monitoring.ThresholdConfig(value=val) + for key, val in attribute_skew_thresholds.items() + }, + ) + else: + expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig( + default_skew_threshold=gca_model_monitoring.ThresholdConfig( + value=skew_thresholds + ) + if skew_thresholds is not None + else None, + attribution_score_skew_thresholds={ + key: gca_model_monitoring.ThresholdConfig(value=val) + for key, val in attribute_skew_thresholds.items() + }, + ) + assert skew_config.as_proto() == expected_gapic_proto + + @pytest.mark.parametrize( + "data_source", + [_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE], + ) + @pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS) + @pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS) + @pytest.mark.parametrize( + "attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS + ) + def test_valid_configs( + self, data_source, data_format, skew_thresholds, attribute_skew_thresholds + ): + """Test config creation validity.""" random_sample_config = model_monitoring.RandomSampleConfig( sample_rate=_TEST_SAMPLING_RATE ) @@ -57,17 +110,16 @@ def test_valid_configs(self, data_source, data_format): ) prediction_drift_config = model_monitoring.DriftDetectionConfig( - drift_thresholds={_TEST_KEY: _TEST_THRESHOLD} + drift_thresholds=_TEST_DRIFT_TRESHOLD ) skew_config = model_monitoring.SkewDetectionConfig( data_source=data_source, - skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + skew_thresholds=skew_thresholds, target_field=_TEST_TARGET_FIELD, - attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + attribute_skew_thresholds=attribute_skew_thresholds, data_format=data_format, ) - expected_training_dataset = ( gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_DATASOURCE), @@ -110,15 +162,21 @@ def test_valid_configs(self, data_source, data_format): @pytest.mark.parametrize("data_source", [_TEST_GCS_DATASOURCE]) @pytest.mark.parametrize("data_format", ["other"]) - def test_invalid_data_format(self, data_source, data_format): + @pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS) + @pytest.mark.parametrize( + "attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS + ) + def test_invalid_data_format( + self, data_source, data_format, skew_thresholds, attribute_skew_thresholds + ): if data_format == "other": with pytest.raises(ValueError) as e: model_monitoring.ObjectiveConfig( skew_detection_config=model_monitoring.SkewDetectionConfig( data_source=data_source, - skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + skew_thresholds=skew_thresholds, target_field=_TEST_TARGET_FIELD, - attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD}, + attribute_skew_thresholds=attribute_skew_thresholds, data_format=data_format, ) ).as_proto()