Skip to content

Commit

Permalink
Add saving/loading for transforms, models, pipelines (#1068)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Jan 16, 2023
1 parent 56990dd commit c04bf64
Show file tree
Hide file tree
Showing 83 changed files with 2,863 additions and 909 deletions.
15 changes: 10 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051))
- `MaxDeviation` metric & `max_deviation` functional metric ([#1061](https://github.com/tinkoff-ai/etna/pull/1061))
- Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068))
-
-
-
-
-
-
### Changed
-
-
-
-
-
-
-
-
-
### Fixed
-
-
-
-
-

-
-
## [1.14.0] - 2022-12-16
### Added
- Add python 3.10 support ([#1005](https://github.com/tinkoff-ai/etna/pull/1005))
Expand Down
3 changes: 3 additions & 0 deletions etna/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from etna.core.mixins import BaseMixin
from etna.core.mixins import SaveMixin
from etna.core.mixins import StringEnumWithRepr
from etna.core.saving import AbstractSaveable
from etna.core.utils import load
111 changes: 111 additions & 0 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import inspect
import json
import pathlib
import pickle
import sys
import warnings
import zipfile
from enum import Enum
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
from typing import cast

from sklearn.base import BaseEstimator

Expand Down Expand Up @@ -89,3 +96,107 @@ class StringEnumWithRepr(str, Enum):
def __repr__(self):
"""Get string representation for enum string so that enum can be created from it."""
return self.value.__repr__()


def get_etna_version() -> Tuple[int, int, int]:
"""Get current version of etna library."""
python_version = sys.version_info
if python_version[0] == 3 and python_version[1] >= 8:
from importlib.metadata import version

str_version = version("etna")
result = tuple([int(x) for x in str_version.split(".")])
result = cast(Tuple[int, int, int], result)
return result
else:
import pkg_resources

str_version = pkg_resources.get_distribution("etna").version
result = tuple([int(x) for x in str_version.split(".")])
result = cast(Tuple[int, int, int], result)
return result


class SaveMixin:
"""Basic implementation of ``AbstractSaveable`` abstract class.
It saves object to the zip archive with 2 files:
* metadata.json: contains library version and class name.
* object.pkl: pickled object.
"""

def _save_metadata(self, archive: zipfile.ZipFile):
full_class_name = f"{inspect.getmodule(self).__name__}.{self.__class__.__name__}" # type: ignore
metadata = {
"etna_version": get_etna_version(),
"class": full_class_name,
}
metadata_str = json.dumps(metadata, indent=2, sort_keys=True)
metadata_bytes = metadata_str.encode("utf-8")
with archive.open("metadata.json", "w") as output_file:
output_file.write(metadata_bytes)

def _save_state(self, archive: zipfile.ZipFile):
with archive.open("object.pkl", "w") as output_file:
pickle.dump(self, output_file)

def save(self, path: pathlib.Path):
"""Save the object.
Parameters
----------
path:
Path to save object to.
"""
with zipfile.ZipFile(path, "w") as archive:
self._save_metadata(archive)
self._save_state(archive)

@classmethod
def _load_metadata(cls, archive: zipfile.ZipFile) -> Dict[str, Any]:
with archive.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
return metadata

@classmethod
def _validate_metadata(cls, metadata: Dict[str, Any]):
current_etna_version = get_etna_version()
saved_etna_version = tuple(metadata["etna_version"])

# if major version is different give a warning
if current_etna_version[0] != saved_etna_version[0] or current_etna_version[:2] < saved_etna_version[:2]:
current_etna_version_str = ".".join([str(x) for x in current_etna_version])
saved_etna_version_str = ".".join([str(x) for x in saved_etna_version])
warnings.warn(
f"The object was saved under etna version {saved_etna_version_str} "
f"but running version is {current_etna_version_str}, this can cause problems with compatibility!"
)

@classmethod
def _load_state(cls, archive: zipfile.ZipFile) -> Any:
with archive.open("object.pkl", "r") as input_file:
return pickle.load(input_file)

@classmethod
def load(cls, path: pathlib.Path) -> Any:
"""Load an object.
Parameters
----------
path:
Path to load object from.
Returns
-------
:
Loaded object.
"""
with zipfile.ZipFile(path, "r") as archive:
metadata = cls._load_metadata(archive)
cls._validate_metadata(metadata)
obj = cls._load_state(archive)
return obj
31 changes: 31 additions & 0 deletions etna/core/saving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pathlib
from abc import ABC
from abc import abstractmethod
from typing import Any


class AbstractSaveable(ABC):
"""Abstract class with methods for saving, loading objects."""

@abstractmethod
def save(self, path: pathlib.Path):
"""Save the object.
Parameters
----------
path:
Path to save object to.
"""
pass

@classmethod
@abstractmethod
def load(cls, path: pathlib.Path) -> Any:
"""Load an object.
Parameters
----------
path:
Path to load object from.
"""
pass
36 changes: 36 additions & 0 deletions etna/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
import inspect
import json
import pathlib
import zipfile
from copy import deepcopy
from functools import wraps
from typing import Any
from typing import Callable

from hydra_slayer import get_factory


def load(path: pathlib.Path, **kwargs: Any) -> Any:
"""Load saved object by path.
Parameters
----------
path:
Path to load object from.
kwargs:
Parameters for loading specific for the loaded object.
Returns
-------
:
Loaded object.
"""
with zipfile.ZipFile(path, "r") as archive:
# read object class
with archive.open("metadata.json", "r") as input_file:
metadata_bytes = input_file.read()
metadata_str = metadata_bytes.decode("utf-8")
metadata = json.loads(metadata_str)
object_class_name = metadata["class"]

# create object for that class
object_class = get_factory(object_class_name)
loaded_object = object_class.load(path=path, **kwargs)

return loaded_object


def init_collector(init: Callable) -> Callable:
"""
Expand Down
2 changes: 1 addition & 1 deletion etna/ensembles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from etna.ensembles.base import EnsembleMixin
from etna.ensembles.direct_ensemble import DirectEnsemble
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.stacking_ensemble import StackingEnsemble
from etna.ensembles.voting_ensemble import VotingEnsemble
55 changes: 0 additions & 55 deletions etna/ensembles/base.py

This file was deleted.

5 changes: 3 additions & 2 deletions etna/ensembles/direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from joblib import delayed

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.pipeline.base import BasePipeline


class DirectEnsemble(BasePipeline, EnsembleMixin):
class DirectEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""DirectEnsemble is a pipeline that forecasts future values merging the forecasts of base pipelines.
Ensemble expects several pipelines during init. These pipelines are expected to have different forecasting horizons.
Expand Down
Loading

1 comment on commit c04bf64

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.