diff --git a/google/cloud/aiplatform/model_monitoring/__init__.py b/google/cloud/aiplatform/model_monitoring/__init__.py index c4562ff147..95c5dcacfc 100644 --- a/google/cloud/aiplatform/model_monitoring/__init__.py +++ b/google/cloud/aiplatform/model_monitoring/__init__.py @@ -15,17 +15,23 @@ # limitations under the License. # -from google.cloud.aiplatform.model_monitoring.alert import EmailAlertConfig +from google.cloud.aiplatform.model_monitoring.alert import ( + AlertConfig, + EmailAlertConfig, +) from google.cloud.aiplatform.model_monitoring.objective import ( SkewDetectionConfig, DriftDetectionConfig, ExplanationConfig, ObjectiveConfig, ) -from google.cloud.aiplatform.model_monitoring.sampling import RandomSampleConfig +from google.cloud.aiplatform.model_monitoring.sampling import ( + RandomSampleConfig, +) from google.cloud.aiplatform.model_monitoring.schedule import ScheduleConfig __all__ = ( + "AlertConfig", "EmailAlertConfig", "SkewDetectionConfig", "DriftDetectionConfig", diff --git a/google/cloud/aiplatform/model_monitoring/alert.py b/google/cloud/aiplatform/model_monitoring/alert.py index fdd3e2e9d9..599301b8f7 100644 --- a/google/cloud/aiplatform/model_monitoring/alert.py +++ b/google/cloud/aiplatform/model_monitoring/alert.py @@ -15,12 +15,13 @@ # limitations under the License. # -from typing import Optional, List +from typing import List, Optional from google.cloud.aiplatform_v1.types import ( model_monitoring as gca_model_monitoring_v1, ) -# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA +# TODO(b/242108750): remove temporary logic once model monitoring for +# batch prediction is GA. from google.cloud.aiplatform_v1beta1.types import ( model_monitoring as gca_model_monitoring_v1beta1, ) @@ -28,43 +29,66 @@ gca_model_monitoring = gca_model_monitoring_v1 -class EmailAlertConfig: +class AlertConfig: def __init__( - self, user_emails: List[str] = [], enable_logging: Optional[bool] = False + self, + user_emails: List[str] = [], + enable_logging: Optional[bool] = False, + notification_channels: List[str] = [], ): - """Initializer for EmailAlertConfig. + """Initializer for AlertConfig. Args: - user_emails (List[str]): - The email addresses to send the alert to. - enable_logging (bool): - Optional. Defaults to False. Streams detected anomalies to Cloud Logging. The anomalies will be - put into json payload encoded from proto - [google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][]. - This can be further sync'd to Pub/Sub or any other services - supported by Cloud Logging. + user_emails (List[str]): The email addresses to send the alert to. + enable_logging (bool): Optional. Defaults to False. Streams detected + anomalies to Cloud Logging. The anomalies will be put into json + payload encoded from proto + [google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][]. + This can be further sync'd to Pub/Sub or any other services supported + by Cloud Logging. + notification_channels (List[str]): The Cloud notification channels to + send the alert to. """ - self.enable_logging = enable_logging self.user_emails = user_emails + self.enable_logging = enable_logging + self.notification_channels = notification_channels self._config_for_bp = False - # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA def as_proto(self) -> gca_model_monitoring.ModelMonitoringAlertConfig: - """Converts EmailAlertConfig to a proto message. + """Converts AlertConfig to a proto message. Returns: - The GAPIC representation of the email alert config. + The GAPIC representation of the alert config. """ + # TODO(b/242108750): remove temporary logic once model monitoring for + # batch prediction is GA. if self._config_for_bp: gca_model_monitoring = gca_model_monitoring_v1beta1 else: gca_model_monitoring = gca_model_monitoring_v1 - user_email_alert_config = ( - gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig( - user_emails=self.user_emails - ) - ) + return gca_model_monitoring.ModelMonitoringAlertConfig( - email_alert_config=user_email_alert_config, + email_alert_config=gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig( + user_emails=self.user_emails + ), enable_logging=self.enable_logging, + notification_channels=self.notification_channels, ) + + +class EmailAlertConfig(AlertConfig): + def __init__( + self, user_emails: List[str] = [], enable_logging: Optional[bool] = False + ): + """Initializer for EmailAlertConfig. + + Args: + user_emails (List[str]): The email addresses to send the alert to. + enable_logging (bool): Optional. Defaults to False. Streams detected + anomalies to Cloud Logging. The anomalies will be put into json + payload encoded from proto + [google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][]. + This can be further sync'd to Pub/Sub or any other services supported + by Cloud Logging. + """ + super().__init__(user_emails=user_emails, enable_logging=enable_logging) diff --git a/tests/system/aiplatform/test_model_monitoring.py b/tests/system/aiplatform/test_model_monitoring.py index 36406097ff..6cc6dbed26 100644 --- a/tests/system/aiplatform/test_model_monitoring.py +++ b/tests/system/aiplatform/test_model_monitoring.py @@ -31,6 +31,7 @@ # constants used for testing USER_EMAIL = "rosiezou@cloudadvocacyorg.joonix.net" +NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456" PERMANENT_CHURN_MODEL_ID = "5295507484113371136" CHURN_MODEL_PATH = "gs://mco-mm/churn" DEFAULT_INPUT = { @@ -90,10 +91,16 @@ # global test constants sampling_strategy = model_monitoring.RandomSampleConfig(sample_rate=LOG_SAMPLE_RATE) -alert_config = model_monitoring.EmailAlertConfig( +email_alert_config = model_monitoring.EmailAlertConfig( user_emails=[USER_EMAIL], enable_logging=True ) +alert_config = model_monitoring.AlertConfig( + user_emails=[USER_EMAIL], + enable_logging=True, + notification_channels=[NOTIFICATION_CHANNEL], +) + schedule_config = model_monitoring.ScheduleConfig(monitor_interval=MONITOR_INTERVAL) skew_config = model_monitoring.SkewDetectionConfig( @@ -149,7 +156,7 @@ def test_mdm_two_models_one_valid_config(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=objective_config, create_request_timeout=3600, project=e2e_base._PROJECT, @@ -211,7 +218,7 @@ def test_mdm_pause_and_update_config(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=model_monitoring.ObjectiveConfig( drift_detection_config=drift_config ), @@ -284,7 +291,7 @@ def test_mdm_two_models_two_valid_configs(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=all_configs, create_request_timeout=3600, project=e2e_base._PROJECT, @@ -338,7 +345,7 @@ def test_mdm_invalid_config_incorrect_model_id(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=objective_config, create_request_timeout=3600, project=e2e_base._PROJECT, @@ -358,7 +365,7 @@ def test_mdm_invalid_config_xai(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=objective_config, create_request_timeout=3600, project=e2e_base._PROJECT, @@ -388,7 +395,7 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state): display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, schedule_config=schedule_config, - alert_config=alert_config, + alert_config=email_alert_config, objective_configs=all_configs, create_request_timeout=3600, project=e2e_base._PROJECT, @@ -399,3 +406,31 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state): "`explanation_config` should only be enabled if the model has `explanation_spec populated" in str(e.value) ) + + def test_mdm_notification_channel_alert_config(self, shared_state): + self.endpoint = shared_state["resources"][0] + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + # test model monitoring configurations + job = aiplatform.ModelDeploymentMonitoringJob.create( + display_name=self._make_display_name(key=JOB_NAME), + logging_sampling_strategy=sampling_strategy, + schedule_config=schedule_config, + alert_config=alert_config, + objective_configs=objective_config, + create_request_timeout=3600, + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + endpoint=self.endpoint, + ) + + gapic_job = job._gca_resource + assert ( + gapic_job.model_monitoring_alert_config.email_alert_config.user_emails + == [USER_EMAIL] + ) + assert gapic_job.model_monitoring_alert_config.enable_logging + assert gapic_job.model_monitoring_alert_config.notification_channels == [ + NOTIFICATION_CHANNEL + ] + + job.delete() diff --git a/tests/unit/aiplatform/test_model_monitoring.py b/tests/unit/aiplatform/test_model_monitoring.py index 4d81a04c6d..d74c012a19 100644 --- a/tests/unit/aiplatform/test_model_monitoring.py +++ b/tests/unit/aiplatform/test_model_monitoring.py @@ -31,6 +31,7 @@ _TEST_DRIFT_TRESHOLD = {"key": 0.2} _TEST_EMAIL1 = "test1" _TEST_EMAIL2 = "test2" +_TEST_NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456" _TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"] _TEST_SAMPLING_RATE = 0.8 _TEST_MONITORING_INTERVAL = 1 @@ -105,10 +106,16 @@ def test_valid_configs( monitor_interval=_TEST_MONITORING_INTERVAL ) - alert_config = model_monitoring.EmailAlertConfig( + email_alert_config = model_monitoring.EmailAlertConfig( user_emails=[_TEST_EMAIL1, _TEST_EMAIL2] ) + alert_config = model_monitoring.AlertConfig( + user_emails=[_TEST_EMAIL1, _TEST_EMAIL2], + enable_logging=True, + notification_channels=[_TEST_NOTIFICATION_CHANNEL], + ) + prediction_drift_config = model_monitoring.DriftDetectionConfig( drift_thresholds=_TEST_DRIFT_TRESHOLD ) @@ -149,8 +156,17 @@ def test_valid_configs( == prediction_drift_config.as_proto() ) assert objective_config.as_proto().explanation_config == xai_config.as_proto() + assert ( + _TEST_EMAIL1 in email_alert_config.as_proto().email_alert_config.user_emails + ) + assert ( + _TEST_EMAIL2 in email_alert_config.as_proto().email_alert_config.user_emails + ) assert _TEST_EMAIL1 in alert_config.as_proto().email_alert_config.user_emails assert _TEST_EMAIL2 in alert_config.as_proto().email_alert_config.user_emails + assert ( + _TEST_NOTIFICATION_CHANNEL in alert_config.as_proto().notification_channels + ) assert ( random_sample_config.as_proto().random_sample_config.sample_rate == _TEST_SAMPLING_RATE