Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update #1671

Merged
merged 31 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5376092
fix: fix endpoint parsing in ModelDeploymentMonitoringJob.update() fu…
rosiezou Sep 16, 2022
5fe5ae5
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 16, 2022
96ea492
addressed PR feedback
rosiezou Sep 17, 2022
592b3c8
Merge branch 'main' into mm-bugfix
rosiezou Sep 17, 2022
c51a903
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 17, 2022
376da4d
Merge branch 'mm-bugfix' of https://github.com/googleapis/python-aipl…
gcf-owl-bot[bot] Sep 17, 2022
6a12b67
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 17, 2022
79ff356
Merge branch 'mm-bugfix' of https://github.com/googleapis/python-aipl…
gcf-owl-bot[bot] Sep 17, 2022
d05fa08
addressed PR comments
rosiezou Sep 20, 2022
d57d9f4
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 20, 2022
f5439b0
Merge branch 'main' into mm-bugfix
rosiezou Sep 21, 2022
b5e8b8c
addressed more PR feedback
rosiezou Sep 22, 2022
56a55da
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 22, 2022
8493589
addressed more PR comments
rosiezou Sep 23, 2022
3977bad
Merge branch 'main' into mm-bugfix
rosiezou Sep 23, 2022
81f70dd
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 23, 2022
df0bd21
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 23, 2022
e5ee24e
Merge branch 'mm-bugfix' of https://github.com/googleapis/python-aipl…
gcf-owl-bot[bot] Sep 23, 2022
61af46c
removed unused code
rosiezou Sep 23, 2022
beee888
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 23, 2022
26c517f
addressed more PR feedback
rosiezou Sep 27, 2022
c991ffa
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 27, 2022
0019689
Merge branch 'main' into mm-bugfix
rosiezou Sep 27, 2022
b3c920c
fixing linter issues
rosiezou Sep 27, 2022
780556c
addressed more PR comments
rosiezou Sep 28, 2022
b4afa3d
fixing pylint errors
rosiezou Sep 28, 2022
7f708f7
🦉 Updates from OwlBot post-processor
gcf-owl-bot[bot] Sep 28, 2022
61f978c
silencing unused import warning
rosiezou Sep 28, 2022
f0b846b
fixed unused import error
rosiezou Sep 28, 2022
714d062
Merge branch 'main' into mm-bugfix
rosiezou Sep 28, 2022
d62e88d
Merge branch 'main' into mm-bugfix
rosiezou Sep 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,32 +2484,31 @@ def update(
update_mask.append("model_deployment_monitoring_objective_configs")
current_job.model_deployment_monitoring_objective_configs = (
ModelDeploymentMonitoringJob._parse_configs(
objective_configs,
current_job.endpoint,
deployed_model_ids,
objective_configs=objective_configs,
endpoint=aiplatform.Endpoint(
current_job.endpoint, credentials=self.credentials
),
deployed_model_ids=deployed_model_ids,
)
)
sasha-gitg marked this conversation as resolved.
Show resolved Hide resolved
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
return self

def pause(self) -> "ModelDeploymentMonitoringJob":
"""Pause a running MDM job."""
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
self.api_client.pause_model_deployment_monitoring_job(
name=self._gca_resource.name
)
self.api_client.pause_model_deployment_monitoring_job(
name=self._gca_resource.name
)
return self

def resume(self) -> "ModelDeploymentMonitoringJob":
"""Resumes a paused MDM job."""
if self.state == gca_job_state.JobState.JOB_STATE_PAUSED:
self.api_client.resume_model_deployment_monitoring_job(
name=self._gca_resource.name
)
self.api_client.resume_model_deployment_monitoring_job(
name=self._gca_resource.name
)
return self

def delete(self) -> None:
Expand Down
104 changes: 83 additions & 21 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,28 @@

# constants used for testing
USER_EMAIL = ""
PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320"
PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
DEFAULT_INPUT = {
"cnt_ad_reward": 0,
"cnt_challenge_a_friend": 0,
"cnt_completed_5_levels": 1,
"cnt_level_complete_quickplay": 3,
"cnt_level_end_quickplay": 5,
"cnt_level_reset_quickplay": 2,
"cnt_level_start_quickplay": 6,
"cnt_post_score": 34,
"cnt_spend_virtual_currency": 0,
"cnt_use_extra_steps": 0,
"cnt_user_engagement": 120,
"country": "Denmark",
"dayofweek": 3,
"julianday": 254,
"language": "da-dk",
"month": 9,
"operating_system": "IOS",
"user_pseudo_id": "104B0770BAE16E8B53DF330C95881893",
}

JOB_NAME = "churn"

Expand Down Expand Up @@ -117,10 +137,7 @@ def test_mdm_two_models_one_valid_config(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert job is not None

gapic_job = job._gca_resource
assert (
Expand Down Expand Up @@ -156,22 +173,77 @@ def test_mdm_two_models_one_valid_config(self):
gca_obj_config.prediction_drift_detection_config == drift_config.as_proto()
)

# delete this job and re-configure it to only enable drift detection for faster testing
job.delete()
job_resource = job._gca_resource.name

# test job update and delete()
timeout = time.time() + 3600
new_obj_config = model_monitoring.ObjectiveConfig(skew_config)
# test job delete
with pytest.raises(core_exceptions.NotFound):
rosiezou marked this conversation as resolved.
Show resolved Hide resolved
job.api_client.get_model_deployment_monitoring_job(name=job_resource)

def test_mdm_pause_and_update_config(self):
"""Test objective config updates for existing MDM job"""
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
objective_configs=model_monitoring.ObjectiveConfig(
drift_detection_config=drift_config
),
create_request_timeout=3600,
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
)
# test unsuccessful job update when it's pending
DRIFT_THRESHOLDS["cnt_user_engagement"] += 0.01
new_obj_config = model_monitoring.ObjectiveConfig(
drift_detection_config=model_monitoring.DriftDetectionConfig(
drift_thresholds=DRIFT_THRESHOLDS,
attribute_drift_thresholds=ATTRIB_DRIFT_THRESHOLDS,
)
)
if job.state == gca_job_state.JobState.JOB_STATE_PENDING:
with pytest.raises(core_exceptions.FailedPrecondition):
job.update(objective_configs=new_obj_config)

# generate traffic to force MDM job to come online
for i in range(2000):
DEFAULT_INPUT["cnt_user_engagement"] += i
self.endpoint.predict([DEFAULT_INPUT], use_raw_predict=True)

while time.time() < timeout:
# test job update
while True:
time.sleep(1)
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
job.update(objective_configs=new_obj_config)
assert str(job._gca_resource.prediction_drift_detection_config) == ""
break
time.sleep(5)

# verify job update
while True:
time.sleep(1)
if job.state == gca_job_state.JobState.JOB_STATE_RUNNING:
gca_obj_config = (
job._gca_resource.model_deployment_monitoring_objective_configs[
0
].objective_config
)
assert (
gca_obj_config.prediction_drift_detection_config
== new_obj_config.drift_detection_config.as_proto()
)
break

# test pause
job.pause()
while job.state != gca_job_state.JobState.JOB_STATE_PAUSED:
time.sleep(1)
job.delete()

# confirm deletion
with pytest.raises(core_exceptions.NotFound):
job.api_client.get_model_deployment_monitoring_job(name=job_resource)
job.state

def test_mdm_two_models_two_valid_configs(self):
[deployed_model1, deployed_model2] = list(
Expand All @@ -181,7 +253,6 @@ def test_mdm_two_models_two_valid_configs(self):
deployed_model1: objective_config,
deployed_model2: objective_config2,
}
job = None
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
Expand All @@ -192,10 +263,7 @@ def test_mdm_two_models_two_valid_configs(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert job is not None

gapic_job = job._gca_resource
assert (
Expand Down Expand Up @@ -246,8 +314,6 @@ def test_mdm_invalid_config_incorrect_model_id(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
deployed_model_ids=[""],
)
assert "Invalid model ID" in str(e.value)
Expand All @@ -265,8 +331,6 @@ def test_mdm_invalid_config_xai(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert (
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
Expand Down Expand Up @@ -294,8 +358,6 @@ def test_mdm_two_models_invalid_configs_xai(self):
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
predict_instance_schema_uri="",
analysis_instance_schema_uri="",
)
assert (
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,16 @@
job_state as gca_job_state_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
model_monitoring as gca_model_monitoring_compat,
)

from google.cloud.aiplatform.compat.services import (
job_service_client,
)
from google.protobuf import field_mask_pb2 # type: ignore

from test_endpoints import get_endpoint_with_models_mock

_TEST_API_CLIENT = job_service_client.JobServiceClient

Expand Down Expand Up @@ -84,6 +89,11 @@
f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}"
)

_TEST_MDM_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/modelDeploymentMonitoringJobs/{_TEST_ID}"
_TEST_ENDPOINT = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
)

_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)
_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
_TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2)
Expand Down Expand Up @@ -164,6 +174,8 @@
_TEST_JOB_DELETE_METHOD_NAME = "delete_custom_job"
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"

_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}

# TODO(b/171333554): Move reusable test fixtures to conftest.py file


Expand Down Expand Up @@ -969,3 +981,87 @@ def test_batch_predict_job_with_versioned_model(
].model
== _TEST_VERSIONED_MODEL_NAME
)


@pytest.fixture
def get_mdm_job_mock():
with mock.patch.object(
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
) as get_mdm_job_mock:
get_mdm_job_mock.return_value = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
)
)
yield get_mdm_job_mock


@pytest.fixture
def update_mdm_job_mock(get_mdm_job_mock, get_endpoint_with_models_mock):
with mock.patch.object(
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
) as update_mdm_job_mock:
expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
drift_thresholds={
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
}
)
)
all_configs = []
for model in get_endpoint_with_models_mock.return_value.deployed_models:
all_configs.append(
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id=model.id,
objective_config=expected_objective_config,
)
)

update_mdm_job_mock.return_vaue.result_type = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
state=_TEST_JOB_STATE_RUNNING,
endpoint=_TEST_ENDPOINT,
model_deployment_monitoring_objective_configs=all_configs,
)
)
yield update_mdm_job_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestModelDeploymentMonitoringJob:
def setup_method(self):
reload(initializer)
reload(aiplatform)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("get_mdm_job_mock", "update_mdm_job_mock")
rosiezou marked this conversation as resolved.
Show resolved Hide resolved
def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
job = jobs.ModelDeploymentMonitoringJob(
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
)
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
)
new_config = aiplatform.model_monitoring.ObjectiveConfig(
drift_detection_config=drift_detection_config
)
job.update(objective_configs=new_config)
assert (
job._gca_resource.model_deployment_monitoring_objective_configs[
0
].objective_config.prediction_drift_detection_config
== drift_detection_config.as_proto()
)
update_mdm_job_mock.assert_called_once_with(
model_deployment_monitoring_job=get_mdm_job_mock.return_value,
update_mask=field_mask_pb2.FieldMask(
paths=["model_deployment_monitoring_objective_configs"]
),
)
rosiezou marked this conversation as resolved.
Show resolved Hide resolved