From 506adc85f255f8e91019bdfa2388be816d84539c Mon Sep 17 00:00:00 2001 From: Michael Hu Date: Tue, 20 Sep 2022 21:59:26 -0400 Subject: [PATCH] feat: add support for HTTPS URI pipeline templates --- google/cloud/aiplatform/constants/pipeline.py | 3 + google/cloud/aiplatform/pipeline_jobs.py | 9 +- google/cloud/aiplatform/utils/yaml_utils.py | 69 +++++++++------ tests/unit/aiplatform/test_pipeline_jobs.py | 83 +++++++++++++++++++ tests/unit/aiplatform/test_utils.py | 39 +++++++-- 5 files changed, 167 insertions(+), 36 deletions(-) diff --git a/google/cloud/aiplatform/constants/pipeline.py b/google/cloud/aiplatform/constants/pipeline.py index 12acf8a52e..fabe1c8e61 100644 --- a/google/cloud/aiplatform/constants/pipeline.py +++ b/google/cloud/aiplatform/constants/pipeline.py @@ -38,6 +38,9 @@ # Pattern for an Artifact Registry URL. _VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*") +# Pattern for any JSON or YAML file over HTTPS. +_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$") + # Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list() _READ_MASK_FIELDS = [ "name", diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index abc4f8eb56..25c2bad2db 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -57,6 +57,9 @@ # Pattern for an Artifact Registry URL. _VALID_AR_URL = pipeline_constants._VALID_AR_URL +# Pattern for any JSON or YAML file over HTTPS. +_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL + _READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS @@ -131,8 +134,8 @@ def __init__( template_path (str): Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), - or an Artifact Registry URI (e.g. - "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"). + an Artifact Registry URI (e.g. + "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI. job_id (str): Optional. The unique ID of the job run. If not specified, pipeline name + timestamp will be used. @@ -277,7 +280,7 @@ def __init__( ), } - if _VALID_AR_URL.match(template_path): + if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path): pipeline_job_args["template_uri"] = template_path self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args) diff --git a/google/cloud/aiplatform/utils/yaml_utils.py b/google/cloud/aiplatform/utils/yaml_utils.py index bac33733dc..c61de0b861 100644 --- a/google/cloud/aiplatform/utils/yaml_utils.py +++ b/google/cloud/aiplatform/utils/yaml_utils.py @@ -13,18 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import re +from types import ModuleType from typing import Any, Dict, Optional from urllib import request from google.auth import credentials as auth_credentials from google.auth import transport from google.cloud import storage +from google.cloud.aiplatform.constants import pipeline as pipeline_constants # Pattern for an Artifact Registry URL. -_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*") +_VALID_AR_URL = pipeline_constants._VALID_AR_URL + +# Pattern for any JSON or YAML file over HTTPS. +_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL def load_yaml( @@ -36,8 +39,8 @@ def load_yaml( Args: path (str): - Required. The path of the YAML document in Google Cloud Storage or - local. + Required. The path of the YAML document. It can be a local path, a + Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI. project (str): Optional. Project to initiate the Storage client with. credentials (auth_credentials.Credentials): @@ -50,10 +53,25 @@ def load_yaml( return _load_yaml_from_gs_uri(path, project, credentials) elif _VALID_AR_URL.match(path): return _load_yaml_from_ar_uri(path, credentials) + elif _VALID_HTTPS_URL.match(path): + return _load_yaml_from_https_uri(path) else: return _load_yaml_from_local_file(path) +def _maybe_import_yaml() -> ModuleType: + """Tries to import the PyYAML module.""" + try: + import yaml + except ImportError: + raise ImportError( + "PyYAML is not installed and is required to parse PipelineJob or " + 'PipelineSpec files. Please install the SDK using "pip install ' + 'google-cloud-aiplatform[pipelines]"' + ) + return yaml + + def _load_yaml_from_gs_uri( uri: str, project: Optional[str] = None, @@ -72,13 +90,7 @@ def _load_yaml_from_gs_uri( Returns: A Dict object representing the YAML document. """ - try: - import yaml - except ImportError: - raise ImportError( - "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. " - 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"' - ) + yaml = _maybe_import_yaml() storage_client = storage.Client(project=project, credentials=credentials) blob = storage.Blob.from_string(uri, storage_client) return yaml.safe_load(blob.download_as_bytes()) @@ -94,13 +106,7 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]: Returns: A Dict object representing the YAML document. """ - try: - import yaml - except ImportError: - raise ImportError( - "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. " - 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"' - ) + yaml = _maybe_import_yaml() with open(file_path) as f: return yaml.safe_load(f) @@ -112,7 +118,7 @@ def _load_yaml_from_ar_uri( """Loads data from a YAML document referenced by a Artifact Registry URI. Args: - path (str): + uri (str): Required. Artifact Registry URI for YAML document. credentials (auth_credentials.Credentials): Optional. Credentials to use with Artifact Registry. @@ -120,13 +126,7 @@ def _load_yaml_from_ar_uri( Returns: A Dict object representing the YAML document. """ - try: - import yaml - except ImportError: - raise ImportError( - "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. " - 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"' - ) + yaml = _maybe_import_yaml() req = request.Request(uri) if credentials: @@ -137,3 +137,18 @@ def _load_yaml_from_ar_uri( response = request.urlopen(req) return yaml.safe_load(response.read().decode("utf-8")) + + +def _load_yaml_from_https_uri(uri: str) -> Dict[str, Any]: + """Loads data from a YAML document referenced by an HTTPS URI. + + Args: + uri (str): + Required. HTTPS URI for YAML document. + + Returns: + A Dict object representing the YAML document. + """ + yaml = _maybe_import_yaml() + response = request.urlopen(uri) + return yaml.safe_load(response.read().decode("utf-8")) diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index a608e7df07..efa3bafc47 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -60,6 +60,7 @@ _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json" _TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" +_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json" _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" @@ -627,6 +628,88 @@ def test_run_call_pipeline_service_create_artifact_registry( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_https( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_pipeline_bucket_exists, + mock_request_urlopen, + job_spec, + mock_load_yaml_and_json, + sync, + ): + import yaml + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_HTTPS_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + enable_caching=True, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + sync=sync, + create_request_timeout=None, + ) + + if not sync: + job.wait() + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + } + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": pipeline_spec["pipelineInfo"], + "root": pipeline_spec["root"], + "schemaVersion": "2.1.0", + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + template_uri=_TEST_HTTPS_TEMPLATE_PATH, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=_TEST_PIPELINE_JOB_ID, + timeout=None, + ) + + mock_pipeline_service_get.assert_called_with( + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY + ) + + assert job._gca_resource == make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @pytest.mark.parametrize( "job_spec", [ diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 081d0ce18a..182d775648 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -24,7 +24,7 @@ from typing import Callable, Dict, Optional from unittest import mock from unittest.mock import patch -from urllib import request +from urllib import request as urllib_request import pytest import yaml @@ -751,15 +751,15 @@ def json_file(tmp_path): @pytest.fixture(scope="function") -def mock_request_urlopen(): +def mock_request_urlopen(request: str) -> str: data = {"key": "val", "list": ["1", 2, 3.0]} - with mock.patch.object(request, "urlopen") as mock_urlopen: + with mock.patch.object(urllib_request, "urlopen") as mock_urlopen: mock_read_response = mock.MagicMock() mock_decode_response = mock.MagicMock() mock_decode_response.return_value = json.dumps(data) mock_read_response.return_value.decode = mock_decode_response mock_urlopen.return_value.read = mock_read_response - yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" + yield request.param class TestYamlUtils: @@ -773,11 +773,38 @@ def test_load_yaml_from_local_file__with_json(self, json_file): expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected + @pytest.mark.parametrize( + "mock_request_urlopen", + ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], + indirect=True, + ) def test_load_yaml_from_ar_uri(self, mock_request_urlopen): actual = yaml_utils.load_yaml(mock_request_urlopen) expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected - def test_load_yaml_from_invalid_uri(self): + @pytest.mark.parametrize( + "mock_request_urlopen", + [ + "https://raw.githubusercontent.com/repo/pipeline.json", + "https://raw.githubusercontent.com/repo/pipeline.yaml", + "https://raw.githubusercontent.com/repo/pipeline.yml", + ], + indirect=True, + ) + def test_load_yaml_from_https_uri(self, mock_request_urlopen): + actual = yaml_utils.load_yaml(mock_request_urlopen) + expected = {"key": "val", "list": ["1", 2, 3.0]} + assert actual == expected + + @pytest.mark.parametrize( + "uri", + [ + "https://us-docker.pkg.dev/v2/proj/repo/img/tags/list", + "https://example.com/pipeline.exe", + "http://example.com/pipeline.yaml", + ], + ) + def test_load_yaml_from_invalid_uri(self, uri: str): with pytest.raises(FileNotFoundError): - yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list") + yaml_utils.load_yaml(uri)