diff --git a/CHANGELOG.md b/CHANGELOG.md index 6026ef750..71d412a48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159)) +- Add `stride` parameter into `backtest` ([#1165](https://github.com/tinkoff-ai/etna/pull/1165)) - Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071)) - Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094)) - Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127)) @@ -21,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - Add more scenarios into tests for models ([#1082](https://github.com/tinkoff-ai/etna/pull/1082)) -- +- - Decouple `SeasonalMovingAverageModel` from `PerSegmentModelMixin` ([#1132](https://github.com/tinkoff-ai/etna/pull/1132)) - Decouple `DeadlineMovingAverageModel` from `PerSegmentModelMixin` ([#1140](https://github.com/tinkoff-ai/etna/pull/1140)) ### Fixed diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py index 5879fc4e8..a8492475f 100644 --- a/etna/pipeline/base.py +++ b/etna/pipeline/base.py @@ -17,6 +17,7 @@ from joblib import delayed from scipy.stats import norm from typing_extensions import TypedDict +from typing_extensions import assert_never from etna.core import AbstractSaveable from etna.core import BaseMixin @@ -35,6 +36,12 @@ class CrossValidationMode(Enum): expand = "expand" constant = "constant" + @classmethod + def _missing_(cls, value): + raise NotImplementedError( + f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} modes are allowed" + ) + class FoldMask(BaseMixin): """Container to hold the description of the fold mask. @@ -209,10 +216,11 @@ def backtest( ts: TSDataset, metrics: List[Metric], n_folds: Union[int, List[FoldMask]] = 5, - mode: str = "expand", + mode: Optional[str] = None, aggregate_metrics: bool = False, n_jobs: int = 1, refit: Union[bool, int] = True, + stride: Optional[int] = None, joblib_params: Optional[Dict[str, Any]] = None, forecast_params: Optional[Dict[str, Any]] = None, ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: @@ -229,7 +237,8 @@ def backtest( n_folds: Number of folds or the list of fold masks mode: - One of 'expand', 'constant' -- train generation policy + Train generation policy: 'expand' or 'constant'. Works only if ``n_folds`` is an integer. + By default, it is set to 'expand'. aggregate_metrics: If True aggregate metrics above folds, return raw metrics otherwise n_jobs: Number of jobs to run in parallel refit: Determines how often pipeline should be retrained during iteration over folds. * If ``True``: pipeline is retrained on each fold. * If ``False``: pipeline is trained only on the first fold. * If ``value: int``: pipeline is trained every ``value`` folds starting from the first. + stride: + Number of points between folds. Works only if ``n_folds`` is an integer. By default, it is set to ``horizon``.
joblib_params: Additional parameters for :py:class:`joblib.Parallel` forecast_params: @@ -453,43 +464,65 @@ def _init_backtest(self): @staticmethod def _validate_backtest_n_folds(n_folds: int): - """Check that given n_folds value is valid.""" + """Check that given n_folds value is >= 1.""" if n_folds < 1: raise ValueError(f"Folds number should be a positive number, {n_folds} given") @staticmethod - def _validate_backtest_dataset(ts: TSDataset, n_folds: int, horizon: int): + def _validate_backtest_mode(n_folds: Union[int, List[FoldMask]], mode: Optional[str]) -> CrossValidationMode: + if mode is None: + return CrossValidationMode.expand + + if not isinstance(n_folds, int): + raise ValueError("Mode shouldn't be set if n_folds is a list of fold masks!") + + return CrossValidationMode(mode.lower()) + + @staticmethod + def _validate_backtest_stride(n_folds: Union[int, List[FoldMask]], horizon: int, stride: Optional[int]) -> int: + if stride is None: + return horizon + + if not isinstance(n_folds, int): + raise ValueError("Stride shouldn't be set if n_folds is a list of fold masks!") + + if stride < 1: + raise ValueError(f"Stride should be a positive number, {stride} given!") + + return stride + + @staticmethod + def _validate_backtest_dataset(ts: TSDataset, n_folds: int, horizon: int, stride: int): """Check all segments have enough timestamps to validate forecaster with given number of splits.""" - min_required_length = horizon * n_folds + min_required_length = horizon + (n_folds - 1) * stride segments = set(ts.df.columns.get_level_values("segment")) for segment in segments: segment_target = ts[:, segment, "target"] if len(segment_target) < min_required_length: raise ValueError( f"All the series from feature dataframe should contain at least " - f"{horizon} * {n_folds} = {min_required_length} timestamps; " + f"{horizon} + {n_folds - 1} * {stride} = {min_required_length} timestamps; " f"series {segment} does not."
) @staticmethod - def _generate_masks_from_n_folds(ts: TSDataset, n_folds: int, horizon: int, mode: str) -> List[FoldMask]: + def _generate_masks_from_n_folds( + ts: TSDataset, n_folds: int, horizon: int, mode: CrossValidationMode, stride: int + ) -> List[FoldMask]: """Generate fold masks from n_folds.""" - mode_enum = CrossValidationMode(mode.lower()) - if mode_enum == CrossValidationMode.expand: + if mode is CrossValidationMode.expand: constant_history_length = 0 - elif mode_enum == CrossValidationMode.constant: + elif mode is CrossValidationMode.constant: constant_history_length = 1 else: - raise NotImplementedError( - f"Only '{CrossValidationMode.expand}' and '{CrossValidationMode.constant}' modes allowed" - ) + assert_never(mode) masks = [] dataset_timestamps = list(ts.index) min_timestamp_idx, max_timestamp_idx = 0, len(dataset_timestamps) for offset in range(n_folds, 0, -1): - min_train_idx = min_timestamp_idx + (n_folds - offset) * horizon * constant_history_length - max_train_idx = max_timestamp_idx - horizon * offset - 1 + min_train_idx = min_timestamp_idx + (n_folds - offset) * stride * constant_history_length + max_train_idx = max_timestamp_idx - stride * (offset - 1) - horizon - 1 min_test_idx = max_train_idx + 1 max_test_idx = max_train_idx + horizon @@ -625,7 +658,7 @@ def _get_fold_info(self) -> pd.DataFrame: tmp_df[f"{stage_name}_{border}_time"] = [fold_info[f"{stage_name}_timerange"][border]] tmp_df[self._fold_column] = fold_number timerange_dfs.append(tmp_df) - timerange_df = pd.concat(timerange_dfs) + timerange_df = pd.concat(timerange_dfs, ignore_index=True) return timerange_df def _get_backtest_forecasts(self) -> pd.DataFrame: @@ -648,12 +681,16 @@ def _get_backtest_forecasts(self) -> pd.DataFrame: forecasts.sort_index(axis=1, inplace=True) return forecasts - def _prepare_fold_masks(self, ts: TSDataset, masks: Union[int, List[FoldMask]], mode: str) -> List[FoldMask]: + def _prepare_fold_masks( + self, ts: TSDataset, masks: Union[int, List[FoldMask]], mode: CrossValidationMode, stride: int + ) -> List[FoldMask]: """Prepare and validate fold masks.""" if isinstance(masks, int): self._validate_backtest_n_folds(n_folds=masks) - self._validate_backtest_dataset(ts=ts, n_folds=masks, horizon=self.horizon) - masks = self._generate_masks_from_n_folds(ts=ts, n_folds=masks, horizon=self.horizon, mode=mode) + self._validate_backtest_dataset(ts=ts, n_folds=masks, horizon=self.horizon, stride=stride) + masks = self._generate_masks_from_n_folds( + ts=ts, n_folds=masks, horizon=self.horizon, mode=mode, stride=stride + ) for i, mask in enumerate(masks): mask.first_train_timestamp = mask.first_train_timestamp if mask.first_train_timestamp else ts.index[0] masks[i] = mask @@ -768,10 +805,11 @@ def backtest( ts: TSDataset, metrics: List[Metric], n_folds: Union[int, List[FoldMask]] = 5, - mode: str = "expand", + mode: Optional[str] = None, aggregate_metrics: bool = False, n_jobs: int = 1, refit: Union[bool, int] = True, + stride: Optional[int] = None, joblib_params: Optional[Dict[str, Any]] = None, forecast_params: Optional[Dict[str, Any]] = None, ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: @@ -788,7 +826,8 @@ def backtest( n_folds: Number of folds or the list of fold masks mode: - One of 'expand', 'constant' -- train generation policy, ignored if n_folds is a list of masks + Train generation policy: 'expand' or 'constant'. Works only if ``n_folds`` is an integer. + By default, it is set to 'expand'.
aggregate_metrics: If True aggregate metrics above folds, return raw metrics otherwise n_jobs: Number of jobs to run in parallel refit: Determines how often pipeline should be retrained during iteration over folds. * If ``True``: pipeline is retrained on each fold. * If ``False``: pipeline is trained only on the first fold. * If ``value: int``: pipeline is trained every ``value`` folds starting from the first. + stride: + Number of points between folds. Works only if ``n_folds`` is an integer. By default, it is set to ``horizon``. joblib_params: Additional parameters for :py:class:`joblib.Parallel` forecast_params: @@ -811,7 +852,17 @@ def backtest( ------- metrics_df, forecast_df, fold_info_df: Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] Metrics dataframe, forecast dataframe and dataframe with information about folds + + Raises + ------ + ValueError: + If ``mode`` is set when ``n_folds`` is a ``List[FoldMask]``. + ValueError: + If ``stride`` is set when ``n_folds`` is a ``List[FoldMask]``. """ + mode_enum = self._validate_backtest_mode(n_folds=n_folds, mode=mode) + stride = self._validate_backtest_stride(n_folds=n_folds, horizon=self.horizon, stride=stride) + if joblib_params is None: joblib_params = dict(verbose=11, backend="multiprocessing", mmap_mode="c") @@ -820,7 +871,7 @@ def backtest( self._init_backtest() self._validate_backtest_metrics(metrics=metrics) - masks = self._prepare_fold_masks(ts=ts, masks=n_folds, mode=mode) + masks = self._prepare_fold_masks(ts=ts, masks=n_folds, mode=mode_enum, stride=stride) self._folds = self._run_all_folds( masks=masks, ts=ts, diff --git a/etna/transforms/timestamp/special_days.py b/etna/transforms/timestamp/special_days.py index 265257b96..8db6c8469 100644 --- a/etna/transforms/timestamp/special_days.py +++ b/etna/transforms/timestamp/special_days.py @@ -1,4 +1,6 @@ import datetime +from typing import Any +from typing import Dict from typing import Optional from typing import Tuple @@ -55,6 +57,7 @@ def __init__(self, find_special_weekday: bool = True, find_special_month_day: bo self.anomaly_week_days: Optional[Tuple[int]] = None self.anomaly_month_days: Optional[Tuple[int]] = None + self.res_type: Dict[str, Any] if self.find_special_weekday and find_special_month_day: self.res_type = {"df_sample": (0, 0), "columns": ["anomaly_weekdays", "anomaly_monthdays"]} elif self.find_special_weekday: diff --git a/poetry.lock b/poetry.lock index 5d48c599a..736ea5126 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,7 +718,7 @@ python-versions = ">=3.7" [[package]] name = "fsspec" -version = "2022.11.0" +version = "2023.1.0" description = "File-system specification" category = "main" optional = true @@ -1409,29 +1409,30 @@ python-versions = ">=3.7" [[package]] name = "mypy" -version = "0.910" +version = "0.950" description = "Optional static typing for Python" category = "main" optional = true -python-versions = ">=3.5" +python-versions = ">=3.6" [package.dependencies] -mypy-extensions = ">=0.4.3,<0.5.0" -toml = "*" -typed-ast = {version = ">=1.4.0,<1.5.0", markers = "python_version < \"3.8\""} -typing-extensions = ">=3.7.4" +mypy-extensions = ">=0.4.3" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} +typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] -python2 = ["typed-ast (>=1.4.0,<1.5.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" -version = "0.4.3" -description = "Experimental type system extensions for programs checked with the mypy typechecker." +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker."
category = "main" optional = true -python-versions = "*" +python-versions = ">=3.5" [[package]] name = "myst-parser" @@ -3303,7 +3304,7 @@ wandb = ["wandb"] [metadata] lock-version = "1.1" python-versions = ">=3.7.1, <3.11.0" -content-hash = "78cfffbb71287b0db8af81005d304cbb6e63ed0b725970863bc5800410b70829" +content-hash = "463b30bacf7ec13ce5181b21a964f0abac9112aed2f5769192c78bcd90c9eec3" [metadata.files] absl-py = [ @@ -3838,8 +3839,8 @@ frozenlist = [ {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, ] fsspec = [ - {file = "fsspec-2022.11.0-py3-none-any.whl", hash = "sha256:d6e462003e3dcdcb8c7aa84c73a228f8227e72453cd22570e2363e8844edfe7b"}, - {file = "fsspec-2022.11.0.tar.gz", hash = "sha256:259d5fd5c8e756ff2ea72f42e7613c32667dc2049a4ac3d84364a7ca034acb8b"}, + {file = "fsspec-2023.1.0-py3-none-any.whl", hash = "sha256:b833e2e541e9e8cde0ab549414187871243177feb3d344f9d27b25a93f5d8139"}, + {file = "fsspec-2023.1.0.tar.gz", hash = "sha256:fbae7f20ff801eb5f7d0bedf81f25c787c0dfac5e982d98fa3884a9cde2b5411"}, ] gitdb = [ {file = "gitdb-4.0.9-py3-none-any.whl", hash = "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd"}, @@ -4199,33 +4200,33 @@ mistune = [ ] multidict = [] mypy = [ - {file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"}, - {file = "mypy-0.910-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb"}, - {file = "mypy-0.910-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9"}, - {file = "mypy-0.910-cp35-cp35m-win_amd64.whl", hash = "sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e"}, - {file = "mypy-0.910-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921"}, - {file = "mypy-0.910-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6"}, - {file = "mypy-0.910-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212"}, - {file = "mypy-0.910-cp36-cp36m-win_amd64.whl", hash = "sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885"}, - {file = "mypy-0.910-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0"}, - {file = "mypy-0.910-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de"}, - {file = "mypy-0.910-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703"}, - {file = "mypy-0.910-cp37-cp37m-win_amd64.whl", hash = "sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a"}, - {file = "mypy-0.910-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504"}, - {file = "mypy-0.910-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9"}, - {file = "mypy-0.910-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072"}, - {file = "mypy-0.910-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811"}, - {file = 
"mypy-0.910-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e"}, - {file = "mypy-0.910-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b"}, - {file = "mypy-0.910-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2"}, - {file = "mypy-0.910-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97"}, - {file = "mypy-0.910-cp39-cp39-win_amd64.whl", hash = "sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8"}, - {file = "mypy-0.910-py3-none-any.whl", hash = "sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d"}, - {file = "mypy-0.910.tar.gz", hash = "sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, + {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, + {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, + {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, + {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, + {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, + {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, + {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, + {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, + {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, + {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, + {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, + {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, + {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, + {file = 
"mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, + {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, + {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, + {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, + {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, + {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, ] mypy-extensions = [ - {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, - {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] myst-parser = [ {file = "myst-parser-0.15.2.tar.gz", hash = "sha256:f7f3b2d62db7655cde658eb5d62b2ec2a4631308137bd8d10f296a40d57bbbeb"}, diff --git a/pyproject.toml b/pyproject.toml index 818a77fe4..862b18ace 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ pep8-naming = {version = "^0.12.1", optional = true} flake8-bugbear = {version = "^22.4.25", optional = true} flake8-comprehensions = {version = "^3.9.0", optional = true} flake8-docstrings = {version = "^1.6.0", optional = true} -mypy = {version = "^0.910", optional = true} +mypy = {version = "^0.950", optional = true} types-PyYAML = {version = "^6.0.0", optional = true} codespell = {version = "^2.0.0", optional = true} @@ -266,6 +266,11 @@ markers = [ "long_2" ] +[tool.coverage.report] +exclude_lines = [ + '^ +assert_never\(.*?\)$', +] + [tool.mypy] ignore_missing_imports = true strict_optional = true diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 9f593192f..ee87b200c 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1,5 +1,4 @@ from copy import deepcopy -from datetime import datetime from typing import Dict from typing import List from unittest.mock import MagicMock @@ -14,7 +13,6 @@ from etna.metrics import MAE from etna.metrics import MSE from etna.metrics import SMAPE -from etna.metrics import Metric from etna.metrics import MetricAggregationMode from etna.metrics import Width from etna.models import CatBoostMultiSegmentModel @@ -30,6 +28,7 @@ from etna.models.base import PredictionIntervalContextRequiredAbstractModel from etna.pipeline import FoldMask from etna.pipeline import Pipeline +from etna.pipeline.base import CrossValidationMode from etna.transforms import AddConstTransform from etna.transforms import DateFlagsTransform from etna.transforms import DifferencingTransform @@ -281,106 +280,86 @@ def test_forecast_prediction_interval_noise(constant_ts, constant_noisy_ts): @pytest.mark.parametrize("n_folds", (0, -1)) def test_invalid_n_folds(catboost_pipeline: Pipeline, n_folds: int, example_tsdf: TSDataset): """Test Pipeline.backtest behavior in case of 
invalid n_folds.""" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Folds number should be a positive number"): _ = catboost_pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS, n_folds=n_folds) -def test_validate_backtest_dataset(catboost_pipeline_big: Pipeline, imbalanced_tsdf: TSDataset): - """Test Pipeline.backtest behavior in case of small dataframe that - can't be divided to required number of splits. - """ - with pytest.raises(ValueError): - _ = catboost_pipeline_big.backtest(ts=imbalanced_tsdf, n_folds=3, metrics=DEFAULT_METRICS) - - -@pytest.mark.parametrize("metrics", ([], [MAE(mode=MetricAggregationMode.macro)])) -def test_invalid_backtest_metrics(catboost_pipeline: Pipeline, metrics: List[Metric], example_tsdf: TSDataset): - """Test Pipeline.backtest behavior in case of invalid metrics.""" - with pytest.raises(ValueError): - _ = catboost_pipeline.backtest(ts=example_tsdf, metrics=metrics, n_folds=2) +@pytest.mark.parametrize( + "min_size, n_folds, horizon, stride", + [ + (1, 10, 1, 1), + (9, 10, 1, 1), + (10, 10, 2, 1), + (19, 10, 2, 2), + (28, 10, 2, 3), + ], +) +def test_invalid_backtest_dataset_size(min_size, n_folds, horizon, stride): + """Test Pipeline.backtest behavior in case the dataframe is too small for the given number of folds.""" + df = generate_ar_df(start_time="2020-01-01", periods=100, n_segments=2, freq="D") + df_wide = TSDataset.to_dataset(df) + to_remove = len(df_wide) - min_size + df_wide.iloc[:to_remove, 0] = np.NaN + ts = TSDataset(df=df_wide, freq="D") + pipeline = Pipeline(model=NaiveModel(lag=horizon), horizon=horizon) + + with pytest.raises(ValueError, match="All the series from feature dataframe should contain at least .* timestamps"): + _ = pipeline.backtest(ts=ts, n_folds=n_folds, stride=stride, metrics=DEFAULT_METRICS) + + +def test_invalid_backtest_metrics_empty(catboost_pipeline: Pipeline, example_tsdf: TSDataset): + """Test Pipeline.backtest behavior in case of empty metrics.""" + with pytest.raises(ValueError, match="At least one metric required"): + _ = catboost_pipeline.backtest(ts=example_tsdf, metrics=[], n_folds=2) + + +def test_invalid_backtest_metrics_macro(catboost_pipeline: Pipeline, example_tsdf: TSDataset): + """Test Pipeline.backtest behavior in case of macro metrics.""" + with pytest.raises(ValueError, match="All the metrics should be in"): + _ = catboost_pipeline.backtest(ts=example_tsdf, metrics=[MAE(mode=MetricAggregationMode.macro)], n_folds=2) + + +def test_invalid_backtest_mode_set_on_fold_mask(catboost_pipeline: Pipeline, example_tsdf: TSDataset): + """Test Pipeline.backtest behavior on setting mode with fold masks.""" + masks = [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-03", + target_timestamps=["2020-04-04", "2020-04-05", "2020-04-06"], + ), + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ] + with pytest.raises(ValueError, match="Mode shouldn't be set if n_folds is a list of fold masks"): + _ = catboost_pipeline.backtest(ts=example_tsdf, n_folds=masks, mode="expand", metrics=DEFAULT_METRICS) -def test_generate_expandable_timeranges_days(): - """Test train-test timeranges generation in expand mode with daily freq""" - df = pd.DataFrame({"timestamp": pd.date_range("2021-01-01", "2021-04-01")}) - df["segment"] = "seg" - df["target"] = 1 - df = df.pivot(index="timestamp", columns="segment").reorder_levels([1, 0], axis=1).sort_index(axis=1) - df.columns.names =
["segment", "feature"] - ts = TSDataset(df, freq="D") +def test_invalid_backtest_stride_set_on_fold_mask(catboost_pipeline: Pipeline, example_tsdf: TSDataset): + """Test Pipeline.backtest behavior on setting stride with fold masks.""" + masks = [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-03", + target_timestamps=["2020-04-04", "2020-04-05", "2020-04-06"], + ), + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ] + with pytest.raises(ValueError, match="Stride shouldn't be set if n_folds are fold masks"): + _ = catboost_pipeline.backtest(ts=example_tsdf, n_folds=masks, stride=2, metrics=DEFAULT_METRICS) - true_borders = ( - (("2021-01-01", "2021-02-24"), ("2021-02-25", "2021-03-08")), - (("2021-01-01", "2021-03-08"), ("2021-03-09", "2021-03-20")), - (("2021-01-01", "2021-03-20"), ("2021-03-21", "2021-04-01")), - ) - masks = Pipeline._generate_masks_from_n_folds(ts=ts, n_folds=3, horizon=12, mode="expand") - for i, stage_dfs in enumerate(Pipeline._generate_folds_datasets(ts, masks=masks, horizon=12)): - for stage_df, borders in zip(stage_dfs, true_borders[i]): - assert stage_df.index.min() == datetime.strptime(borders[0], "%Y-%m-%d").date() - assert stage_df.index.max() == datetime.strptime(borders[1], "%Y-%m-%d").date() - - -def test_generate_expandable_timeranges_hours(): - """Test train-test timeranges generation in expand mode with hour freq""" - df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", "2020-02-01", freq="H")}) - df["segment"] = "seg" - df["target"] = 1 - df = df.pivot(index="timestamp", columns="segment").reorder_levels([1, 0], axis=1).sort_index(axis=1) - df.columns.names = ["segment", "feature"] - ts = TSDataset(df, freq="H") - - true_borders = ( - (("2020-01-01 00:00:00", "2020-01-30 12:00:00"), ("2020-01-30 13:00:00", "2020-01-31 00:00:00")), - (("2020-01-01 00:00:00", "2020-01-31 00:00:00"), ("2020-01-31 01:00:00", "2020-01-31 12:00:00")), - (("2020-01-01 00:00:00", "2020-01-31 12:00:00"), ("2020-01-31 13:00:00", "2020-02-01 00:00:00")), - ) - masks = Pipeline._generate_masks_from_n_folds(ts=ts, n_folds=3, horizon=12, mode="expand") - for i, stage_dfs in enumerate(Pipeline._generate_folds_datasets(ts, horizon=12, masks=masks)): - for stage_df, borders in zip(stage_dfs, true_borders[i]): - assert stage_df.index.min() == datetime.strptime(borders[0], "%Y-%m-%d %H:%M:%S").date() - assert stage_df.index.max() == datetime.strptime(borders[1], "%Y-%m-%d %H:%M:%S").date() - - -def test_generate_constant_timeranges_days(): - """Test train-test timeranges generation with constant mode with daily freq""" - df = pd.DataFrame({"timestamp": pd.date_range("2021-01-01", "2021-04-01")}) - df["segment"] = "seg" - df["target"] = 1 - df = df.pivot(index="timestamp", columns="segment").reorder_levels([1, 0], axis=1).sort_index(axis=1) - df.columns.names = ["segment", "feature"] - ts = TSDataset(df, freq="D") - true_borders = ( - (("2021-01-01", "2021-02-24"), ("2021-02-25", "2021-03-08")), - (("2021-01-13", "2021-03-08"), ("2021-03-09", "2021-03-20")), - (("2021-01-25", "2021-03-20"), ("2021-03-21", "2021-04-01")), - ) - masks = Pipeline._generate_masks_from_n_folds(ts=ts, n_folds=3, horizon=12, mode="constant") - for i, stage_dfs in enumerate(Pipeline._generate_folds_datasets(ts, horizon=12, masks=masks)): - for stage_df, borders in zip(stage_dfs, true_borders[i]): - assert stage_df.index.min() == datetime.strptime(borders[0], 
"%Y-%m-%d").date() - assert stage_df.index.max() == datetime.strptime(borders[1], "%Y-%m-%d").date() - - -def test_generate_constant_timeranges_hours(): - """Test train-test timeranges generation with constant mode with hours freq""" - df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", "2020-02-01", freq="H")}) - df["segment"] = "seg" - df["target"] = 1 - df = df.pivot(index="timestamp", columns="segment").reorder_levels([1, 0], axis=1).sort_index(axis=1) - df.columns.names = ["segment", "feature"] - ts = TSDataset(df, freq="H") - true_borders = ( - (("2020-01-01 00:00:00", "2020-01-30 12:00:00"), ("2020-01-30 13:00:00", "2020-01-31 00:00:00")), - (("2020-01-01 12:00:00", "2020-01-31 00:00:00"), ("2020-01-31 01:00:00", "2020-01-31 12:00:00")), - (("2020-01-02 00:00:00", "2020-01-31 12:00:00"), ("2020-01-31 13:00:00", "2020-02-01 00:00:00")), - ) - masks = Pipeline._generate_masks_from_n_folds(ts=ts, n_folds=3, horizon=12, mode="constant") - for i, stage_dfs in enumerate(Pipeline._generate_folds_datasets(ts, horizon=12, masks=masks)): - for stage_df, borders in zip(stage_dfs, true_borders[i]): - assert stage_df.index.min() == datetime.strptime(borders[0], "%Y-%m-%d %H:%M:%S").date() - assert stage_df.index.max() == datetime.strptime(borders[1], "%Y-%m-%d %H:%M:%S").date() +@pytest.mark.parametrize("stride", [-1, 0]) +def test_invalid_backtest_stride_not_positive(stride, catboost_pipeline: Pipeline, example_tsdf: TSDataset): + """Test Pipeline.backtest behavior on setting not positive stride.""" + with pytest.raises(ValueError, match="Stride should be a positive number, .* given"): + _ = catboost_pipeline.backtest(ts=example_tsdf, n_folds=3, stride=stride, metrics=DEFAULT_METRICS) @pytest.mark.parametrize( @@ -396,7 +375,7 @@ def test_generate_constant_timeranges_hours(): ), ), ) -def test_get_metrics_interface( +def test_backtest_metrics_interface( catboost_pipeline: Pipeline, aggregate_metrics: bool, expected_columns: List[str], big_daily_example_tsdf: TSDataset ): """Check that Pipeline.backtest returns metrics in correct format.""" @@ -408,84 +387,235 @@ def test_get_metrics_interface( assert sorted(expected_columns) == sorted(metrics_df.columns) -def test_get_forecasts_interface_daily(catboost_pipeline: Pipeline, big_daily_example_tsdf: TSDataset): +@pytest.mark.parametrize( + "ts_fixture", + [ + "big_daily_example_tsdf", + "example_tsdf", + ], +) +def test_backtest_forecasts_columns(ts_fixture, catboost_pipeline, request): """Check that Pipeline.backtest returns forecasts in correct format.""" - _, forecast, _ = catboost_pipeline.backtest(ts=big_daily_example_tsdf, metrics=DEFAULT_METRICS) + ts = request.getfixturevalue(ts_fixture) + _, forecast, _ = catboost_pipeline.backtest(ts=ts, metrics=DEFAULT_METRICS) expected_columns = sorted( ["regressor_lag_feature_10", "regressor_lag_feature_11", "regressor_lag_feature_12", "fold_number", "target"] ) assert expected_columns == sorted(set(forecast.columns.get_level_values("feature"))) -def test_get_forecasts_interface_hours(catboost_pipeline: Pipeline, example_tsdf: TSDataset): - """Check that Pipeline.backtest returns forecasts in correct format with non-daily seasonality.""" - _, forecast, _ = catboost_pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS) - expected_columns = sorted( - ["regressor_lag_feature_10", "regressor_lag_feature_11", "regressor_lag_feature_12", "fold_number", "target"] - ) - assert expected_columns == sorted(set(forecast.columns.get_level_values("feature"))) +@pytest.mark.parametrize( + "n_folds, 
horizon, expected_timestamps", + [ + (2, 3, [-6, -5, -4, -3, -2, -1]), + (2, 5, [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1]), + ( + [ + FoldMask( + first_train_timestamp=pd.Timestamp("2020-01-01"), + last_train_timestamp=pd.Timestamp("2020-01-31 14:00"), + target_timestamps=[pd.Timestamp("2020-01-31 17:00")], + ), + FoldMask( + first_train_timestamp=pd.Timestamp("2020-01-01"), + last_train_timestamp=pd.Timestamp("2020-01-31 19:00"), + target_timestamps=[pd.Timestamp("2020-01-31 22:00")], + ), + ], + 5, + [-8, -3], + ), + ], +) +def test_backtest_forecasts_timestamps(n_folds, horizon, expected_timestamps, example_tsdf): + """Check that Pipeline.backtest returns forecasts with expected timestamps.""" + pipeline = Pipeline(model=NaiveModel(lag=horizon), horizon=horizon) + _, forecast, _ = pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS, n_folds=n_folds) + timestamp = example_tsdf.index + np.testing.assert_array_equal(forecast.index, timestamp[expected_timestamps]) -def test_get_fold_info_interface_daily(catboost_pipeline: Pipeline, big_daily_example_tsdf: TSDataset): - """Check that Pipeline.backtest returns info dataframe in correct format.""" - _, _, info_df = catboost_pipeline.backtest(ts=big_daily_example_tsdf, metrics=DEFAULT_METRICS) - expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"] - assert expected_columns == sorted(info_df.columns) +@pytest.mark.parametrize( + "n_folds, horizon, stride, expected_timestamps", + [ + (2, 3, 3, [-6, -5, -4, -3, -2, -1]), + (2, 3, 1, [-4, -3, -2, -3, -2, -1]), + (2, 3, 5, [-8, -7, -6, -3, -2, -1]), + ], +) +def test_backtest_forecasts_timestamps_with_stride(n_folds, horizon, stride, expected_timestamps, example_tsdf): + """Check that Pipeline.backtest with stride returns forecasts with expected timestamps.""" + pipeline = Pipeline(model=NaiveModel(lag=horizon), horizon=horizon) + _, forecast, _ = pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS, n_folds=n_folds, stride=stride) + timestamp = example_tsdf.index -def test_get_fold_info_interface_hours(catboost_pipeline: Pipeline, example_tsdf: TSDataset): - """Check that Pipeline.backtest returns info dataframe in correct format with non-daily seasonality.""" - _, _, info_df = catboost_pipeline.backtest(ts=example_tsdf, metrics=DEFAULT_METRICS) - expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"] - assert expected_columns == sorted(info_df.columns) + np.testing.assert_array_equal(forecast.index, timestamp[expected_timestamps]) -def test_get_fold_info_refit_true(example_tsdf: TSDataset): - """Check that Pipeline.backtest returns info dataframe with correct train with regular refit.""" - n_folds = 5 +@pytest.mark.parametrize( + "ts_fixture, n_folds", + [ + ("big_daily_example_tsdf", 1), + ("big_daily_example_tsdf", 2), + ("example_tsdf", 1), + ("example_tsdf", 2), + ], +) +def test_backtest_fold_info_format(ts_fixture, n_folds, request): + """Check that Pipeline.backtest returns info dataframe in correct format.""" + ts = request.getfixturevalue(ts_fixture) pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7) - _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=True) - assert info_df["train_start_time"].nunique() == 1 - assert info_df["train_end_time"].nunique() == n_folds - assert info_df["test_start_time"].nunique() == n_folds - assert info_df["test_end_time"].nunique() == n_folds + _, _, info_df = 
pipeline.backtest(ts=ts, metrics=DEFAULT_METRICS, n_folds=n_folds) - -def test_get_fold_info_refit_false(example_tsdf: TSDataset): - """Check that Pipeline.backtest returns info dataframe with correct train with no refit.""" - n_folds = 5 - pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7) - _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=False) - assert info_df["train_start_time"].nunique() == 1 - assert info_df["train_end_time"].nunique() == 1 - assert info_df["test_start_time"].nunique() == n_folds - assert info_df["test_end_time"].nunique() == n_folds + expected_folds = pd.Series(np.arange(n_folds)) + pd.testing.assert_series_equal(info_df["fold_number"], expected_folds, check_names=False) + expected_columns = ["fold_number", "test_end_time", "test_start_time", "train_end_time", "train_start_time"] + assert expected_columns == sorted(info_df.columns) @pytest.mark.parametrize( - "n_folds, refit, expected_refits", + "mode, n_folds, refit, horizon, stride, expected_train_starts, expected_train_ends, expected_test_starts, expected_test_ends", [ - (1, 1, 1), - (1, 2, 1), - (3, 1, 3), - (3, 2, 2), - (3, 3, 1), - (3, 4, 1), - (4, 1, 4), - (4, 2, 2), - (4, 3, 2), - (4, 4, 1), - (4, 5, 1), + ("expand", 3, True, 7, None, [0, 0, 0], [-22, -15, -8], [-21, -14, -7], [-15, -8, -1]), + ("expand", 3, True, 7, 1, [0, 0, 0], [-10, -9, -8], [-9, -8, -7], [-3, -2, -1]), + ("expand", 3, True, 7, 10, [0, 0, 0], [-28, -18, -8], [-27, -17, -7], [-21, -11, -1]), + ("expand", 3, False, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("expand", 3, False, 7, 1, [0, 0, 0], [-10, -10, -10], [-9, -8, -7], [-3, -2, -1]), + ("expand", 3, False, 7, 10, [0, 0, 0], [-28, -28, -28], [-27, -17, -7], [-21, -11, -1]), + ("expand", 1, 1, 7, None, [0], [-8], [-7], [-1]), + ("expand", 1, 2, 7, None, [0], [-8], [-7], [-1]), + ("expand", 3, 1, 7, None, [0, 0, 0], [-22, -15, -8], [-21, -14, -7], [-15, -8, -1]), + ("expand", 3, 2, 7, None, [0, 0, 0], [-22, -22, -8], [-21, -14, -7], [-15, -8, -1]), + ("expand", 3, 3, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("expand", 3, 4, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("expand", 4, 1, 7, None, [0, 0, 0, 0], [-29, -22, -15, -8], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("expand", 4, 2, 7, None, [0, 0, 0, 0], [-29, -29, -15, -15], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("expand", 4, 2, 7, 1, [0, 0, 0, 0], [-11, -11, -9, -9], [-10, -9, -8, -7], [-4, -3, -2, -1]), + ("expand", 4, 2, 7, 10, [0, 0, 0, 0], [-38, -38, -18, -18], [-37, -27, -17, -7], [-31, -21, -11, -1]), + ("expand", 4, 3, 7, None, [0, 0, 0, 0], [-29, -29, -29, -8], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("expand", 4, 4, 7, None, [0, 0, 0, 0], [-29, -29, -29, -29], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("expand", 4, 5, 7, None, [0, 0, 0, 0], [-29, -29, -29, -29], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("constant", 3, True, 7, None, [0, 7, 14], [-22, -15, -8], [-21, -14, -7], [-15, -8, -1]), + ("constant", 3, True, 7, 1, [0, 1, 2], [-10, -9, -8], [-9, -8, -7], [-3, -2, -1]), + ("constant", 3, True, 7, 10, [0, 10, 20], [-28, -18, -8], [-27, -17, -7], [-21, -11, -1]), + ("constant", 3, False, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("constant", 3, False, 7, 1, [0, 0, 0], [-10, -10, -10], [-9, -8, -7], [-3, -2, -1]), + ("constant", 3, False, 7, 10, [0, 0, 0], [-28, -28, -28], [-27, -17, -7], [-21, -11, -1]), + ("constant", 1, 1, 
7, None, [0], [-8], [-7], [-1]), + ("constant", 1, 2, 7, None, [0], [-8], [-7], [-1]), + ("constant", 3, 1, 7, None, [0, 7, 14], [-22, -15, -8], [-21, -14, -7], [-15, -8, -1]), + ("constant", 3, 2, 7, None, [0, 0, 14], [-22, -22, -8], [-21, -14, -7], [-15, -8, -1]), + ("constant", 3, 3, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("constant", 3, 4, 7, None, [0, 0, 0], [-22, -22, -22], [-21, -14, -7], [-15, -8, -1]), + ("constant", 4, 1, 7, None, [0, 7, 14, 21], [-29, -22, -15, -8], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("constant", 4, 2, 7, None, [0, 0, 14, 14], [-29, -29, -15, -15], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("constant", 4, 2, 7, 1, [0, 0, 2, 2], [-11, -11, -9, -9], [-10, -9, -8, -7], [-4, -3, -2, -1]), + ("constant", 4, 2, 7, 10, [0, 0, 20, 20], [-38, -38, -18, -18], [-37, -27, -17, -7], [-31, -21, -11, -1]), + ("constant", 4, 3, 7, None, [0, 0, 0, 21], [-29, -29, -29, -8], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("constant", 4, 4, 7, None, [0, 0, 0, 0], [-29, -29, -29, -29], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ("constant", 4, 5, 7, None, [0, 0, 0, 0], [-29, -29, -29, -29], [-28, -21, -14, -7], [-22, -15, -8, -1]), + ( + None, + [ + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-31 10:00"), + target_timestamps=[pd.Timestamp("2020-01-31 14:00")], + ), + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-31 17:00"), + target_timestamps=[pd.Timestamp("2020-01-31 21:00")], + ), + ], + True, + 7, + None, + [0, 0], + [-15, -8], + [-14, -7], + [-8, -1], + ), + ( + None, + [ + FoldMask( + first_train_timestamp=pd.Timestamp("2020-01-01 1:00"), + last_train_timestamp=pd.Timestamp("2020-01-31 10:00"), + target_timestamps=[pd.Timestamp("2020-01-31 14:00")], + ), + FoldMask( + first_train_timestamp=pd.Timestamp("2020-01-01 8:00"), + last_train_timestamp=pd.Timestamp("2020-01-31 17:00"), + target_timestamps=[pd.Timestamp("2020-01-31 21:00")], + ), + ], + True, + 7, + None, + [1, 8], + [-15, -8], + [-14, -7], + [-8, -1], + ), + ( + None, + [ + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-30 20:00"), + target_timestamps=[pd.Timestamp("2020-01-31 00:00")], + ), + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-31 03:00"), + target_timestamps=[pd.Timestamp("2020-01-31 07:00")], + ), + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-31 10:00"), + target_timestamps=[pd.Timestamp("2020-01-31 14:00")], + ), + FoldMask( + first_train_timestamp=None, + last_train_timestamp=pd.Timestamp("2020-01-31 17:00"), + target_timestamps=[pd.Timestamp("2020-01-31 21:00")], + ), + ], + 2, + 7, + None, + [0, 0, 0, 0], + [-29, -29, -15, -15], + [-28, -21, -14, -7], + [-22, -15, -8, -1], + ), ], ) -def test_get_fold_info_refit_int(n_folds, refit, expected_refits, example_tsdf: TSDataset): - """Check that Pipeline.backtest returns info dataframe with correct train with rare refit.""" - pipeline = Pipeline(model=NaiveModel(lag=7), horizon=7) - _, _, info_df = pipeline.backtest(ts=example_tsdf, n_jobs=1, metrics=DEFAULT_METRICS, n_folds=n_folds, refit=refit) - assert info_df["train_start_time"].nunique() == 1 - assert info_df["train_end_time"].nunique() == expected_refits - assert info_df["test_start_time"].nunique() == n_folds - assert info_df["test_end_time"].nunique() == n_folds +def test_backtest_fold_info_timestamps( + mode, + n_folds, + refit, + horizon, + stride, + 
expected_train_starts, + expected_train_ends, + expected_test_starts, + expected_test_ends, + example_tsdf, +): + """Check that Pipeline.backtest returns info dataframe with correct timestamps.""" + pipeline = Pipeline(model=NaiveModel(lag=horizon), horizon=horizon) + _, _, info_df = pipeline.backtest( + ts=example_tsdf, metrics=DEFAULT_METRICS, mode=mode, n_folds=n_folds, refit=refit, stride=stride + ) + timestamp = example_tsdf.index + + np.testing.assert_array_equal(info_df["train_start_time"], timestamp[expected_train_starts]) + np.testing.assert_array_equal(info_df["train_end_time"], timestamp[expected_train_ends]) + np.testing.assert_array_equal(info_df["test_start_time"], timestamp[expected_test_starts]) + np.testing.assert_array_equal(info_df["test_end_time"], timestamp[expected_test_ends]) def test_backtest_refit_success(catboost_pipeline: Pipeline, big_example_tsdf: TSDataset): @@ -543,11 +673,13 @@ def test_forecast_pipeline_with_nan_at_the_end(df_with_nans_in_tails): @pytest.mark.parametrize( - "n_folds, mode, expected_masks", + "n_folds, horizon, stride, mode, expected_masks", ( ( 2, - "expand", + 3, + 3, + CrossValidationMode.expand, [ FoldMask( first_train_timestamp="2020-01-01", @@ -563,7 +695,45 @@ def test_forecast_pipeline_with_nan_at_the_end(df_with_nans_in_tails): ), ( 2, - "constant", + 3, + 1, + CrossValidationMode.expand, + [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-05", + target_timestamps=["2020-04-06", "2020-04-07", "2020-04-08"], + ), + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ], + ), + ( + 2, + 3, + 5, + CrossValidationMode.expand, + [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-01", + target_timestamps=["2020-04-02", "2020-04-03", "2020-04-04"], + ), + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ], + ), + ( + 2, + 3, + 3, + CrossValidationMode.constant, [ FoldMask( first_train_timestamp="2020-01-01", @@ -577,10 +747,48 @@ def test_forecast_pipeline_with_nan_at_the_end(df_with_nans_in_tails): ), ], ), + ( + 2, + 3, + 1, + CrossValidationMode.constant, + [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-05", + target_timestamps=["2020-04-06", "2020-04-07", "2020-04-08"], + ), + FoldMask( + first_train_timestamp="2020-01-02", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ], + ), + ( + 2, + 3, + 5, + CrossValidationMode.constant, + [ + FoldMask( + first_train_timestamp="2020-01-01", + last_train_timestamp="2020-04-01", + target_timestamps=["2020-04-02", "2020-04-03", "2020-04-04"], + ), + FoldMask( + first_train_timestamp="2020-01-06", + last_train_timestamp="2020-04-06", + target_timestamps=["2020-04-07", "2020-04-08", "2020-04-09"], + ), + ], + ), ), ) -def test_generate_masks_from_n_folds(example_tsds: TSDataset, n_folds, mode, expected_masks): - masks = Pipeline._generate_masks_from_n_folds(ts=example_tsds, n_folds=n_folds, horizon=3, mode=mode) +def test_generate_masks_from_n_folds(example_tsds: TSDataset, n_folds, horizon, stride, mode, expected_masks): + masks = Pipeline._generate_masks_from_n_folds( + ts=example_tsds, n_folds=n_folds, horizon=horizon, stride=stride, mode=mode + ) for mask, expected_mask in zip(masks, expected_masks): assert 
mask.first_train_timestamp == expected_mask.first_train_timestamp assert mask.last_train_timestamp == expected_mask.last_train_timestamp @@ -597,7 +805,7 @@ def test_generate_folds_datasets(ts_name, mask, request): """Check _generate_folds_datasets for correct work.""" ts = request.getfixturevalue(ts_name) pipeline = Pipeline(model=NaiveModel(lag=7)) - mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant")[0] + mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant", stride=-1)[0] train, test = list(pipeline._generate_folds_datasets(ts, [mask], 4))[0] assert train.index.min() == np.datetime64(mask.first_train_timestamp) assert train.index.max() == np.datetime64(mask.last_train_timestamp) @@ -615,7 +823,7 @@ def test_generate_folds_datasets_without_first_date(ts_name, mask, request): """Check _generate_folds_datasets for correct work without first date.""" ts = request.getfixturevalue(ts_name) pipeline = Pipeline(model=NaiveModel(lag=7)) - mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant")[0] + mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant", stride=-1)[0] train, test = list(pipeline._generate_folds_datasets(ts, [mask], 4))[0] assert train.index.min() == np.datetime64(ts.index.min()) assert train.index.max() == np.datetime64(mask.last_train_timestamp)
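To make the new behavior concrete, here is a minimal usage sketch of the `stride` parameter (not part of the diff; the dataset, metric, and model choices are illustrative and reuse only the public helpers that already appear in the tests above):

```python
from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.metrics import MAE
from etna.models import NaiveModel
from etna.pipeline import Pipeline

# Build a small synthetic dataset, as the new tests do.
df = generate_ar_df(start_time="2020-01-01", periods=100, n_segments=2, freq="D")
ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

horizon, n_folds, stride = 7, 3, 10
pipeline = Pipeline(model=NaiveModel(lag=horizon), horizon=horizon)

# Fold placement mirrors _generate_masks_from_n_folds: the last test window
# covers the final `horizon` points, and each earlier window starts `stride`
# points before the next one. The dataset therefore needs at least
# horizon + (n_folds - 1) * stride points, which _validate_backtest_dataset enforces.
assert len(ts.index) >= horizon + (n_folds - 1) * stride  # 100 >= 7 + 2 * 10

# stride defaults to horizon (back-to-back folds); stride > horizon leaves gaps
# between test windows, while stride < horizon makes them overlap.
metrics_df, forecast_df, fold_info_df = pipeline.backtest(
    ts=ts, metrics=[MAE()], n_folds=n_folds, stride=stride
)
```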
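And a quick sketch of the stricter mode validation introduced via the `_missing_` hook (again illustrative; `CrossValidationMode` is imported from `etna.pipeline.base` exactly as in the updated test imports):

```python
from etna.pipeline.base import CrossValidationMode

# Valid values still resolve as before; backtest lower-cases user input first.
assert CrossValidationMode("expand") is CrossValidationMode.expand

# Unknown values now fail loudly instead of hitting the removed else-branch
# in _generate_masks_from_n_folds at a later point.
try:
    CrossValidationMode("sliding")
except NotImplementedError as e:
    print(e)  # sliding is not a valid CrossValidationMode. Only 'expand', 'constant' modes are allowed
```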