Skip to content

Commit

Permalink
fixed runtime errors in update and pause functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rosiezou committed Jun 16, 2022
1 parent e6ffa31 commit ee05588
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from google.auth import credentials as auth_credentials
from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.rpc import status_pb2

from google.cloud import aiplatform
Expand Down Expand Up @@ -1952,7 +1953,9 @@ def __init__(
location=location,
credentials=credentials,
)
self._gca_resource = self._get_gca_resource(resource_name=self.job_name)
self._gca_resource = self._get_gca_resource(
resource_name=model_deployment_monitoring_job_name
)
self._endpoint_resource_name = ""

@classmethod
Expand Down Expand Up @@ -1985,6 +1988,10 @@ def _parse_configs(
all_configs = []
all_models = []
default_endpoint = "aiplatform.googleapis.com"
if aiplatform.initializer.global_config._location is None:
raise ValueError(
"Error parsing model monitoring objective configs: project location is not set"
)
client_options = dict(
api_endpoint=f"{aiplatform.initializer.global_config._location}-{default_endpoint}"
)
Expand Down Expand Up @@ -2309,7 +2316,7 @@ def update(
) -> "ModelDeploymentMonitoringJob":
""""""
current_job = self.api_client.get_model_deployment_monitoring_job(
name=self.model_deployment_monitoring_job_name
name=self._gca_resource.name
)
update_mask: List[str] = []
if display_name:
Expand Down Expand Up @@ -2339,27 +2346,40 @@ def update(
update_mask.append("model_deployment_monitoring_objective_configs")
current_job.model_deployment_monitoring_objective_configs = (
ModelDeploymentMonitoringJob._parse_configs(
objective_configs, self._endpoint_resource_name
objective_configs, current_job.endpoint, deployed_model_ids
)
)
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job, update_mask=update_mask
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)

def pause(self) -> "ModelDeploymentMonitoringJob":
""""""
self.api_client.pause_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name
)
if self.state == gca_job_state.JobState.JOB_STATE_RUNNING:
self.api_client.pause_model_deployment_monitoring_job(
name=self._gca_resource.name
)
else:
raise RuntimeError(
"The monitoring job can only be paused under running / pending state, the current state is: %s"
% self.state
)

def resume(self) -> "ModelDeploymentMonitoringJob":
""""""
self.api_client.resume_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name
)
if self.state == gca_job_state.JobState.JOB_STATE_PAUSED:
self.api_client.resume_model_deployment_monitoring_job(
name=self._gca_resource.name
)
else:
raise RuntimeError(
"The monitoring job can only be resumed under paused state"
)

def delete(self) -> "ModelDeploymentMonitoringJob":
""""""
self.pause()
self.api_client.delete_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name
name=self._gca_resource.name
)

0 comments on commit ee05588

Please sign in to comment.