Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing content_type kwarg (#47)
Browse files Browse the repository at this point in the history
* Add missing content_type kwarg

* Add changelog

* Add kwargs to tasks

* Update CHANGELOG.md

* Update prefect_gcp/cloud_storage.py

* Add tests
  • Loading branch information
ahuang11 authored Sep 7, 2022
1 parent 78a3fbe commit 9788b5b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- `content_type` keyword argument in `cloud_storage_upload_blob_from_file` task - [#47](https://github.com/PrefectHQ/prefect-gcp/pull/47)
- `**kwargs` for all tasks in the module `cloud_storage.py` - [#47](https://github.com/PrefectHQ/prefect-gcp/pull/47)

### Changed
- Allowed `~` character to be used in the path for service account file - [#38](https://github.com/PrefectHQ/prefect-gcp/pull/38)
Expand Down
60 changes: 51 additions & 9 deletions prefect_gcp/cloud_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from anyio import to_thread
from prefect import get_run_logger, task
Expand All @@ -21,6 +21,7 @@ async def cloud_storage_create_bucket(
gcp_credentials: "GcpCredentials",
project: Optional[str] = None,
location: Optional[str] = None,
**create_kwargs: Dict[str, Any]
) -> str:
"""
Creates a bucket.
Expand All @@ -31,6 +32,7 @@ async def cloud_storage_create_bucket(
project: Name of the project to use; overrides the
gcp_credentials project if provided.
location: Location of the bucket.
create_kwargs: Additional keyword arguments to pass to `client.create_bucket`.
Returns:
The bucket name.
Expand All @@ -55,7 +57,9 @@ def example_cloud_storage_create_bucket_flow():
logger.info("Creating %s bucket", bucket)

client = gcp_credentials.get_cloud_storage_client(project=project)
partial_create_bucket = partial(client.create_bucket, bucket, location=location)
partial_create_bucket = partial(
client.create_bucket, bucket, location=location, **create_kwargs
)
await to_thread.run_sync(partial_create_bucket)
return bucket

Expand Down Expand Up @@ -83,6 +87,7 @@ async def cloud_storage_download_blob_as_bytes(
encryption_key: Optional[str] = None,
timeout: Union[float, Tuple[float, float]] = 60,
project: Optional[str] = None,
**download_kwargs: Dict[str, Any]
) -> bytes:
"""
Downloads a blob as bytes.
Expand All @@ -102,6 +107,8 @@ async def cloud_storage_download_blob_as_bytes(
(connect_timeout, read_timeout).
project: Name of the project to use; overrides the
gcp_credentials project if provided.
download_kwargs: Additional keyword arguments to pass to
`Blob.download_as_bytes`.
Returns:
A bytes or string representation of the blob object.
Expand Down Expand Up @@ -131,7 +138,9 @@ def example_cloud_storage_download_blob_flow():
blob, chunk_size=chunk_size, encryption_key=encryption_key
)

partial_download = partial(blob_obj.download_as_bytes, timeout=timeout)
partial_download = partial(
blob_obj.download_as_bytes, timeout=timeout, **download_kwargs
)
contents = await to_thread.run_sync(partial_download)
return contents

Expand All @@ -146,6 +155,7 @@ async def cloud_storage_download_blob_to_file(
encryption_key: Optional[str] = None,
timeout: Union[float, Tuple[float, float]] = 60,
project: Optional[str] = None,
**download_kwargs: Dict[str, Any]
) -> Union[str, Path]:
"""
Downloads a blob to a file path.
Expand All @@ -165,6 +175,8 @@ async def cloud_storage_download_blob_to_file(
(connect_timeout, read_timeout).
project: Name of the project to use; overrides the
gcp_credentials project if provided.
download_kwargs: Additional keyword arguments to pass to
`Blob.download_to_filename`.
Returns:
The path to the blob object.
Expand Down Expand Up @@ -203,7 +215,9 @@ def example_cloud_storage_download_blob_flow():
else:
path = os.path.join(path, blob) # keep as str if a str is passed

partial_download = partial(blob_obj.download_to_filename, path, timeout=timeout)
partial_download = partial(
blob_obj.download_to_filename, path, timeout=timeout, **download_kwargs
)
await to_thread.run_sync(partial_download)
return path

Expand All @@ -219,6 +233,7 @@ async def cloud_storage_upload_blob_from_string(
encryption_key: Optional[str] = None,
timeout: Union[float, Tuple[float, float]] = 60,
project: Optional[str] = None,
**upload_kwargs: Dict[str, Any]
) -> str:
"""
Uploads a blob from a string or bytes representation of data.
Expand All @@ -229,7 +244,7 @@ async def cloud_storage_upload_blob_from_string(
blob: Name of the Cloud Storage blob.
gcp_credentials: Credentials to use for authentication with GCP.
content_type: Type of content being uploaded.
chunk_size (int, optional): The size of a chunk of data whenever
chunk_size: The size of a chunk of data whenever
iterating (in bytes). This must be a multiple of 256 KB
per the API specification.
encryption_key: An encryption key.
Expand All @@ -238,6 +253,8 @@ async def cloud_storage_upload_blob_from_string(
(connect_timeout, read_timeout).
project: Name of the project to use; overrides the
gcp_credentials project if provided.
upload_kwargs: Additional keyword arguments to pass to
`Blob.upload_from_string`.
Returns:
The blob name.
Expand Down Expand Up @@ -269,7 +286,11 @@ def example_cloud_storage_upload_blob_from_string_flow():
)

partial_upload = partial(
blob_obj.upload_from_string, data, content_type=content_type, timeout=timeout
blob_obj.upload_from_string,
data,
content_type=content_type,
timeout=timeout,
**upload_kwargs,
)
await to_thread.run_sync(partial_upload)
return blob
Expand All @@ -281,10 +302,12 @@ async def cloud_storage_upload_blob_from_file(
bucket: str,
blob: str,
gcp_credentials: "GcpCredentials",
content_type: Optional[str] = None,
chunk_size: Optional[int] = None,
encryption_key: Optional[str] = None,
timeout: Union[float, Tuple[float, float]] = 60,
project: Optional[str] = None,
**upload_kwargs: Dict[str, Any]
) -> str:
"""
Uploads a blob from file path or file-like object. Usage for passing in
Expand All @@ -296,7 +319,8 @@ async def cloud_storage_upload_blob_from_file(
bucket: Name of the bucket.
blob: Name of the Cloud Storage blob.
gcp_credentials: Credentials to use for authentication with GCP.
chunk_size (int, optional): The size of a chunk of data whenever
content_type: Type of content being uploaded.
chunk_size: The size of a chunk of data whenever
iterating (in bytes). This must be a multiple of 256 KB
per the API specification.
encryption_key: An encryption key.
Expand All @@ -305,6 +329,8 @@ async def cloud_storage_upload_blob_from_file(
(connect_timeout, read_timeout).
project: Name of the project to use; overrides the
gcp_credentials project if provided.
upload_kwargs: Additional keyword arguments to pass to
`Blob.upload_from_file` or `Blob.upload_from_filename`.
Returns:
The blob name.
Expand Down Expand Up @@ -336,9 +362,21 @@ def example_cloud_storage_upload_blob_from_file_flow():
)

if isinstance(file, BytesIO):
partial_upload = partial(blob_obj.upload_from_file, file, timeout=timeout)
partial_upload = partial(
blob_obj.upload_from_file,
file,
content_type=content_type,
timeout=timeout,
**upload_kwargs,
)
else:
partial_upload = partial(blob_obj.upload_from_filename, file, timeout=timeout)
partial_upload = partial(
blob_obj.upload_from_filename,
file,
content_type=content_type,
timeout=timeout,
**upload_kwargs,
)
await to_thread.run_sync(partial_upload)
return blob

Expand All @@ -352,6 +390,7 @@ async def cloud_storage_copy_blob(
dest_blob: Optional[str] = None,
timeout: Union[float, Tuple[float, float]] = 60,
project: Optional[str] = None,
**copy_kwargs: Dict[str, Any]
) -> str:
"""
Copies data from one Google Cloud Storage bucket to another,
Expand All @@ -368,6 +407,8 @@ async def cloud_storage_copy_blob(
(connect_timeout, read_timeout).
project: Name of the project to use; overrides the
gcp_credentials project if provided.
copy_kwargs: Additional keyword arguments to pass to
`Bucket.copy_blob`.
Returns:
Destination blob name.
Expand Down Expand Up @@ -417,6 +458,7 @@ def example_cloud_storage_copy_blob_flow():
destination_bucket=dest_bucket_obj,
new_name=dest_blob,
timeout=timeout,
**copy_kwargs,
)
await to_thread.run_sync(partial_copy_blob)

Expand Down
11 changes: 9 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

import pytest
from google.cloud.exceptions import NotFound
from prefect.testing.utilities import prefect_test_harness


@pytest.fixture(scope="session", autouse=True)
def prefect_db():
with prefect_test_harness():
yield


@pytest.fixture
Expand All @@ -17,8 +24,8 @@ def __init__(self, credentials=None, project=None):
self.credentials = credentials
self.project = project

def create_bucket(self, bucket, location=None):
return {"bucket": bucket, "location": location}
def create_bucket(self, bucket, location=None, **create_kwargs):
return {"bucket": bucket, "location": location, **create_kwargs}

def get_bucket(self, bucket):
blob_obj = MagicMock()
Expand Down
20 changes: 13 additions & 7 deletions tests/test_cloud_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def test_cloud_storage_create_bucket(gcp_credentials):

@flow
def test_flow():
return cloud_storage_create_bucket(bucket, gcp_credentials, location=location)
return cloud_storage_create_bucket(
bucket, gcp_credentials, location=location, timeout=10
)

assert test_flow() == "expected"

Expand All @@ -30,7 +32,7 @@ def test_cloud_storage_download_blob_to_file(path, gcp_credentials):
@flow
def test_flow():
return cloud_storage_download_blob_to_file(
"bucket", "blob", path, gcp_credentials
"bucket", "blob", path, gcp_credentials, timeout=10
)

assert test_flow() == path
Expand All @@ -39,25 +41,28 @@ def test_flow():
def test_cloud_storage_download_blob_as_bytes(gcp_credentials):
@flow
def test_flow():
return cloud_storage_download_blob_as_bytes("bucket", "blob", gcp_credentials)
return cloud_storage_download_blob_as_bytes(
"bucket", "blob", gcp_credentials, timeout=10
)

assert test_flow() == b"bytes"


@pytest.mark.parametrize(
"file",
[
"./file_path",
BytesIO(b"bytes_data"),
"./file_path.html",
BytesIO(b"<div>bytes_data</div>"),
],
)
def test_cloud_storage_upload_blob_from_file(file, gcp_credentials):
blob = "blob"
content_type = "text/html"

@flow
def test_flow():
return cloud_storage_upload_blob_from_file(
file, "bucket", blob, gcp_credentials
file, "bucket", blob, gcp_credentials, content_type=content_type, timeout=10
)

assert test_flow() == blob
Expand All @@ -77,7 +82,7 @@ def test_cloud_storage_upload_blob_from_string(data, blob, gcp_credentials):
@flow
def test_flow():
return cloud_storage_upload_blob_from_string(
data, "bucket", "blob", gcp_credentials
data, "bucket", "blob", gcp_credentials, timeout=10
)

assert test_flow() == blob
Expand All @@ -93,6 +98,7 @@ def test_flow():
"source_blob",
gcp_credentials,
dest_blob=dest_blob,
timeout=10,
)

if dest_blob is None:
Expand Down

0 comments on commit 9788b5b

Please sign in to comment.