diff --git a/google/cloud/aiplatform/utils/yaml_utils.py b/google/cloud/aiplatform/utils/yaml_utils.py index 4558a80200..4c660957a1 100644 --- a/google/cloud/aiplatform/utils/yaml_utils.py +++ b/google/cloud/aiplatform/utils/yaml_utils.py @@ -52,8 +52,10 @@ def load_yaml( if path.startswith("gs://"): return _load_yaml_from_gs_uri(path, project, credentials) elif path.startswith("http://") or path.startswith("https://"): - if _VALID_AR_URL.match(path) or _VALID_HTTPS_URL.match(path): + if _VALID_AR_URL.match(path): return _load_yaml_from_https_uri(path, credentials) + elif _VALID_HTTPS_URL.match(path): + return _load_yaml_from_https_uri(path) else: raise ValueError( "Invalid HTTPS URI. If not using Artifact Registry, please " diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index e81866bfef..f3c79df814 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -21,7 +21,7 @@ import json import os import textwrap -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple from unittest import mock from unittest.mock import patch from urllib import request as urllib_request @@ -29,6 +29,7 @@ import pytest import yaml from google.api_core import client_options, gapic_v1 +from google.auth import credentials from google.cloud import aiplatform from google.cloud import storage from google.cloud.aiplatform import compat, utils @@ -775,7 +776,7 @@ def json_file(tmp_path): @pytest.fixture(scope="function") -def mock_request_urlopen(request: str) -> str: +def mock_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]: data = {"key": "val", "list": ["1", 2, 3.0]} with mock.patch.object(urllib_request, "urlopen") as mock_urlopen: mock_read_response = mock.MagicMock() @@ -783,7 +784,7 @@ def mock_request_urlopen(request: str) -> str: mock_decode_response.return_value = json.dumps(data) mock_read_response.return_value.decode = mock_decode_response mock_urlopen.return_value.read = mock_read_response - yield request.param + yield request.param, mock_urlopen class TestYamlUtils: @@ -802,10 +803,17 @@ def test_load_yaml_from_local_file__with_json(self, json_file): ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], indirect=True, ) - def test_load_yaml_from_ar_uri(self, mock_request_urlopen): - actual = yaml_utils.load_yaml(mock_request_urlopen) + def test_load_yaml_from_ar_uri_passes_creds(self, mock_request_urlopen): + url, mock_urlopen = mock_request_urlopen + mock_credentials = mock.create_autospec(credentials.Credentials, instance=True) + mock_credentials.valid = True + mock_credentials.token = "some_token" + actual = yaml_utils.load_yaml(url, credentials=mock_credentials) expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected + assert mock_urlopen.call_args[0][0].headers == { + "Authorization": "Bearer some_token" + } @pytest.mark.parametrize( "mock_request_urlopen", @@ -816,10 +824,15 @@ def test_load_yaml_from_ar_uri(self, mock_request_urlopen): ], indirect=True, ) - def test_load_yaml_from_https_uri(self, mock_request_urlopen): - actual = yaml_utils.load_yaml(mock_request_urlopen) + def test_load_yaml_from_https_uri_ignores_creds(self, mock_request_urlopen): + url, mock_urlopen = mock_request_urlopen + mock_credentials = mock.create_autospec(credentials.Credentials, instance=True) + mock_credentials.valid = True + mock_credentials.token = "some_token" + actual = yaml_utils.load_yaml(url, credentials=mock_credentials) expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected + assert mock_urlopen.call_args[0][0].headers == {} @pytest.mark.parametrize( "uri",