Skip to content

Update ruptures version #141

Merged
merged 3 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update EDA notebook ([#96](https://github.com/tinkoff-ai/etna-ts/pull/96), [#114](https://github.com/tinkoff-ai/etna-ts/pull/114))
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
- Add Pipeline example in Get started notebook ([#115](https://github.com/tinkoff-ai/etna-ts/pull/115))
- Internal implementation of BinsegTrendTransform ([#141](https://github.com/tinkoff-ai/etna-ts/pull/141))

### Fixed
- Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102))
Expand Down
44 changes: 2 additions & 42 deletions etna/transforms/binseg.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,15 @@
from functools import lru_cache
from typing import Any
from typing import Optional

from ruptures.base import BaseCost
from ruptures.costs import cost_factory
from ruptures.detection import Binseg
from sklearn.linear_model import LinearRegression

from etna.transforms.change_points_trend import ChangePointsTrendTransform
from etna.transforms.change_points_trend import TDetrendModel


class _Binseg(Binseg):
    """Binary segmentation whose breakpoint search is memoised through lru_cache."""

    def __init__(
        self,
        model: str = "l2",
        custom_cost: Optional[BaseCost] = None,
        min_size: int = 2,
        jump: int = 5,
        params: Any = None,
    ):
        """Create a _Binseg instance.

        The attributes are assigned directly here; the parent initializer is not invoked.

        Args:
            model (str, optional): segment model, ["l1", "l2", "rbf",...]. Ignored when
                ``custom_cost`` is a ``BaseCost`` instance.
            custom_cost (BaseCost, optional): custom cost function. Defaults to None.
            min_size (int, optional): minimum segment length. Defaults to 2 samples.
            jump (int, optional): subsample (one every *jump* points). Defaults to 5 samples.
            params (dict, optional): keyword arguments forwarded to the cost factory.
        """
        if isinstance(custom_cost, BaseCost):
            # isinstance() is False for None, so no separate None check is required.
            chosen_cost = custom_cost
        else:
            # No usable custom cost: build one from the model name, forwarding params if given.
            extra_kwargs = {} if params is None else params
            chosen_cost = cost_factory(model=model, **extra_kwargs)

        self.cost = chosen_cost
        # The cost object may demand a larger minimum segment than the caller asked for.
        self.min_size = max(min_size, self.cost.min_size)
        self.jump = jump
        self.n_samples = None
        self.signal = None

    @lru_cache(maxsize=None)
    def single_bkp(self, start: int, end: int) -> Any:
        """Delegate to ``_single_bkp`` and cache the result for the given arguments."""
        return self._single_bkp(start=start, end=end)


class BinsegTrendTransform(ChangePointsTrendTransform):
"""BinsegTrendTransform uses _Binseg model as a change point detection model in ChangePointsTrendTransform transform."""
"""BinsegTrendTransform uses Binseg model as a change point detection model in ChangePointsTrendTransform transform."""

def __init__(
self,
Expand Down Expand Up @@ -95,7 +55,7 @@ def __init__(
self.epsilon = epsilon
super().__init__(
in_column=in_column,
change_point_model=_Binseg(
change_point_model=Binseg(
model=self.model, custom_cost=self.custom_cost, min_size=self.min_size, jump=self.jump
),
detrend_model=detrend_model,
Expand Down
4 changes: 1 addition & 3 deletions etna/transforms/change_points_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def fit(self, df: pd.DataFrame) -> "OneSegmentChangePointsTransform":
-------
self
"""
# we need copy here because Binseg with CostAR (model="ar") changes given signal inplace; if it is fixed
# @TODO: delete copy
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column].copy(deep=True)
series = df.loc[df[self.in_column].first_valid_index() :, self.in_column]
change_points = self._get_change_points(series=series)
self.intervals = self._build_trend_intervals(change_points=change_points)
self.per_interval_models = self._init_detrend_models(intervals=self.intervals)
Expand Down
64 changes: 31 additions & 33 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ scikit-learn = "^0.24.1"
prophet = "^1.0"
pandas = "^1"
catboost = "^0.25"
ruptures = "1.1.3"
ruptures = "1.1.5"
torch = "1.8.*"
pytorch-forecasting = "0.8.5"
numba = "^0.53.1"
Expand Down
Loading