diff --git a/CHANGELOG.md b/CHANGELOG.md index 46eb7c27..db8c1a3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- `content_type` keyword argument in `cloud_storage_upload_blob_from_file` task - [#47](https://github.com/PrefectHQ/prefect-gcp/pull/47) +- `**kwargs` for all tasks in the module `cloud_storage.py` - [#47](https://github.com/PrefectHQ/prefect-gcp/pull/47) ### Changed - Allowed `~` character to be used in the path for service account file - [#38](https://github.com/PrefectHQ/prefect-gcp/pull/38) diff --git a/prefect_gcp/cloud_storage.py b/prefect_gcp/cloud_storage.py index e6fbe826..cce21e1f 100644 --- a/prefect_gcp/cloud_storage.py +++ b/prefect_gcp/cloud_storage.py @@ -4,7 +4,7 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from anyio import to_thread from prefect import get_run_logger, task @@ -21,6 +21,7 @@ async def cloud_storage_create_bucket( gcp_credentials: "GcpCredentials", project: Optional[str] = None, location: Optional[str] = None, + **create_kwargs: Dict[str, Any] ) -> str: """ Creates a bucket. @@ -31,6 +32,7 @@ async def cloud_storage_create_bucket( project: Name of the project to use; overrides the gcp_credentials project if provided. location: Location of the bucket. + create_kwargs: Additional keyword arguments to pass to `client.create_bucket`. Returns: The bucket name. 
@@ -55,7 +57,9 @@ def example_cloud_storage_create_bucket_flow(): logger.info("Creating %s bucket", bucket) client = gcp_credentials.get_cloud_storage_client(project=project) - partial_create_bucket = partial(client.create_bucket, bucket, location=location) + partial_create_bucket = partial( + client.create_bucket, bucket, location=location, **create_kwargs + ) await to_thread.run_sync(partial_create_bucket) return bucket @@ -83,6 +87,7 @@ async def cloud_storage_download_blob_as_bytes( encryption_key: Optional[str] = None, timeout: Union[float, Tuple[float, float]] = 60, project: Optional[str] = None, + **download_kwargs: Dict[str, Any] ) -> bytes: """ Downloads a blob as bytes. @@ -102,6 +107,8 @@ async def cloud_storage_download_blob_as_bytes( (connect_timeout, read_timeout). project: Name of the project to use; overrides the gcp_credentials project if provided. + download_kwargs: Additional keyword arguments to pass to + `Blob.download_as_bytes`. Returns: A bytes or string representation of the blob object. @@ -131,7 +138,9 @@ def example_cloud_storage_download_blob_flow(): blob, chunk_size=chunk_size, encryption_key=encryption_key ) - partial_download = partial(blob_obj.download_as_bytes, timeout=timeout) + partial_download = partial( + blob_obj.download_as_bytes, timeout=timeout, **download_kwargs + ) contents = await to_thread.run_sync(partial_download) return contents @@ -146,6 +155,7 @@ async def cloud_storage_download_blob_to_file( encryption_key: Optional[str] = None, timeout: Union[float, Tuple[float, float]] = 60, project: Optional[str] = None, + **download_kwargs: Dict[str, Any] ) -> Union[str, Path]: """ Downloads a blob to a file path. @@ -165,6 +175,8 @@ async def cloud_storage_download_blob_to_file( (connect_timeout, read_timeout). project: Name of the project to use; overrides the gcp_credentials project if provided. + download_kwargs: Additional keyword arguments to pass to + `Blob.download_to_filename`. Returns: The path to the blob object. 
@@ -203,7 +215,9 @@ def example_cloud_storage_download_blob_flow(): else: path = os.path.join(path, blob) # keep as str if a str is passed - partial_download = partial(blob_obj.download_to_filename, path, timeout=timeout) + partial_download = partial( + blob_obj.download_to_filename, path, timeout=timeout, **download_kwargs + ) await to_thread.run_sync(partial_download) return path @@ -219,6 +233,7 @@ async def cloud_storage_upload_blob_from_string( encryption_key: Optional[str] = None, timeout: Union[float, Tuple[float, float]] = 60, project: Optional[str] = None, + **upload_kwargs: Dict[str, Any] ) -> str: """ Uploads a blob from a string or bytes representation of data. @@ -229,7 +244,7 @@ async def cloud_storage_upload_blob_from_string( blob: Name of the Cloud Storage blob. gcp_credentials: Credentials to use for authentication with GCP. content_type: Type of content being uploaded. - chunk_size (int, optional): The size of a chunk of data whenever + chunk_size: The size of a chunk of data whenever iterating (in bytes). This must be a multiple of 256 KB per the API specification. encryption_key: An encryption key. @@ -238,6 +253,8 @@ async def cloud_storage_upload_blob_from_string( (connect_timeout, read_timeout). project: Name of the project to use; overrides the gcp_credentials project if provided. + upload_kwargs: Additional keyword arguments to pass to + `Blob.upload_from_string`. Returns: The blob name. 
@@ -269,7 +286,11 @@ def example_cloud_storage_upload_blob_from_string_flow(): ) partial_upload = partial( - blob_obj.upload_from_string, data, content_type=content_type, timeout=timeout + blob_obj.upload_from_string, + data, + content_type=content_type, + timeout=timeout, + **upload_kwargs, ) await to_thread.run_sync(partial_upload) return blob @@ -281,10 +302,12 @@ async def cloud_storage_upload_blob_from_file( bucket: str, blob: str, gcp_credentials: "GcpCredentials", + content_type: Optional[str] = None, chunk_size: Optional[int] = None, encryption_key: Optional[str] = None, timeout: Union[float, Tuple[float, float]] = 60, project: Optional[str] = None, + **upload_kwargs: Dict[str, Any] ) -> str: """ Uploads a blob from file path or file-like object. Usage for passing in @@ -296,7 +319,8 @@ async def cloud_storage_upload_blob_from_file( bucket: Name of the bucket. blob: Name of the Cloud Storage blob. gcp_credentials: Credentials to use for authentication with GCP. - chunk_size (int, optional): The size of a chunk of data whenever + content_type: Type of content being uploaded. + chunk_size: The size of a chunk of data whenever iterating (in bytes). This must be a multiple of 256 KB per the API specification. encryption_key: An encryption key. @@ -305,6 +329,8 @@ async def cloud_storage_upload_blob_from_file( (connect_timeout, read_timeout). project: Name of the project to use; overrides the gcp_credentials project if provided. + upload_kwargs: Additional keyword arguments to pass to + `Blob.upload_from_file` or `Blob.upload_from_filename`. Returns: The blob name. 
@@ -336,9 +362,21 @@ def example_cloud_storage_upload_blob_from_file_flow(): ) if isinstance(file, BytesIO): - partial_upload = partial(blob_obj.upload_from_file, file, timeout=timeout) + partial_upload = partial( + blob_obj.upload_from_file, + file, + content_type=content_type, + timeout=timeout, + **upload_kwargs, + ) else: - partial_upload = partial(blob_obj.upload_from_filename, file, timeout=timeout) + partial_upload = partial( + blob_obj.upload_from_filename, + file, + content_type=content_type, + timeout=timeout, + **upload_kwargs, + ) await to_thread.run_sync(partial_upload) return blob @@ -352,6 +390,7 @@ async def cloud_storage_copy_blob( dest_blob: Optional[str] = None, timeout: Union[float, Tuple[float, float]] = 60, project: Optional[str] = None, + **copy_kwargs: Dict[str, Any] ) -> str: """ Copies data from one Google Cloud Storage bucket to another, @@ -368,6 +407,8 @@ async def cloud_storage_copy_blob( (connect_timeout, read_timeout). project: Name of the project to use; overrides the gcp_credentials project if provided. + copy_kwargs: Additional keyword arguments to pass to + `Bucket.copy_blob`. Returns: Destination blob name. 
@@ -417,6 +458,7 @@ def example_cloud_storage_copy_blob_flow(): destination_bucket=dest_bucket_obj, new_name=dest_blob, timeout=timeout, + **copy_kwargs, ) await to_thread.run_sync(partial_copy_blob) diff --git a/tests/conftest.py b/tests/conftest.py index d1fa2134..3f7c76f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,13 @@ import pytest from google.cloud.exceptions import NotFound +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(scope="session", autouse=True) +def prefect_db(): + with prefect_test_harness(): + yield @pytest.fixture @@ -17,8 +24,8 @@ def __init__(self, credentials=None, project=None): self.credentials = credentials self.project = project - def create_bucket(self, bucket, location=None): - return {"bucket": bucket, "location": location} + def create_bucket(self, bucket, location=None, **create_kwargs): + return {"bucket": bucket, "location": location, **create_kwargs} def get_bucket(self, bucket): blob_obj = MagicMock() diff --git a/tests/test_cloud_storage.py b/tests/test_cloud_storage.py index f86be37a..42137870 100644 --- a/tests/test_cloud_storage.py +++ b/tests/test_cloud_storage.py @@ -20,7 +20,9 @@ def test_cloud_storage_create_bucket(gcp_credentials): @flow def test_flow(): - return cloud_storage_create_bucket(bucket, gcp_credentials, location=location) + return cloud_storage_create_bucket( + bucket, gcp_credentials, location=location, timeout=10 + ) assert test_flow() == "expected" @@ -30,7 +32,7 @@ def test_cloud_storage_download_blob_to_file(path, gcp_credentials): @flow def test_flow(): return cloud_storage_download_blob_to_file( - "bucket", "blob", path, gcp_credentials + "bucket", "blob", path, gcp_credentials, timeout=10 ) assert test_flow() == path @@ -39,7 +41,9 @@ def test_flow(): def test_cloud_storage_download_blob_as_bytes(gcp_credentials): @flow def test_flow(): - return cloud_storage_download_blob_as_bytes("bucket", "blob", gcp_credentials) + return 
cloud_storage_download_blob_as_bytes( + "bucket", "blob", gcp_credentials, timeout=10 + ) assert test_flow() == b"bytes" @@ -47,17 +51,18 @@ def test_flow(): @pytest.mark.parametrize( "file", [ - "./file_path", - BytesIO(b"bytes_data"), + "./file_path.html", + BytesIO(b"bytes_data"), ], ) def test_cloud_storage_upload_blob_from_file(file, gcp_credentials): blob = "blob" + content_type = "text/html" @flow def test_flow(): return cloud_storage_upload_blob_from_file( file, "bucket", blob, gcp_credentials, content_type=content_type, timeout=10 ) assert test_flow() == blob @@ -77,7 +82,7 @@ def test_cloud_storage_upload_blob_from_string(data, blob, gcp_credentials): @flow def test_flow(): return cloud_storage_upload_blob_from_string( - data, "bucket", "blob", gcp_credentials + data, "bucket", "blob", gcp_credentials, timeout=10 ) assert test_flow() == blob @@ -93,6 +98,7 @@ def test_flow(): "source_blob", gcp_credentials, dest_blob=dest_blob, + timeout=10, ) if dest_blob is None: