Expose load and save publicly for each dataset (#3920)
* Expose `load` and `save` publicly for each dataset

Signed-off-by: Deepyaman Datta <[email protected]>

* Split docstrings to maintain one-line descriptions

Signed-off-by: Deepyaman Datta <[email protected]>

* Assign new-style public `load` and `save` in tests

Signed-off-by: Deepyaman Datta <[email protected]>

* Don't fail when trying to access unset annotations

Signed-off-by: Deepyaman Datta <[email protected]>

* Add coverage for legacy dataset functionality, too

Signed-off-by: Deepyaman Datta <[email protected]>

* Fix detection of `load` or `save` being overridden

Signed-off-by: Deepyaman Datta <[email protected]>

* Cover saving versioned over unversioned for legacy

Signed-off-by: Deepyaman Datta <[email protected]>

* Add tests for invalid data and version consistency

Signed-off-by: Deepyaman Datta <[email protected]>

* Guard against wrapping inherited `load` and `save`

Signed-off-by: Deepyaman Datta <[email protected]>

* Add missing type annotations for wrapper functions

Signed-off-by: Deepyaman Datta <[email protected]>

* Update type hints for wrapper methods, adding Self

Signed-off-by: Deepyaman Datta <[email protected]>

* Try to remove `Self` annotation from function defs

Signed-off-by: Deepyaman Datta <[email protected]>

* Restore `Self` annotation for load, save functions

This reverts commit 5041f5a.

Signed-off-by: Deepyaman Datta <[email protected]>

* Update the instructions for creating a new dataset

Signed-off-by: Deepyaman Datta <[email protected]>

---------

Signed-off-by: Deepyaman Datta <[email protected]>
deepyaman authored Jul 29, 2024
1 parent 6609455 commit 52458c2
Showing 9 changed files with 304 additions and 101 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -1,6 +1,7 @@
# Upcoming Release 0.19.7

## Major features and improvements
+ * Exposed `load` and `save` publicly for each dataset in the core `kedro` library, and enabled other datasets to do the same. If a dataset doesn't expose `load` or `save` publicly, Kedro will fall back to using `_load` or `_save`, respectively.
* Kedro commands are now lazily loaded to add performance gains when running Kedro commands.

## Bug fixes and other changes
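To make the contract in the release note above concrete, here is a minimal sketch of the two styles it describes. The class names, file paths, and pandas usage are illustrative, not part of the commit:

```python
from pathlib import Path
from typing import Any

import pandas as pd

from kedro.io import AbstractDataset


class PublicCSVDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
    """New-style dataset: `load` and `save` are implemented publicly."""

    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    def load(self) -> pd.DataFrame:
        return pd.read_csv(self._filepath)

    def save(self, data: pd.DataFrame) -> None:
        data.to_csv(self._filepath, index=False)

    def _describe(self) -> dict[str, Any]:
        return {"filepath": self._filepath}


class LegacyCSVDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
    """Old-style dataset: `_load` and `_save` still work via the fallback."""

    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    def _load(self) -> pd.DataFrame:
        return pd.read_csv(self._filepath)

    def _save(self, data: pd.DataFrame) -> None:
        data.to_csv(self._filepath, index=False)

    def _describe(self) -> dict[str, Any]:
        return {"filepath": self._filepath}
```

In both cases callers use `dataset.load()` and `dataset.save(df)`; the only difference is which hook the subclass implements.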
2 changes: 1 addition & 1 deletion docs/source/data/how_to_create_a_custom_dataset.md
@@ -312,7 +312,7 @@ To add versioning support to the new dataset we need to extend the
{py:class}`~kedro.io.AbstractVersionedDataset` to:

* Accept a `version` keyword argument as part of the constructor
- * Adapt the `_save` and `_load` method to use the versioned data path obtained from `_get_save_path` and `_get_load_path` respectively
+ * Adapt the `_load` and `_save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively

The following amends the full implementation of our basic `ImageDataset`. It now loads and saves data to and from a versioned subfolder (`data/01_raw/pokemon-images-and-types/images/images/pikachu.png/<version>/pikachu.png` with `version` being a datetime-formatted string `YYYY-MM-DDThh.mm.ss.sssZ` by default):

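Following the updated instruction, the guide's versioned `ImageDataset` ends up looking roughly like this. This is an abbreviated sketch based on the published guide; the fsspec/PIL details and the explicit `format="PNG"` are illustrative, and only the `load`/`save` bodies are the point:

```python
from __future__ import annotations

from pathlib import PurePosixPath

import fsspec
import numpy as np
from PIL import Image

from kedro.io import AbstractVersionedDataset, Version
from kedro.io.core import get_filepath_str, get_protocol_and_path


class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]):
    def __init__(self, filepath: str, version: Version | None = None):
        protocol, path = get_protocol_and_path(filepath)
        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol)
        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

    def load(self) -> np.ndarray:
        # _get_load_path resolves the timestamped directory of the
        # requested (or latest) version.
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        with self._fs.open(load_path, mode="rb") as f:
            return np.asarray(Image.open(f).convert("RGBA"))

    def save(self, data: np.ndarray) -> None:
        # _get_save_path generates the new versioned path to write to.
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, mode="wb") as f:
            Image.fromarray(data).save(f, format="PNG")

    def _describe(self) -> dict:
        return {"filepath": self._filepath, "version": self._version}
```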
2 changes: 1 addition & 1 deletion kedro/io/cached_dataset.py
@@ -71,7 +71,7 @@ def __init__(
"The argument type of 'dataset' should be either a dict/YAML "
"representation of the dataset, or the actual dataset object."
)
- self._cache = MemoryDataset(copy_mode=copy_mode)
+ self._cache = MemoryDataset(copy_mode=copy_mode)  # type: ignore[abstract]
self.metadata = metadata

def _release(self) -> None:
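The `# type: ignore[abstract]` appears to be needed because mypy still sees `load` and `save` as abstract on `MemoryDataset` (which only defines `_load`/`_save`), while at runtime the `__init_subclass__` hook introduced in `core.py` has presumably already supplied concrete methods by the time the ABC machinery checks abstractness. A quick, illustrative check of the runtime behaviour:

```python
from kedro.io import MemoryDataset

# No TypeError at instantiation: the class is not actually abstract at
# runtime, since __init_subclass__ assigned concrete `load`/`save`.
cache = MemoryDataset()
cache.save({"a": 1})
assert cache.load() == {"a": 1}
```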
237 changes: 149 additions & 88 deletions kedro/io/core.py
@@ -13,7 +13,7 @@
import warnings
from collections import namedtuple
from datetime import datetime, timezone
- from functools import partial
+ from functools import partial, wraps
from glob import iglob
from operator import attrgetter
from pathlib import Path, PurePath, PurePosixPath
@@ -22,6 +22,7 @@

from cachetools import Cache, cachedmethod
from cachetools.keys import hashkey
+ from typing_extensions import Self

from kedro.utils import load_obj

@@ -74,6 +75,7 @@ class VersionNotFoundError(DatasetError):

class AbstractDataset(abc.ABC, Generic[_DI, _DO]):
"""``AbstractDataset`` is the base class for all data set implementations.
All data set implementations should extend this abstract class
and implement the methods marked as abstract.
If a specific dataset implementation cannot be used in conjunction with
@@ -93,10 +95,10 @@ class AbstractDataset(abc.ABC, Generic[_DI, _DO]):
>>> self._param1 = param1
>>> self._param2 = param2
>>>
- >>> def _load(self) -> pd.DataFrame:
+ >>> def load(self) -> pd.DataFrame:
>>> return pd.read_csv(self._filepath)
>>>
- >>> def _save(self, df: pd.DataFrame) -> None:
+ >>> def save(self, df: pd.DataFrame) -> None:
>>> df.to_csv(str(self._filepath))
>>>
>>> def _exists(self) -> bool:
@@ -178,57 +180,6 @@ def from_config(
def _logger(self) -> logging.Logger:
return logging.getLogger(__name__)

-     def load(self) -> _DO:
-         """Loads data by delegation to the provided load method.
- 
-         Returns:
-             Data returned by the provided load method.
- 
-         Raises:
-             DatasetError: When underlying load method raises error.
- 
-         """
-         self._logger.debug("Loading %s", str(self))
- 
-         try:
-             return self._load()
-         except DatasetError:
-             raise
-         except Exception as exc:
-             # This exception handling is by design as the composed data sets
-             # can throw any type of exception.
-             message = (
-                 f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
-             )
-             raise DatasetError(message) from exc
- 
-     def save(self, data: _DI) -> None:
-         """Saves data by delegation to the provided save method.
- 
-         Args:
-             data: the value to be saved by provided save method.
- 
-         Raises:
-             DatasetError: when underlying save method raises error.
-             FileNotFoundError: when save method got file instead of dir, on Windows.
-             NotADirectoryError: when save method got file instead of dir, on Unix.
- 
-         """
-         if data is None:
-             raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")
- 
-         try:
-             self._logger.debug("Saving %s", str(self))
-             self._save(data)
-         except DatasetError:
-             raise
-         except (FileNotFoundError, NotADirectoryError):
-             raise
-         except Exception as exc:
-             message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
-             raise DatasetError(message) from exc

def __str__(self) -> str:
# TODO: Replace with __repr__ implementation in 0.20.0 release.
def _to_str(obj: Any, is_root: bool = False) -> str:
@@ -257,6 +208,85 @@ def _to_str(obj: Any, is_root: bool = False) -> str:

return f"{type(self).__name__}({_to_str(self._describe(), True)})"

+     @classmethod
+     def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
+         """Decorate `load_func` with logging and error handling code."""
+ 
+         @wraps(load_func)
+         def load(self: Self) -> _DO:
+             self._logger.debug("Loading %s", str(self))
+ 
+             try:
+                 return load_func(self)
+             except DatasetError:
+                 raise
+             except Exception as exc:
+                 # This exception handling is by design as the composed data sets
+                 # can throw any type of exception.
+                 message = (
+                     f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
+                 )
+                 raise DatasetError(message) from exc
+ 
+         load.__annotations__["return"] = load_func.__annotations__.get("return")
+         load.__loadwrapped__ = True  # type: ignore[attr-defined]
+         return load
+ 
+     @classmethod
+     def _save_wrapper(
+         cls, save_func: Callable[[Self, _DI], None]
+     ) -> Callable[[Self, _DI], None]:
+         """Decorate `save_func` with logging and error handling code."""
+ 
+         @wraps(save_func)
+         def save(self: Self, data: _DI) -> None:
+             if data is None:
+                 raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")
+ 
+             try:
+                 self._logger.debug("Saving %s", str(self))
+                 save_func(self, data)
+             except (DatasetError, FileNotFoundError, NotADirectoryError):
+                 raise
+             except Exception as exc:
+                 message = (
+                     f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
+                 )
+                 raise DatasetError(message) from exc
+ 
+         save.__annotations__["data"] = save_func.__annotations__.get("data", Any)
+         save.__annotations__["return"] = save_func.__annotations__.get("return")
+         save.__savewrapped__ = True  # type: ignore[attr-defined]
+         return save
+ 
+     def __init_subclass__(cls, **kwargs: Any) -> None:
+         """Decorate the `load` and `save` methods provided by the class.
+ 
+         If `_load` or `_save` are defined, alias them as a prerequisite.
+ 
+         """
+         super().__init_subclass__(**kwargs)
+ 
+         if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
+             cls.load = cls._load  # type: ignore[method-assign]
+ 
+         if hasattr(cls, "_save") and not cls._save.__qualname__.startswith("Abstract"):
+             cls.save = cls._save  # type: ignore[method-assign]
+ 
+         if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
+             cls.load = cls._load_wrapper(  # type: ignore[assignment]
+                 cls.load
+                 if not getattr(cls.load, "__loadwrapped__", False)
+                 else cls.load.__wrapped__  # type: ignore[attr-defined]
+             )
+ 
+         if hasattr(cls, "save") and not cls.save.__qualname__.startswith("Abstract"):
+             cls.save = cls._save_wrapper(  # type: ignore[assignment]
+                 cls.save
+                 if not getattr(cls.save, "__savewrapped__", False)
+                 else cls.save.__wrapped__  # type: ignore[attr-defined]
+             )

def _pretty_repr(self, object_description: dict[str, Any]) -> str:
str_keys = []
for arg_name, arg_descr in object_description.items():
@@ -276,17 +306,37 @@ def __repr__(self) -> str:
return self._pretty_repr(self._describe())

      @abc.abstractmethod
-     def _load(self) -> _DO:
+     def load(self) -> _DO:
+         """Loads data by delegation to the provided load method.
+ 
+         Returns:
+             Data returned by the provided load method.
+ 
+         Raises:
+             DatasetError: When underlying load method raises error.
+ 
+         """
          raise NotImplementedError(
              f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
-             f"it must implement the '_load' method"
+             f"it must implement the 'load' method"
          )

      @abc.abstractmethod
-     def _save(self, data: _DI) -> None:
+     def save(self, data: _DI) -> None:
+         """Saves data by delegation to the provided save method.
+ 
+         Args:
+             data: the value to be saved by provided save method.
+ 
+         Raises:
+             DatasetError: when underlying save method raises error.
+             FileNotFoundError: when save method got file instead of dir, on Windows.
+             NotADirectoryError: when save method got file instead of dir, on Unix.
+ 
+         """
          raise NotImplementedError(
              f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
-             f"it must implement the '_save' method"
+             f"it must implement the 'save' method"
          )

@abc.abstractmethod
@@ -509,7 +559,9 @@ def _local_exists(local_filepath: str) -> bool: # SKIP_IF_NO_SPARK
class AbstractVersionedDataset(AbstractDataset[_DI, _DO], abc.ABC):
"""
``AbstractVersionedDataset`` is the base class for all versioned data set
- implementations. All data sets that implement versioning should extend this
implementations.
All data sets that implement versioning should extend this
abstract class and implement the methods marked as abstract.
Example:
@@ -526,11 +578,11 @@ class AbstractVersionedDataset(AbstractDataset[_DI, _DO], abc.ABC):
>>> self._param1 = param1
>>> self._param2 = param2
>>>
- >>> def _load(self) -> pd.DataFrame:
+ >>> def load(self) -> pd.DataFrame:
>>> load_path = self._get_load_path()
>>> return pd.read_csv(load_path)
>>>
- >>> def _save(self, df: pd.DataFrame) -> None:
+ >>> def save(self, df: pd.DataFrame) -> None:
>>> save_path = self._get_save_path()
>>> df.to_csv(str(save_path))
>>>
@@ -652,34 +704,43 @@ def _get_save_path(self) -> PurePosixPath:
def _get_versioned_path(self, version: str) -> PurePosixPath:
return self._filepath / version / self._filepath.name

-     def load(self) -> _DO:
-         return super().load()
- 
-     def save(self, data: _DI) -> None:
-         self._version_cache.clear()
-         save_version = self.resolve_save_version()  # Make sure last save version is set
-         try:
-             super().save(data)
-         except (FileNotFoundError, NotADirectoryError) as err:
-             # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
-             _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
-             raise DatasetError(
-                 f"Cannot save versioned dataset '{self._filepath.name}' to "
-                 f"'{self._filepath.parent.as_posix()}' because a file with the same "
-                 f"name already exists in the directory. This is likely because "
-                 f"versioning was enabled on a dataset already saved previously. Either "
-                 f"remove '{self._filepath.name}' from the directory or manually "
-                 f"convert it into a versioned dataset by placing it in a versioned "
-                 f"directory (e.g. with default versioning format "
-                 f"'{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
-                 f"')."
-             ) from err
- 
-         load_version = self.resolve_load_version()
-         if load_version != save_version:
-             warnings.warn(
-                 _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
-             )
+     @classmethod
+     def _save_wrapper(
+         cls, save_func: Callable[[Self, _DI], None]
+     ) -> Callable[[Self, _DI], None]:
+         """Decorate `save_func` with logging and error handling code."""
+ 
+         @wraps(save_func)
+         def save(self: Self, data: _DI) -> None:
+             self._version_cache.clear()
+             save_version = (
+                 self.resolve_save_version()
+             )  # Make sure last save version is set
+             try:
+                 super()._save_wrapper(save_func)(self, data)
+             except (FileNotFoundError, NotADirectoryError) as err:
+                 # FileNotFoundError raised in Win, NotADirectoryError raised in Unix
+                 _default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
+                 raise DatasetError(
+                     f"Cannot save versioned dataset '{self._filepath.name}' to "
+                     f"'{self._filepath.parent.as_posix()}' because a file with the same "
+                     f"name already exists in the directory. This is likely because "
+                     f"versioning was enabled on a dataset already saved previously. Either "
+                     f"remove '{self._filepath.name}' from the directory or manually "
+                     f"convert it into a versioned dataset by placing it in a versioned "
+                     f"directory (e.g. with default versioning format "
+                     f"'{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
+                     f"')."
+                 ) from err
+ 
+             load_version = self.resolve_load_version()
+             if load_version != save_version:
+                 warnings.warn(
+                     _CONSISTENCY_WARNING.format(save_version, load_version, str(self))
+                 )
+ 
+         return save

def exists(self) -> bool:
"""Checks whether a data set's output already exists by calling
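The net effect of the `core.py` changes, sketched as a quick demonstration (hypothetical subclass and path; pandas is only for flavour): a subclass that still defines legacy `_load`/`_save` ends up with wrapped public methods.

```python
import pandas as pd

from kedro.io import AbstractDataset


class LegacyDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]):
    def __init__(self, filepath: str):
        self._filepath = filepath

    def _load(self) -> pd.DataFrame:  # old-style hook
        return pd.read_csv(self._filepath)

    def _save(self, data: pd.DataFrame) -> None:  # old-style hook
        data.to_csv(self._filepath, index=False)

    def _describe(self) -> dict:
        return {"filepath": self._filepath}


# `__init_subclass__` aliased `_load`/`_save` to `load`/`save` at class
# creation, then wrapped them with the logging/error-handling decorators:
assert LegacyDataset.load.__wrapped__ is LegacyDataset._load
assert getattr(LegacyDataset.save, "__savewrapped__", False)
```

The `__loadwrapped__`/`__savewrapped__` flags and the `__wrapped__` fallback in `__init_subclass__` guard against wrapping an inherited, already-wrapped method a second time.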
2 changes: 1 addition & 1 deletion kedro/io/data_catalog.py
@@ -732,7 +732,7 @@ def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None
if isinstance(feed_dict[dataset_name], AbstractDataset):
dataset = feed_dict[dataset_name]
else:
- dataset = MemoryDataset(data=feed_dict[dataset_name])
+ dataset = MemoryDataset(data=feed_dict[dataset_name])  # type: ignore[abstract]

self.add(dataset_name, dataset, replace)

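For reference, `add_feed_dict` is the path exercised by the annotated line: plain values get boxed into a `MemoryDataset`, while dataset instances are registered untouched. A small usage sketch, not from the commit:

```python
from kedro.io import DataCatalog, MemoryDataset

catalog = DataCatalog()
catalog.add_feed_dict({"parameters": {"learning_rate": 0.01}})  # boxed
catalog.add_feed_dict({"cache": MemoryDataset(data=[1, 2, 3])})  # as-is

assert catalog.load("parameters") == {"learning_rate": 0.01}
assert catalog.load("cache") == [1, 2, 3]
```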
16 changes: 8 additions & 8 deletions kedro/io/lambda_dataset.py
@@ -49,21 +49,21 @@ def _to_str(func: Any) -> str | None:

return descr

-     def _save(self, data: Any) -> None:
-         if not self.__save:
-             raise DatasetError(
-                 "Cannot save to data set. No 'save' function "
-                 "provided when LambdaDataset was created."
-             )
-         self.__save(data)
- 
-     def _load(self) -> Any:
-         if not self.__load:
-             raise DatasetError(
-                 "Cannot load data set. No 'load' function "
-                 "provided when LambdaDataset was created."
-             )
-         return self.__load()
+     def load(self) -> Any:
+         if not self.__load:
+             raise DatasetError(
+                 "Cannot load data set. No 'load' function "
+                 "provided when LambdaDataset was created."
+             )
+         return self.__load()
+ 
+     def save(self, data: Any) -> None:
+         if not self.__save:
+             raise DatasetError(
+                 "Cannot save to data set. No 'save' function "
+                 "provided when LambdaDataset was created."
+             )
+         self.__save(data)

def _exists(self) -> bool:
if not self.__exists:
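With the rename, the public methods themselves carry the "no function provided" guard; usage is unchanged. A hedged sketch, assuming the standard `LambdaDataset` constructor and using a dict-backed store purely for illustration:

```python
from kedro.io import LambdaDataset

store: dict[str, str] = {}

dataset = LambdaDataset(
    load=lambda: store["value"],
    save=lambda data: store.__setitem__("value", data),
)

dataset.save("hello")  # wrapped public `save` from AbstractDataset
assert dataset.load() == "hello"

read_only = LambdaDataset(load=lambda: "constant", save=None)
# read_only.save("x")  # would raise DatasetError: no 'save' function provided
```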
2 changes: 1 addition & 1 deletion kedro/io/memory_dataset.py
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.metadata = metadata
self._EPHEMERAL = True
if data is not _EMPTY:
- self._save(data)
+ self.save.__wrapped__(self, data)  # type: ignore[attr-defined]

def _load(self) -> Any:
if self._data is _EMPTY:
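Note that the constructor now seeds the store through `save.__wrapped__` — the undecorated function — presumably so construction keeps the previous `self._save(data)` semantics and bypasses the logging/error-translation layer that `__init_subclass__` adds. From the outside, behaviour is unchanged (illustrative check):

```python
from kedro.io import MemoryDataset

ds = MemoryDataset(data={"a": 1})  # seeded via save.__wrapped__
assert ds.load() == {"a": 1}

ds.save({"b": 2})                  # normal saves go through the wrapper
assert ds.load() == {"b": 2}
```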