-
Notifications
You must be signed in to change notification settings - Fork 80
Merged
Regularization search #1001
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from etna.experimental.change_points.regularization_search import get_ruptures_regularization |
178 changes: 178 additions & 0 deletions
178
etna/experimental/change_points/regularization_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from enum import Enum | ||
from typing import Dict | ||
from typing import Tuple | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from ruptures.base import BaseEstimator | ||
from ruptures.costs import CostLinear | ||
|
||
from etna.datasets import TSDataset | ||
|
||
|
||
class OptimizationMode(str, Enum):
    """Names of the regularization parameters that the search can optimize."""

    pen = "pen"
    epsilon = "epsilon"

    @classmethod
    def _missing_(cls, value):
        """Reject a value that matches no member, listing the allowed modes in the error."""
        allowed_modes = ", ".join(repr(member.value) for member in cls)
        raise NotImplementedError(f"{value} is not a valid {cls.__name__}. Only {allowed_modes} modes allowed")
|
||
|
||
def _get_n_bkps(series: pd.Series, change_point_model: BaseEstimator, **model_predict_params) -> int:
    """Count the change points that the model detects with the given predict parameters.

    Parameters
    ----------
    series:
        series to detect change points
    change_point_model:
        model to get trend change points
    model_predict_params:
        keyword arguments forwarded to ``change_point_model.predict``

    Returns
    -------
    :
        number of change points
    """
    values = series.to_numpy()
    # a model with a linear cost is fitted on a single-column 2D signal
    uses_linear_cost = isinstance(change_point_model.cost, CostLinear)
    fit_input = values.reshape((-1, 1)) if uses_linear_cost else values
    change_point_model.fit(signal=fit_input)

    breakpoints = change_point_model.predict(**model_predict_params)
    # the final entry returned by predict is not counted as a change point
    return len(breakpoints[:-1])
|
||
|
||
def _get_next_value( | ||
now_value: float, lower_bound: float, upper_bound: float, need_greater: bool | ||
) -> Tuple[float, float, float]: | ||
"""Give next value according to binary search. | ||
Parameters | ||
---------- | ||
now_value: | ||
current value | ||
lower_bound: | ||
lower bound for search | ||
upper_bound: | ||
upper bound for search | ||
need_greater: | ||
True if we need greater value for n_bkps than previous time | ||
|
||
Returns | ||
------- | ||
: | ||
next value and its bounds | ||
""" | ||
if need_greater: | ||
return np.mean([now_value, lower_bound]), lower_bound, now_value | ||
else: | ||
return np.mean([now_value, upper_bound]), now_value, upper_bound | ||
|
||
|
||
def bin_search(
    series: pd.Series,
    change_point_model: BaseEstimator,
    n_bkps: int,
    opt_param: str,
    max_value: float,
    max_iters: int = 200,
) -> float:
    """Search the interval ``[0, max_value]`` for a parameter value that gives ``n_bkps`` change points.

    Parameters
    ----------
    series:
        series for search
    change_point_model:
        model to get trend change points
    n_bkps:
        target numbers of changepoints
    opt_param:
        parameter for optimization
    max_value:
        maximum possible value, the upper bound for search
    max_iters:
        maximum iterations; in case if the required number of points is unattainable, values will be selected after max_iters iterations

    Returns
    -------
    :
        regularization parameters value

    Raises
    ------
    ValueError:
        If ``n_bkps`` is greater than the number of change points found with the parameter set to 0
    ValueError:
        If ``n_bkps`` is less than the number of change points found with the parameter set to ``max_value``
    """

    def count_bkps(param_value: float) -> int:
        # number of change points detected when opt_param is set to param_value
        return _get_n_bkps(series, change_point_model, **{opt_param: param_value})

    # probe both ends of the interval to make sure the target lies between them
    bkps_at_zero = count_bkps(0)
    bkps_at_max = count_bkps(max_value)
    if n_bkps > bkps_at_zero:
        raise ValueError("Impossible number of changepoints. Please, decrease n_bkps value.")
    if bkps_at_max > n_bkps:
        raise ValueError("Impossible number of changepoints. Please, increase max_value or increase n_bkps value.")

    lower_bound, upper_bound = 0.0, max_value
    now_value = np.mean([lower_bound, upper_bound])
    now_n_bkps = count_bkps(now_value)

    # stop on an exact match or when the iteration budget is spent; in the latter
    # case the last tried value is returned even though it does not give n_bkps
    remaining_iters = max_iters
    while remaining_iters > 0 and now_n_bkps != n_bkps:
        too_few_bkps = now_n_bkps < n_bkps
        now_value, lower_bound, upper_bound = _get_next_value(now_value, lower_bound, upper_bound, too_few_bkps)
        now_n_bkps = count_bkps(now_value)
        remaining_iters -= 1
    return now_value
|
||
|
||
def get_ruptures_regularization(
    ts: TSDataset,
    in_column: str,
    change_point_model: BaseEstimator,
    n_bkps: Union[Dict[str, int], int],
    mode: OptimizationMode,
    max_value: float = 10000,
    max_iters: int = 200,
) -> Dict[str, Dict[str, float]]:
    """Get regularization parameter values for given number of changepoints.

    It is assumed that as the regularization being selected increases, the number of change points decreases.

    Parameters
    ----------
    ts:
        Dataset with timeseries data
    in_column:
        name of processed column
    change_point_model:
        model to get trend change points
    n_bkps:
        target numbers of changepoints; either a single number that is used for every segment
        or a dictionary with a target for each segment name
    mode:
        optimization mode, i.e. the name of the regularization parameter to select
    max_value:
        maximum possible value, the upper bound for search
    max_iters:
        maximum iterations; in case if the required number of points is unattainable, values will be selected after max_iters iterations

    Returns
    -------
    :
        regularization parameters values as a nested dictionary ``{segment: {mode: value}}``,
        where the inner key is the name of the optimized parameter,
        e.g. ``{"segment_0": {"pen": 12.5}, "segment_1": {"pen": 3.0}}``

    Raises
    ------
    ValueError:
        If n_bkps is greater than the number of change points found without regularization
    ValueError:
        If max_value is so low for needed n_bkps
    NotImplementedError:
        If mode is not a valid optimization mode
    """
    mode = OptimizationMode(mode)
    segments = ts.to_pandas().columns.get_level_values(0).unique()

    # a single target is shared by all segments
    if isinstance(n_bkps, int):
        n_bkps = dict.fromkeys(segments, n_bkps)

    # bin_search expects the plain parameter name, so pass the enum value
    opt_param = mode.value
    regularization = {}
    for segment in segments:
        series = ts[:, segment, in_column]
        regularization[segment] = {
            opt_param: bin_search(series, change_point_model, n_bkps[segment], opt_param, max_value, max_iters)
        }
    return regularization
Empty file.
88 changes: 88 additions & 0 deletions
88
tests/test_experimental/test_change_points/test_regularization_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import pytest | ||
from ruptures import Binseg | ||
|
||
from etna.datasets import TSDataset | ||
from etna.datasets import generate_ar_df | ||
from etna.experimental.change_points import get_ruptures_regularization | ||
from etna.experimental.change_points.regularization_search import _get_n_bkps | ||
|
||
|
||
@pytest.fixture
def simple_change_points_ts():
    """Build a daily dataset of three segments with 125 points each from ``generate_ar_df``."""
    long_df = generate_ar_df(periods=125, start_time="2021-05-20", n_segments=3, freq="D", random_seed=42)
    return TSDataset(TSDataset.to_dataset(long_df), freq="D")
|
||
|
||
@pytest.mark.parametrize(
    "segment,params,expected",
    (
        ("segment_0", {"pen": 20}, 6),
        ("segment_0", {"epsilon": 20}, 24),
        ("segment_1", {"pen": 10}, 7),
        ("segment_1", {"epsilon": 100}, 12),
        ("segment_2", {"pen": 2}, 14),
        ("segment_2", {"epsilon": 200}, 6),
    ),
)
def test_get_n_bkps(segment, params, expected, simple_change_points_ts):
    """Check the number of change points found by ``Binseg`` for fixed predict parameters."""
    segment_series = simple_change_points_ts[:, segment, "target"]
    detected = _get_n_bkps(segment_series, Binseg(), **params)
    assert detected == expected
|
||
|
||
@pytest.mark.parametrize(
    "n_bkps,mode",
    (
        ({"segment_0": 3, "segment_1": 14, "segment_2": 19}, "pen"),
        ({"segment_0": 5, "segment_1": 2, "segment_2": 8}, "epsilon"),
        ({"segment_0": 11, "segment_1": 18, "segment_2": 4}, "pen"),
        ({"segment_0": 18, "segment_1": 21, "segment_2": 7}, "epsilon"),
    ),
)
def test_get_regularization(n_bkps, mode, simple_change_points_ts):
    """Check that the selected values reproduce the requested number of change points per segment."""
    in_column = "target"
    found = get_ruptures_regularization(
        simple_change_points_ts,
        in_column=in_column,
        change_point_model=Binseg(),
        n_bkps=n_bkps,
        mode=mode,
    )

    # the result must contain exactly the segments of the dataset
    dataset_segments = simple_change_points_ts.to_pandas().columns.get_level_values(0).unique()
    assert sorted(found.keys()) == sorted(dataset_segments)

    # re-running detection with each selected value must give the target count
    for seg, seg_params in found.items():
        seg_series = simple_change_points_ts[:, seg, in_column]
        detected = _get_n_bkps(seg_series, Binseg(), **{mode: seg_params[mode]})
        assert detected == n_bkps[seg]
|
||
|
||
@pytest.mark.parametrize(
    "n_bkps,mode",
    (
        ({"segment_0": 3, "segment_1": 34, "segment_2": 19}, "pen"),
        ({"segment_0": 45, "segment_1": 2, "segment_2": 8}, "epsilon"),
    ),
)
def test_fail_get_regularization_high(n_bkps, mode, simple_change_points_ts):
    """Check that a target above the attainable number of change points is rejected."""
    expected_message = "Impossible number of changepoints. Please, decrease n_bkps value."
    with pytest.raises(ValueError, match=expected_message):
        _ = get_ruptures_regularization(
            simple_change_points_ts,
            in_column="target",
            change_point_model=Binseg(),
            n_bkps=n_bkps,
            mode=mode,
        )
|
||
|
||
@pytest.mark.parametrize(
    "n_bkps,mode",
    (
        ({"segment_0": 3, "segment_1": 1, "segment_2": 19}, "pen"),
        ({"segment_0": 1, "segment_1": 2, "segment_2": 8}, "epsilon"),
    ),
)
def test_fail_get_regularization_low(n_bkps, mode, simple_change_points_ts):
    """Check that a target that cannot be reached with ``max_value=1`` is rejected."""
    expected_message = "Impossible number of changepoints. Please, increase max_value or increase n_bkps value."
    search_kwargs = dict(
        in_column="target",
        change_point_model=Binseg(),
        n_bkps=n_bkps,
        mode=mode,
        max_value=1,
    )
    with pytest.raises(ValueError, match=expected_message):
        _ = get_ruptures_regularization(simple_change_points_ts, **search_kwargs)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be "in" -> "into"?