diff --git a/CHANGELOG.md b/CHANGELOG.md index 15a97e82b2f..23d189e5e2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216)) +- Added support for multiprocessing of batches in `PESQ` metric ([#1227](https://github.com/Lightning-AI/metrics/pull/1227)) + ### Changed - Classification refactor ( diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 8d7d8fa5043..89d94b89689 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -42,7 +42,8 @@ class PerceptualEvaluationSpeechQuality(Metric): fs: sampling frequency, should be 16000 or 8000 (Hz) mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band) keep_same_device: whether to move the pesq value to the device of preds - + n_processes: integer specifying the number of processes to run in parallel for the metric calculation. + Only applies to batches of data and if ``multiprocessing`` package is installed. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
Raises: @@ -80,6 +81,7 @@ def __init__( self, fs: int, mode: str, + n_processes: int = 1, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -94,6 +96,9 @@ if mode not in ("wb", "nb"): raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}") self.mode = mode + if not isinstance(n_processes, int) or n_processes <= 0: + raise ValueError(f"Expected argument `n_processes` to be an int larger than 0 but got {n_processes}") + self.n_processes = n_processes self.add_state("sum_pesq", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") @@ -105,9 +110,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model target: Ground truth values """ - pesq_batch = perceptual_evaluation_speech_quality(preds, target, self.fs, self.mode, False).to( - self.sum_pesq.device - ) + pesq_batch = perceptual_evaluation_speech_quality( + preds, target, self.fs, self.mode, False, self.n_processes + ).to(self.sum_pesq.device) self.sum_pesq += pesq_batch.sum() self.total += pesq_batch.numel() diff --git a/src/torchmetrics/functional/audio/pesq.py b/src/torchmetrics/functional/audio/pesq.py index 7456edf9763..bf681843770 100644 --- a/src/torchmetrics/functional/audio/pesq.py +++ b/src/torchmetrics/functional/audio/pesq.py @@ -16,7 +16,7 @@ from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.imports import _PESQ_AVAILABLE +from torchmetrics.utilities.imports import _MULTIPROCESSING_AVAILABLE, _PESQ_AVAILABLE if _PESQ_AVAILABLE: import pesq as pesq_backend @@ -28,7 +28,12 @@ def perceptual_evaluation_speech_quality( - preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False + preds: Tensor, + target: Tensor, + fs: int, + mode: str, + keep_same_device: bool = False, + n_processes: int = 1, ) -> Tensor: r"""PESQ (Perceptual Evaluation of Speech Quality) @@ -46,6 
+51,8 @@ def perceptual_evaluation_speech_quality( fs: sampling frequency, should be 16000 or 8000 (Hz) mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band) keep_same_device: whether to move the pesq value to the device of preds + n_processes: integer specifying the number of processes to run in parallel for the metric calculation. + Only applies to batches of data and if ``multiprocessing`` package is installed. Returns: pesq value of shape [...] @@ -89,9 +96,14 @@ def perceptual_evaluation_speech_quality( else: preds_np = preds.reshape(-1, preds.shape[-1]).detach().cpu().numpy() target_np = target.reshape(-1, preds.shape[-1]).detach().cpu().numpy() - pesq_val_np = np.empty(shape=(preds_np.shape[0])) - for b in range(preds_np.shape[0]): - pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode) + + if _MULTIPROCESSING_AVAILABLE and n_processes != 1: + pesq_val_np = pesq_backend.pesq_batch(fs, target_np, preds_np, mode, n_processor=n_processes) + pesq_val_np = np.array(pesq_val_np) + else: + pesq_val_np = np.empty(shape=(preds_np.shape[0])) + for b in range(preds_np.shape[0]): + pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode) pesq_val = torch.from_numpy(pesq_val_np) pesq_val = pesq_val.reshape(preds.shape[:-1]) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 834b3ff63db..75d0fa65bdf 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -124,3 +124,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _REGEX_AVAILABLE: bool = _package_available("regex") _PYSTOI_AVAILABLE: bool = _package_available("pystoi") _FAST_BSS_EVAL_AVAILABLE: bool = _package_available("fast_bss_eval") +_MULTIPROCESSING_AVAILABLE: bool = _package_available("multiprocessing") diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index f6bfe1def55..8a370761a76 100644 --- 
a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -77,9 +77,12 @@ def average_metric(preds, target, metric_func): class TestPESQ(MetricTester): atol = 1e-2 + @pytest.mark.parametrize("n_processes", [1, 2]) @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step): + def test_pesq(self, preds, target, sk_metric, fs, mode, n_processes, ddp, dist_sync_on_step): + if n_processes != 1 and ddp: + pytest.skip("Multiprocessing and ddp do not work together") self.run_class_metric_test( ddp, preds, @@ -87,16 +90,17 @@ def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step): PerceptualEvaluationSpeechQuality, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, - metric_args=dict(fs=fs, mode=mode), + metric_args=dict(fs=fs, mode=mode, n_processes=n_processes), ) - def test_pesq_functional(self, preds, target, sk_metric, fs, mode): + @pytest.mark.parametrize("n_processes", [1, 2]) + def test_pesq_functional(self, preds, target, sk_metric, fs, mode, n_processes): self.run_functional_metric_test( preds, target, perceptual_evaluation_speech_quality, sk_metric, - metric_args=dict(fs=fs, mode=mode), + metric_args=dict(fs=fs, mode=mode, n_processes=n_processes), ) def test_pesq_differentiability(self, preds, target, sk_metric, fs, mode):