diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index d14d55ed4..58989fc08 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,8 +1,12 @@ # Upcoming Release ## Major features and improvements +* Added `NetCDFDataset` for loading and saving `*.nc` files. ## Bug fixes and other changes ## Community contributions +Many thanks to the following Kedroids for contributing PRs to this release: +* [Riley Brady](https://github.com/riley-brady) + # Release 2.1.0 ## Major features and improvements diff --git a/kedro-datasets/docs/source/api/kedro_datasets.rst b/kedro-datasets/docs/source/api/kedro_datasets.rst index 13e956878..c6fe0d04e 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets.rst @@ -23,6 +23,7 @@ kedro_datasets kedro_datasets.json.JSONDataset kedro_datasets.matlab.MatlabDataset kedro_datasets.matplotlib.MatplotlibWriter + kedro_datasets.netcdf.NetCDFDataset kedro_datasets.networkx.GMLDataset kedro_datasets.networkx.GraphMLDataset kedro_datasets.networkx.JSONDataset diff --git a/kedro-datasets/kedro_datasets/netcdf/__init__.py b/kedro-datasets/kedro_datasets/netcdf/__init__.py new file mode 100644 index 000000000..875b319c8 --- /dev/null +++ b/kedro-datasets/kedro_datasets/netcdf/__init__.py @@ -0,0 +1,14 @@ +"""``NetCDFDataset`` is an ``AbstractDataset`` to save and load NetCDF files.""" +from __future__ import annotations + +from typing import Any + +import lazy_loader as lazy + +# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 +NetCDFDataset: type[NetCDFDataset] +NetCDFDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"netcdf_dataset": ["NetCDFDataset"]} +) diff --git a/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py b/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py new file mode 100644 index 000000000..afed2f4d8 --- /dev/null +++ b/kedro-datasets/kedro_datasets/netcdf/netcdf_dataset.py @@ -0,0 +1,218 @@ +"""NetCDFDataset loads and saves data to a local netcdf (.nc) file.""" +import logging +from copy import deepcopy +from glob import glob +from pathlib import Path, PurePosixPath +from typing import Any + +import fsspec +import xarray as xr +from kedro.io.core import ( + AbstractDataset, + DatasetError, + get_protocol_and_path, +) + +logger = logging.getLogger(__name__) + + +class NetCDFDataset(AbstractDataset): + """``NetCDFDataset`` loads/saves data from/to a NetCDF file using an underlying + filesystem (e.g.: local, S3, GCS). It uses xarray to handle the NetCDF file. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + single-file: + type: netcdf.NetCDFDataset + filepath: s3://bucket_name/path/to/folder/data.nc + save_args: + mode: a + load_args: + decode_times: False + + multi-file: + type: netcdf.NetCDFDataset + filepath: s3://bucket_name/path/to/folder/data*.nc + load_args: + concat_dim: time + combine: nested + parallel: True + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.netcdf import NetCDFDataset + >>> import xarray as xr + >>> ds = xr.DataArray( + ... [0, 1, 2], dims=["x"], coords={"x": [0, 1, 2]}, name="data" + ... ).to_dataset() + >>> dataset = NetCDFDataset( + ... filepath="path/to/folder", + ... save_args={"mode": "w"}, + ... ) + >>> dataset.save(ds) + >>> reloaded = dataset.load() + """ + + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa + self, + *, + filepath: str, + temppath: str = None, + load_args: dict[str, Any] = None, + save_args: dict[str, Any] = None, + fs_args: dict[str, Any] = None, + credentials: dict[str, Any] = None, + metadata: dict[str, Any] = None, + ): + """Creates a new instance of ``NetCDFDataset`` pointing to a concrete NetCDF + file on a specific filesystem + + Args: + filepath: Filepath in POSIX format to a NetCDF file prefixed with a + protocol like `s3://`. If prefix is not provided, `file` protocol + (local filesystem) will be used. The prefix should be any protocol + supported by ``fsspec``. It can also be a path to a glob. If a + glob is provided then it can be used for reading multiple NetCDF + files. + temppath: Local temporary directory, used when reading from remote storage, + since NetCDF files cannot be directly read from remote storage. + load_args: Additional options for loading NetCDF file(s). + Here you can find all available arguments when reading single file: + https://xarray.pydata.org/en/stable/generated/xarray.open_dataset.html + Here you can find all available arguments when reading multiple files: + https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html + All defaults are preserved. + save_args: Additional saving options for saving NetCDF file(s). + Here you can find all available arguments: + https://xarray.pydata.org/en/stable/generated/xarray.Dataset.to_netcdf.html + All defaults are preserved. + fs_args: Extra arguments to pass into underlying filesystem class + constructor (e.g. `{"cache_regions": "us-east-1"}` for + ``s3fs.S3FileSystem``). + credentials: Credentials required to get access to the underlying filesystem. + E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + self._fs_args = deepcopy(fs_args) or {} + self._credentials = deepcopy(credentials) or {} + self._temppath = Path(temppath) if temppath is not None else None + protocol, path = get_protocol_and_path(filepath) + if protocol == "file": + self._fs_args.setdefault("auto_mkdir", True) + elif protocol != "file" and self._temppath is None: + raise ValueError( + "Need to set temppath in catalog if NetCDF file exists on remote " + + "filesystem" + ) + self._protocol = protocol + self._filepath = filepath + + self._storage_options = {**self._credentials, **self._fs_args} + self._fs = fsspec.filesystem(self._protocol, **self._storage_options) + + self.metadata = metadata + + # Handle default load and save arguments + self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) + if load_args is not None: + self._load_args.update(load_args) + self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) + if save_args is not None: + self._save_args.update(save_args) + + # Determine if multiple NetCDF files are being loaded in. + self._is_multifile = ( + True if "*" in str(PurePosixPath(self._filepath).stem) else False + ) + + def _load(self) -> xr.Dataset: + load_path = self._filepath + + # If NetCDF(s) are on any type of remote storage, need to sync to local to open. + # Kerchunk could be implemented here in the future for direct remote reading. + if self._protocol != "file": + logger.info("Syncing remote NetCDF file to local storage.") + + if self._is_multifile: + load_path = sorted(self._fs.glob(load_path)) + + self._fs.get(load_path, f"{self._temppath}/") + load_path = f"{self._temppath}/{self._filepath.stem}.nc" + + if self._is_multifile: + data = xr.open_mfdataset(str(load_path), **self._load_args) + else: + data = xr.open_dataset(load_path, **self._load_args) + + return data + + def _save(self, data: xr.Dataset): + if self._is_multifile: + raise DatasetError( + "Globbed multifile datasets with '*' in filepath cannot be saved. " + + "Create an alternate NetCDFDataset with a single .nc output file." + ) + else: + save_path = self._filepath + bytes_buffer = data.to_netcdf(**self._save_args) + + with self._fs.open(save_path, mode="wb") as fs_file: + fs_file.write(bytes_buffer) + + self._invalidate_cache() + + def _describe(self) -> dict[str, Any]: + return dict( + filepath=self._filepath, + protocol=self._protocol, + load_args=self._load_args, + save_args=self._save_args, + ) + + def _exists(self) -> bool: + load_path = self._filepath + + if self._is_multifile: + files = self._fs.glob(load_path) + exists = True if files else False + else: + exists = self._fs.exists(load_path) + + return exists + + def _invalidate_cache(self): + """Invalidate underlying filesystem caches.""" + self._fs.invalidate_cache(self._filepath) + + def __del__(self): + """Cleanup temporary directory""" + if self._temppath is not None: + logger.info("Deleting local temporary files.") + temp_filepath = self._temppath / PurePosixPath(self._filepath).stem + if self._is_multifile: + temp_files = glob(str(temp_filepath)) + for file in temp_files: + try: + Path(file).unlink() + except FileNotFoundError: # pragma: no cover + pass # pragma: no cover + else: + temp_filepath = ( + str(temp_filepath) + "/" + PurePosixPath(self._filepath).name + ) + try: + Path(temp_filepath).unlink() + except FileNotFoundError: + pass diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index a09f2c9cb..d53f058b1 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -32,7 +32,7 @@ version = {attr = "kedro_datasets.__version__"} fail_under = 100 show_missing = true # temporarily ignore kedro_datasets/__init__.py in coverage report -omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py", "kedro_datasets/databricks/*"] +omit = ["tests/*", "kedro_datasets/holoviews/*", "kedro_datasets/netcdf/*", "kedro_datasets/snowflake/*", "kedro_datasets/tensorflow/*", "kedro_datasets/__init__.py", "kedro_datasets/conftest.py", "kedro_datasets/databricks/*"] exclude_lines = ["pragma: no cover", "raise NotImplementedError", "if TYPE_CHECKING:"] [tool.pytest.ini_options] diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index afef7aaf6..6ffad7007 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -31,6 +31,14 @@ def _collect_requirements(requires): } matplotlib_require = {"matplotlib.MatplotlibWriter": ["matplotlib>=3.0.3, <4.0"]} matlab_require = {"matlab.MatlabDataset": ["scipy"]} +netcdf_require = { + "netcdf.NetCDFDataset": [ + "h5netcdf>=1.2.0", + "netcdf4>=1.6.4", + "xarray<=0.20.2; python_version == '3.7'", + "xarray>=2023.1.0; python_version >= '3.8'", + ] +} networkx_require = {"networkx.NetworkXDataset": ["networkx~=2.4"]} pandas_require = { "pandas.CSVDataset": [PANDAS], @@ -118,6 +126,7 @@ def _collect_requirements(requires): "huggingface": _collect_requirements(huggingface_require), "matlab": _collect_requirements(matlab_require), "matplotlib": _collect_requirements(matplotlib_require), + "netcdf": _collect_requirements(netcdf_require), "networkx": _collect_requirements(networkx_require), "pandas": _collect_requirements(pandas_require), "pickle": _collect_requirements(pickle_require), @@ -235,6 +244,7 @@ def _collect_requirements(requires): "tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'", "triad>=0.6.7, <1.0", "trufflehog~=2.1", + "xarray>=2023.1.0", "xlsxwriter~=1.0", # huggingface "datasets", diff --git a/kedro-datasets/tests/netcdf/__init__.py b/kedro-datasets/tests/netcdf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/tests/netcdf/test_netcdf_dataset.py b/kedro-datasets/tests/netcdf/test_netcdf_dataset.py new file mode 100644 index 000000000..51eea1e15 --- /dev/null +++ b/kedro-datasets/tests/netcdf/test_netcdf_dataset.py @@ -0,0 +1,275 @@ +import os + +import boto3 +import pytest +import xarray as xr +from kedro.io.core import DatasetError +from moto import mock_aws +from s3fs import S3FileSystem +from xarray.testing import assert_equal + +from kedro_datasets.netcdf import NetCDFDataset + +FILE_NAME = "test.nc" +MULTIFILE_NAME = "test*.nc" +BUCKET_NAME = "test_bucket" +MULTIFILE_BUCKET_NAME = "test_bucket_multi" +AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + +# Pathlib cannot be used since it strips out the second slash from "s3://" +S3_PATH = f"s3://{BUCKET_NAME}/{FILE_NAME}" +S3_PATH_MULTIFILE = f"s3://{MULTIFILE_BUCKET_NAME}/{MULTIFILE_NAME}" + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing to store a singular NetCDF file.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id=AWS_CREDENTIALS["key"], + aws_secret_access_key=AWS_CREDENTIALS["secret"], + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_s3_bucket_multi(): + """Create a bucket for testing to store multiple NetCDF files.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id=AWS_CREDENTIALS["key"], + aws_secret_access_key=AWS_CREDENTIALS["secret"], + ) + conn.create_bucket(Bucket=MULTIFILE_BUCKET_NAME) + yield conn + + +def dummy_data() -> xr.Dataset: + """Sample xarray dataset for load/save testing.""" + ds = xr.DataArray( + [0, 1, 2, 3], dims=["x"], coords={"x": [0, 1, 2, 3]}, name="data" + ).to_dataset() + return ds + + +@pytest.fixture +def dummy_xr_dataset() -> xr.Dataset: + """Expected result for load/save on a single NetCDF file.""" + return dummy_data() + + +@pytest.fixture +def dummy_xr_dataset_multi() -> xr.Dataset: + """Expected concatenated result for load/save on multiple NetCDF files.""" + data = dummy_data() + return xr.concat([data, data], dim="dummy") + + +@pytest.fixture +def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_xr_dataset: xr.Dataset): + """Creates singular test NetCDF and adds it to mocked S3 bucket.""" + temporary_path = tmp_path / FILE_NAME + dummy_xr_dataset.to_netcdf(str(temporary_path)) + + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, Key=FILE_NAME, Body=temporary_path.read_bytes() + ) + return mocked_s3_bucket + + +@pytest.fixture +def mocked_s3_object_multi( + tmp_path, mocked_s3_bucket_multi, dummy_xr_dataset: xr.Dataset +): + """Creates multiple test NetCDFs and adds them to mocked S3 bucket.""" + + def put_data(file_name: str): + temporary_path = tmp_path / file_name + dummy_xr_dataset.to_netcdf(str(temporary_path)) + mocked_s3_bucket_multi.put_object( + Bucket=MULTIFILE_BUCKET_NAME, + Key=file_name, + Body=temporary_path.read_bytes(), + ) + return mocked_s3_bucket_multi + + mocked_s3_bucket_multi = put_data("test1.nc") + mocked_s3_bucket_multi = put_data("test2.nc") + return mocked_s3_bucket_multi + + +@pytest.fixture +def s3_dataset(load_args, save_args, tmp_path): + """Sample NetCDF dataset pointing to mocked S3 bucket with single NetCDF file.""" + return NetCDFDataset( + filepath=S3_PATH, + temppath=tmp_path, + credentials=AWS_CREDENTIALS, + load_args=load_args, + save_args=save_args, + ) + + +@pytest.fixture +def s3_dataset_multi(save_args, tmp_path): + """Sample NetCDF dataset pointing to mocked S3 bucket with multiple NetCDF files.""" + return NetCDFDataset( + filepath=S3_PATH_MULTIFILE, + temppath=tmp_path, + credentials=AWS_CREDENTIALS, + load_args={"concat_dim": "dummy", "combine": "nested"}, + save_args=save_args, + ) + + +@pytest.fixture() +def s3fs_cleanup(): + # clear cache so we get a clean slate every time we instantiate a S3FileSystem + yield + S3FileSystem.cachable = False + + +@pytest.mark.usefixtures("s3fs_cleanup") +class TestNetCDFDataset: + os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" + os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" + + def test_temppath_error_raised(self): + """Test that error is raised if S3 NetCDF file referenced without a temporary + path.""" + pattern = "Need to set temppath in catalog" + with pytest.raises(ValueError, match=pattern): + NetCDFDataset( + filepath=S3_PATH, + temppath=None, + ) + + @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}]) + def test_empty_credentials_load(self, bad_credentials, tmp_path): + """Test that error is raised if there are no AWS credentials.""" + netcdf_dataset = NetCDFDataset( + filepath=S3_PATH, temppath=tmp_path, credentials=bad_credentials + ) + pattern = r"Failed while loading data from data set NetCDFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + netcdf_dataset.load() + + @pytest.mark.xfail(reason="Pending rewrite with new s3fs version") + def test_pass_credentials(self, mocker, tmp_path): + """Test that AWS credentials are passed successfully into boto3 + client instantiation on creating S3 connection.""" + client_mock = mocker.patch("botocore.session.Session.create_client") + s3_dataset = NetCDFDataset( + filepath=S3_PATH, temppath=tmp_path, credentials=AWS_CREDENTIALS + ) + pattern = r"Failed while loading data from data set NetCDFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + s3_dataset.load() + + assert client_mock.call_count == 1 + args, kwargs = client_mock.call_args_list[0] + assert args == ("s3",) + assert kwargs["aws_access_key_id"] == AWS_CREDENTIALS["key"] + assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"] + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_save_data_single(self, s3_dataset, dummy_xr_dataset, mocked_s3_bucket): + """Test saving a single NetCDF file to S3.""" + s3_dataset.save(dummy_xr_dataset) + loaded_data = s3_dataset.load() + assert_equal(loaded_data, dummy_xr_dataset) + + def test_save_data_multi_error(self, s3_dataset_multi, dummy_xr_dataset_multi): + """Test that error is raised when trying to save to a NetCDF destination with + a glob pattern.""" + pattern = r"Globbed multifile datasets with '*'" + with pytest.raises(DatasetError, match=pattern): + s3_dataset_multi.save(dummy_xr_dataset) + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_load_data_single(self, s3_dataset, dummy_xr_dataset, mocked_s3_object): + """Test loading a single NetCDF file from S3.""" + loaded_data = s3_dataset.load() + assert_equal(loaded_data, dummy_xr_dataset) + + @pytest.mark.skip(reason="S3 tests that load datasets don't work properly") + def test_load_data_multi( + self, s3_dataset_multi, dummy_xr_dataset_multi, mocked_s3_object_multi + ): + """Test loading multiple NetCDF files from S3.""" + loaded_data = s3_dataset_multi.load() + assert_equal(loaded_data, dummy_xr_dataset_multi) + + def test_exists(self, s3_dataset, dummy_xr_dataset, mocked_s3_bucket): + """Test `exists` method invocation for both existing and nonexistent single + NetCDF file.""" + assert not s3_dataset.exists() + s3_dataset.save(dummy_xr_dataset) + assert s3_dataset.exists() + + @pytest.mark.usefixtures("mocked_s3_object_multi") + def test_exists_multi_remote(self, s3_dataset_multi): + """Test `exists` method invocation works for multifile glob pattern on S3.""" + assert s3_dataset_multi.exists() + + def test_exists_multi_locally(self, tmp_path, dummy_xr_dataset): + """Test `exists` method invocation for both existing and nonexistent set of + multiple local NetCDF files.""" + dataset = NetCDFDataset(filepath=str(tmp_path / MULTIFILE_NAME)) + assert not dataset.exists() + NetCDFDataset(filepath=str(tmp_path / "test1.nc")).save(dummy_xr_dataset) + NetCDFDataset(filepath=str(tmp_path / "test2.nc")).save(dummy_xr_dataset) + assert dataset.exists() + + def test_save_load_locally(self, tmp_path, dummy_xr_dataset): + """Test loading and saving the a NetCDF file locally.""" + file_path = str(tmp_path / "some" / "dir" / FILE_NAME) + dataset = NetCDFDataset(filepath=file_path) + + assert not dataset.exists() + dataset.save(dummy_xr_dataset) + assert dataset.exists() + loaded_data = dataset.load() + dummy_xr_dataset.equals(loaded_data) + + def test_load_locally_multi( + self, tmp_path, dummy_xr_dataset, dummy_xr_dataset_multi + ): + """Test loading multiple NetCDF files locally.""" + file_path = str(tmp_path / "some" / "dir" / MULTIFILE_NAME) + dataset = NetCDFDataset( + filepath=file_path, load_args={"concat_dim": "dummy", "combine": "nested"} + ) + + assert not dataset.exists() + NetCDFDataset(filepath=str(tmp_path / "some" / "dir" / "test1.nc")).save( + dummy_xr_dataset + ) + NetCDFDataset(filepath=str(tmp_path / "some" / "dir" / "test2.nc")).save( + dummy_xr_dataset + ) + assert dataset.exists() + loaded_data = dataset.load() + dummy_xr_dataset_multi.equals(loaded_data.compute()) + + @pytest.mark.parametrize( + "load_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_load_extra_params(self, s3_dataset, load_args): + """Test overriding the default load arguments.""" + for key, value in load_args.items(): + assert s3_dataset._load_args[key] == value + + @pytest.mark.parametrize( + "save_args", [{"k1": "v1", "index": "value"}], indirect=True + ) + def test_save_extra_params(self, s3_dataset, save_args): + """Test overriding the default save arguments.""" + for key, value in save_args.items(): + assert s3_dataset._save_args[key] == value + + for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items(): + assert s3_dataset._save_args[key] == value