Skip to content

Commit

Permalink
ChangePointSegmentationTransform (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ama16 authored Aug 19, 2022
1 parent 0198d11 commit ced2072
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 69 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Add `ChangePointsSegmentationTransform`, `RupturesChangePointsModel` ([#821](https://github.com/tinkoff-ai/etna/issues/821))
-
-
### Changed
Expand Down
32 changes: 4 additions & 28 deletions etna/analysis/change_points_trend/search.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,12 @@
from typing import Dict
from typing import List

import numpy as np
import pandas as pd
from ruptures.base import BaseEstimator
from ruptures.costs import CostLinear

from etna.datasets import TSDataset


def _prepare_signal(series: pd.Series, model: BaseEstimator) -> np.ndarray:
    """Convert a series into the array layout expected by the change point model.

    Parameters
    ----------
    series:
        series to convert
    model:
        change point model whose ``cost`` attribute selects the layout

    Returns
    -------
    :
        values of the series; reshaped to a single column when the model cost is ``CostLinear``
    """
    values = series.to_numpy()
    needs_column_layout = isinstance(model.cost, CostLinear)
    return values.reshape((-1, 1)) if needs_column_layout else values


def _find_change_points_segment(
    series: pd.Series, change_point_model: BaseEstimator, **model_predict_params
) -> List[pd.Timestamp]:
    """Find trend change points within one segment.

    Parameters
    ----------
    series:
        series to search for change points, indexed with timestamp
    change_point_model:
        model that is fitted on the series and asked for change points
    model_predict_params:
        keyword arguments forwarded to ``change_point_model.predict``

    Returns
    -------
    :
        timestamps of the detected change points
    """
    prepared = _prepare_signal(series=series, model=change_point_model)
    index = series.index
    change_point_model.fit(signal=prepared)
    breakpoints = change_point_model.predict(**model_predict_params)
    # the final breakpoint is the first position after the series, so it has no timestamp
    return [index[position] for position in breakpoints[:-1]]


def find_change_points(
ts: TSDataset, in_column: str, change_point_model: BaseEstimator, **model_predict_params
) -> Dict[str, List[pd.Timestamp]]:
Expand All @@ -51,13 +28,12 @@ def find_change_points(
Dict[str, List[pd.Timestamp]]
dictionary with list of trend change points for each segment
"""
from etna.transforms.decomposition.base_change_points import RupturesChangePointsModel

result: Dict[str, List[pd.Timestamp]] = {}
df = ts.to_pandas()
ruptures = RupturesChangePointsModel(change_point_model, **model_predict_params)
for segment in ts.segments:
df_segment = df[segment]
raw_series = df_segment[in_column]
series = raw_series.loc[raw_series.first_valid_index() : raw_series.last_valid_index()]
result[segment] = _find_change_points_segment(
series=series, change_point_model=change_point_model, **model_predict_params
)
result[segment] = ruptures.get_change_points(df=df_segment, in_column=in_column)
return result
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.decomposition import BinsegTrendTransform
from etna.transforms.decomposition import ChangePointsSegmentationTransform
from etna.transforms.decomposition import ChangePointsTrendTransform
from etna.transforms.decomposition import LinearTrendTransform
from etna.transforms.decomposition import STLTransform
Expand Down
2 changes: 2 additions & 0 deletions etna/transforms/decomposition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from etna.transforms.decomposition.base_change_points import RupturesChangePointsModel
from etna.transforms.decomposition.binseg import BinsegTrendTransform
from etna.transforms.decomposition.change_points_segmentation import ChangePointsSegmentationTransform
from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform
Expand Down
108 changes: 108 additions & 0 deletions etna/transforms/decomposition/base_change_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from abc import ABC
from abc import abstractmethod
from typing import List
from typing import Tuple
from typing import Type

import pandas as pd
from ruptures.base import BaseEstimator
from ruptures.costs import CostLinear
from sklearn.base import RegressorMixin

TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp]
TDetrendModel = Type[RegressorMixin]


class BaseChangePointsModelAdapter(ABC):
    """BaseChangePointsModelAdapter is the base class for change point models adapters."""

    @abstractmethod
    def get_change_points(self, df: pd.DataFrame, in_column: str) -> List[pd.Timestamp]:
        """Find change points within one segment.

        Parameters
        ----------
        df:
            dataframe indexed with timestamp
        in_column:
            name of column to get change points

        Returns
        -------
        change points:
            change point timestamps
        """
        pass

    @staticmethod
    def _build_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]:
        """Create list of stable intervals from list of change points.

        Parameters
        ----------
        change_points:
            change point timestamps in any order; the list is left untouched

        Returns
        -------
        :
            consecutive ``(left, right)`` pairs that start at ``pd.Timestamp.min``
            and end at ``pd.Timestamp.max``
        """
        # Build a fresh list rather than extending the argument in place: the caller's
        # list must not end up containing the artificial min/max borders.
        borders = sorted([pd.Timestamp.min, *change_points, pd.Timestamp.max])
        return list(zip(borders[:-1], borders[1:]))

    def get_change_points_intervals(self, df: pd.DataFrame, in_column: str) -> List[TTimestampInterval]:
        """Find change point intervals in given dataframe and column.

        Parameters
        ----------
        df:
            dataframe indexed with timestamp
        in_column:
            name of column to get change points

        Returns
        -------
        :
            change points intervals
        """
        change_points = self.get_change_points(df=df, in_column=in_column)
        return self._build_intervals(change_points=change_points)


class RupturesChangePointsModel(BaseChangePointsModelAdapter):
    """Adapter that exposes a ruptures change point model through the adapter interface."""

    def __init__(self, change_point_model: BaseEstimator, **change_point_model_predict_params):
        """Init RupturesChangePointsModel.

        Parameters
        ----------
        change_point_model:
            model to get change points
        change_point_model_predict_params:
            params for ``change_point_model.predict`` method
        """
        self.change_point_model = change_point_model
        self.model_predict_params = change_point_model_predict_params

    def get_change_points(self, df: pd.DataFrame, in_column: str) -> List[pd.Timestamp]:
        """Find change points within one segment.

        Parameters
        ----------
        df:
            dataframe indexed with timestamp
        in_column:
            name of column to get change points

        Returns
        -------
        change points:
            change point timestamps

        Raises
        ------
        ValueError
            If the column still contains NaNs after leading and trailing NaNs are cut off
        """
        # keep only the span between the first and the last valid observation
        first_valid = df[in_column].first_valid_index()
        last_valid = df[in_column].last_valid_index()
        series = df.loc[first_valid:last_valid, in_column]
        if series.isnull().values.any():
            raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")

        model = self.change_point_model
        signal = series.to_numpy()
        if isinstance(model.cost, CostLinear):
            # for the linear cost the signal is passed as a single column
            signal = signal.reshape((-1, 1))
        index = series.index
        model.fit(signal=signal)
        breakpoints = model.predict(**self.model_predict_params)
        # the final breakpoint is the first position after the series, so it has no timestamp
        return [index[position] for position in breakpoints[:-1]]
121 changes: 121 additions & 0 deletions etna/transforms/decomposition/change_points_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import List
from typing import Optional

import pandas as pd

from etna.transforms.base import FutureMixin
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.decomposition.base_change_points import BaseChangePointsModelAdapter
from etna.transforms.decomposition.base_change_points import TTimestampInterval


class _OneSegmentChangePointsSegmentationTransform(Transform):
    """_OneSegmentChangePointsSegmentationTransform make label encoder to change points.

    Every timestamp gets the ordinal number of the change point interval it falls into.
    """

    def __init__(self, in_column: str, out_column: str, change_point_model: BaseChangePointsModelAdapter):
        """Init _OneSegmentChangePointsSegmentationTransform.

        Parameters
        ----------
        in_column:
            name of column to apply transform to
        out_column:
            result column name
        change_point_model:
            model to get change points
        """
        self.in_column = in_column
        self.out_column = out_column
        # filled by ``fit``; ``None`` means the transform is not fitted yet
        self.intervals: Optional[List[TTimestampInterval]] = None
        self.change_point_model = change_point_model

    def _fill_per_interval(self, series: pd.Series) -> pd.Series:
        """Fill values in resulting series.

        Parameters
        ----------
        series:
            series indexed with timestamp

        Returns
        -------
        :
            categorical series with the interval number for every timestamp of ``series``

        Raises
        ------
        ValueError
            If the transform is not fitted
        """
        if self.intervals is None:
            raise ValueError("Transform is not fitted! Fit the Transform before calling transform method.")
        # Explicit dtype: the default dtype of a data-less Series differs between pandas
        # versions and constructing one without a dtype raises a deprecation warning.
        result_series = pd.Series(index=series.index, dtype=float)
        for k, (left_border, right_border) in enumerate(self.intervals):
            # label slicing includes both borders, so a timestamp equal to a change point
            # is labelled twice and keeps the number of the later interval
            tmp_series = series[left_border:right_border]
            if tmp_series.empty:
                continue
            result_series[tmp_series.index] = k
        return result_series.astype(int).astype("category")

    def fit(self, df: pd.DataFrame) -> "_OneSegmentChangePointsSegmentationTransform":
        """Fit _OneSegmentChangePointsSegmentationTransform: find change points in ``df`` and build intervals.

        Parameters
        ----------
        df:
            one segment dataframe indexed with timestamp

        Returns
        -------
        :
            instance with trained change points

        Raises
        ------
        ValueError
            If series contains NaNs in the middle
        """
        self.intervals = self.change_point_model.get_change_points_intervals(df=df, in_column=self.in_column)
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Split df to intervals.

        Parameters
        ----------
        df:
            one segment dataframe

        Returns
        -------
        df:
            the same dataframe with the ``out_column`` column added
        """
        series = df[self.in_column]
        result_series = self._fill_per_interval(series=series)
        df.loc[:, self.out_column] = result_series
        return df


class ChangePointsSegmentationTransform(PerSegmentWrapper, FutureMixin):
    """ChangePointsSegmentationTransform make label encoder to change points.

    Warning
    -------
    This transform can suffer from look-ahead bias. For transforming data at some timestamp
    it uses information from the whole train part.
    """

    def __init__(
        self,
        in_column: str,
        change_point_model: BaseChangePointsModelAdapter,
        out_column: Optional[str] = None,
    ):
        """Init ChangePointsSegmentationTransform.

        Parameters
        ----------
        in_column:
            name of column to fit change point model
        out_column:
            result column name. If not given use ``self.__repr__()``
        change_point_model:
            model to get change points
        """
        self.in_column = in_column
        self.out_column = out_column
        self.change_point_model = change_point_model
        # all three attributes are assigned before ``repr(self)`` is evaluated
        if self.out_column is None:
            self.out_column = repr(self)
        one_segment_transform = _OneSegmentChangePointsSegmentationTransform(
            in_column=self.in_column,
            out_column=self.out_column,
            change_point_model=self.change_point_model,
        )
        super().__init__(transform=one_segment_transform)
34 changes: 11 additions & 23 deletions etna/transforms/decomposition/change_points_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from ruptures.base import BaseEstimator
from sklearn.base import RegressorMixin

from etna.analysis.change_points_trend.search import _find_change_points_segment
from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform
from etna.transforms.decomposition.base_change_points import RupturesChangePointsModel
from etna.transforms.decomposition.base_change_points import TTimestampInterval
from etna.transforms.utils import match_target_quantiles

TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp]
TDetrendModel = Type[RegressorMixin]


Expand All @@ -37,32 +37,23 @@ def __init__(
name of column to apply transform to
change_point_model:
model to get trend change points
TODO: replace this parameters with the instance of BaseChangePointsModelAdapter in ETNA 2.0
detrend_model:
model to get trend in data
change_point_model_predict_params:
params for ``change_point_model.predict`` method
"""
self.in_column = in_column
self.out_columns = in_column
self.change_point_model = change_point_model
self.ruptures_change_point_model = RupturesChangePointsModel(
change_point_model=change_point_model, **change_point_model_predict_params
)
self.detrend_model = detrend_model
self.per_interval_models: Optional[Dict[TTimestampInterval, TDetrendModel]] = None
self.intervals: Optional[List[TTimestampInterval]] = None
self.change_point_model = change_point_model
self.change_point_model_predict_params = change_point_model_predict_params

@staticmethod
def _build_trend_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]:
"""Create list of stable trend intervals from list of change points."""
change_points = sorted(change_points)
left_border = pd.Timestamp.min
intervals = []
for point in change_points:
right_border = point
intervals.append((left_border, right_border))
left_border = right_border
intervals.append((left_border, pd.Timestamp.max))
return intervals

def _init_detrend_models(
self, intervals: List[TTimestampInterval]
) -> Dict[Tuple[pd.Timestamp, pd.Timestamp], TDetrendModel]:
Expand Down Expand Up @@ -112,14 +103,10 @@ def fit(self, df: pd.DataFrame) -> "_OneSegmentChangePointsTrendTransform":
-------
:
"""
series = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index(), self.in_column]
if series.isnull().values.any():
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")
change_points = _find_change_points_segment(
series=series, change_point_model=self.change_point_model, **self.change_point_model_predict_params
)
self.intervals = self._build_trend_intervals(change_points=change_points)
self.intervals = self.ruptures_change_point_model.get_change_points_intervals(df=df, in_column=self.in_column)
self.per_interval_models = self._init_detrend_models(intervals=self.intervals)

series = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index(), self.in_column]
self._fit_per_interval_model(series=series)
return self

Expand Down Expand Up @@ -190,6 +177,7 @@ def __init__(
name of column to apply transform to
change_point_model:
model to get trend change points
TODO: replace this parameters with the instance of BaseChangePointsModelAdapter in ETNA 2.0
detrend_model:
model to get trend in data
change_point_model_predict_params:
Expand Down
Loading

1 comment on commit ced2072

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.