From 16dfbdd8ddbfa48b6dc61b2d4a9a2bc8b95f3797 Mon Sep 17 00:00:00 2001 From: Riley Brady Date: Wed, 28 Feb 2024 09:45:05 -0700 Subject: [PATCH] feat(datasets): Add NetCDFDataSet class (#360) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initialize template and early additions Signed-off-by: Riley Brady * add placeholder for remote file system load Signed-off-by: Riley Brady * switch to versioned dataset Signed-off-by: Riley Brady * add initial remote -> local get for S3 Signed-off-by: Riley Brady * further generalize remote retrieval Signed-off-by: Riley Brady * add in credentials Signed-off-by: Riley Brady * make temppath optional for remote datasets Signed-off-by: Riley Brady * add initial idea for multifile glob Signed-off-by: Riley Brady * style: Introduce `ruff` for linting in all plugins. (#354) Signed-off-by: Merel Theisen Signed-off-by: Riley Brady * add suggested style changes Signed-off-by: Riley Brady * add temppath to attributes Signed-off-by: Riley Brady * more temppath fixes Signed-off-by: Riley Brady * more temppath updates Signed-off-by: Riley Brady * add better tempfile deletion and work on saving files Signed-off-by: Riley Brady * make __del__ flexible Signed-off-by: Riley Brady * formatting Signed-off-by: Riley Brady * feat(datasets): create custom `DeprecationWarning` (#356) * feat(datasets): create custom `DeprecationWarning` Signed-off-by: Deepyaman Datta * feat(datasets): use the custom deprecation warning Signed-off-by: Deepyaman Datta * chore(datasets): show Kedro's deprecation warnings Signed-off-by: Deepyaman Datta * fix(datasets): remove unused imports in test files Signed-off-by: Deepyaman Datta --------- Signed-off-by: Deepyaman Datta Signed-off-by: Riley Brady * docs(datasets): add note about DataSet deprecation (#357) Signed-off-by: Riley Brady * test(datasets): skip `tensorflow` tests on Windows (#363) Signed-off-by: Deepyaman Datta Signed-off-by: Riley Brady * ci: Pin `tables` version (#370) * Pin tables version Signed-off-by: Ankita Katiyar * Also fix kedro-airflow Signed-off-by: Ankita Katiyar * Revert trying to fix airflow Signed-off-by: Ankita Katiyar --------- Signed-off-by: Ankita Katiyar Signed-off-by: Riley Brady * build(datasets): Release `1.7.1` (#378) Signed-off-by: Merel Theisen Signed-off-by: Riley Brady * docs: Update CONTRIBUTING.md and add one for `kedro-datasets` (#379) Update CONTRIBUTING.md + add one for kedro-datasets Signed-off-by: Ankita Katiyar Signed-off-by: Riley Brady * ci(datasets): Run tensorflow tests separately from other dataset tests (#377) Signed-off-by: Merel Theisen Signed-off-by: Riley Brady * feat: Kedro-Airflow convert all pipelines option (#335) * feat: kedro airflow convert --all option Signed-off-by: Simon Brugman * docs: release docs Signed-off-by: Simon Brugman --------- Signed-off-by: Simon Brugman Signed-off-by: Riley Brady * docs(datasets): blacken code in rst literal blocks (#362) Signed-off-by: Deepyaman Datta Signed-off-by: Riley Brady * docs: cloudpickle is an interesting extension of the pickle functionality (#361) Signed-off-by: H. Felix Wittmann Signed-off-by: Riley Brady * fix(datasets): Fix secret scan entropy error (#383) Fix secret scan entropy error Signed-off-by: Merel Theisen Signed-off-by: Riley Brady * style: Rename mentions of `DataSet` to `Dataset` in `kedro-airflow` and `kedro-telemetry` (#384) Signed-off-by: Merel Theisen Signed-off-by: Riley Brady * feat(datasets): Migrated `PartitionedDataSet` and `IncrementalDataSet` from main repository to kedro-datasets (#253) Signed-off-by: Peter Bludau Co-authored-by: Merel Theisen Signed-off-by: Riley Brady * fix: backwards compatibility for `kedro-airflow` (#381) Signed-off-by: Simon Brugman Signed-off-by: Riley Brady * fix(datasets): Don't warn for SparkDataset on Databricks when using s3 (#341) Signed-off-by: Alistair McKelvie Signed-off-by: Riley Brady * update docs API and release notes Signed-off-by: Riley Brady * add netcdf requirements to setup Signed-off-by: Riley Brady * lint Signed-off-by: Riley Brady * add initial tests Signed-off-by: Riley Brady * update dataset exists for multifile Signed-off-by: Riley Brady * Add full test suite for NetCDFDataSet Signed-off-by: Riley Brady * Add docstring examples Signed-off-by: Riley Brady * change xarray version req Signed-off-by: Riley Brady * update dask req Signed-off-by: Riley Brady * rename DataSet -> Dataset Signed-off-by: Riley Brady * Update xarray reqs for earlier python versions Signed-off-by: Riley Brady * fix setup Signed-off-by: Riley Brady * update test coverage Signed-off-by: Riley Brady * exclude init from test coverage Signed-off-by: Riley Brady * Sub in pathlib for os.remove Signed-off-by: Riley Brady * add metadata to dataset Signed-off-by: Riley Brady * add doctest for the new datasets Signed-off-by: Nok * add patch for supporting http/https Signed-off-by: Riley Brady * Small fixes post-merge Signed-off-by: Juan Luis Cano Rodríguez * Lint Signed-off-by: Juan Luis Cano Rodríguez * Fix import Signed-off-by: Juan Luis Cano Rodríguez * Un-ignore NetCDF doctest Signed-off-by: Juan Luis Cano Rodríguez * Add fixture Signed-off-by: Ankita Katiyar * Mark problematic test as xfail Signed-off-by: Juan Luis Cano Rodríguez * Skip problematic test instead of making it fail Signed-off-by: Juan Luis Cano Rodríguez * Skip problematic tests and fix failing tests Signed-off-by: Ankita Katiyar * Remove comment Signed-off-by: Ankita Katiyar --------- Signed-off-by: Riley Brady Signed-off-by: Merel Theisen Signed-off-by: Deepyaman Datta Signed-off-by: Ankita Katiyar Signed-off-by: Simon Brugman Signed-off-by: H. Felix Wittmann Signed-off-by: Peter Bludau Signed-off-by: Alistair McKelvie Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Signed-off-by: Nok Lam Chan Signed-off-by: Nok Signed-off-by: Juan Luis Cano Rodríguez Signed-off-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Co-authored-by: Deepyaman Datta Co-authored-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Co-authored-by: Simon Brugman Co-authored-by: Felix Wittmann Co-authored-by: PtrBld <7523956+PtrBld@users.noreply.github.com> Co-authored-by: Merel Theisen Co-authored-by: Alistair McKelvie Co-authored-by: Nok Lam Chan Co-authored-by: Juan Luis Cano Rodríguez Co-authored-by: Ankita Katiyar --- kedro-datasets/RELEASE.md | 4 + .../docs/source/api/kedro_datasets.rst | 1 + .../kedro_datasets/netcdf/__init__.py | 14 + .../kedro_datasets/netcdf/netcdf_dataset.py | 218 ++++++++++++++ kedro-datasets/pyproject.toml | 2 +- kedro-datasets/setup.py | 10 + kedro-datasets/tests/netcdf/__init__.py | 0 .../tests/netcdf/test_netcdf_dataset.py | 275 ++++++++++++++++++ 8 files changed, 523 insertions(+), 1 deletion(-) create mode 100644 kedro-datasets/kedro_datasets/netcdf/__init__.py create mode 100644 kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py create mode 100644 kedro-datasets/tests/netcdf/__init__.py create mode 100644 kedro-datasets/tests/netcdf/test_netcdf_dataset.py diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index d14d55ed4..58989fc08 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,8 +1,12 @@ # Upcoming Release ## Major features and improvements +* Added `NetCDFDataset` for loading and saving `*.nc` files. ## Bug fixes and other changes ## Community contributions +Many thanks to the following Kedroids for contributing PRs to this release: +* [Riley Brady](https://github.com/riley-brady) + # Release 2.1.0 ## Major features and improvements diff --git a/kedro-datasets/docs/source/api/kedro_datasets.rst b/kedro-datasets/docs/source/api/kedro_datasets.rst index 13e956878..c6fe0d04e 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets.rst @@ -23,6 +23,7 @@ kedro_datasets kedro_datasets.json.JSONDataset kedro_datasets.matlab.MatlabDataset kedro_datasets.matplotlib.MatplotlibWriter + kedro_datasets.netcdf.NetCDFDataset kedro_datasets.networkx.GMLDataset kedro_datasets.networkx.GraphMLDataset kedro_datasets.networkx.JSONDataset diff --git a/kedro-datasets/kedro_datasets/netcdf/__init__.py b/kedro-datasets/kedro_datasets/netcdf/__init__.py new file mode 100644 index 000000000..875b319c8 --- /dev/null +++ b/kedro-datasets/kedro_datasets/netcdf/__init__.py @@ -0,0 +1,14 @@ +"""``NetCDFDataset`` is an ``AbstractDataset`` to save and load NetCDF files.""" +from __future__ import annotations + +from typing import Any + +import lazy_loader as lazy + +# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 +NetCDFDataset: type[NetCDFDataset] +NetCDFDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"netcdf_dataset": ["NetCDFDataset"]} +) diff --git a/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py b/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py new file mode 100644 index 000000000..afed2f4d8 --- /dev/null +++ b/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py @@ -0,0 +1,218 @@ +"""NetCDFDataset loads and saves data to a local netcdf (.nc) file.""" +import logging +from copy import deepcopy +from glob import glob +from pathlib import Path, PurePosixPath +from typing import Any + +import fsspec +import xarray as xr +from kedro.io.core import ( + AbstractDataset, + DatasetError, + get_protocol_and_path, +) + +logger = logging.getLogger(__name__) + + +class NetCDFDataset(AbstractDataset): + """``NetCDFDataset`` loads/saves data from/to a NetCDF file using an underlying + filesystem (e.g.: local, S3, GCS). It uses xarray to handle the NetCDF file. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + single-file: + type: netcdf.NetCDFDataset + filepath: s3://bucket_name/path/to/folder/data.nc + save_args: + mode: a + load_args: + decode_times: False + + multi-file: + type: netcdf.NetCDFDataset + filepath: s3://bucket_name/path/to/folder/data*.nc + load_args: + concat_dim: time + combine: nested + parallel: True + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.netcdf import NetCDFDataset + >>> import xarray as xr + >>> ds = xr.DataArray( + ... [0, 1, 2], dims=["x"], coords={"x": [0, 1, 2]}, name="data" + ... ).to_dataset() + >>> dataset = NetCDFDataset( + ... filepath="path/to/folder", + ... save_args={"mode": "w"}, + ... ) + >>> dataset.save(ds) + >>> reloaded = dataset.load() + """ + + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa + self, + *, + filepath: str, + temppath: str = None, + load_args: dict[str, Any] = None, + save_args: dict[str, Any] = None, + fs_args: dict[str, Any] = None, + credentials: dict[str, Any] = None, + metadata: dict[str, Any] = None, + ): + """Creates a new instance of ``NetCDFDataset`` pointing to a concrete NetCDF + file on a specific filesystem + + Args: + filepath: Filepath in POSIX format to a NetCDF file prefixed with a + protocol like `s3://`. If prefix is not provided, `file` protocol + (local filesystem) will be used. The prefix should be any protocol + supported by ``fsspec``. It can also be a path to a glob. If a + glob is provided then it can be used for reading multiple NetCDF + files. + temppath: Local temporary directory, used when reading from remote storage, + since NetCDF files cannot be directly read from remote storage. + load_args: Additional options for loading NetCDF file(s). + Here you can find all available arguments when reading single file: + https://xarray.pydata.org/en/stable/generated/xarray.open_dataset.html + Here you can find all available arguments when reading multiple files: + https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html + All defaults are preserved. + save_args: Additional saving options for saving NetCDF file(s). + Here you can find all available arguments: + https://xarray.pydata.org/en/stable/generated/xarray.Dataset.to_netcdf.html + All defaults are preserved. + fs_args: Extra arguments to pass into underlying filesystem class + constructor (e.g. `{"cache_regions": "us-east-1"}` for + ``s3fs.S3FileSystem``). + credentials: Credentials required to get access to the underlying filesystem. + E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + self._fs_args = deepcopy(fs_args) or {} + self._credentials = deepcopy(credentials) or {} + self._temppath = Path(temppath) if temppath is not None else None + protocol, path = get_protocol_and_path(filepath) + if protocol == "file": + self._fs_args.setdefault("auto_mkdir", True) + elif protocol != "file" and self._temppath is None: + raise ValueError( + "Need to set temppath in catalog if NetCDF file exists on remote " + + "filesystem" + ) + self._protocol = protocol + self._filepath = filepath + + self._storage_options = {**self._credentials, **self._fs_args} + self._fs = fsspec.filesystem(self._protocol, **self._storage_options) + + self.metadata = metadata + + # Handle default load and save arguments + self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) + if load_args is not None: + self._load_args.update(load_args) + self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) + if save_args is not None: + self._save_args.update(save_args) + + # Determine if multiple NetCDF files are being loaded in. + self._is_multifile = ( + True if "*" in str(PurePosixPath(self._filepath).stem) else False + ) + + def _load(self) -> xr.Dataset: + load_path = self._filepath + + # If NetCDF(s) are on any type of remote storage, need to sync to local to open. + # Kerchunk could be implemented here in the future for direct remote reading. + if self._protocol != "file": + logger.info("Syncing remote NetCDF file to local storage.") + + if self._is_multifile: + load_path = sorted(self._fs.glob(load_path)) + + self._fs.get(load_path, f"{self._temppath}/") + load_path = f"{self._temppath}/{self._filepath.stem}.nc" + + if self._is_multifile: + data = xr.open_mfdataset(str(load_path), **self._load_args) + else: + data = xr.open_dataset(load_path, **self._load_args) + + return data + + def _save(self, data: xr.Dataset): + if self._is_multifile: + raise DatasetError( + "Globbed multifile datasets with '*' in filepath cannot be saved. " + + "Create an alternate NetCDFDataset with a single .nc output file." + ) + else: + save_path = self._filepath + bytes_buffer = data.to_netcdf(**self._save_args) + + with self._fs.open(save_path, mode="wb") as fs_file: + fs_file.write(bytes_buffer) + + self._invalidate_cache() + + def _describe(self) -> dict[str, Any]: + return dict( + filepath=self._filepath, + protocol=self._protocol, + load_args=self._load_args, + save_args=self._save_args, + ) + + def _exists(self) -> bool: + load_path = self._filepath + + if self._is_multifile: + files = self._fs.glob(load_path) + exists = True if files else False + else: + exists = self._fs.exists(load_path) + + return exists + + def _invalidate_cache(self): + """Invalidate underlying filesystem caches.""" + self._fs.invalidate_cache(self._filepath) + + def __del__(self): + """Cleanup temporary directory""" + if self._temppath is not None: + logger.info("Deleting local temporary files.") + temp_filepath = self._temppath / PurePosixPath(self._filepath).stem + if self._is_multifile: + temp_files = glob(str(temp_filepath)) + for file in temp_files: + try: + Path(file).unlink() + except FileNotFoundError: # pragma: no cover + pass # pragma: no cover + else: + temp_filepath = ( + str(temp_filepath) + "/" + PurePosixPath(self._filepath).name + ) + try: + Path(temp_filepath).unlink() + except FileNotFoundError: + pass diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index a09f2c9cb..d53f058b1 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"} fail_under = 100 show_missing = true # temporarily ignore kedro_datasets/__init__.py in coverage report -omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py", "kedro_datasets/databricks/*"] +omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/netcdf/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py", "kedro_datasets/databricks/*"] exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"] [tool.pytest.ini_options] diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index afef7aaf6..6ffad7007 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -31,6 +31,14 @@ def _collect_requirements(requires): } matplotlib_require = {"matplotlib.MatplotlibWriter": ["matplotlib>=3.0.3, <4.0"]} matlab_require = {"matlab.MatlabDataset": ["scipy"]} +netcdf_require = { + "netcdf.NetCDFDataset": [ + "h5netcdf>=1.2.0", + "netcdf4>=1.6.4", + "xarray<=0.20.2; python_version == '3.7'", + "xarray>=2023.1.0; python_version >= '3.8'", + ] +} networkx_require = {"networkx.NetworkXDataset": ["networkx~=2.4"]} pandas_require = { "pandas.CSVDataset": [PANDAS], @@ -118,6 +126,7 @@ def _collect_requirements(requires): "huggingface": _collect_requirements(huggingface_require), "matlab": _collect_requirements(matlab_require), "matplotlib": _collect_requirements(matplotlib_require), + "netcdf": _collect_requirements(netcdf_require), "networkx": _collect_requirements(networkx_require), "pandas": _collect_requirements(pandas_require), "pickle": _collect_requirements(pickle_require), @@ -235,6 +244,7 @@ def _collect_requirements(requires): "tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'", "triad>=0.6.7, <1.0", "trufflehog~=2.1", + "xarray>=2023.1.0", "xlsxwriter~=1.0", # huggingface "datasets", diff --git a/kedro-datasets/tests/netcdf/__init__.py b/kedro-datasets/tests/netcdf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/tests/netcdf/test_netcdf_dataset.py b/kedro-datasets/tests/netcdf/test_netcdf_dataset.py new file mode 100644 index 000000000..51eea1e15 --- /dev/null +++ b/kedro-datasets/tests/netcdf/test_netcdf_dataset.py @@ -0,0 +1,275 @@ +import os + +import boto3 +import pytest +import xarray as xr +from kedro.io.core import DatasetError +from moto import mock_aws +from s3fs import S3FileSystem +from xarray.testing import assert_equal + +from kedro_datasets.netcdf import NetCDFDataset + +FILE_NAME = "test.nc" +MULTIFILE_NAME = "test*.nc" +BUCKET_NAME = "test_bucket" +MULTIFILE_BUCKET_NAME = "test_bucket_multi" +AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + +# Pathlib cannot be used since it strips out the second slash from "s3://" +S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}" +S3_PATH_MULTIFILE = f"s3://{MULTIFILE_BUCKET_NAME}/{MULTIFILE_NAME}" + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing to store a singular NetCDF file.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id=AWS_CREDENTIALS["key"], + aws_secret_access_key=AWS_CREDENTIALS["secret"], + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_s3_bucket_multi(): + """Create a bucket for testing to store multiple NetCDF files.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id=AWS_CREDENTIALS["key"], + aws_secret_access_key=AWS_CREDENTIALS["secret"], + ) + conn.create_bucket(Bucket=MULTIFILE_BUCKET_NAME) + yield conn + + +def dummy_data() -> xr.Dataset: + """Sample xarray dataset for load/save testing.""" + ds = xr.DataArray( + [0, 1, 2, 3], dims=["x"], coords={"x": [0, 1, 2, 3]}, name="data" + ).to_dataset() + return ds + + +@pytest.fixture +def dummy_xr_dataset() -> xr.Dataset: + """Expected result for load/save on a single NetCDF file.""" + return dummy_data() + + +@pytest.fixture +def dummy_xr_dataset_multi() -> xr.Dataset: + """Expected concatenated result for load/save on multiple NetCDF files.""" + data = dummy_data() + return xr.concat([data, data], dim="dummy") + + +@pytest.fixture +def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_xr_dataset: xr.Dataset): + """Creates singular test NetCDF and adds it to mocked S3 bucket.""" + temporary_path = tmp_path / FILE_NAME + dummy_xr_dataset.to_netcdf(str(temporary_path)) + + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes() + ) + return mocked_s3_bucket + + +@pytest.fixture +def mocked_s3_object_multi( + tmp_path, mocked_s3_bucket_multi, dummy_xr_dataset: xr.Dataset +): + """Creates multiple test NetCDFs and adds them to mocked S3 bucket.""" + + def put_data(file_name: str): + temporary_path = tmp_path / file_name + dummy_xr_dataset.to_netcdf(str(temporary_path)) + mocked_s3_bucket_multi.put_object( + Bucket=MULTIFILE_BUCKET_NAME, + Key=file_name, + Body=temporary_path.read_bytes(), + ) + return mocked_s3_bucket_multi + + mocked_s3_bucket_multi = put_data("test1.nc") + mocked_s3_bucket_multi = put_data("test2.nc") + return mocked_s3_bucket_multi + + +@pytest.fixture +def s3_dataset(load_args, save_args, tmp_path): + """Sample NetCDF dataset pointing to mocked S3 bucket with single NetCDF file.""" + return NetCDFDataset( + filepath=S3_PATH, + temppath=tmp_path, + credentials=AWS_CREDENTIALS, + load_args=load_args, + save_args=save_args, + ) + + +@pytest.fixture +def s3_dataset_multi(save_args, tmp_path): + """Sample NetCDF dataset pointing to mocked S3 bucket with multiple NetCDF files.""" + return NetCDFDataset( + filepath=S3_PATH_MULTIFILE, + temppath=tmp_path, + credentials=AWS_CREDENTIALS, + load_args={"concat_dim": "dummy", "combine": "nested"}, + save_args=save_args, + ) + + +@pytest.fixture() +def s3fs_cleanup(): + # clear cache so we get a clean slate every time we instantiate a S3FileSystem + yield + S3FileSystem.cachable = False + + +@pytest.mark.usefixtures("s3fs_cleanup") +class TestNetCDFDataset: + os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" + os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" + + def test_temppath_error_raised(self): + """Test that error is raised if S3 NetCDF file referenced without a temporary + path.""" + pattern = "Need to set temppath in catalog" + with pytest.raises(ValueError, match=pattern): + NetCDFDataset( + filepath=S3_PATH, + temppath=None, + ) + + @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}]) + def test_empty_credentials_load(self, bad_credentials, tmp_path): + """Test that error is raised if there are no AWS credentials.""" + netcdf_dataset = NetCDFDataset( + filepath=S3_PATH, temppath=tmp_path, credentials=bad_credentials + ) + pattern = r"Failed while loading data from data set NetCDFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + netcdf_dataset.load() + + @pytest.mark.xfail(reason="Pending rewrite with new s3fs version") + def test_pass_credentials(self, mocker, tmp_path): + """Test that AWS credentials are passed successfully into boto3 + client instantiation on creating S3 connection.""" + client_mock = mocker.patch("botocore.session.Session.create_client") + s3_dataset = NetCDFDataset( + filepath=S3_PATH, temppath=tmp_path, credentials=AWS_CREDENTIALS + ) + pattern = r"Failed while loading data from data set NetCDFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + s3_dataset.load() + + assert client_mock.call_count == 1 + args, kwargs = client_mock.call_args_list[0] + assert args == ("s3",) + assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"] + assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"] + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_save_data_single(self, s3_dataset, dummy_xr_dataset, mocked_s3_bucket): + """Test saving a single NetCDF file to S3.""" + s3_dataset.save(dummy_xr_dataset) + loaded_data = s3_dataset.load() + assert_equal(loaded_data, dummy_xr_dataset) + + def test_save_data_multi_error(self, s3_dataset_multi, dummy_xr_dataset_multi): + """Test that error is raised when trying to save to a NetCDF destination with + a glob pattern.""" + pattern = r"Globbed multifile datasets with '*'" + with pytest.raises(DatasetError, match=pattern): + s3_dataset_multi.save(dummy_xr_dataset) + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_load_data_single(self, s3_dataset, dummy_xr_dataset, mocked_s3_object): + """Test loading a single NetCDF file from S3.""" + loaded_data = s3_dataset.load() + assert_equal(loaded_data, dummy_xr_dataset) + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_load_data_multi( + self, s3_dataset_multi, dummy_xr_dataset_multi, mocked_s3_object_multi + ): + """Test loading multiple NetCDF files from S3.""" + loaded_data = s3_dataset_multi.load() + assert_equal(loaded_data, dummy_xr_dataset_multi) + + def test_exists(self, s3_dataset, dummy_xr_dataset, mocked_s3_bucket): + """Test `exists` method invocation for both existing and nonexistent single + NetCDF file.""" + assert not s3_dataset.exists() + s3_dataset.save(dummy_xr_dataset) + assert s3_dataset.exists() + + @pytest.mark.usefixtures("mocked_s3_object_multi") + def test_exists_multi_remote(self, s3_dataset_multi): + """Test `exists` method invocation works for multifile glob pattern on S3.""" + assert s3_dataset_multi.exists() + + def test_exists_multi_locally(self, tmp_path, dummy_xr_dataset): + """Test `exists` method invocation for both existing and nonexistent set of + multiple local NetCDF files.""" + dataset = NetCDFDataset(filepath=str(tmp_path / MULTIFILE_NAME)) + assert not dataset.exists() + NetCDFDataset(filepath=str(tmp_path / "test1.nc")).save(dummy_xr_dataset) + NetCDFDataset(filepath=str(tmp_path / "test2.nc")).save(dummy_xr_dataset) + assert dataset.exists() + + def test_save_load_locally(self, tmp_path, dummy_xr_dataset): + """Test loading and saving the a NetCDF file locally.""" + file_path = str(tmp_path / "some" / "dir" / FILE_NAME) + dataset = NetCDFDataset(filepath=file_path) + + assert not dataset.exists() + dataset.save(dummy_xr_dataset) + assert dataset.exists() + loaded_data = dataset.load() + dummy_xr_dataset.equals(loaded_data) + + def test_load_locally_multi( + self, tmp_path, dummy_xr_dataset, dummy_xr_dataset_multi + ): + """Test loading multiple NetCDF files locally.""" + file_path = str(tmp_path / "some" / "dir" / MULTIFILE_NAME) + dataset = NetCDFDataset( + filepath=file_path, load_args={"concat_dim": "dummy", "combine": "nested"} + ) + + assert not dataset.exists() + NetCDFDataset(filepath=str(tmp_path / "some" / "dir" / "test1.nc")).save( + dummy_xr_dataset + ) + NetCDFDataset(filepath=str(tmp_path / "some" / "dir" / "test2.nc")).save( + dummy_xr_dataset + ) + assert dataset.exists() + loaded_data = dataset.load() + dummy_xr_dataset_multi.equals(loaded_data.compute()) + + @pytest.mark.parametrize( + "load_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_load_extra_params(self, s3_dataset, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert s3_dataset._load_args[key] == value + + @pytest.mark.parametrize( + "save_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_save_extra_params(self, s3_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert s3_dataset._save_args[key] == value + + for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items(): + assert s3_dataset._save_args[key] == value