From 9bcbdf4235b4341cd0d4090400483d6eaefa3bee Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 8 Jan 2022 20:33:52 +0100 Subject: [PATCH] Refactor: SDR & SI_SDR (#711) * signal_distortion_ratio * scale_invariant_signal_distortion_ratio * SignalDistortionRatio * ScaleInvariantSignalDistortionRatio Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 7 ++ README.md | 2 +- docs/source/references/functional.rst | 12 +- docs/source/references/modules.rst | 12 +- tests/audio/test_pit.py | 10 +- tests/audio/test_sdr.py | 26 ++-- tests/audio/test_si_sdr.py | 18 +-- torchmetrics/__init__.py | 12 +- torchmetrics/audio/__init__.py | 2 +- torchmetrics/audio/sdr.py | 147 ++++++++++++++++++++-- torchmetrics/audio/si_sdr.py | 82 ++---------- torchmetrics/functional/__init__.py | 4 +- torchmetrics/functional/audio/__init__.py | 6 +- torchmetrics/functional/audio/pit.py | 8 +- torchmetrics/functional/audio/sdr.py | 82 +++++++++++- torchmetrics/functional/audio/si_sdr.py | 50 ++------ torchmetrics/functional/audio/si_snr.py | 4 +- 17 files changed, 313 insertions(+), 171 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f86a4e62a1a..303c6bf35f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `SpearmanCorrcoef` -> `SpearmanCorrCoef` +- Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711)) + * `functional.sdr` -> `functional.signal_distortion_ratio` + * `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio` + * `SDR` -> `SignalDistortionRatio` + * `SI_SDR` -> `ScaleInvariantSignalDistortionRatio` + + ### Removed - Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638)) diff --git a/README.md b/README.md index 9251c47d121..83a11254556 100644 --- a/README.md +++ b/README.md @@ -266,7 +266,7 @@ acc = torchmetrics.functional.accuracy(preds, target) We currently have implemented metrics within the following domains: - Audio ( - [SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr), + [ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio), [SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr), [SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr) and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 00f63adb5a4..96aae85d36c 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -24,17 +24,17 @@ pit [func] :noindex: -sdr [func] -~~~~~~~~~~ +signal_distortion_ratio [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.sdr +.. autofunction:: torchmetrics.functional.signal_distortion_ratio :noindex: -si_sdr [func] -~~~~~~~~~~~~~ +scale_invariant_signal_distortion_ratio [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.si_sdr +.. autofunction:: torchmetrics.functional.scale_invariant_signal_distortion_ratio :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index c666456e9d6..8a660668a8d 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -85,16 +85,16 @@ PIT .. autoclass:: torchmetrics.PIT :noindex: -SDR -~~~ +SignalDistortionRatio +~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.SDR +.. autoclass:: torchmetrics.SignalDistortionRatio :noindex: -SI_SDR -~~~~~~ +ScaleInvariantSignalDistortionRatio +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.SI_SDR +.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio :noindex: SI_SNR diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 22b229cdbb5..71406a361b2 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -24,7 +24,7 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.audio import PIT -from torchmetrics.functional import pit, si_sdr, snr +from torchmetrics.functional import pit, scale_invariant_signal_distortion_ratio, snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -97,16 +97,18 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=snr, eval_func="max") -si_sdr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=si_sdr, eval_func="max") +si_sdr_pit_scipy = partial( + naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max" +) @pytest.mark.parametrize( "preds, target, sk_metric, metric_func, eval_func", [ (inputs1.preds, inputs1.target, snr_pit_scipy, snr, "max"), - (inputs1.preds, inputs1.target, si_sdr_pit_scipy, si_sdr, "max"), + (inputs1.preds, inputs1.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"), (inputs2.preds, inputs2.target, snr_pit_scipy, snr, "max"), - (inputs2.preds, inputs2.target, si_sdr_pit_scipy, si_sdr, "max"), + (inputs2.preds, inputs2.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"), ], ) class TestPIT(MetricTester): diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 2f28e8976f3..b760a6aaa29 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -24,8 +24,8 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester -from torchmetrics.audio import SDR -from torchmetrics.functional import sdr +from torchmetrics.audio import SignalDistortionRatio +from torchmetrics.functional import signal_distortion_ratio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8 seed_all(42) @@ -49,7 +49,7 @@ def sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool preds = preds.detach().cpu().numpy() mss = [] for b in range(preds.shape[0]): - sdr_val_np, sir_val_np, sar_val_np, perm = bss_eval_sources(target[b], preds[b], compute_permutation) + sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation) mss.append(sdr_val_np) return torch.tensor(mss) @@ -83,7 +83,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step): ddp, preds, target, - SDR, + SignalDistortionRatio, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(), @@ -93,7 +93,7 @@ def test_sdr_functional(self, preds, target, sk_metric): self.run_functional_metric_test( preds, target, - sdr, + signal_distortion_ratio, sk_metric, metric_args=dict(), ) @@ -103,8 +103,8 @@ def test_sdr_differentiability(self, preds, target, sk_metric): self.run_differentiability_test( preds=preds, target=target, - metric_module=SDR, - metric_functional=sdr, + metric_module=SignalDistortionRatio, + metric_functional=signal_distortion_ratio, metric_args=dict(), ) @@ -115,8 +115,8 @@ def test_sdr_half_cpu(self, preds, target, sk_metric): self.run_precision_test_cpu( preds=preds, target=target, - metric_module=SDR, - metric_functional=sdr, + metric_module=SignalDistortionRatio, + metric_functional=signal_distortion_ratio, metric_args=dict(), ) @@ -125,13 +125,13 @@ def test_sdr_half_gpu(self, preds, target, sk_metric): self.run_precision_test_gpu( preds=preds, target=target, - metric_module=SDR, - metric_functional=sdr, + metric_module=SignalDistortionRatio, + metric_functional=signal_distortion_ratio, metric_args=dict(), ) -def test_error_on_different_shape(metric_class=SDR): +def test_error_on_different_shape(metric_class=SignalDistortionRatio): metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) @@ -143,7 +143,7 @@ def test_on_real_audio(): rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav")) rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav")) assert torch.allclose( - sdr(torch.from_numpy(deg), torch.from_numpy(ref)).float(), + signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)).float(), torch.tensor(0.2211), rtol=0.0001, atol=1e-4, diff --git a/tests/audio/test_si_sdr.py b/tests/audio/test_si_sdr.py index 479b2c09c6a..32911592459 100644 --- a/tests/audio/test_si_sdr.py +++ b/tests/audio/test_si_sdr.py @@ -21,8 +21,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import SI_SDR -from torchmetrics.functional import si_sdr +from torchmetrics.audio import ScaleInvariantSignalDistortionRatio +from torchmetrics.functional import scale_invariant_signal_distortion_ratio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -84,7 +84,7 @@ def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_ste ddp, preds, target, - SI_SDR, + ScaleInvariantSignalDistortionRatio, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(zero_mean=zero_mean), @@ -94,7 +94,7 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean): self.run_functional_metric_test( preds, target, - si_sdr, + scale_invariant_signal_distortion_ratio, sk_metric, metric_args=dict(zero_mean=zero_mean), ) @@ -103,8 +103,8 @@ def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean): self.run_differentiability_test( preds=preds, target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, + metric_module=ScaleInvariantSignalDistortionRatio, + metric_functional=scale_invariant_signal_distortion_ratio, metric_args={"zero_mean": zero_mean}, ) @@ -119,13 +119,13 @@ def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean): self.run_precision_test_gpu( preds=preds, target=target, - metric_module=SI_SDR, - metric_functional=si_sdr, + metric_module=ScaleInvariantSignalDistortionRatio, + metric_functional=scale_invariant_signal_distortion_ratio, metric_args={"zero_mean": zero_mean}, ) -def test_error_on_different_shape(metric_class=SI_SDR): +def test_error_on_different_shape(metric_class=ScaleInvariantSignalDistortionRatio): metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 95199124583..63b492b6d59 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -13,7 +13,15 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 -from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR # noqa: E402 +from torchmetrics.audio import ( # noqa: E402 + PIT, + SDR, + SI_SDR, + SI_SNR, + SNR, + ScaleInvariantSignalDistortionRatio, + SignalDistortionRatio, +) from torchmetrics.classification import ( # noqa: E402, F401 AUC, AUROC, @@ -142,6 +150,8 @@ "ROC", "SacreBLEUScore", "SDR", + "SignalDistortionRatio", + "ScaleInvariantSignalDistortionRatio", "SI_SDR", "SI_SNR", "SNR", diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 70552ffa05e..092cb1bd60b 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.audio.pit import PIT # noqa: F401 -from torchmetrics.audio.sdr import SDR # noqa: F401 +from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 from torchmetrics.audio.snr import SNR # noqa: F401 diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index f39bced9022..b440d9599a9 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn from torch import Tensor, tensor -from torchmetrics.functional.audio.sdr import sdr +from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio from torchmetrics.metric import Metric -class SDR(Metric): +class SignalDistortionRatio(Metric): r"""Signal to Distortion Ratio (SDR) [1,2,3] Forward accepts @@ -60,20 +61,20 @@ class SDR(Metric): If ``fast-bss-eval`` package is not installed Example: - >>> from torchmetrics.audio import SDR + >>> from torchmetrics.audio import SignalDistortionRatio >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) - >>> sdr = SDR() + >>> sdr = SignalDistortionRatio() >>> sdr(preds, target) tensor(-12.0589) >>> # use with pit >>> from torchmetrics.audio import PIT - >>> from torchmetrics.functional.audio import sdr + >>> from torchmetrics.functional.audio import signal_distortion_ratio >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] >>> target = torch.randn(4, 2, 8000) - >>> pit = PIT(sdr, 'max') + >>> pit = PIT(signal_distortion_ratio, 'max') >>> pit(preds, target) tensor(-11.6051) @@ -134,7 +135,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model target: Ground truth values """ - sdr_batch = sdr(preds, target, self.use_cg_iter, self.filter_length, self.zero_mean, self.load_diag) + sdr_batch = signal_distortion_ratio( + preds, target, self.use_cg_iter, self.filter_length, self.zero_mean, self.load_diag + ) self.sum_sdr += sdr_batch.sum() self.total += sdr_batch.numel() @@ -142,3 +145,133 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: """Computes average SDR.""" return self.sum_sdr / self.total + + +class SDR(SignalDistortionRatio): + r"""Signal to Distortion Ratio (SDR) + + .. deprecated:: v0.7 + Use :class:`torchmetrics.SignalDistortionRatio`. Will be removed in v0.8. + + Example: + >>> import torch + >>> g = torch.manual_seed(1) + >>> sdr = SDR() + >>> sdr(torch.randn(8000), torch.randn(8000)) + tensor(-12.0589) + >>> # use with pit + >>> from torchmetrics.audio import PIT + >>> from torchmetrics.functional.audio import signal_distortion_ratio + >>> pit = PIT(signal_distortion_ratio, 'max') + >>> pit(torch.randn(4, 2, 8000), torch.randn(4, 2, 8000)) + tensor(-11.6051) + """ + + def __init__( + self, + use_cg_iter: Optional[int] = None, + filter_length: int = 512, + zero_mean: bool = False, + load_diag: Optional[float] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + warn("`SDR` was renamed to `SignalDistortionRatio` in v0.7 and it will be removed in v0.8", DeprecationWarning) + + super().__init__( + use_cg_iter, + filter_length, + zero_mean, + load_diag, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn, + ) + + +class ScaleInvariantSignalDistortionRatio(Metric): + """Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is in general considered an overall + measure of how good a source sound. + + Forward accepts + + - ``preds``: ``shape [...,time]`` + - ``target``: ``shape [...,time]`` + + Args: + zero_mean: + if to zero mean target and preds or not + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + + Raises: + TypeError: + if target and preds have a different shape + + Returns: + average si-sdr value + + Example: + >>> import torch + >>> from torchmetrics import ScaleInvariantSignalDistortionRatio + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_sdr = ScaleInvariantSignalDistortionRatio() + >>> si_sdr(preds, target) + tensor(18.4030) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. + """ + + is_differentiable = True + higher_is_better = True + sum_si_sdr: Tensor + total: Tensor + + def __init__( + self, + zero_mean: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.zero_mean = zero_mean + + self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + si_sdr_batch = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=self.zero_mean) + + self.sum_si_sdr += si_sdr_batch.sum() + self.total += si_sdr_batch.numel() + + def compute(self) -> Tensor: + """Computes average SI-SDR.""" + return self.sum_si_sdr / self.total diff --git a/torchmetrics/audio/si_sdr.py b/torchmetrics/audio/si_sdr.py index 4e632ce78fe..fe40f3a90ce 100644 --- a/torchmetrics/audio/si_sdr.py +++ b/torchmetrics/audio/si_sdr.py @@ -12,63 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn -from torch import Tensor, tensor +from torch import Tensor -from torchmetrics.functional.audio.si_sdr import si_sdr -from torchmetrics.metric import Metric +from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio -class SI_SDR(Metric): - """Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is in general considered an overall - measure of how good a source sound. +class SI_SDR(ScaleInvariantSignalDistortionRatio): + """Scale-invariant signal-to-distortion ratio (SI-SDR). - Forward accepts - - - ``preds``: ``shape [...,time]`` - - ``target``: ``shape [...,time]`` - - Args: - zero_mean: - if to zero mean target and preds or not - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. - - Raises: - TypeError: - if target and preds have a different shape - - Returns: - average si-sdr value + .. deprecated:: v0.7 + Use :class:`torchmetrics.ScaleInvariantSignalDistortionRatio`. Will be removed in v0.8. Example: >>> import torch - >>> from torchmetrics import SI_SDR - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) >>> si_sdr = SI_SDR() - >>> si_sdr_val = si_sdr(preds, target) - >>> si_sdr_val + >>> si_sdr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) tensor(18.4030) - - References: - [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech - and Signal Processing (ICASSP) 2019. """ - is_differentiable = True - higher_is_better = True - sum_si_sdr: Tensor - total: Tensor - def __init__( self, zero_mean: bool = False, @@ -77,29 +40,8 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, + warn( + "`SI_SDR` was renamed to `ScaleInvariantSignalDistortionRatio` in v0.7 and it will be removed in v0.8", + DeprecationWarning, ) - self.zero_mean = zero_mean - - self.add_state("sum_si_sdr", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - si_sdr_batch = si_sdr(preds=preds, target=target, zero_mean=self.zero_mean) - - self.sum_si_sdr += si_sdr_batch.sum() - self.total += si_sdr_batch.numel() - - def compute(self) -> Tensor: - """Computes average SI-SDR.""" - return self.sum_si_sdr / self.total + super().__init__(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index a4c188f53a5..3ba382bbd9c 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.audio.pit import pit, pit_permutate -from torchmetrics.functional.audio.sdr import sdr +from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, sdr, signal_distortion_ratio from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.functional.audio.si_snr import si_snr from torchmetrics.functional.audio.snr import snr @@ -129,7 +129,9 @@ "rouge_score", "sacre_bleu_score", "sdr", + "signal_distortion_ratio", "si_sdr", + "scale_invariant_signal_distortion_ratio", "si_snr", "snr", "spearman_corrcoef", diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index c651c9a994b..db7f29ecbd3 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.audio.pit import pit, pit_permutate # noqa: F401 -from torchmetrics.functional.audio.sdr import sdr # noqa: F401 +from torchmetrics.functional.audio.sdr import ( # noqa: F401 + scale_invariant_signal_distortion_ratio, + sdr, + signal_distortion_ratio, +) from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 59384a3899e..2b48d7b4fc0 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -127,11 +127,11 @@ def pit( best_perm of shape [batch] Example: - >>> from torchmetrics.functional.audio import si_sdr + >>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio >>> # [batch, spk, time] >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) - >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') + >>> best_metric, best_perm = pit(preds, target, scale_invariant_signal_distortion_ratio, 'max') >>> best_metric tensor([-5.1091]) >>> best_perm @@ -189,11 +189,11 @@ def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor: Tensor: the permutated version of estimate Example: - >>> from torchmetrics.functional.audio import si_sdr + >>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio >>> # [batch, spk, time] >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) - >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') + >>> best_metric, best_perm = pit(preds, target, scale_invariant_signal_distortion_ratio, 'max') >>> best_metric tensor([-5.1091]) >>> best_perm diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 258a6058739..567fc45fbb0 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +from warnings import warn import torch @@ -46,7 +47,7 @@ from torchmetrics.utilities.checks import _check_same_shape -def sdr( +def signal_distortion_ratio( preds: Tensor, target: Tensor, use_cg_iter: Optional[int] = None, @@ -86,18 +87,18 @@ def sdr( sdr value of shape ``[...]`` Example: - >>> from torchmetrics.functional.audio import sdr + >>> from torchmetrics.functional.audio import signal_distortion_ratio >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) - >>> sdr(preds, target) + >>> signal_distortion_ratio(preds, target) tensor(-12.0589) >>> # use with pit >>> from torchmetrics.functional.audio import pit >>> preds = torch.randn(4, 2, 8000) # [batch, spk, time] >>> target = torch.randn(4, 2, 8000) - >>> best_metric, best_perm = pit(preds, target, sdr, 'max') + >>> best_metric, best_perm = pit(preds, target, signal_distortion_ratio, 'max') >>> best_metric tensor([-11.6375, -11.4358, -11.7148, -11.6325]) >>> best_perm @@ -190,5 +191,74 @@ def sdr( # transform to decibels ratio = coh / (1 - coh) - sdr_val = 10.0 * torch.log10(ratio) - return sdr_val + val = 10.0 * torch.log10(ratio) + return val + + +def sdr( + preds: Tensor, + target: Tensor, + use_cg_iter: Optional[int] = None, + filter_length: int = 512, + zero_mean: bool = False, + load_diag: Optional[float] = None, +) -> Tensor: + r"""Signal to Distortion Ratio (SDR) + + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.signal_distortion_ratio`. Will be removed in v0.8. + + Example: + >>> import torch + >>> g = torch.manual_seed(1) + >>> sdr(torch.randn(8000), torch.randn(8000)) + tensor(-12.0589) + """ + warn("`sdr` was renamed to `signal_distortion_ratio` in v0.7 and it will be removed in v0.8", DeprecationWarning) + return signal_distortion_ratio(preds, target, use_cg_iter, filter_length, zero_mean, load_diag) + + +def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: + """Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric. The SI-SDR value is in general + considered an overall measure of how good a source sound. + + Args: + preds: + shape ``[...,time]`` + target: + shape ``[...,time]`` + zero_mean: + If to zero mean target and preds or not + + Returns: + si-sdr value of shape [...] + + Example: + >>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> scale_invariant_signal_distortion_ratio(preds, target) + tensor(18.4030) + + References: + [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech + and Signal Processing (ICASSP) 2019. + """ + _check_same_shape(preds, target) + EPS = torch.finfo(preds.dtype).eps + + if zero_mean: + target = target - torch.mean(target, dim=-1, keepdim=True) + preds = preds - torch.mean(preds, dim=-1, keepdim=True) + + alpha = (torch.sum(preds * target, dim=-1, keepdim=True) + EPS) / ( + torch.sum(target ** 2, dim=-1, keepdim=True) + EPS + ) + target_scaled = alpha * target + + noise = target_scaled - preds + + val = (torch.sum(target_scaled ** 2, dim=-1) + EPS) / (torch.sum(noise ** 2, dim=-1) + EPS) + val = 10 * torch.log10(val) + + return val diff --git a/torchmetrics/functional/audio/si_sdr.py b/torchmetrics/functional/audio/si_sdr.py index ec50ef09b92..d67cb531521 100644 --- a/torchmetrics/functional/audio/si_sdr.py +++ b/torchmetrics/functional/audio/si_sdr.py @@ -11,54 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from warnings import warn + import torch from torch import Tensor -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio def si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: - """Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric. The SI-SDR value is in general - considered an overall measure of how good a source sound. - - Args: - preds: - shape ``[...,time]`` - target: - shape ``[...,time]`` - zero_mean: - If to zero mean target and preds or not + """Calculates Scale-invariant signal-to-distortion ratio (SI-SDR) metric. - Returns: - si-sdr value of shape [...] + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.scale_invariant_signal_distortion_ratio`. Will be removed in v0.8. Example: - >>> from torchmetrics.functional.audio import si_sdr - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_sdr_val = si_sdr(preds, target) - >>> si_sdr_val + >>> si_sdr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) tensor(18.4030) - - References: - [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE International Conference on Acoustics, Speech - and Signal Processing (ICASSP) 2019. """ - _check_same_shape(preds, target) - EPS = torch.finfo(preds.dtype).eps - - if zero_mean: - target = target - torch.mean(target, dim=-1, keepdim=True) - preds = preds - torch.mean(preds, dim=-1, keepdim=True) - - alpha = (torch.sum(preds * target, dim=-1, keepdim=True) + EPS) / ( - torch.sum(target ** 2, dim=-1, keepdim=True) + EPS + warn( + "`si_sdr` was renamed to `scale_invariant_signal_distortion_ratio` in v0.7 and it will be removed in v0.8", + DeprecationWarning, ) - target_scaled = alpha * target - - noise = target_scaled - preds - - si_sdr_value = (torch.sum(target_scaled ** 2, dim=-1) + EPS) / (torch.sum(noise ** 2, dim=-1) + EPS) - si_sdr_value = 10 * torch.log10(si_sdr_value) - - return si_sdr_value + return scale_invariant_signal_distortion_ratio(preds, target, zero_mean) diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index 7f92fa49c4f..a967ddc306c 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -13,7 +13,7 @@ # limitations under the License. from torch import Tensor -from torchmetrics.functional.audio.si_sdr import si_sdr +from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio def si_snr(preds: Tensor, target: Tensor) -> Tensor: @@ -43,4 +43,4 @@ def si_snr(preds: Tensor, target: Tensor) -> Tensor: 696-700, doi: 10.1109/ICASSP.2018.8462116. """ - return si_sdr(target=target, preds=preds, zero_mean=True) + return scale_invariant_signal_distortion_ratio(target=target, preds=preds, zero_mean=True)