From bb27619d71fe237690f9c14a37461f1ca839822b Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 26 Jan 2023 15:25:46 -0800 Subject: [PATCH] fix: Use Client.list_blobs instead of Bucket.list_blobs in CPR artifact downloader, to make sure that CPR works with custom service accounts on Vertex Prediction. PiperOrigin-RevId: 504956857 --- google/cloud/aiplatform/utils/prediction_utils.py | 3 +-- tests/unit/aiplatform/test_utils.py | 14 +++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/google/cloud/aiplatform/utils/prediction_utils.py b/google/cloud/aiplatform/utils/prediction_utils.py index 9cdfa14d17..a66ad07849 100644 --- a/google/cloud/aiplatform/utils/prediction_utils.py +++ b/google/cloud/aiplatform/utils/prediction_utils.py @@ -135,8 +135,7 @@ def download_model_artifacts(artifact_uri: str) -> None: bucket_name, prefix = matches.groups() gcs_client = storage.Client() - bucket = gcs_client.get_bucket(bucket_name) - blobs = bucket.list_blobs(prefix=prefix) + blobs = gcs_client.list_blobs(bucket_name, prefix=prefix) for blob in blobs: name_without_prefix = blob.name[len(prefix) :] name_without_prefix = ( diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 164dec7d4f..47e6caa421 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -70,13 +70,11 @@ def __init__(self, name): blob2 = mock.MagicMock() type(blob2).name = mock.PropertyMock(return_value=f"{GCS_PREFIX}/") - def get_blobs(prefix): + def get_blobs(bucket_name, prefix=""): return [blob1, blob2] with patch.object(storage, "Client") as mock_storage_client: - get_bucket_mock = mock.Mock() - get_bucket_mock.return_value.list_blobs.side_effect = get_blobs - mock_storage_client.return_value.get_bucket.return_value = get_bucket_mock() + mock_storage_client.return_value.list_blobs.side_effect = get_blobs yield mock_storage_client @@ -806,16 +804,14 @@ def test_download_model_artifacts(self, mock_storage_client): prediction_utils.download_model_artifacts(f"gs://{GCS_BUCKET}/{GCS_PREFIX}") assert mock_storage_client.called - mock_storage_client().get_bucket.assert_called_once_with(GCS_BUCKET) - mock_storage_client().get_bucket().list_blobs.assert_called_once_with( - prefix=GCS_PREFIX + mock_storage_client().list_blobs.assert_called_once_with( + GCS_BUCKET, prefix=GCS_PREFIX ) - mock_storage_client().get_bucket().list_blobs.side_effect("")[ + mock_storage_client().list_blobs.side_effect("")[ 0 ].download_to_filename.assert_called_once_with(FAKE_FILENAME) assert ( not mock_storage_client() - .get_bucket() .list_blobs.side_effect("")[1] .download_to_filename.called )