
Expose load and save publicly for each dataset #3920

Merged · 21 commits · Jul 29, 2024

Commits
131a994
Expose `load` and `save` publicly for each dataset
deepyaman Jun 4, 2024
57f721a
Split docstrings to maintain one-line descriptions
deepyaman Jun 4, 2024
c788b14
Assign new-style public `load` and `save` in tests
deepyaman Jun 5, 2024
8bcb34f
Don't fail when trying to access unset annotations
deepyaman Jun 5, 2024
39b73a2
Add coverage for legacy dataset functionality, too
deepyaman Jun 5, 2024
04590b6
Fix detection of `load` or `save` being overridden
deepyaman Jun 5, 2024
aa29258
Cover saving versioned over unversioned for legacy
deepyaman Jun 5, 2024
a4cf5fa
Add tests for invalid data and version consistency
deepyaman Jun 5, 2024
cd8e878
Guard against wrapping inherited `load` and `save`
deepyaman Jun 5, 2024
2d94fcb
Add missing type annotations for wrapper functions
deepyaman Jun 5, 2024
73d64da
Update type hints for wrapper methods, adding Self
deepyaman Jun 6, 2024
11e12f8
Try to remove `Self` annotation from function defs
deepyaman Jun 6, 2024
4e31497
Restore `Self` annotation for load, save functions
deepyaman Jun 6, 2024
f1eeb2c
Update the instructions for creating a new dataset
deepyaman Jun 17, 2024
253d660
Auto-convert from `_load`/`_save` to `load`/`save`
deepyaman Jun 27, 2024
b06c64b
Remove unneeded `load`/`save` in versioned dataset
deepyaman Jun 27, 2024
18e426b
Add docstrings for newly-added, non-public methods
deepyaman Jul 1, 2024
aeea33b
Merge branch 'main' into feat/render-concrete-types
deepyaman Jul 23, 2024
4d355c2
Do not remove `_load` and `_save` until Kedro 0.20
deepyaman Jul 23, 2024
a1d82e8
Ignore mypy not knowing load, save are implemented
deepyaman Jul 23, 2024
6fd55c6
Merge branch 'main' into feat/render-concrete-types
deepyaman Jul 26, 2024
1 change: 1 addition & 0 deletions RELEASE.md
@@ -1,6 +1,7 @@
# Upcoming Release 0.19.7

## Major features and improvements
* Exposed `load` and `save` publicly for each dataset in the core `kedro` library, and enabled other datasets to do the same. If a dataset doesn't expose `load` or `save` publicly, Kedro will fall back to using `_load` or `_save`, respectively.
* Kedro commands are now lazily loaded to add performance gains when running Kedro commands.

## Bug fixes and other changes
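The fallback described in the release note above can be sketched with plain Python. This is a simplified stand-in for what `AbstractDataset.__init_subclass__` does, not kedro's actual code — the real implementation detects overrides via `__qualname__` checks and also wraps the methods with logging and error handling:

```python
class Base:
    """Mimics the PR's fallback: if a subclass defines only the legacy
    `_load`, expose it publicly as `load` (sketch only)."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Alias the legacy private method to the public name if the
        # subclass did not provide a public one itself.
        if "_load" in cls.__dict__ and "load" not in cls.__dict__:
            cls.load = cls._load


class LegacyDataset(Base):
    def _load(self):  # old-style dataset: private method only
        return "data"


class ModernDataset(Base):
    def load(self):  # new-style dataset: public method
        return "data"


assert LegacyDataset().load() == "data"  # falls back to `_load`
assert ModernDataset().load() == "data"
```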
2 changes: 1 addition & 1 deletion docs/source/data/how_to_create_a_custom_dataset.md
@@ -312,7 +312,7 @@ To add versioning support to the new dataset we need to extend the
{py:class}`~kedro.io.AbstractVersionedDataset` to:

* Accept a `version` keyword argument as part of the constructor
* Adapt the `_save` and `_load` method to use the versioned data path obtained from `_get_save_path` and `_get_load_path` respectively
* Adapt the `_load` and `_save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively

The following amends the full implementation of our basic `ImageDataset`. It now loads and saves data to and from a versioned subfolder (`data/01_raw/pokemon-images-and-types/images/images/pikachu.png/<version>/pikachu.png` with `version` being a datetime-formatted string `YYYY-MM-DDThh.mm.ss.sssZ` by default):

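The two steps in the docs change above — accept a `version` keyword argument and route `load`/`save` through versioned paths — can be sketched without depending on kedro. `VersionedTextDataset` and its helpers are hypothetical stand-ins for `AbstractVersionedDataset`'s `_get_load_path`/`_get_save_path` machinery:

```python
from datetime import datetime, timezone
from pathlib import PurePosixPath
from typing import Optional


class VersionedTextDataset:
    """Sketch of a versioned dataset: the filepath gains a
    `<version>/<filename>` suffix, as in the docs example."""

    def __init__(self, filepath: str, version: Optional[str] = None):
        self._filepath = PurePosixPath(filepath)
        # Default version is a datetime string, e.g. 2024-07-29T12.00.00.000Z
        self._version = version or (
            datetime.now(timezone.utc).strftime("%Y-%m-%dT%H.%M.%S.%f")[:-3] + "Z"
        )

    def _get_versioned_path(self) -> PurePosixPath:
        # data/file.png -> data/file.png/<version>/file.png
        return self._filepath / self._version / self._filepath.name

    def load(self) -> str:
        return f"would read from {self._get_versioned_path()}"

    def save(self, data: str) -> None:
        print(f"would write {data!r} to {self._get_versioned_path()}")


ds = VersionedTextDataset(
    "data/01_raw/pikachu.png", version="2024-07-29T00.00.00.000Z"
)
assert str(ds._get_versioned_path()) == (
    "data/01_raw/pikachu.png/2024-07-29T00.00.00.000Z/pikachu.png"
)
```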
2 changes: 1 addition & 1 deletion kedro/io/cached_dataset.py
@@ -71,7 +71,7 @@ def __init__(
"The argument type of 'dataset' should be either a dict/YAML "
"representation of the dataset, or the actual dataset object."
)
self._cache = MemoryDataset(copy_mode=copy_mode)
self._cache = MemoryDataset(copy_mode=copy_mode) # type: ignore[abstract]
self.metadata = metadata

def _release(self) -> None:
237 changes: 149 additions & 88 deletions kedro/io/core.py
@@ -13,7 +13,7 @@
import warnings
from collections import namedtuple
from datetime import datetime, timezone
from functools import partial
from functools import partial, wraps
from glob import iglob
from operator import attrgetter
from pathlib import Path, PurePath, PurePosixPath
@@ -22,6 +22,7 @@

from cachetools import Cache, cachedmethod
from cachetools.keys import hashkey
from typing_extensions import Self

from kedro.utils import load_obj

@@ -74,6 +75,7 @@ class VersionNotFoundError(DatasetError):

class AbstractDataset(abc.ABC, Generic[_DI, _DO]):
"""``AbstractDataset`` is the base class for all data set implementations.

All data set implementations should extend this abstract class
and implement the methods marked as abstract.
If a specific dataset implementation cannot be used in conjunction with
Expand All @@ -93,10 +95,10 @@ class AbstractDataset(abc.ABC, Generic[_DI, _DO]):
>>> self._param1 = param1
>>> self._param2 = param2
>>>
>>> def _load(self) -> pd.DataFrame:
>>> def load(self) -> pd.DataFrame:
>>> return pd.read_csv(self._filepath)
>>>
>>> def _save(self, df: pd.DataFrame) -> None:
>>> def save(self, df: pd.DataFrame) -> None:
>>> df.to_csv(str(self._filepath))
>>>
>>> def _exists(self) -> bool:
@@ -178,57 +180,6 @@ def from_config(
def _logger(self) -> logging.Logger:
return logging.getLogger(__name__)

def load(self) -> _DO:
"""Loads data by delegation to the provided load method.

Returns:
Data returned by the provided load method.

Raises:
DatasetError: When underlying load method raises error.

"""

self._logger.debug("Loading %s", str(self))

try:
return self._load()
except DatasetError:
raise
except Exception as exc:
# This exception handling is by design as the composed data sets
# can throw any type of exception.
message = (
f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
)
raise DatasetError(message) from exc

def save(self, data: _DI) -> None:
"""Saves data by delegation to the provided save method.

Args:
data: the value to be saved by provided save method.

Raises:
DatasetError: when underlying save method raises error.
FileNotFoundError: when save method got file instead of dir, on Windows.
NotADirectoryError: when save method got file instead of dir, on Unix.
"""

if data is None:
raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")

try:
self._logger.debug("Saving %s", str(self))
self._save(data)
except DatasetError:
raise
except (FileNotFoundError, NotADirectoryError):
raise
except Exception as exc:
message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
raise DatasetError(message) from exc

def __str__(self) -> str:
# TODO: Replace with __repr__ implementation in 0.20.0 release.
def _to_str(obj: Any, is_root: bool = False) -> str:
@@ -257,6 +208,85 @@ def _to_str(obj: Any, is_root: bool = False) -> str:

return f"{type(self).__name__}({_to_str(self._describe(), True)})"

@classmethod
def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
"""Decorate `load_func` with logging and error handling code."""

@wraps(load_func)
def load(self: Self) -> _DO:
self._logger.debug("Loading %s", str(self))

try:
return load_func(self)
except DatasetError:
raise
except Exception as exc:
# This exception handling is by design as the composed data sets
# can throw any type of exception.
message = (
f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
)
raise DatasetError(message) from exc

load.__annotations__["return"] = load_func.__annotations__.get("return")
load.__loadwrapped__ = True # type: ignore[attr-defined]
return load

@classmethod
def _save_wrapper(
cls, save_func: Callable[[Self, _DI], None]
) -> Callable[[Self, _DI], None]:
"""Decorate `save_func` with logging and error handling code."""

@wraps(save_func)
def save(self: Self, data: _DI) -> None:
if data is None:
raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")

try:
self._logger.debug("Saving %s", str(self))
save_func(self, data)
except (DatasetError, FileNotFoundError, NotADirectoryError):
raise
except Exception as exc:
message = (
f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
)
raise DatasetError(message) from exc

save.__annotations__["data"] = save_func.__annotations__.get("data", Any)
save.__annotations__["return"] = save_func.__annotations__.get("return")
save.__savewrapped__ = True # type: ignore[attr-defined]
return save

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Decorate the `load` and `save` methods provided by the class.

If `_load` or `_save` are defined, alias them as a prerequisite.

"""
super().__init_subclass__(**kwargs)

if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
cls.load = cls._load # type: ignore[method-assign]

if hasattr(cls, "_save") and not cls._save.__qualname__.startswith("Abstract"):
cls.save = cls._save # type: ignore[method-assign]

if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
Member: The test on __qualname__.startswith exists in MlflowArtifactDataset, but I think there is no test on MlflowAbstractDataset which could conflict, so hopefully we are fine 🤞

Member (author): If there is an issue, we can try to make it more robust (see #3920 (comment)); I think for now, if it works, it's a reasonable implementation until we see more cases.

Member: It's good for now, let's go for it!

cls.load = cls._load_wrapper( # type: ignore[assignment]
cls.load
if not getattr(cls.load, "__loadwrapped__", False)
else cls.load.__wrapped__ # type: ignore[attr-defined]
)

if hasattr(cls, "save") and not cls.save.__qualname__.startswith("Abstract"):
cls.save = cls._save_wrapper( # type: ignore[assignment]
cls.save
if not getattr(cls.save, "__savewrapped__", False)
else cls.save.__wrapped__ # type: ignore[attr-defined]
)

def _pretty_repr(self, object_description: dict[str, Any]) -> str:
str_keys = []
for arg_name, arg_descr in object_description.items():
@@ -276,17 +306,37 @@ def __repr__(self) -> str:
return self._pretty_repr(self._describe())

@abc.abstractmethod
def _load(self) -> _DO:
def load(self) -> _DO:
"""Loads data by delegation to the provided load method.

Returns:
Data returned by the provided load method.

Raises:
DatasetError: When underlying load method raises error.

"""
raise NotImplementedError(
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
f"it must implement the '_load' method"
f"it must implement the 'load' method"
)

@abc.abstractmethod
def _save(self, data: _DI) -> None:
def save(self, data: _DI) -> None:
"""Saves data by delegation to the provided save method.

Args:
data: the value to be saved by provided save method.

Raises:
DatasetError: when underlying save method raises error.
FileNotFoundError: when save method got file instead of dir, on Windows.
NotADirectoryError: when save method got file instead of dir, on Unix.

"""
raise NotImplementedError(
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
f"it must implement the '_save' method"
f"it must implement the 'save' method"
)

@abc.abstractmethod
@@ -509,7 +559,9 @@ def _local_exists(local_filepath: str) -> bool:  # SKIP_IF_NO_SPARK
class AbstractVersionedDataset(AbstractDataset[_DI, _DO], abc.ABC):
"""
``AbstractVersionedDataset`` is the base class for all versioned data set
implementations. All data sets that implement versioning should extend this
implementations.

All data sets that implement versioning should extend this
abstract class and implement the methods marked as abstract.

Example:
@@ -526,11 +578,11 @@ class AbstractVersionedDataset(AbstractDataset[_DI, _DO], abc.ABC):
>>> self._param1 = param1
>>> self._param2 = param2
>>>
>>> def _load(self) -> pd.DataFrame:
>>> def load(self) -> pd.DataFrame:
>>> load_path = self._get_load_path()
>>> return pd.read_csv(load_path)
>>>
>>> def _save(self, df: pd.DataFrame) -> None:
>>> def save(self, df: pd.DataFrame) -> None:
>>> save_path = self._get_save_path()
>>> df.to_csv(str(save_path))
>>>
@@ -652,34 +704,43 @@ def _get_save_path(self) -> PurePosixPath:
def _get_versioned_path(self, version: str) -> PurePosixPath:
return self._filepath / version / self._filepath.name

def load(self) -> _DO:
return super().load()

def save(self, data: _DI) -> None:
self._version_cache.clear()
save_version = self.resolve_save_version() # Make sure last save version is set
try:
super().save(data)
except (FileNotFoundError, NotADirectoryError) as err:
# FileNotFoundError raised in Win, NotADirectoryError raised in Unix
_default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
raise DatasetError(
f"Cannot save versioned dataset '{self._filepath.name}' to "
f"'{self._filepath.parent.as_posix()}' because a file with the same "
f"name already exists in the directory. This is likely because "
f"versioning was enabled on a dataset already saved previously. Either "
f"remove '{self._filepath.name}' from the directory or manually "
f"convert it into a versioned dataset by placing it in a versioned "
f"directory (e.g. with default versioning format "
f"'{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
f"')."
) from err
@classmethod
def _save_wrapper(
cls, save_func: Callable[[Self, _DI], None]
) -> Callable[[Self, _DI], None]:
"""Decorate `save_func` with logging and error handling code."""

@wraps(save_func)
def save(self: Self, data: _DI) -> None:
self._version_cache.clear()
save_version = (
self.resolve_save_version()
) # Make sure last save version is set
try:
super()._save_wrapper(save_func)(self, data)
except (FileNotFoundError, NotADirectoryError) as err:
# FileNotFoundError raised in Win, NotADirectoryError raised in Unix
_default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
raise DatasetError(
f"Cannot save versioned dataset '{self._filepath.name}' to "
f"'{self._filepath.parent.as_posix()}' because a file with the same "
f"name already exists in the directory. This is likely because "
f"versioning was enabled on a dataset already saved previously. Either "
f"remove '{self._filepath.name}' from the directory or manually "
f"convert it into a versioned dataset by placing it in a versioned "
f"directory (e.g. with default versioning format "
f"'{self._filepath.as_posix()}/{_default_version}/{self._filepath.name}"
f"')."
) from err

load_version = self.resolve_load_version()
if load_version != save_version:
warnings.warn(
_CONSISTENCY_WARNING.format(save_version, load_version, str(self))
)
self._version_cache.clear()

load_version = self.resolve_load_version()
if load_version != save_version:
warnings.warn(
_CONSISTENCY_WARNING.format(save_version, load_version, str(self))
)
return save

def exists(self) -> bool:
"""Checks whether a data set's output already exists by calling
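The wrapping behavior introduced in core.py above can be illustrated standalone: arbitrary failures inside a dataset's `load` are re-raised as `DatasetError` (with the original exception chained), while a `DatasetError` raised by the dataset itself passes through unchanged. This is a minimal sketch of the idea, not kedro's actual `_load_wrapper`:

```python
import functools


class DatasetError(Exception):
    """Stand-in for kedro.io.DatasetError (sketch only)."""


def load_wrapper(load_func):
    """Re-raise any failure from `load_func` as DatasetError, letting
    DatasetError itself propagate untouched."""

    @functools.wraps(load_func)
    def load(self):
        try:
            return load_func(self)
        except DatasetError:
            raise
        except Exception as exc:
            message = f"Failed while loading data from data set {self}.\n{exc}"
            raise DatasetError(message) from exc

    # Marker so __init_subclass__-style machinery never wraps twice.
    load.__loadwrapped__ = True
    return load


class FlakyDataset:
    @load_wrapper
    def load(self):
        raise ValueError("disk on fire")


try:
    FlakyDataset().load()
except DatasetError as err:
    # The original error is preserved as the cause.
    assert isinstance(err.__cause__, ValueError)
```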
2 changes: 1 addition & 1 deletion kedro/io/data_catalog.py
@@ -732,7 +732,7 @@ def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> Non
if isinstance(feed_dict[dataset_name], AbstractDataset):
dataset = feed_dict[dataset_name]
else:
dataset = MemoryDataset(data=feed_dict[dataset_name])
dataset = MemoryDataset(data=feed_dict[dataset_name]) # type: ignore[abstract]

self.add(dataset_name, dataset, replace)

16 changes: 8 additions & 8 deletions kedro/io/lambda_dataset.py
@@ -49,21 +49,21 @@ def _to_str(func: Any) -> str | None:

return descr

def _save(self, data: Any) -> None:
if not self.__save:
def load(self) -> Any:
if not self.__load:
raise DatasetError(
"Cannot save to data set. No 'save' function "
"Cannot load data set. No 'load' function "
"provided when LambdaDataset was created."
)
self.__save(data)
return self.__load()

def _load(self) -> Any:
if not self.__load:
def save(self, data: Any) -> None:
if not self.__save:
raise DatasetError(
"Cannot load data set. No 'load' function "
"Cannot save to data set. No 'save' function "
"provided when LambdaDataset was created."
)
return self.__load()
self.__save(data)

def _exists(self) -> bool:
if not self.__exists:
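The `LambdaDataset` pattern above — public `load`/`save` that delegate to user-supplied callables and fail clearly when one is missing — can be sketched as follows. `LambdaLikeDataset` is a hypothetical stand-in, not kedro's class:

```python
from typing import Any, Callable, Optional


class LambdaLikeDataset:
    """Sketch: wrap plain callables as a dataset-like object."""

    def __init__(
        self,
        load: Optional[Callable[[], Any]] = None,
        save: Optional[Callable[[Any], None]] = None,
    ):
        self.__load = load
        self.__save = save

    def load(self) -> Any:
        if not self.__load:
            raise ValueError("No 'load' function provided when dataset was created.")
        return self.__load()

    def save(self, data: Any) -> None:
        if not self.__save:
            raise ValueError("No 'save' function provided when dataset was created.")
        self.__save(data)


store = {}
ds = LambdaLikeDataset(load=lambda: store.get("x"), save=lambda d: store.update(x=d))
ds.save(42)
assert ds.load() == 42
```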
2 changes: 1 addition & 1 deletion kedro/io/memory_dataset.py
@@ -57,7 +57,7 @@ def __init__(
self.metadata = metadata
self._EPHEMERAL = True
if data is not _EMPTY:
self._save(data)
self.save.__wrapped__(self, data) # type: ignore[attr-defined]

def _load(self) -> Any:
if self._data is _EMPTY:
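The `self.save.__wrapped__(self, data)` call in `MemoryDataset.__init__` above relies on `functools.wraps` exposing the undecorated function via `__wrapped__`, so the constructor can skip the wrapper's checks and side effects. A minimal sketch of why that bypass works (the classes here are illustrative, not kedro's):

```python
import functools


def save_wrapper(save_func):
    """Sketch: forbid saving None, but keep the raw function reachable
    through `__wrapped__` (set automatically by functools.wraps)."""

    @functools.wraps(save_func)
    def save(self, data):
        if data is None:
            raise ValueError("Saving 'None' is not allowed")
        save_func(self, data)

    return save


class Memoryish:
    def __init__(self, data=None):
        self._data = data

    @save_wrapper
    def save(self, data):
        self._data = data


m = Memoryish()
# Calling the unwrapped function bypasses the None check, analogous to
# MemoryDataset's `self.save.__wrapped__(self, data)`.
m.save.__wrapped__(m, None)
assert m._data is None
m.save("x")  # the wrapped path still enforces the check for normal calls
assert m._data == "x"
```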