diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 6023583918a..da2cabb9b0d 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -28,6 +28,7 @@ from google.auth import credentials as auth_credentials from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore from google.rpc import status_pb2 from google.cloud import aiplatform @@ -1952,7 +1953,9 @@ def __init__( location=location, credentials=credentials, ) - self._gca_resource = self._get_gca_resource(resource_name=self.job_name) + self._gca_resource = self._get_gca_resource( + resource_name=model_deployment_monitoring_job_name + ) self._endpoint_resource_name = "" @classmethod @@ -1985,6 +1988,10 @@ def _parse_configs( all_configs = [] all_models = [] default_endpoint = "aiplatform.googleapis.com" + if aiplatform.initializer.global_config._location is None: + raise ValueError( + "Error parsing model monitoring objective configs: project location is not set" + ) client_options = dict( api_endpoint=f"{aiplatform.initializer.global_config._location}-{default_endpoint}" ) @@ -2309,7 +2316,7 @@ def update( ) -> "ModelDeploymentMonitoringJob": """""" current_job = self.api_client.get_model_deployment_monitoring_job( - name=self.model_deployment_monitoring_job_name + name=self._gca_resource.name ) update_mask: List[str] = [] if display_name: @@ -2339,27 +2346,40 @@ def update( update_mask.append("model_deployment_monitoring_objective_configs") current_job.model_deployment_monitoring_objective_configs = ( ModelDeploymentMonitoringJob._parse_configs( - objective_configs, self._endpoint_resource_name + objective_configs, current_job.endpoint, deployed_model_ids ) ) self.api_client.update_model_deployment_monitoring_job( - model_deployment_monitoring_job=current_job, update_mask=update_mask + model_deployment_monitoring_job=current_job, + update_mask=field_mask_pb2.FieldMask(paths=update_mask), ) def pause(self) -> "ModelDeploymentMonitoringJob": """""" - self.api_client.pause_model_deployment_monitoring_job( - self.model_deployment_monitoring_job_name - ) + if self.state == gca_job_state.JobState.JOB_STATE_RUNNING: + self.api_client.pause_model_deployment_monitoring_job( + name=self._gca_resource.name + ) + else: + raise RuntimeError( + "The monitoring job can only be paused under running / pending state, the current state is: %s" + % self.state + ) def resume(self) -> "ModelDeploymentMonitoringJob": """""" - self.api_client.resume_model_deployment_monitoring_job( - self.model_deployment_monitoring_job_name - ) + if self.state == gca_job_state.JobState.JOB_STATE_PAUSED: + self.api_client.resume_model_deployment_monitoring_job( + name=self._gca_resource.name + ) + else: + raise RuntimeError( + "The monitoring job can only be resumed under paused state" + ) def delete(self) -> "ModelDeploymentMonitoringJob": """""" + self.pause() self.api_client.delete_model_deployment_monitoring_job( - self.model_deployment_monitoring_job_name + name=self._gca_resource.name )