[src&tests] Add MetricTracker (#394)
mpariente authored Feb 2, 2021
1 parent c041060 commit 24e7636
Showing 2 changed files with 105 additions and 4 deletions.
91 changes: 88 additions & 3 deletions asteroid/metrics.py
@@ -1,8 +1,8 @@
import json
import warnings
import traceback
from collections import Counter
from typing import List

from collections import Counter
import pandas as pd
import numpy as np
from pb_bss_eval import InputMetrics, OutputMetrics
@@ -30,7 +30,7 @@ def get_metrics(
        clean (np.array): reference array.
        estimate (np.array): estimate array.
        sample_rate (int): sampling rate of the audio clips.
        metrics_list (Union [str, list]): List of metrics to compute.
        metrics_list (Union[List[str], str]): List of metrics to compute.
            Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
        average (bool): Return dict([float]) if True, else dict([array]).
        compute_permutation (bool): Whether to compute the permutation on
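
For context, a minimal sketch of how get_metrics is called with the arguments documented above (the array shapes and the metric choice are illustrative, not part of this commit):

import numpy as np
from asteroid.metrics import get_metrics

mix = np.random.randn(1, 8000)
clean = np.random.randn(1, 8000)
estimate = np.random.randn(1, 8000)
# With average=True, returns a dict of floats, e.g. {"si_sdr": ..., "input_si_sdr": ...}.
utt_metrics = get_metrics(
    mix, clean, estimate, sample_rate=8000, metrics_list=["si_sdr"], average=True
)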
@@ -115,6 +115,91 @@ def get_metrics(
    return utt_metrics


class MetricTracker:
    """Metric tracker, subject to change.

    Args:
        sample_rate (int): sampling rate of the audio clips.
        metrics_list (Union[List[str], str]): List of metrics to compute.
            Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
        average (bool): Return dict([float]) if True, else dict([array]).
        compute_permutation (bool): Whether to compute the permutation on
            estimate sources for the output metrics (default False).
        ignore_metrics_errors (bool): Whether to ignore errors that occur in
            computing the metrics. A warning will be printed instead.
    """

    def __init__(
        self,
        sample_rate,
        metrics_list=tuple(ALL_METRICS),
        average=True,
        compute_permutation=False,
        ignore_metrics_errors=False,
    ):
        self.sample_rate = sample_rate
        # TODO: support WER in metrics_list when merged.
        self.metrics_list = metrics_list
        self.average = average
        self.compute_permutation = compute_permutation
        self.ignore_metrics_errors = ignore_metrics_errors

        self.series_list = []
        self._len_last_saved = 0
        self._all_metrics = pd.DataFrame()

    def __call__(
        self, *, mix: np.ndarray, clean: np.ndarray, estimate: np.ndarray, filename=None, **kwargs
    ):
        """Compute metrics for mix/clean/estimate and log them to the tracker.

        Args:
            mix (np.array): mixture array.
            clean (np.array): reference array.
            estimate (np.array): estimate array.
            filename (str, optional): If computing a metric fails, print this
                filename along with the exception/warning message for debugging purposes.
            **kwargs: Any key/value pair to log with the utterance metrics
                (filename, speaker ID, etc.).
        """
        utt_metrics = get_metrics(
            mix,
            clean,
            estimate,
            sample_rate=self.sample_rate,
            metrics_list=self.metrics_list,
            average=self.average,
            compute_permutation=self.compute_permutation,
            ignore_metrics_errors=self.ignore_metrics_errors,
            filename=filename,
        )
        utt_metrics.update(kwargs)
        self.series_list.append(pd.Series(utt_metrics))

    def as_df(self):
        """Return dataframe containing the results (cached)."""
        if self._len_last_saved == len(self.series_list):
            return self._all_metrics
        self._len_last_saved = len(self.series_list)
        self._all_metrics = pd.DataFrame(self.series_list)
        return self._all_metrics

    def final_report(self, dump_path: str = None):
        """Return dict of average metrics. Dump to JSON if `dump_path` is not None."""
        final_results = {}
        metrics_df = self.as_df()
        for metric_name in self.metrics_list:
            input_metric_name = "input_" + metric_name
            ldf = metrics_df[metric_name] - metrics_df[input_metric_name]
            final_results[metric_name] = metrics_df[metric_name].mean()
            final_results[metric_name + "_imp"] = ldf.mean()
        if dump_path is not None:
            dump_path = dump_path + ".json" if not dump_path.endswith(".json") else dump_path
            with open(dump_path, "w") as f:
                json.dump(final_results, f, indent=0)
        return final_results


class MockWERTracker:
    def __init__(self, *args, **kwargs):
        pass
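For reviewers, a minimal usage sketch of the new MetricTracker added above (file names and metric choices are illustrative, not part of this commit):

import numpy as np
from asteroid.metrics import MetricTracker

tracker = MetricTracker(sample_rate=8000, metrics_list=["si_sdr", "stoi"])
for i in range(3):
    mix = np.random.randn(1, 8000)
    clean = np.random.randn(1, 8000)
    estimate = np.random.randn(1, 8000)
    # Extra keyword arguments (e.g. a file path or speaker ID) are logged with the metrics.
    tracker(mix=mix, clean=clean, estimate=estimate, mix_path=f"utt_{i}.wav")

df = tracker.as_df()  # one row per utterance, including the "mix_path" column
report = tracker.final_report(dump_path="final_metrics")  # also writes final_metrics.json

final_report returns the per-metric average over utterances together with the improvement over the input mixture, stored as "<metric>_imp" (the mean of metric minus input_metric).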
18 changes: 17 additions & 1 deletion tests/metrics_test.py
@@ -1,7 +1,7 @@
from unittest import mock
import numpy as np
import pytest
from asteroid.metrics import get_metrics
from asteroid.metrics import get_metrics, MetricTracker


@pytest.mark.parametrize("fs", [8000, 16000])
@@ -69,3 +69,19 @@ def test_ignore_errors(filename, average):
    )
    assert metrics_dict["si_sdr"] is None
    assert metrics_dict["pesq"] is not None


def test_metric_tracker():
    metric_tracker = MetricTracker(sample_rate=8000, metrics_list=["si_sdr", "stoi"])
    for i in range(5):
        mix = np.random.randn(1, 4000)
        clean = np.random.randn(1, 4000)
        est = np.random.randn(1, 4000)
        metric_tracker(mix=mix, clean=clean, estimate=est, mix_path=f"path{i}")

    # Test dump & final report
    metric_tracker.final_report()
    metric_tracker.final_report(dump_path="final_metrics.json")

    # Check that kwargs are passed.
    assert "mix_path" in metric_tracker.as_df()
