-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ChangePointSegmentationTransform (#822)
- Loading branch information
Showing
10 changed files
with
443 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from typing import List | ||
from typing import Tuple | ||
from typing import Type | ||
|
||
import pandas as pd | ||
from ruptures.base import BaseEstimator | ||
from ruptures.costs import CostLinear | ||
from sklearn.base import RegressorMixin | ||
|
||
TTimestampInterval = Tuple[pd.Timestamp, pd.Timestamp] | ||
TDetrendModel = Type[RegressorMixin] | ||
|
||
|
||
class BaseChangePointsModelAdapter(ABC): | ||
"""BaseChangePointsModelAdapter is the base class for change point models adapters.""" | ||
|
||
@abstractmethod | ||
def get_change_points(self, df: pd.DataFrame, in_column: str) -> List[pd.Timestamp]: | ||
"""Find change points within one segment. | ||
Parameters | ||
---------- | ||
df: | ||
dataframe indexed with timestamp | ||
in_column: | ||
name of column to get change points | ||
Returns | ||
------- | ||
change points: | ||
change point timestamps | ||
""" | ||
pass | ||
|
||
@staticmethod | ||
def _build_intervals(change_points: List[pd.Timestamp]) -> List[TTimestampInterval]: | ||
"""Create list of stable intervals from list of change points.""" | ||
change_points.extend([pd.Timestamp.min, pd.Timestamp.max]) | ||
change_points = sorted(change_points) | ||
intervals = list(zip(change_points[:-1], change_points[1:])) | ||
return intervals | ||
|
||
def get_change_points_intervals(self, df: pd.DataFrame, in_column: str) -> List[TTimestampInterval]: | ||
"""Find change point intervals in given dataframe and column. | ||
Parameters | ||
---------- | ||
df: | ||
dataframe indexed with timestamp | ||
in_column: | ||
name of column to get change points | ||
Returns | ||
------- | ||
: | ||
change points intervals | ||
""" | ||
change_points = self.get_change_points(df=df, in_column=in_column) | ||
intervals = self._build_intervals(change_points=change_points) | ||
return intervals | ||
|
||
|
||
class RupturesChangePointsModel(BaseChangePointsModelAdapter): | ||
"""RupturesChangePointsModel is ruptures change point models adapter.""" | ||
|
||
def __init__(self, change_point_model: BaseEstimator, **change_point_model_predict_params): | ||
"""Init RupturesChangePointsModel. | ||
Parameters | ||
---------- | ||
change_point_model: | ||
model to get change points | ||
change_point_model_predict_params: | ||
params for ``change_point_model.predict`` method | ||
""" | ||
self.change_point_model = change_point_model | ||
self.model_predict_params = change_point_model_predict_params | ||
|
||
def get_change_points(self, df: pd.DataFrame, in_column: str) -> List[pd.Timestamp]: | ||
"""Find change points within one segment. | ||
Parameters | ||
---------- | ||
df: | ||
dataframe indexed with timestamp | ||
in_column: | ||
name of column to get change points | ||
Returns | ||
------- | ||
change points: | ||
change point timestamps | ||
""" | ||
series = df.loc[df[in_column].first_valid_index() : df[in_column].last_valid_index(), in_column] | ||
if series.isnull().values.any(): | ||
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.") | ||
|
||
signal = series.to_numpy() | ||
if isinstance(self.change_point_model.cost, CostLinear): | ||
signal = signal.reshape((-1, 1)) | ||
timestamp = series.index | ||
self.change_point_model.fit(signal=signal) | ||
# last point in change points is the first index after the series | ||
change_points_indices = self.change_point_model.predict(**self.model_predict_params)[:-1] | ||
change_points = [timestamp[idx] for idx in change_points_indices] | ||
return change_points |
121 changes: 121 additions & 0 deletions
121
etna/transforms/decomposition/change_points_segmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from typing import List | ||
from typing import Optional | ||
|
||
import pandas as pd | ||
|
||
from etna.transforms.base import FutureMixin | ||
from etna.transforms.base import PerSegmentWrapper | ||
from etna.transforms.base import Transform | ||
from etna.transforms.decomposition.base_change_points import BaseChangePointsModelAdapter | ||
from etna.transforms.decomposition.base_change_points import TTimestampInterval | ||
|
||
|
||
class _OneSegmentChangePointsSegmentationTransform(Transform): | ||
"""_OneSegmentChangePointsSegmentationTransform make label encoder to change points.""" | ||
|
||
def __init__(self, in_column: str, out_column: str, change_point_model: BaseChangePointsModelAdapter): | ||
"""Init _OneSegmentChangePointsSegmentationTransform. | ||
Parameters | ||
---------- | ||
in_column: | ||
name of column to apply transform to | ||
out_column: | ||
result column name. If not given use ``self.__repr__()`` | ||
change_point_model: | ||
model to get change points | ||
""" | ||
self.in_column = in_column | ||
self.out_column = out_column | ||
self.intervals: Optional[List[TTimestampInterval]] = None | ||
self.change_point_model = change_point_model | ||
|
||
def _fill_per_interval(self, series: pd.Series) -> pd.Series: | ||
"""Fill values in resulting series.""" | ||
if self.intervals is None: | ||
raise ValueError("Transform is not fitted! Fit the Transform before calling transform method.") | ||
result_series = pd.Series(index=series.index) | ||
for k, interval in enumerate(self.intervals): | ||
tmp_series = series[interval[0] : interval[1]] | ||
if tmp_series.empty: | ||
continue | ||
result_series[tmp_series.index] = k | ||
return result_series.astype(int).astype("category") | ||
|
||
def fit(self, df: pd.DataFrame) -> "_OneSegmentChangePointsSegmentationTransform": | ||
"""Fit _OneSegmentChangePointsSegmentationTransform: find change points in ``df`` and build intervals. | ||
Parameters | ||
---------- | ||
df: | ||
one segment dataframe indexed with timestamp | ||
Returns | ||
------- | ||
: | ||
instance with trained change points | ||
Raises | ||
------ | ||
ValueError | ||
If series contains NaNs in the middle | ||
""" | ||
self.intervals = self.change_point_model.get_change_points_intervals(df=df, in_column=self.in_column) | ||
return self | ||
|
||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
"""Split df to intervals. | ||
Parameters | ||
---------- | ||
df: | ||
one segment dataframe | ||
Returns | ||
------- | ||
df: | ||
df with new column | ||
""" | ||
series = df[self.in_column] | ||
result_series = self._fill_per_interval(series=series) | ||
df.loc[:, self.out_column] = result_series | ||
return df | ||
|
||
|
||
class ChangePointsSegmentationTransform(PerSegmentWrapper, FutureMixin): | ||
"""ChangePointsSegmentationTransform make label encoder to change points. | ||
Warning | ||
------- | ||
This transform can suffer from look-ahead bias. For transforming data at some timestamp | ||
it uses information from the whole train part. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_column: str, | ||
change_point_model: BaseChangePointsModelAdapter, | ||
out_column: Optional[str] = None, | ||
): | ||
"""Init ChangePointsSegmentationTransform. | ||
Parameterss | ||
---------- | ||
in_column: | ||
name of column to fit change point model | ||
out_column: | ||
result column name. If not given use ``self.__repr__()`` | ||
change_point_model: | ||
model to get change points | ||
""" | ||
self.in_column = in_column | ||
self.out_column = out_column | ||
self.change_point_model = change_point_model | ||
if self.out_column is None: | ||
self.out_column = repr(self) | ||
super().__init__( | ||
transform=_OneSegmentChangePointsSegmentationTransform( | ||
in_column=self.in_column, | ||
out_column=self.out_column, | ||
change_point_model=self.change_point_model, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
ced2072
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉 Published on https://etna-docs.netlify.app as production
🚀 Deployed on https://62ff5927eb8ded6e04070977--etna-docs.netlify.app