Skip to content

Commit

Permalink
build(datasets): Bump s3fs (#463)
Browse files Browse the repository at this point in the history
* Use mocking for AWS responses

Signed-off-by: Merel Theisen <[email protected]>

* Add change to release notes

Signed-off-by: Merel Theisen <[email protected]>

* Update release notes

Signed-off-by: Merel Theisen <[email protected]>

* Use pytest xfail instead of commenting out test

Signed-off-by: Merel Theisen <[email protected]>

---------

Signed-off-by: Merel Theisen <[email protected]>
  • Loading branch information
merelcht authored Dec 7, 2023
1 parent 2e5686f commit ff02a51
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 19 deletions.
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Major features and improvements
* Removed support for Python 3.7 and 3.8
* Spark and Databricks based datasets now support [databricks-connect>=13.0](https://docs.databricks.com/en/dev-tools/databricks-connect-ref.html)
* Bump `s3fs` to latest calendar versioning.

## Bug fixes and other changes
* Fixed bug with loading models saved with `TensorFlowModelDataset`.
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
PANDAS = "pandas>=1.3, <3.0"
SPARK = "pyspark>=2.2, <4.0"
HDFS = "hdfs>=2.5.8, <3.0"
S3FS = "s3fs>=0.3.0, <0.5"
S3FS = "s3fs>=2021.4, <2024.1" # Upper bound set arbitrarily, to be reassessed in early 2024
POLARS = "polars>=0.18.0"
DELTA = "delta-spark~=1.2.1"

Expand Down Expand Up @@ -210,7 +210,7 @@ def _collect_requirements(requires):
"requests-mock~=1.6",
"requests~=2.20",
"ruff~=0.0.290",
"s3fs>=0.3.0, <0.5", # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem.
"s3fs>=2021.04, <2024.1",
"snowflake-snowpark-python~=1.0; python_version == '3.9'",
"scikit-learn>=1.0.2,<2",
"scipy>=1.7.3",
Expand Down
70 changes: 70 additions & 0 deletions kedro-datasets/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,79 @@
https://docs.pytest.org/en/latest/fixture.html
"""

from typing import Callable
from unittest.mock import MagicMock

import aiobotocore.awsrequest
import aiobotocore.endpoint
import aiohttp
import aiohttp.client_reqrep
import aiohttp.typedefs
import botocore.awsrequest
import botocore.model
from kedro.io.core import generate_timestamp
from pytest import fixture

BUCKET_NAME = "test_bucket"
IP_ADDRESS = "127.0.0.1"
PORT = 5555
ENDPOINT_URI = f"http://{IP_ADDRESS}:{PORT}/"


"""
Patch aiobotocore to work with moto
See https://github.com/aio-libs/aiobotocore/issues/755
"""


class MockAWSResponse(aiobotocore.awsrequest.AioAWSResponse):
def __init__(self, response: botocore.awsrequest.AWSResponse):
self._moto_response = response
self.status_code = response.status_code
self.raw = MockHttpClientResponse(response)

# adapt async methods to use moto's response
async def _content_prop(self) -> bytes:
return self._moto_response.content

async def _text_prop(self) -> str:
return self._moto_response.text


class MockHttpClientResponse(aiohttp.client_reqrep.ClientResponse):
def __init__(self, response: botocore.awsrequest.AWSResponse):
async def read(self, n: int = -1) -> bytes:
# streaming/range requests. used by s3fs
return response.content

self.content = MagicMock(aiohttp.StreamReader)
self.content.read = read
self.response = response

@property
def raw_headers(self) -> aiohttp.typedefs.RawHeaders:
# Return the headers encoded the way that aiobotocore expects them
return {
k.encode("utf-8"): str(v).encode("utf-8")
for k, v in self.response.headers.items()
}.items()


@fixture(scope="session", autouse=True)
def patch_aiobotocore():
def factory(original: Callable) -> Callable:
def patched_convert_to_response_dict(
http_response: botocore.awsrequest.AWSResponse,
operation_model: botocore.model.OperationModel,
):
return original(MockAWSResponse(http_response), operation_model)

return patched_convert_to_response_dict

aiobotocore.endpoint.convert_to_response_dict = factory(
aiobotocore.endpoint.convert_to_response_dict
)


@fixture(params=[None])
def load_version(request):
Expand Down
10 changes: 4 additions & 6 deletions kedro-datasets/tests/dask/test_parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_empty_credentials_load(self, bad_credentials):
with pytest.raises(DatasetError, match=pattern):
parquet_dataset.load().compute()

@pytest.mark.xfail
def test_pass_credentials(self, mocker):
"""Test that AWS credentials are passed successfully into boto3
client instantiation on creating S3 connection."""
Expand All @@ -106,8 +107,7 @@ def test_pass_credentials(self, mocker):
assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"]
assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"]

@pytest.mark.usefixtures("mocked_s3_bucket")
def test_save_data(self, s3_dataset):
def test_save_data(self, s3_dataset, mocked_s3_bucket):
"""Test saving the data to S3."""
pd_data = pd.DataFrame(
{"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}
Expand All @@ -117,14 +117,12 @@ def test_save_data(self, s3_dataset):
loaded_data = s3_dataset.load()
assert_frame_equal(loaded_data.compute(), dd_data.compute())

@pytest.mark.usefixtures("mocked_s3_object")
def test_load_data(self, s3_dataset, dummy_dd_dataframe):
def test_load_data(self, s3_dataset, dummy_dd_dataframe, mocked_s3_object):
"""Test loading the data from S3."""
loaded_data = s3_dataset.load()
assert_frame_equal(loaded_data.compute(), dummy_dd_dataframe.compute())

@pytest.mark.usefixtures("mocked_s3_bucket")
def test_exists(self, s3_dataset, dummy_dd_dataframe):
def test_exists(self, s3_dataset, dummy_dd_dataframe, mocked_s3_bucket):
"""Test `exists` method invocation for both existing and
nonexistent data set."""
assert not s3_dataset.exists()
Expand Down
27 changes: 16 additions & 11 deletions kedro-datasets/tests/spark/test_spark_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
import sys
import tempfile
Expand Down Expand Up @@ -143,8 +144,8 @@ def mocked_s3_bucket():
with mock_s3():
conn = boto3.client(
"s3",
aws_access_key_id="fake_access_key",
aws_secret_access_key="fake_secret_key",
aws_access_key_id=AWS_CREDENTIALS["key"],
aws_secret_access_key=AWS_CREDENTIALS["secret"],
)
conn.create_bucket(Bucket=BUCKET_NAME)
yield conn
Expand All @@ -159,7 +160,7 @@ def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructT
mocked_s3_bucket.put_object(
Bucket=BUCKET_NAME, Key=SCHEMA_FILE_NAME, Body=temporary_path.read_bytes()
)
return mocked_s3_bucket
return f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}"


class FileInfo:
Expand Down Expand Up @@ -726,6 +727,10 @@ def test_dbfs_path_in_different_os(self, os_name, mocker):


class TestSparkDatasetVersionedS3:
os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY"
os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY"

@pytest.mark.xfail
def test_no_version(self, versioned_dataset_s3):
pattern = r"Did not find any versions for SparkDataset\(.+\)"
with pytest.raises(DatasetError, match=pattern):
Expand Down Expand Up @@ -766,27 +771,27 @@ def test_load_exact(self, mocker):
f"s3a://{BUCKET_NAME}/{FILENAME}/{ts}/{FILENAME}", "parquet"
)

def test_save(self, versioned_dataset_s3, version, mocker):
def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker):
mocked_spark_df = mocker.Mock()

# need resolve_load_version() call to return a load version that
# matches save version due to consistency check in versioned_dataset_s3.save()
mocker.patch.object(
versioned_dataset_s3, "resolve_load_version", return_value=version.save
ds_s3 = SparkDataset(
filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=version
)

versioned_dataset_s3.save(mocked_spark_df)
# need resolve_load_version() call to return a load version that
# matches save version due to consistency check in versioned_dataset_s3.save()
mocker.patch.object(ds_s3, "resolve_load_version", return_value=version.save)
ds_s3.save(mocked_spark_df)
mocked_spark_df.write.save.assert_called_once_with(
f"s3a://{BUCKET_NAME}/{FILENAME}/{version.save}/{FILENAME}",
"parquet",
)

def test_save_version_warning(self, mocker):
def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mocker):
exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
ds_s3 = SparkDataset(
filepath=f"s3a://{BUCKET_NAME}/{FILENAME}",
version=exact_version,
credentials=AWS_CREDENTIALS,
)
mocked_spark_df = mocker.Mock()

Expand Down

0 comments on commit ff02a51

Please sign in to comment.