Add support for multiprocessing in PESQ (#1227)
* multiprocessing

* tests

* changelog

* fix tests

* add docs

* Apply suggestions from code review
SkafteNicki authored Sep 16, 2022
1 parent 5659805 commit fbe06ef
Showing 5 changed files with 37 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216))


- Added support for multiprocessing of batches in `PESQ` metric ([#1227](https://github.com/Lightning-AI/metrics/pull/1227))

### Changed

- Classification refactor (
13 changes: 9 additions & 4 deletions src/torchmetrics/audio/pesq.py
@@ -42,7 +42,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
fs: sampling frequency, should be 16000 or 8000 (Hz)
mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
keep_same_device: whether to move the pesq value to the device of preds
n_processes: integer specifying the number of processes to run in parallel for the metric calculation.
Only applies to batches of data and only if the ``multiprocessing`` package is installed.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
@@ -80,6 +81,7 @@ def __init__(
self,
fs: int,
mode: str,
n_processes: int = 1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -94,6 +96,9 @@ def __init__(
if mode not in ("wb", "nb"):
raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}")
self.mode = mode
if not isinstance(n_processes, int) or n_processes <= 0:
raise ValueError(f"Expected argument `n_processes` to be an int larger than 0 but got {n_processes}")
self.n_processes = n_processes

self.add_state("sum_pesq", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
@@ -105,9 +110,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
pesq_batch = perceptual_evaluation_speech_quality(preds, target, self.fs, self.mode, False).to(
self.sum_pesq.device
)
pesq_batch = perceptual_evaluation_speech_quality(
preds, target, self.fs, self.mode, False, self.n_processes
).to(self.sum_pesq.device)

self.sum_pesq += pesq_batch.sum()
self.total += pesq_batch.numel()
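
For context, a minimal usage sketch of the module metric with the new argument (illustrative only: the tensor shapes, the random placeholder data and the choice of ``n_processes=2`` are assumptions, and the ``pesq`` package must be installed):

    import torch
    from torchmetrics.audio import PerceptualEvaluationSpeechQuality

    # narrow-band PESQ at 8 kHz, spreading each batch over two worker processes
    pesq_metric = PerceptualEvaluationSpeechQuality(fs=8000, mode="nb", n_processes=2)

    preds = torch.randn(4, 8000)   # four one-second predicted signals (placeholder data)
    target = torch.randn(4, 8000)  # matching reference signals
    score = pesq_metric(preds, target)  # mean PESQ over the batch
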
22 changes: 17 additions & 5 deletions src/torchmetrics/functional/audio/pesq.py
@@ -16,7 +16,7 @@
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
from torchmetrics.utilities.imports import _MULTIPROCESSING_AVAILABLE, _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
@@ -28,7 +28,12 @@


def perceptual_evaluation_speech_quality(
preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False
preds: Tensor,
target: Tensor,
fs: int,
mode: str,
keep_same_device: bool = False,
n_processes: int = 1,
) -> Tensor:
r"""PESQ (Perceptual Evaluation of Speech Quality)
@@ -46,6 +51,8 @@ def perceptual_evaluation_speech_quality(
fs: sampling frequency, should be 16000 or 8000 (Hz)
mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
keep_same_device: whether to move the pesq value to the device of preds
n_processes: integer specifying the number of processes to run in parallel for the metric calculation.
Only applies to batches of data and only if the ``multiprocessing`` package is installed.
Returns:
pesq value of shape [...]
@@ -89,9 +96,14 @@ def perceptual_evaluation_speech_quality(
else:
preds_np = preds.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
target_np = target.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)

if _MULTIPROCESSING_AVAILABLE and n_processes != 1:
pesq_val_np = pesq_backend.pesq_batch(fs, target_np, preds_np, mode, n_processor=n_processes)
pesq_val_np = np.array(pesq_val_np)
else:
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])
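
A corresponding sketch for the functional interface (again illustrative; the shapes and data are assumptions, and as the diff above shows, the parallel ``pesq_batch`` path is only taken for batched input when ``multiprocessing`` is available and ``n_processes != 1``):

    import torch
    from torchmetrics.functional.audio import perceptual_evaluation_speech_quality

    preds = torch.randn(4, 16000)   # batch of predicted signals (placeholder data)
    target = torch.randn(4, 16000)  # batch of reference signals

    # one PESQ value per batch element, computed across two worker processes
    scores = perceptual_evaluation_speech_quality(preds, target, fs=16000, mode="wb", n_processes=2)
    print(scores.shape)  # torch.Size([4])
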

1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
@@ -124,3 +124,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_REGEX_AVAILABLE: bool = _package_available("regex")
_PYSTOI_AVAILABLE: bool = _package_available("pystoi")
_FAST_BSS_EVAL_AVAILABLE: bool = _package_available("fast_bss_eval")
_MULTIPROCESSING_AVAILABLE: bool = _package_available("multiprocessing")
12 changes: 8 additions & 4 deletions tests/unittests/audio/test_pesq.py
@@ -77,26 +77,30 @@ def average_metric(preds, target, metric_func):
class TestPESQ(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("n_processes", [1, 2])
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step):
def test_pesq(self, preds, target, sk_metric, fs, mode, n_processes, ddp, dist_sync_on_step):
if n_processes != 1 and ddp:
pytest.skip("Multiprocessing and ddp do not work together")
self.run_class_metric_test(
ddp,
preds,
target,
PerceptualEvaluationSpeechQuality,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, mode=mode),
metric_args=dict(fs=fs, mode=mode, n_processes=n_processes),
)

def test_pesq_functional(self, preds, target, sk_metric, fs, mode):
@pytest.mark.parametrize("n_processes", [1, 2])
def test_pesq_functional(self, preds, target, sk_metric, fs, mode, n_processes):
self.run_functional_metric_test(
preds,
target,
perceptual_evaluation_speech_quality,
sk_metric,
metric_args=dict(fs=fs, mode=mode),
metric_args=dict(fs=fs, mode=mode, n_processes=n_processes),
)

def test_pesq_differentiability(self, preds, target, sk_metric, fs, mode):
