Skip to content

Commit

Permalink
Change Point Interactive (#988)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ama16 authored Oct 18, 2022
1 parent 1d06d22 commit 124353a
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from etna.analysis.plotters import plot_anomalies_interactive
from etna.analysis.plotters import plot_backtest
from etna.analysis.plotters import plot_backtest_interactive
from etna.analysis.plotters import plot_change_points_interactive
from etna.analysis.plotters import plot_clusters
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
Expand Down
139 changes: 139 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import plotly.graph_objects as go
import seaborn as sns
from matplotlib.lines import Line2D
from ruptures.base import BaseCost
from ruptures.base import BaseEstimator
from ruptures.exceptions import BadSegmentationParameters
from scipy.signal import periodogram
from typing_extensions import Literal

Expand Down Expand Up @@ -1819,3 +1822,139 @@ def metric_per_segment_distribution_plot(

plt.title("Metric per-segment distribution plot")
plt.grid()


def plot_change_points_interactive(
ts,
change_point_model: BaseEstimator,
model: BaseCost,
params_bounds: Dict[str, Tuple[Union[int, float], Union[int, float], Union[int, float]]],
model_params: List[str],
predict_params: List[str],
in_column: str = "target",
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
start: Optional[str] = None,
end: Optional[str] = None,
):
"""Plot a time series with indicated change points.
Change points are obtained using the specified method. The method parameters values
can be changed using the corresponding sliders.
Parameters
----------
ts:
TSDataset with timeseries data
change_point_model:
model to get trend change points
model:
binseg segment model, ["l1", "l2", "rbf",...]. Not used if 'custom_cost' is not None
params_bounds:
Parameters ranges of the change points detection. Bounds for the parameter are (min,max,step)
model_params:
List of iterable parameters for initialize the model
predict_params:
List of iterable parameters for predict method
in_column:
column to plot
segments:
segments to use
columns_num:
number of subplots columns
figsize:
size of the figure in inches
start:
start timestamp for plot
end:
end timestamp for plot
Notes
-----
Jupyter notebook might display the results incorrectly,
in this case try to use ``!jupyter nbextension enable --py widgetsnbextension``.
Examples
--------
>>> from etna.datasets import TSDataset
>>> from etna.datasets import generate_ar_df
>>> from etna.analysis import plot_change_points_interactive
>>> from ruptures.detection import Binseg
>>> classic_df = generate_ar_df(periods=1000, start_time="2021-08-01", n_segments=2)
>>> df = TSDataset.to_dataset(classic_df)
>>> ts = TSDataset(df, "D")
>>> params_bounds = {"n_bkps": [0, 5, 1], "min_size":[1,10,3]}
>>> plot_change_points_interactive(ts=ts, change_point_model=Binseg, model="l2", params_bounds=params_bounds, model_params=["min_size"], predict_params=["n_bkps"], figsize=(20, 10)) # doctest: +SKIP
"""
from ipywidgets import FloatSlider
from ipywidgets import IntSlider
from ipywidgets import interact

if segments is None:
segments = sorted(ts.segments)

cache = {}

sliders = dict()
style = {"description_width": "initial"}
for param, bounds in params_bounds.items():
min_, max_, step = bounds
if isinstance(min_, float) or isinstance(max_, float) or isinstance(step, float):
sliders[param] = FloatSlider(min=min_, max=max_, step=step, continuous_update=False, style=style)
else:
sliders[param] = IntSlider(min=min_, max=max_, step=step, continuous_update=False, style=style)

def update(**kwargs):
_, ax = prepare_axes(num_plots=len(segments), columns_num=columns_num, figsize=figsize)

key = "_".join([str(val) for val in kwargs.values()])

is_fitted = False

if key not in cache:
m_params = {x: kwargs[x] for x in model_params}
p_params = {x: kwargs[x] for x in predict_params}
cache[key] = {}
else:
is_fitted = True

for i, segment in enumerate(segments):
ax[i].cla()
segment_df = ts[start:end, segment, :][segment]
timestamp = segment_df.index.values
target = segment_df[in_column].values

if not is_fitted:
try:
algo = change_point_model(model=model, **m_params).fit(signal=target)
bkps = algo.predict(**p_params)
cache[key][segment] = bkps
cache[key][segment].insert(0, 1)
except BadSegmentationParameters:
cache[key][segment] = None

segment_bkps = cache[key][segment]

if segment_bkps is not None:
for idx in range(len(segment_bkps[:-1])):
bkp = segment_bkps[idx] - 1
start_time = timestamp[bkp]
end_time = timestamp[segment_bkps[idx + 1] - 1]
selected_indices = (timestamp >= start_time) & (timestamp <= end_time)
cur_timestamp = timestamp[selected_indices]
cur_target = target[selected_indices]
ax[i].plot(cur_timestamp, cur_target)
if bkp != 0:
ax[i].axvline(timestamp[bkp], linestyle="dashed", c="grey")

else:
box = {"facecolor": "grey", "edgecolor": "red", "boxstyle": "round"}
ax[i].text(
0.5, 0.4, "Parameters\nError", bbox=box, horizontalalignment="center", color="white", fontsize=50
)
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
plt.show()

interact(update, **sliders)
140 changes: 128 additions & 12 deletions examples/EDA.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ We have prepared a set of tutorials for an easy introduction:
- Outliers
- Median method
- Density method
- Change Points
- Change points plot
- Interactive change points plot
#### 04. [Outliers](https://github.com/tinkoff-ai/etna/tree/master/examples/outliers.ipynb)
- Point outliers
- Median method
Expand Down

1 comment on commit 124353a

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.