Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiprocessing in PESQ #1227

Merged
merged 10 commits into from
Sep 16, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added argument `normalize` to `LPIPS` metric ([#1216](https://github.com/Lightning-AI/metrics/pull/1216))


- Added support for multiprocessing of batches in `PESQ` metric ([#1227](https://github.com/Lightning-AI/metrics/pull/1227))

### Changed

- Classification refactor (
Expand Down
13 changes: 9 additions & 4 deletions src/torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
fs: sampling frequency, should be 16000 or 8000 (Hz)
mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
keep_same_device: whether to move the pesq value to the device of preds

n_processes: integer specifying the number of processes to run in parallel for the metric calculation.
Only applies to batches of data and if ``multiprocessing`` package is installed.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
self,
fs: int,
mode: str,
n_processes: int = 1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -94,6 +96,9 @@ def __init__(
if mode not in ("wb", "nb"):
raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}")
self.mode = mode
if not isinstance(n_processes, int) or n_processes <= 0:
raise ValueError(f"Expected argument `n_processes` to be an int larger than 0 but got {n_processes}")
self.n_processes = n_processes

self.add_state("sum_pesq", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
Expand All @@ -105,9 +110,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model
target: Ground truth values
"""
pesq_batch = perceptual_evaluation_speech_quality(preds, target, self.fs, self.mode, False).to(
self.sum_pesq.device
)
pesq_batch = perceptual_evaluation_speech_quality(
preds, target, self.fs, self.mode, False, self.n_processes
).to(self.sum_pesq.device)

self.sum_pesq += pesq_batch.sum()
self.total += pesq_batch.numel()
Expand Down
22 changes: 17 additions & 5 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
from torchmetrics.utilities.imports import _MULTIPROCESSING_AVAILABLE, _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
Expand All @@ -28,7 +28,12 @@


def perceptual_evaluation_speech_quality(
preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False
preds: Tensor,
target: Tensor,
fs: int,
mode: str,
keep_same_device: bool = False,
n_processes: int = 1,
) -> Tensor:
r"""PESQ (Perceptual Evaluation of Speech Quality)

Expand All @@ -46,6 +51,8 @@ def perceptual_evaluation_speech_quality(
fs: sampling frequency, should be 16000 or 8000 (Hz)
mode: ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
keep_same_device: whether to move the pesq value to the device of preds
n_processes: integer specifying the number of processes to run in parallel for the metric calculation.
Only applies to batches of data and if ``multiprocessing`` package is installed.

Returns:
pesq value of shape [...]
Expand Down Expand Up @@ -89,9 +96,14 @@ def perceptual_evaluation_speech_quality(
else:
preds_np = preds.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
target_np = target.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)

if _MULTIPROCESSING_AVAILABLE and n_processes != 1:
pesq_val_np = pesq_backend.pesq_batch(fs, target_np, preds_np, mode, n_processor=n_processes)
pesq_val_np = np.array(pesq_val_np)
else:
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])

Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_REGEX_AVAILABLE: bool = _package_available("regex")
_PYSTOI_AVAILABLE: bool = _package_available("pystoi")
_FAST_BSS_EVAL_AVAILABLE: bool = _package_available("fast_bss_eval")
_MULTIPROCESSING_AVAILABLE: bool = _package_available("multiprocessing")
12 changes: 8 additions & 4 deletions tests/unittests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,30 @@ def average_metric(preds, target, metric_func):
class TestPESQ(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("n_processes", [1, 2])
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_pesq(self, preds, target, sk_metric, fs, mode, ddp, dist_sync_on_step):
def test_pesq(self, preds, target, sk_metric, fs, mode, n_processes, ddp, dist_sync_on_step):
if n_processes != 1 and ddp:
pytest.skip("Multiprocessing and ddp does not work together")
self.run_class_metric_test(
ddp,
preds,
target,
PerceptualEvaluationSpeechQuality,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, mode=mode),
metric_args=dict(fs=fs, mode=mode, n_processes=n_processes),
)

def test_pesq_functional(self, preds, target, sk_metric, fs, mode):
@pytest.mark.parametrize("n_processes", [1, 2])
def test_pesq_functional(self, preds, target, sk_metric, fs, mode, n_processes):
self.run_functional_metric_test(
preds,
target,
perceptual_evaluation_speech_quality,
sk_metric,
metric_args=dict(fs=fs, mode=mode),
metric_args=dict(fs=fs, mode=mode, n_processes=n_processes),
)

def test_pesq_differentiability(self, preds, target, sk_metric, fs, mode):
Expand Down