diff --git a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
index 8520646b9f..85def084a2 100644
--- a/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
+++ b/google/cloud/aiplatform/preview/pipelinejobschedule/pipeline_job_schedules.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Optional
+from typing import List, Optional
 
 from google.auth import credentials as auth_credentials
 from google.cloud.aiplatform import base
@@ -39,6 +39,7 @@ from google.cloud.aiplatform_v1beta1.types import (
     pipeline_job as gca_pipeline_job_v1beta1,
 )
+from google.protobuf import field_mask_pb2 as field_mask
 
 _LOGGER = base.Logger(__name__)
@@ -68,6 +69,7 @@ def __init__(
     ):
         """Retrieves a PipelineJobSchedule resource and instantiates its representation.
+
         Args:
             pipeline_job (PipelineJob):
                 Required. PipelineJob used to init the schedule.
@@ -255,3 +257,131 @@ def _create(
         )
 
         _LOGGER.info("View Schedule:\n%s" % self._dashboard_uri())
+
+    @classmethod
+    def list(
+        cls,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        enable_simple_view: bool = True,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["PipelineJobSchedule"]:
+        """List all instances of this PipelineJobSchedule resource.
+
+        Example Usage:
+
+        aiplatform.PipelineJobSchedule.list(
+            filter='display_name="experiment_a27"',
+            order_by='create_time desc'
+        )
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            enable_simple_view (bool):
+                Optional. Whether to pass the `read_mask` parameter to the list call.
+                Defaults to True if not provided. This will improve the performance of calling
+                list(). However, the returned PipelineJobSchedule list will not include all fields for
+                each PipelineJobSchedule. Setting this to True will exclude the following fields in your
+                response: 'create_pipeline_job_request', 'next_run_time', 'last_pause_time',
+                'last_resume_time', 'max_concurrent_run_count', 'allow_queueing', 'last_scheduled_run_response'.
+                The following fields will be included in each PipelineJobSchedule resource in your
+                response: 'name', 'display_name', 'start_time', 'end_time', 'max_run_count',
+                'started_run_count', 'state', 'create_time', 'update_time', 'cron', 'catch_up'.
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[PipelineJobSchedule] - A list of PipelineJobSchedule resource objects.
+        """
+
+        read_mask_fields = None
+
+        if enable_simple_view:
+            read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
+            _LOGGER.warning(
+                "By enabling simple view, the PipelineJobSchedule resources returned from this method will not contain all fields."
+            )
+
+        return cls._list_with_local_order(
+            filter=filter,
+            order_by=order_by,
+            read_mask=read_mask_fields,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
+    def list_jobs(
+        self,
+        filter: Optional[str] = None,
+        order_by: Optional[str] = None,
+        enable_simple_view: bool = False,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[PipelineJob]:
+        """List all PipelineJobs created by this PipelineJobSchedule.
+
+        Example usage:
+
+        pipeline_job_schedule.list_jobs(order_by='create_time desc')
+
+        Args:
+            filter (str):
+                Optional. An expression for filtering the results of the request.
+                For field names both snake_case and camelCase are supported.
+            order_by (str):
+                Optional. A comma-separated list of fields to order by, sorted in
+                ascending order. Use "desc" after a field name for descending.
+                Supported fields: `display_name`, `create_time`, `update_time`
+            enable_simple_view (bool):
+                Optional. Whether to pass the `read_mask` parameter to the list call.
+                Defaults to False if not provided. This will improve the performance of calling
+                list(). However, the returned PipelineJob list will not include all fields for
+                each PipelineJob. Setting this to True will exclude the following fields in your
+                response: `runtime_config`, `service_account`, `network`, and some subfields of
+                `pipeline_spec` and `job_detail`. The following fields will be included in
+                each PipelineJob resource in your response: `state`, `display_name`,
+                `pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`,
+                `update_time`, `labels`, `template_uri`, `template_metadata.version`,
+                `job_detail.pipeline_run_context`, `job_detail.pipeline_context`.
+            project (str):
+                Optional. Project to retrieve list from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to retrieve list from. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to retrieve list. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List[PipelineJob] - A list of PipelineJob resource objects.
+        """
+        list_filter = f"schedule_name={self._gca_resource.name}"
+        if filter:
+            list_filter = list_filter + f" AND {filter}"
+
+        return PipelineJob.list(
+            filter=list_filter,
+            order_by=order_by,
+            enable_simple_view=enable_simple_view,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
diff --git a/google/cloud/aiplatform/preview/schedule/schedules.py b/google/cloud/aiplatform/preview/schedule/schedules.py
index 400c7be82f..dd16ca3723 100644
--- a/google/cloud/aiplatform/preview/schedule/schedules.py
+++ b/google/cloud/aiplatform/preview/schedule/schedules.py
@@ -60,7 +60,6 @@ def __init__(
         location: str,
     ):
         """Retrieves a Schedule resource and instantiates its representation.
-
         Args:
             credentials (auth_credentials.Credentials):
                 Optional. Custom credentials to use to create this Schedule.
@@ -111,6 +110,35 @@ def get(
 
         return self
 
+    def pause(self) -> None:
+        """Starts asynchronous pause on the Schedule.
+
+        Changes Schedule state from State.ACTIVE to State.PAUSED.
+        """
+        self.api_client.pause_schedule(name=self.resource_name)
+
+    def resume(
+        self,
+        catch_up: bool = True,
+    ) -> None:
+        """Starts asynchronous resume on the Schedule.
+
+        Changes Schedule state from State.PAUSED to State.ACTIVE.
+
+        Args:
+            catch_up (bool):
+                Optional. Whether to backfill missed runs when the Schedule is
+                resumed from State.PAUSED.
+        """
+        self.api_client.resume_schedule(name=self.resource_name, catch_up=catch_up)
+
+    def done(self) -> bool:
+        """Returns True if the Schedule is done, False otherwise."""
+        if not self._gca_resource:
+            return False
+
+        return self.state in _SCHEDULE_COMPLETE_STATES
+
     def wait(self) -> None:
         """Wait for this Schedule to complete."""
         if self._latest_future is None:
diff --git a/tests/unit/aiplatform/test_pipeline_job_schedules.py b/tests/unit/aiplatform/test_pipeline_job_schedules.py
index 277b77bae2..f1468b4c2a 100644
--- a/tests/unit/aiplatform/test_pipeline_job_schedules.py
+++ b/tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -21,7 +21,6 @@
 from unittest import mock
 from unittest.mock import patch
 from urllib import request
-import yaml
 
 from google.auth import credentials as auth_credentials
 from google.cloud import storage
@@ -29,10 +28,13 @@
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform.compat.services import (
+    pipeline_service_client,
     schedule_service_client_v1beta1 as schedule_service_client,
 )
 from google.cloud.aiplatform.compat.types import (
+    context_v1beta1 as gca_context,
     pipeline_job_v1beta1 as gca_pipeline_job,
+    pipeline_state_v1beta1 as gca_pipeline_state,
     schedule_v1beta1 as gca_schedule,
 )
 from google.cloud.aiplatform.preview.constants import (
@@ -47,6 +49,7 @@
 )
 from google.cloud.aiplatform.utils import gcs_utils
 import pytest
+import yaml
 
 from google.protobuf import field_mask_pb2 as field_mask
 from google.protobuf import json_format
@@ -66,10 +69,6 @@
 _TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
 _TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2
 
-_TEST_PIPELINE_JOB_SCHEDULE_LIST_READ_MASK = field_mask.FieldMask(
-    paths=schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
-)
-
 _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
 _TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
 _TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
@@ -274,6 +273,22 @@ def make_schedule(state):
     )
 
 
+def make_pipeline_job(state):
+    test_pipeline_job_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/sample-test-pipeline-20230605"
+    return gca_pipeline_job.PipelineJob(
+        name=test_pipeline_job_name,
+        state=state,
+        create_time=_TEST_PIPELINE_CREATE_TIME,
+        service_account=_TEST_SERVICE_ACCOUNT,
+        network=_TEST_NETWORK,
+        job_detail=gca_pipeline_job.PipelineJobDetail(
+            pipeline_run_context=gca_context.Context(
+                name=test_pipeline_job_name,
+            )
+        ),
+    )
+
+
 @pytest.fixture
 def mock_schedule_service_get():
     with mock.patch.object(
@@ -308,6 +323,54 @@ def mock_schedule_service_get_with_fail():
     yield mock_get_schedule
 
 
+@pytest.fixture
+def mock_schedule_service_pause():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "pause_schedule"
+    ) as mock_pause_schedule:
+        yield mock_pause_schedule
+
+
+@pytest.fixture
+def mock_schedule_service_resume():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "resume_schedule"
+    ) as mock_resume_schedule:
+        yield mock_resume_schedule
+
+
+@pytest.fixture
+def mock_schedule_service_list():
+    with mock.patch.object(
+        schedule_service_client.ScheduleServiceClient, "list_schedules"
+    ) as mock_list_schedules:
+        mock_list_schedules.return_value = [
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+            make_schedule(gca_schedule.Schedule.State.COMPLETED),
+        ]
+        yield mock_list_schedules
+
+
+@pytest.fixture
+def mock_pipeline_service_list():
+    with mock.patch.object(
+        pipeline_service_client.PipelineServiceClient, "list_pipeline_jobs"
+    ) as mock_list_pipeline_jobs:
+        mock_list_pipeline_jobs.return_value = [
+            make_pipeline_job(
+                gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+            ),
+            make_pipeline_job(
+                gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+            ),
+            make_pipeline_job(
+                gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+            ),
+        ]
+        yield mock_list_pipeline_jobs
+
+
 @pytest.fixture
 def mock_load_yaml_and_json(job_spec):
     with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json:
@@ -909,3 +972,335 @@ def test_get_schedule(self, mock_schedule_service_get):
         assert isinstance(
             pipeline_job_schedule, pipeline_job_schedules.PipelineJobSchedule
         )
+
+    @pytest.mark.parametrize(
+        "job_spec",
+        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
+    )
+    def test_done_method_schedule_service(
+        self,
+        mock_schedule_service_create,
+        mock_schedule_service_get,
+        mock_schedule_bucket_exists,
+        job_spec,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        pipeline_job_schedule.create(
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+        )
+        assert pipeline_job_schedule.done() is False
+
+        pipeline_job_schedule.wait()
+
+        assert pipeline_job_schedule.done() is True
+
+    @pytest.mark.usefixtures(
+        "mock_schedule_service_create",
+        "mock_schedule_service_get",
+        "mock_schedule_bucket_exists",
+    )
+    @pytest.mark.parametrize(
+        "job_spec",
+        [
+            _TEST_PIPELINE_SPEC_JSON,
+            _TEST_PIPELINE_SPEC_YAML,
+            _TEST_PIPELINE_JOB,
+            _TEST_PIPELINE_SPEC_LEGACY_JSON,
+            _TEST_PIPELINE_SPEC_LEGACY_YAML,
+            _TEST_PIPELINE_JOB_LEGACY,
+        ],
+    )
+    def test_pause_resume_schedule_service(
+        self,
+        mock_schedule_service_pause,
+        mock_schedule_service_resume,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        pipeline_job_schedule.create(
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+        )
+
+        pipeline_job_schedule.pause()
+
+        mock_schedule_service_pause.assert_called_once_with(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME
+        )
+
+        pipeline_job_schedule.resume()
+
+        mock_schedule_service_resume.assert_called_once_with(
+            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME, catch_up=True
+        )
+
+    @pytest.mark.usefixtures(
"mock_schedule_service_create", + "mock_schedule_service_get", + "mock_schedule_bucket_exists", + ) + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + def test_list_schedules(self, mock_schedule_service_list, mock_load_yaml_and_json): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, + enable_caching=True, + ) + + pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule( + pipeline_job=job, + display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME, + ) + + pipeline_job_schedule.create( + cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION, + max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT, + max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + create_request_timeout=None, + ) + + pipeline_job_schedule.list(enable_simple_view=False) + + mock_schedule_service_list.assert_called_once_with( + request={"parent": _TEST_PARENT} + ) + + @pytest.mark.usefixtures( + "mock_schedule_service_create", + "mock_schedule_service_get", + "mock_schedule_bucket_exists", + ) + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + def test_list_schedules_with_read_mask( + self, mock_schedule_service_list, mock_load_yaml_and_json + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, + enable_caching=True, + ) + + pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule( + pipeline_job=job, + display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME, + ) + + pipeline_job_schedule.create( + cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION, + max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT, + max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + create_request_timeout=None, + ) + + pipeline_job_schedule.list(enable_simple_view=True) + + test_pipeline_job_schedule_list_read_mask = field_mask.FieldMask( + paths=schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS + ) + + mock_schedule_service_list.assert_called_once_with( + request={ + "parent": _TEST_PARENT, + "read_mask": test_pipeline_job_schedule_list_read_mask, + }, + ) + + @pytest.mark.usefixtures( + "mock_schedule_service_create", + "mock_schedule_service_get", + "mock_schedule_bucket_exists", + ) + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + def test_list_schedule_jobs( + self, + mock_pipeline_service_list, + mock_load_yaml_and_json, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + 
+            template_path=_TEST_TEMPLATE_PATH,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
+            enable_caching=True,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        pipeline_job_schedule.create(
+            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
+            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
+            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            create_request_timeout=None,
+        )
+
+        pipeline_job_schedule.list_jobs()
+
+        mock_pipeline_service_list.assert_called_once_with(
+            request={
+                "parent": _TEST_PARENT,
+                "filter": f"schedule_name={_TEST_PIPELINE_JOB_SCHEDULE_NAME}",
+            },
+        )
+
+    @pytest.mark.usefixtures(
+        "mock_schedule_service_create",
+        "mock_schedule_service_get",
+    )
+    @pytest.mark.parametrize(
+        "job_spec",
+        [
+            _TEST_PIPELINE_SPEC_JSON,
+            _TEST_PIPELINE_SPEC_YAML,
+            _TEST_PIPELINE_JOB,
+            _TEST_PIPELINE_SPEC_LEGACY_JSON,
+            _TEST_PIPELINE_SPEC_LEGACY_YAML,
+            _TEST_PIPELINE_JOB_LEGACY,
+        ],
+    )
+    def test_pause_pipeline_job_schedule_without_created(
+        self,
+        mock_schedule_service_pause,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        with pytest.raises(RuntimeError) as e:
+            pipeline_job_schedule.pause()
+
+        assert e.match(regexp=r"Schedule resource has not been created")
+
+    @pytest.mark.usefixtures(
+        "mock_schedule_service_create",
+        "mock_schedule_service_get",
+    )
+    @pytest.mark.parametrize(
+        "job_spec",
+        [
+            _TEST_PIPELINE_SPEC_JSON,
+            _TEST_PIPELINE_SPEC_YAML,
+            _TEST_PIPELINE_JOB,
+            _TEST_PIPELINE_SPEC_LEGACY_JSON,
+            _TEST_PIPELINE_SPEC_LEGACY_YAML,
+            _TEST_PIPELINE_JOB_LEGACY,
+        ],
+    )
+    def test_resume_pipeline_job_schedule_without_created(
+        self,
+        mock_schedule_service_resume,
+        mock_load_yaml_and_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+        )
+
+        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
+            pipeline_job=job,
+            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
+        )
+
+        with pytest.raises(RuntimeError) as e:
+            pipeline_job_schedule.resume()
+
+        assert e.match(regexp=r"Schedule resource has not been created")
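
Usage note: a minimal end-to-end sketch of the preview surface added in this change, mirroring the calls exercised by the tests above. The project, location, bucket, template path, display names, and cron expression are placeholder values, not part of this diff.

# Sketch: create, pause/resume, and list with the new preview methods.
from google.cloud import aiplatform
from google.cloud.aiplatform.preview.pipelinejobschedule import (
    pipeline_job_schedules,
)

aiplatform.init(
    project="my-project",  # placeholder
    location="us-central1",  # placeholder
    staging_bucket="gs://my-bucket",  # placeholder
)

# A PipelineJob built from a compiled pipeline spec (placeholder path).
job = aiplatform.PipelineJob(
    display_name="my-pipeline",
    template_path="gs://my-bucket/job_spec.json",
)

schedule = pipeline_job_schedules.PipelineJobSchedule(
    pipeline_job=job,
    display_name="my-schedule",
)

# Run hourly, one run at a time, at most two runs total.
schedule.create(
    cron_expression="0 * * * *",
    max_concurrent_run_count=1,
    max_run_count=2,
)

# Pause, then resume; resume() backfills missed runs by default (catch_up=True).
schedule.pause()
schedule.resume(catch_up=True)

# list() applies a read_mask by default (enable_simple_view=True), so the
# returned schedules carry only the lightweight fields named in its docstring.
schedules = pipeline_job_schedules.PipelineJobSchedule.list(
    order_by="create_time desc"
)

# Jobs created by this schedule so far, via the schedule_name server-side filter.
jobs = schedule.list_jobs(order_by="create_time desc")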