Skip to content

Commit

Permalink
fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update (#1671)
Browse files Browse the repository at this point in the history
* fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update() function

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed PR feedback

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed PR comments

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed more PR feedback

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed more PR comments

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* removed unused code

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed more PR feedback

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fixing linter issues

* addressed more PR comments

* fixing pylint errors

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* silencing unused import warning

* fixed unused import error

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
rosiezou and gcf-owl-bot[bot] authored Sep 29, 2022
1 parent 876fb2a commit 186872d
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 37 deletions.
31 changes: 15 additions & 16 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,32 +2484,31 @@ def update(
update_mask.append("model_deployment_monitoring_objective_configs")
current_job.model_deployment_monitoring_objective_configs = (
ModelDeploymentMonitoringJob._parse_configs(
objective_configs,
current_job.endpoint,
deployed_model_ids,
objective_configs=objective_configs,
endpoint=aiplatform.Endpoint(
current_job.endpoint, credentials=self.credentials
),
deployed_model_ids=deployed_model_ids,
)
)
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
return self

def pause(self) -> "ModelDeploymentMonitoringJob":
"""Pause a running MDM job."""
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
self.api_client.pause_model_deployment_monitoring_job(
name=self._gca_resource.name
)
self.api_client.pause_model_deployment_monitoring_job(
name=self._gca_resource.name
)
return self

def resume(self) -> "ModelDeploymentMonitoringJob":
"""Resumes a paused MDM job."""
if self.state == gca_job_state.JobState.JOB_STATE_PAUSED:
self.api_client.resume_model_deployment_monitoring_job(
name=self._gca_resource.name
)
self.api_client.resume_model_deployment_monitoring_job(
name=self._gca_resource.name
)
return self

def delete(self) -> None:
Expand Down
104 changes: 83 additions & 21 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,28 @@

# constants used for testing
USER_EMAIL = ""
PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320"
PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
DEFAULT_INPUT = {
"cnt_ad_reward": 0,
"cnt_challenge_a_friend": 0,
"cnt_completed_5_levels": 1,
"cnt_level_complete_quickplay": 3,
"cnt_level_end_quickplay": 5,
"cnt_level_reset_quickplay": 2,
"cnt_level_start_quickplay": 6,
"cnt_post_score": 34,
"cnt_spend_virtual_currency": 0,
"cnt_use_extra_steps": 0,
"cnt_user_engagement": 120,
"country": "Denmark",
"dayofweek": 3,
"julianday": 254,
"language": "da-dk",
"month": 9,
"operating_system": "IOS",
"user_pseudo_id": "104B0770BAE16E8B53DF330C95881893",
}

JOB_NAME = "churn"

Expand Down Expand Up @@ -117,10 +137,7 @@ def test_mdm_two_models_one_valid_config(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert job is not None

gapic_job = job._gca_resource
assert (
Expand Down Expand Up @@ -156,22 +173,77 @@ def test_mdm_two_models_one_valid_config(self):
gca_obj_config.prediction_drift_detection_config == drift_config.as_proto()
)

# delete this job and re-configure it to only enable drift detection for faster testing
job.delete()
job_resource = job._gca_resource.name

# test job update and delete()
timeout = time.time() + 3600
new_obj_config = model_monitoring.ObjectiveConfig(skew_config)
# test job delete
with pytest.raises(core_exceptions.NotFound):
job.api_client.get_model_deployment_monitoring_job(name=job_resource)

def test_mdm_pause_and_update_config(self):
"""Test objective config updates for existing MDM job"""
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
objective_configs=model_monitoring.ObjectiveConfig(
drift_detection_config=drift_config
),
create_request_timeout=3600,
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
)
# test unsuccessful job update when it's pending
DRIFT_THRESHOLDS["cnt_user_engagement"] += 0.01
new_obj_config = model_monitoring.ObjectiveConfig(
drift_detection_config=model_monitoring.DriftDetectionConfig(
drift_thresholds=DRIFT_THRESHOLDS,
attribute_drift_thresholds=ATTRIB_DRIFT_THRESHOLDS,
)
)
if job.state == gca_job_state.JobState.JOB_STATE_PENDING:
with pytest.raises(core_exceptions.FailedPrecondition):
job.update(objective_configs=new_obj_config)

# generate traffic to force MDM job to come online
for i in range(2000):
DEFAULT_INPUT["cnt_user_engagement"] += i
self.endpoint.predict([DEFAULT_INPUT], use_raw_predict=True)

while time.time() < timeout:
# test job update
while True:
time.sleep(1)
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
job.update(objective_configs=new_obj_config)
assert str(job._gca_resource.prediction_drift_detection_config) == ""
break
time.sleep(5)

# verify job update
while True:
time.sleep(1)
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
gca_obj_config = (
job._gca_resource.model_deployment_monitoring_objective_configs[
0
].objective_config
)
assert (
gca_obj_config.prediction_drift_detection_config
== new_obj_config.drift_detection_config.as_proto()
)
break

# test pause
job.pause()
while job.state != gca_job_state.JobState.JOB_STATE_PAUSED:
time.sleep(1)
job.delete()

# confirm deletion
with pytest.raises(core_exceptions.NotFound):
job.api_client.get_model_deployment_monitoring_job(name=job_resource)
job.state

def test_mdm_two_models_two_valid_configs(self):
[deployed_model1, deployed_model2] = list(
Expand All @@ -181,7 +253,6 @@ def test_mdm_two_models_two_valid_configs(self):
deployed_model1: objective_config,
deployed_model2: objective_config2,
}
job = None
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
Expand All @@ -192,10 +263,7 @@ def test_mdm_two_models_two_valid_configs(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert job is not None

gapic_job = job._gca_resource
assert (
Expand Down Expand Up @@ -246,8 +314,6 @@ def test_mdm_invalid_config_incorrect_model_id(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
deployed_model_ids=[""],
)
assert "Invalid model ID" in str(e.value)
Expand All @@ -265,8 +331,6 @@ def test_mdm_invalid_config_xai(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert (
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
Expand Down Expand Up @@ -294,8 +358,6 @@ def test_mdm_two_models_invalid_configs_xai(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert (
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
Expand Down
99 changes: 99 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
job_state as gca_job_state_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
model_monitoring as gca_model_monitoring_compat,
)

from google.cloud.aiplatform.compat.services import (
job_service_client,
)
from google.protobuf import field_mask_pb2 # type: ignore

from test_endpoints import get_endpoint_with_models_mock # noqa: F401

_TEST_API_CLIENT = job_service_client.JobServiceClient

Expand Down Expand Up @@ -84,6 +89,11 @@
f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}"
)

_TEST_MDM_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/modelDeploymentMonitoringJobs/{_TEST_ID}"
_TEST_ENDPOINT = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
)

_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
_TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2)
Expand Down Expand Up @@ -164,6 +174,8 @@
_TEST_JOB_DELETE_METHOD_NAME = "delete_custom_job"
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"

_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}

# TODO(b/171333554): Move reusable test fixtures to conftest.py file


Expand Down Expand Up @@ -969,3 +981,90 @@ def test_batch_predict_job_with_versioned_model(
].model
== _TEST_VERSIONED_MODEL_NAME
)


@pytest.fixture
def get_mdm_job_mock():
with mock.patch.object(
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
) as get_mdm_job_mock:
get_mdm_job_mock.return_value = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
)
)
yield get_mdm_job_mock


@pytest.fixture
@pytest.mark.usefixtures("get_mdm_job_mock")
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
with mock.patch.object(
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
) as update_mdm_job_mock:
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
drift_thresholds={
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
}
)
)
all_configs = []
for model in get_endpoint_with_models_mock.return_value.deployed_models:
all_configs.append(
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id=model.id,
objective_config=expected_objective_config,
)
)

update_mdm_job_mock.return_vaue.result_type = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
model_deployment_monitoring_objective_configs=all_configs,
)
)
yield update_mdm_job_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestModelDeploymentMonitoringJob:
def setup_method(self):
reload(initializer)
reload(aiplatform)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
job = jobs.ModelDeploymentMonitoringJob(
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
)
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
)
new_config = aiplatform.model_monitoring.ObjectiveConfig(
drift_detection_config=drift_detection_config
)
job.update(objective_configs=new_config)
assert (
job._gca_resource.model_deployment_monitoring_objective_configs[
0
].objective_config.prediction_drift_detection_config
== drift_detection_config.as_proto()
)
get_mdm_job_mock.assert_called_with(
name=_TEST_MDM_JOB_NAME,
)
update_mdm_job_mock.assert_called_once_with(
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
update_mask=field_mask_pb2.FieldMask(
paths=["model_deployment_monitoring_objective_configs"]
),
)

0 comments on commit 186872d

Please sign in to comment.