-
Notifications
You must be signed in to change notification settings - Fork 348
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/main' into feature/batch-predi…
…ction/service-account
- Loading branch information
Showing
5 changed files
with
110 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
|
||
# constants used for testing | ||
USER_EMAIL = "[email protected]" | ||
PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280" | ||
PERMANENT_CHURN_MODEL_ID = "5295507484113371136" | ||
CHURN_MODEL_PATH = "gs://mco-mm/churn" | ||
DEFAULT_INPUT = { | ||
"cnt_ad_reward": 0, | ||
|
@@ -117,15 +117,26 @@ | |
objective_config2 = model_monitoring.ObjectiveConfig(skew_config, drift_config2) | ||
|
||
|
||
@pytest.mark.usefixtures("tear_down_resources") | ||
class TestModelDeploymentMonitoring(e2e_base.TestEndToEnd): | ||
_temp_prefix = "temp_e2e_model_monitoring_test_" | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
endpoint = aiplatform.Endpoint(PERMANENT_CHURN_ENDPOINT_ID) | ||
|
||
def test_mdm_two_models_one_valid_config(self): | ||
def test_create_endpoint(self, shared_state): | ||
# initial setup | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
self.endpoint = aiplatform.Endpoint.create(self._make_display_name("endpoint")) | ||
shared_state["resources"] = [self.endpoint] | ||
self.model = aiplatform.Model(PERMANENT_CHURN_MODEL_ID) | ||
self.endpoint.deploy(self.model) | ||
self.endpoint.deploy(self.model, traffic_percentage=50) | ||
|
||
def test_mdm_two_models_one_valid_config(self, shared_state): | ||
""" | ||
Enable model monitoring on two existing models deployed to the same endpoint. | ||
""" | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
# test model monitoring configurations | ||
job = aiplatform.ModelDeploymentMonitoringJob.create( | ||
display_name=self._make_display_name(key=JOB_NAME), | ||
|
@@ -153,6 +164,7 @@ def test_mdm_two_models_one_valid_config(self): | |
== [USER_EMAIL] | ||
) | ||
assert gapic_job.model_monitoring_alert_config.enable_logging | ||
assert len(gapic_job.model_deployment_monitoring_objective_configs) == 2 | ||
|
||
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[ | ||
0 | ||
|
@@ -181,8 +193,11 @@ def test_mdm_two_models_one_valid_config(self): | |
with pytest.raises(core_exceptions.NotFound): | ||
job.api_client.get_model_deployment_monitoring_job(name=job_resource) | ||
|
||
def test_mdm_pause_and_update_config(self): | ||
def test_mdm_pause_and_update_config(self, shared_state): | ||
"""Test objective config updates for existing MDM job""" | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
job = aiplatform.ModelDeploymentMonitoringJob.create( | ||
display_name=self._make_display_name(key=JOB_NAME), | ||
logging_sampling_strategy=sampling_strategy, | ||
|
@@ -245,7 +260,10 @@ def test_mdm_pause_and_update_config(self): | |
with pytest.raises(core_exceptions.NotFound): | ||
job.state | ||
|
||
def test_mdm_two_models_two_valid_configs(self): | ||
def test_mdm_two_models_two_valid_configs(self, shared_state): | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
[deployed_model1, deployed_model2] = list( | ||
map(lambda x: x.id, self.endpoint.list_models()) | ||
) | ||
|
@@ -302,7 +320,10 @@ def test_mdm_two_models_two_valid_configs(self): | |
|
||
job.delete() | ||
|
||
def test_mdm_invalid_config_incorrect_model_id(self): | ||
def test_mdm_invalid_config_incorrect_model_id(self, shared_state): | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
with pytest.raises(ValueError) as e: | ||
aiplatform.ModelDeploymentMonitoringJob.create( | ||
display_name=self._make_display_name(key=JOB_NAME), | ||
|
@@ -318,7 +339,10 @@ def test_mdm_invalid_config_incorrect_model_id(self): | |
) | ||
assert "Invalid model ID" in str(e.value) | ||
|
||
def test_mdm_invalid_config_xai(self): | ||
def test_mdm_invalid_config_xai(self, shared_state): | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
with pytest.raises(RuntimeError) as e: | ||
objective_config.explanation_config = model_monitoring.ExplanationConfig() | ||
aiplatform.ModelDeploymentMonitoringJob.create( | ||
|
@@ -337,7 +361,10 @@ def test_mdm_invalid_config_xai(self): | |
in str(e.value) | ||
) | ||
|
||
def test_mdm_two_models_invalid_configs_xai(self): | ||
def test_mdm_two_models_invalid_configs_xai(self, shared_state): | ||
assert len(shared_state["resources"]) == 1 | ||
self.endpoint = shared_state["resources"][0] | ||
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) | ||
[deployed_model1, deployed_model2] = list( | ||
map(lambda x: x.id, self.endpoint.list_models()) | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters