Skip to content

Commit

Permalink
issue-726-fix (#753)
Browse files Browse the repository at this point in the history
  • Loading branch information
DBcreator authored Jun 16, 2022
1 parent 88f4766 commit 4b034be
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Changed
- Add columns and mode parameters in plot_correlation_matrix ([#726](https://github.com/tinkoff-ai/etna/pull/753))
-
-
-
Expand Down
57 changes: 48 additions & 9 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,14 +633,19 @@ def plot_anomalies(


def get_correlation_matrix(
ts: "TSDataset", segments: Optional[List[str]] = None, method: str = "pearson"
ts: "TSDataset",
columns: Optional[List[str]] = None,
segments: Optional[List[str]] = None,
method: str = "pearson",
) -> np.ndarray:
"""Compute pairwise correlation of timeseries for selected segments.
Parameters
----------
ts:
TSDataset with timeseries data
columns:
Columns to use, if None use all columns
segments:
Segments to use
method:
Expand All @@ -659,16 +664,23 @@ def get_correlation_matrix(
"""
if method not in ["pearson", "kendall", "spearman"]:
raise ValueError(f"'{method}' is not a valid method of correlation.")

if segments is None:
segments = sorted(ts.segments)
correlation_matrix = ts[:, segments, :].corr(method=method).values
if columns is None:
columns = list(set(ts.df.columns.get_level_values("feature")))

correlation_matrix = ts[:, segments, columns].corr(method=method).values
return correlation_matrix


def plot_correlation_matrix(
ts: "TSDataset",
columns: Optional[List[str]] = None,
segments: Optional[List[str]] = None,
method: str = "pearson",
mode: str = "macro",
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 10),
**heatmap_kwargs,
):
Expand All @@ -678,6 +690,8 @@ def plot_correlation_matrix(
----------
ts:
TSDataset with timeseries data
columns:
Columns to use, if None use all columns
segments:
Segments to use
method:
Expand All @@ -689,23 +703,48 @@ def plot_correlation_matrix(
* spearman: Spearman rank correlation
mode: 'macro' or 'per-segment'
Aggregation mode
columns_num:
Number of subplots columns
figsize:
size of the figure in inches
"""
if segments is None:
segments = sorted(ts.segments)
if columns is None:
columns = list(set(ts.df.columns.get_level_values("feature")))
if "vmin" not in heatmap_kwargs:
heatmap_kwargs["vmin"] = -1
if "vmax" not in heatmap_kwargs:
heatmap_kwargs["vmax"] = 1

correlation_matrix = get_correlation_matrix(ts, segments, method)
fig, ax = plt.subplots(figsize=figsize)
ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs)
labels = list(ts[:, segments, :].columns.values)
ax.set_xticklabels(labels, rotation=45, horizontalalignment="right")
ax.set_yticklabels(labels, rotation=0, horizontalalignment="right")
ax.set_title("Correlation Heatmap")
if mode not in ["macro", "per-segment"]:
raise ValueError(f"'{mode}' is not a valid method of mode.")

if mode == "macro":
fig, ax = plt.subplots(figsize=figsize)
correlation_matrix = get_correlation_matrix(ts, columns, segments, method)
labels = list(ts[:, segments, columns].columns.values)
ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs)
ax.set_xticks(np.arange(len(labels)) + 0.5, labels=labels)
ax.set_yticks(np.arange(len(labels)) + 0.5, labels=labels)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax.get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
ax.set_title("Correlation Heatmap")

if mode == "per-segment":
fig, ax = prepare_axes(len(segments), columns_num=columns_num, figsize=figsize)

for i, segment in enumerate(segments):
correlation_matrix = get_correlation_matrix(ts, columns, [segment], method)
labels = list(ts[:, segment, columns].columns.values)
ax[i] = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax[i], **heatmap_kwargs)
ax[i].set_xticks(np.arange(len(labels)) + 0.5, labels=labels)
ax[i].set_yticks(np.arange(len(labels)) + 0.5, labels=labels)
plt.setp(ax[i].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax[i].get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
ax[i].set_title("Correlation Heatmap" + " " + segment)


def plot_anomalies_interactive(
Expand Down

1 comment on commit 4b034be

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.