Skip to content

ChangePointSegmentationTransform #822

Merged
merged 28 commits into from
Aug 19, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
-
-
-
- Add ChangePointSegmentationTransform ([#821](https://github.com/tinkoff-ai/etna/issues/821))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.decomposition import BinsegTrendTransform
from etna.transforms.decomposition import ChangePointsSegmentationTransform
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
from etna.transforms.decomposition import ChangePointsTrendTransform
from etna.transforms.decomposition import LinearTrendTransform
from etna.transforms.decomposition import STLTransform
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/decomposition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from etna.transforms.decomposition.binseg import BinsegTrendTransform
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
from etna.transforms.decomposition.change_points_segmentation import ChangePointsSegmentationTransform
from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform
Expand Down
72 changes: 72 additions & 0 deletions etna/transforms/decomposition/change_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

import pandas as pd
from ruptures.base import BaseEstimator
from sklearn.base import RegressorMixin

from etna.analysis.change_points_trend.search import _find_change_points_segment
from etna.transforms.base import Transform

TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp]
TDetrendModel = Type[RegressorMixin]


class ChangePointsTransform(Transform):
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""ChangePointsTransform is the base class for transforms with change points."""

def __init__(self, in_column: str, change_point_model: BaseEstimator, **change_point_model_predict_params):
"""Init ChangePointsTransform.

Parameters
----------
in_column:
name of column to apply transform to
change_point_model:
model to get change points
change_point_model_predict_params:
params for ``change_point_model.predict`` method
"""
self.in_column = in_column
self.out_columns = in_column
self.change_point_model = change_point_model
self.intervals: Optional[List[TTimestampInterval]] = None
self.change_point_model_predict_params = change_point_model_predict_params

@staticmethod
def _build_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]:
"""Create list of stable intervals from list of change points."""
change_points = sorted(change_points)
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
left_border = pd.Timestamp.min
intervals = []
for point in change_points:
right_border = point
intervals.append((left_border, right_border))
left_border = right_border
intervals.append((left_border, pd.Timestamp.max))
return intervals

def fit(self, df: pd.DataFrame) -> "ChangePointsTransform":
"""Fit ChangePointsTransform: find change points in ``df`` and build intervals..
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
df:
one segment dataframe indexed with timestamp

Returns
-------
:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""
self.series = df.loc[
df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index(), self.in_column
]
if self.series.isnull().values.any():
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")
change_points = _find_change_points_segment(
series=self.series, change_point_model=self.change_point_model, **self.change_point_model_predict_params
)
self.intervals = self._build_intervals(change_points=change_points)
return self
119 changes: 119 additions & 0 deletions etna/transforms/decomposition/change_points_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Optional

import pandas as pd
from ruptures.base import BaseEstimator

from etna.transforms.base import FutureMixin
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.decomposition.change_points import ChangePointsTransform


class _OneSegmentChangePointsSegmentationTransform(ChangePointsTransform):
"""_OneSegmentChangePointsSegmentationTransform make label encoder to change points."""

def __init__(
self, in_column: str, change_point_model: BaseEstimator, out_column: str, **change_point_model_predict_params
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
):
"""Init _OneSegmentChangePointsSegmentationTransform.
Parameters
----------
in_column:
name of column to apply transform to
change_point_model:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
model to get change points
out_column: str, optional
result column name. If not given use ``self.__repr__()``
change_point_model_predict_params:
params for ``change_point_model.predict`` method
"""
super(_OneSegmentChangePointsSegmentationTransform, self).__init__(
in_column=in_column, change_point_model=change_point_model, **change_point_model_predict_params
)

self.out_column = out_column

def _fill_per_interval(self, series: pd.Series) -> pd.Series:
"""Fill values in resulting series."""
if self.intervals is None:
raise ValueError("Transform is not fitted! Fit the Transform before calling transform method.")
result_series = pd.Series(index=series.index)
for k, interval in enumerate(self.intervals):
tmp_series = series[interval[0] : interval[1]]
if tmp_series.empty:
continue
result_series[tmp_series.index] = k
return result_series.astype(int).astype("category")

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""Split df to intervals.

Parameters
----------
df:
one segment dataframe
Returns
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
-------
df: pd.DataFrame
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
df with new column
"""
series = df[self.in_column]
result_series = self._fill_per_interval(series=series)
df.loc[:, self.out_column] = result_series
return df

def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""Do nothing in this case.

Parameters
----------
df:
one segment dataframe
Returns
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
-------
df: pd.DataFrame
one segment dataframe
"""
return df


class ChangePointsSegmentationTransform(PerSegmentWrapper, FutureMixin):
"""ChangePointsSegmentationTransform make label encoder to change points.

Warning
-------
This transform can suffer from look-ahead bias. For transforming data at some timestamp
it uses information from the whole train part.
"""

def __init__(
self,
in_column: str,
change_point_model: BaseEstimator,
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
out_column: Optional[str] = None,
**change_point_model_predict_params,
):
"""Init ChangePointsSegmentationTransform.

Parameters
----------
in_column:
name of column to fit change point model
change_point_model:
model to get change points
out_column: str, optional
result column name. If not given use ``self.__repr__()``
change_point_model_predict_params:
params for ``change_point_model.predict`` method
"""
self.in_column = in_column
self.out_column = out_column
self.change_point_model = change_point_model
self.change_point_model_predict_params = change_point_model_predict_params
super().__init__(
transform=_OneSegmentChangePointsSegmentationTransform(
in_column=self.in_column,
out_column=self.out_column if self.out_column is not None else self.__repr__(),
change_point_model=self.change_point_model,
**self.change_point_model_predict_params,
)
)
48 changes: 12 additions & 36 deletions etna/transforms/decomposition/change_points_trend.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from copy import deepcopy
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

import numpy as np
import pandas as pd
from ruptures.base import BaseEstimator
from sklearn.base import RegressorMixin

from etna.analysis.change_points_trend.search import _find_change_points_segment
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.decomposition.change_points import ChangePointsTransform
from etna.transforms.decomposition.change_points import TDetrendModel
from etna.transforms.decomposition.change_points import TTimestampInterval
from etna.transforms.utils import match_target_quantiles

TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp]
TDetrendModel = Type[RegressorMixin]


class _OneSegmentChangePointsTrendTransform(Transform):
class _OneSegmentChangePointsTrendTransform(ChangePointsTransform):
"""_OneSegmentChangePointsTransform subtracts multiple linear trend from series."""

def __init__(
Expand All @@ -42,26 +37,12 @@ def __init__(
change_point_model_predict_params:
params for ``change_point_model.predict`` method
"""
self.in_column = in_column
super(_OneSegmentChangePointsTrendTransform, self).__init__(
in_column=in_column, change_point_model=change_point_model, **change_point_model_predict_params
)

self.out_columns = in_column
self.change_point_model = change_point_model
self.detrend_model = detrend_model
self.per_interval_models: Optional[Dict[TTimestampInterval, TDetrendModel]] = None
self.intervals: Optional[List[TTimestampInterval]] = None
self.change_point_model_predict_params = change_point_model_predict_params
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _build_trend_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]:
"""Create list of stable trend intervals from list of change points."""
change_points = sorted(change_points)
left_border = pd.Timestamp.min
intervals = []
for point in change_points:
right_border = point
intervals.append((left_border, right_border))
left_border = right_border
intervals.append((left_border, pd.Timestamp.max))
return intervals

def _init_detrend_models(
self, intervals: List[TTimestampInterval]
Expand Down Expand Up @@ -112,15 +93,10 @@ def fit(self, df: pd.DataFrame) -> "_OneSegmentChangePointsTrendTransform":
-------
:
"""
series = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index(), self.in_column]
if series.isnull().values.any():
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")
change_points = _find_change_points_segment(
series=series, change_point_model=self.change_point_model, **self.change_point_model_predict_params
)
self.intervals = self._build_trend_intervals(change_points=change_points)
self.per_interval_models = self._init_detrend_models(intervals=self.intervals)
self._fit_per_interval_model(series=series)
super(_OneSegmentChangePointsTrendTransform, self).fit(df=df)

self.per_interval_models = self._init_detrend_models(intervals=self.intervals) # type: ignore
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
self._fit_per_interval_model(series=self.series)
julia-shenshina marked this conversation as resolved.
Show resolved Hide resolved
return self

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
Expand Down
Loading