Skip to content

Commit

Permalink
Merge pull request #2 from jaycee-li/988-add-a-way-to-easily-clone-a-…
Browse files Browse the repository at this point in the history
…pipelinejob

feat: add clone method to PipelineJob
  • Loading branch information
jaycee-li authored May 17, 2022
2 parents bcbb21d + 1a223db commit fc63bbb
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 1 deletion.
140 changes: 139 additions & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -473,3 +473,141 @@ def list(
def wait_for_resource_creation(self) -> None:
"""Waits until the resource has been created.

Thin public wrapper that delegates to ``self._wait_for_resource_creation()``
and returns nothing.
"""
self._wait_for_resource_creation()

def clone(
    self,
    display_name: Optional[str] = None,
    job_id: Optional[str] = None,
    pipeline_root: Optional[str] = None,
    parameter_values: Optional[Dict[str, Any]] = None,
    enable_caching: Optional[bool] = None,
    encryption_spec_key_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
) -> "PipelineJob":
    """Returns a new PipelineJob object with the same settings as the original one.

    The clone is only constructed here; this method does not submit it.
    Every argument is an optional override: anything left unset falls back
    to the corresponding value of the original PipelineJob.

    Args:
        display_name (str):
            Optional. The user-defined name of this cloned Pipeline.
            If not specified, original pipeline name will be used.
        job_id (str):
            Optional. The unique ID of the job run.
            If not specified, "cloned" + pipeline name + timestamp will be used.
        pipeline_root (str):
            Optional. The root of the pipeline outputs. Default to be the same
            staging bucket as original pipeline.
        parameter_values (Dict[str, Any]):
            Optional. The mapping from runtime parameter names to its values that
            control the pipeline run. Defaults to be the same values as original
            PipelineJob.
        enable_caching (bool):
            Optional. Whether to turn on caching for the run.
            If this is not set, defaults to be the same as original pipeline.
            If this is set, the setting applies to all tasks in the pipeline.
        encryption_spec_key_name (str):
            Optional. The Cloud KMS resource identifier of the customer
            managed encryption key used to protect the job. Has the
            form:
            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
            The key needs to be in the same region as where the compute resource is created.
            If this is set, then all
            resources created by the PipelineJob will
            be encrypted with the provided encryption key.
            If not specified, encryption_spec of original PipelineJob will be used.
        labels (Dict[str, str]):
            Optional. The user defined metadata to organize PipelineJob.
            If not specified (or empty), labels of original PipelineJob will be used.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to create this PipelineJob.
            If not set, credentials of original PipelineJob will be used.
        project (str):
            Optional. The project that you want to run this PipelineJob in.
            If not set, the project set in original PipelineJob will be used.
        location (str):
            Optional. Location to create PipelineJob.
            If not set, location set in original PipelineJob will be used.

    Returns:
        PipelineJob: A new, not yet submitted PipelineJob built from the
        original job's resource plus the given overrides.

    Raises:
        ValueError: If the supplied or generated job ID does not match the
            module's ``_VALID_NAME_PATTERN``. Validation of ``display_name``
            and ``labels`` is delegated to ``utils.validate_display_name``
            and ``utils.validate_labels``.
    """
    # Anything the caller did not override is inherited from the original job.
    project = project or self.project
    location = location or self.location
    credentials = credentials or self.credentials

    # Start from an empty PipelineJob bound to the target project/location.
    cloned = self.__class__._empty_constructor(
        project=project,
        location=location,
        credentials=credentials,
    )
    cloned._parent = initializer.global_config.common_location_path(
        project=project, location=location
    )

    # Work on a plain-dict copy of the original resource (camelCase keys).
    pipeline_job = json_format.MessageToDict(self._gca_resource._pb)

    # The clone's spec is the original spec without its deploymentConfig entry.
    pipeline_spec = pipeline_job["pipelineSpec"]
    pipeline_spec.pop("deploymentConfig", None)

    # None means "keep the caching settings already present in the spec".
    if enable_caching is not None:
        _set_enable_caching_value(pipeline_spec, enable_caching)

    # Only derive a default ID (and touch pipelineInfo) when none was given.
    if not job_id:
        pipeline_name = pipeline_spec["pipelineInfo"]["name"]
        job_id = "cloned-{pipeline_name}-{timestamp}".format(
            pipeline_name=re.sub(
                "[^-0-9a-z]+", "-", pipeline_name.lower()
            ).strip("-"),
            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
        )
    if not _VALID_NAME_PATTERN.match(job_id):
        raise ValueError(
            "Generated job ID: {} is illegal as a Vertex pipelines job ID. "
            "Expecting an ID following the regex pattern "
            '"[a-z][-a-z0-9]{{0,127}}"'.format(job_id)
        )
    cloned.job_id = job_id

    # Explicit overrides are validated; otherwise inherit from the original.
    if display_name:
        utils.validate_display_name(display_name)
    else:
        display_name = pipeline_job.get("displayName", display_name)

    if labels:
        utils.validate_labels(labels)
    else:
        labels = pipeline_job.get("labels", labels)

    # An explicit key wins; without one, reuse the original encryptionSpec if
    # it exists, else let the global config decide.
    if encryption_spec_key_name or "encryptionSpec" not in pipeline_job:
        encryption_spec = initializer.global_config.get_encryption_spec(
            encryption_spec_key_name=encryption_spec_key_name
        )
    else:
        encryption_spec = pipeline_job["encryptionSpec"]

    # Rebuild the runtime config from the original job, then apply overrides.
    builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
        pipeline_job
    )
    builder.update_pipeline_root(pipeline_root)
    builder.update_runtime_parameters(parameter_values)
    runtime_config_dict = builder.build()
    runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(runtime_config_dict, runtime_config)

    # Assemble the resource that backs the cloned PipelineJob.
    cloned._gca_resource = gca_pipeline_job_v1.PipelineJob(
        display_name=display_name,
        pipeline_spec=pipeline_spec,
        labels=labels,
        runtime_config=runtime_config,
        encryption_spec=encryption_spec,
    )

    return cloned
164 changes: 164 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,3 +1045,167 @@ def test_pipeline_failure_raises(self, mock_load_yaml_and_json, sync):

if not sync:
job.wait()

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_clone_pipeline_job(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec,
    mock_load_yaml_and_json,
):
    """Cloning with only a job_id override reuses the original job's settings.

    The create request issued for the clone is expected to carry the original
    display name, parameter values and staging bucket, under the new job ID.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    original_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    cloned_job_id = f"cloned-{_TEST_PIPELINE_JOB_ID}"
    cloned = original_job.clone(job_id=cloned_job_id)

    cloned.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        create_request_timeout=None,
    )

    # Runtime config the clone should inherit from the original job.
    expected_runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(
        {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        },
        expected_runtime_config,
    )

    # The fixture is either a bare pipeline spec or a job wrapping one.
    parsed_spec = yaml.safe_load(job_spec)
    pipeline_spec = parsed_spec.get("pipelineSpec") or parsed_spec

    # Construct expected request
    expected_request_job = gca_pipeline_job_v1.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec={
            "components": {},
            "pipelineInfo": pipeline_spec["pipelineInfo"],
            "root": pipeline_spec["root"],
            "schemaVersion": "2.1.0",
        },
        runtime_config=expected_runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_request_job,
        pipeline_job_id=cloned_job_id,
        timeout=None,
    )

    # No GET should have been issued before wait() is called.
    assert not mock_pipeline_service_get.called

    cloned.wait()

    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    succeeded_job = make_pipeline_job(
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert cloned._gca_resource == succeeded_job

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_clone_pipeline_job_with_all_args(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    job_spec,
    mock_load_yaml_and_json,
):
    """Cloning with explicit overrides sends those overrides in the request.

    Display name, job ID and pipeline root are overridden with "cloned-"
    prefixed values and are expected to appear in the create request.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    original_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )

    # Overrides shared between the clone() call and the expected request.
    cloned_display_name = f"cloned-{_TEST_PIPELINE_JOB_DISPLAY_NAME}"
    cloned_job_id = f"cloned-{_TEST_PIPELINE_JOB_ID}"
    cloned_pipeline_root = f"cloned-{_TEST_GCS_BUCKET_NAME}"

    cloned = original_job.clone(
        display_name=cloned_display_name,
        job_id=cloned_job_id,
        pipeline_root=cloned_pipeline_root,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
        credentials=_TEST_CREDENTIALS,
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )

    cloned.submit(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        create_request_timeout=None,
    )

    # Runtime config must reflect the overridden pipeline root.
    expected_runtime_config = gca_pipeline_job_v1.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(
        {
            "gcsOutputDirectory": cloned_pipeline_root,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        },
        expected_runtime_config,
    )

    # The fixture is either a bare pipeline spec or a job wrapping one.
    parsed_spec = yaml.safe_load(job_spec)
    pipeline_spec = parsed_spec.get("pipelineSpec") or parsed_spec

    # Construct expected request
    expected_request_job = gca_pipeline_job_v1.PipelineJob(
        display_name=cloned_display_name,
        pipeline_spec={
            "components": {},
            "pipelineInfo": pipeline_spec["pipelineInfo"],
            "root": pipeline_spec["root"],
            "schemaVersion": "2.1.0",
        },
        runtime_config=expected_runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_request_job,
        pipeline_job_id=cloned_job_id,
        timeout=None,
    )

    # No GET should have been issued before wait() is called.
    assert not mock_pipeline_service_get.called

    cloned.wait()

    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    succeeded_job = make_pipeline_job(
        gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert cloned._gca_resource == succeeded_job

0 comments on commit fc63bbb

Please sign in to comment.