diff --git a/CHANGELOG.md b/CHANGELOG.md index 24b6f526e..6249ff942 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - ### Changed +- Add columns and mode parameters in plot_correlation_matrix ([#726](https://github.com/tinkoff-ai/etna/pull/753)) - - - diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 11857b226..9c1fe5e9f 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -633,7 +633,10 @@ def plot_anomalies( def get_correlation_matrix( - ts: "TSDataset", segments: Optional[List[str]] = None, method: str = "pearson" + ts: "TSDataset", + columns: Optional[List[str]] = None, + segments: Optional[List[str]] = None, + method: str = "pearson", ) -> np.ndarray: """Compute pairwise correlation of timeseries for selected segments. @@ -641,6 +644,8 @@ def get_correlation_matrix( ---------- ts: TSDataset with timeseries data + columns: + Columns to use, if None use all columns segments: Segments to use method: @@ -659,16 +664,23 @@ def get_correlation_matrix( """ if method not in ["pearson", "kendall", "spearman"]: raise ValueError(f"'{method}' is not a valid method of correlation.") + if segments is None: segments = sorted(ts.segments) - correlation_matrix = ts[:, segments, :].corr(method=method).values + if columns is None: + columns = list(set(ts.df.columns.get_level_values("feature"))) + + correlation_matrix = ts[:, segments, columns].corr(method=method).values return correlation_matrix def plot_correlation_matrix( ts: "TSDataset", + columns: Optional[List[str]] = None, segments: Optional[List[str]] = None, method: str = "pearson", + mode: str = "macro", + columns_num: int = 2, figsize: Tuple[int, int] = (10, 10), **heatmap_kwargs, ): @@ -678,6 +690,8 @@ def plot_correlation_matrix( ---------- ts: TSDataset with timeseries data + columns: + Columns to use, if None use all columns segments: Segments to use method: @@ -689,23 +703,48 @@ def plot_correlation_matrix( * spearman: Spearman rank correlation + mode: 'macro' or 'per-segment' + Aggregation mode + columns_num: + Number of subplots columns figsize: size of the figure in inches """ if segments is None: segments = sorted(ts.segments) + if columns is None: + columns = list(set(ts.df.columns.get_level_values("feature"))) if "vmin" not in heatmap_kwargs: heatmap_kwargs["vmin"] = -1 if "vmax" not in heatmap_kwargs: heatmap_kwargs["vmax"] = 1 - correlation_matrix = get_correlation_matrix(ts, segments, method) - fig, ax = plt.subplots(figsize=figsize) - ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs) - labels = list(ts[:, segments, :].columns.values) - ax.set_xticklabels(labels, rotation=45, horizontalalignment="right") - ax.set_yticklabels(labels, rotation=0, horizontalalignment="right") - ax.set_title("Correlation Heatmap") + if mode not in ["macro", "per-segment"]: + raise ValueError(f"'{mode}' is not a valid method of mode.") + + if mode == "macro": + fig, ax = plt.subplots(figsize=figsize) + correlation_matrix = get_correlation_matrix(ts, columns, segments, method) + labels = list(ts[:, segments, columns].columns.values) + ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs) + ax.set_xticks(np.arange(len(labels)) + 0.5, labels=labels) + ax.set_yticks(np.arange(len(labels)) + 0.5, labels=labels) + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + plt.setp(ax.get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor") + ax.set_title("Correlation Heatmap") + + if mode == "per-segment": + fig, ax = prepare_axes(len(segments), columns_num=columns_num, figsize=figsize) + + for i, segment in enumerate(segments): + correlation_matrix = get_correlation_matrix(ts, columns, [segment], method) + labels = list(ts[:, segment, columns].columns.values) + ax[i] = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax[i], **heatmap_kwargs) + ax[i].set_xticks(np.arange(len(labels)) + 0.5, labels=labels) + ax[i].set_yticks(np.arange(len(labels)) + 0.5, labels=labels) + plt.setp(ax[i].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + plt.setp(ax[i].get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor") + ax[i].set_title("Correlation Heatmap" + " " + segment) def plot_anomalies_interactive(