Skip to content

Commit

Permalink
Refactor: SDR & SI_SDR (#711)
Browse files Browse the repository at this point in the history
* signal_distortion_ratio
* scale_invariant_signal_distortion_ratio
* SignalDistortionRatio
* ScaleInvariantSignalDistortionRatio

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Jan 8, 2022
1 parent 2f32c2d commit fdf5b3f
Show file tree
Hide file tree
Showing 17 changed files with 313 additions and 171 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`


- Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711))
* `functional.sdr` -> `functional.signal_distortion_ratio`
* `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio`
* `SDR` -> `SignalDistortionRatio`
* `SI_SDR` -> `ScaleInvariantSignalDistortionRatio`


### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ acc = torchmetrics.functional.accuracy(preds, target)
We currently have implemented metrics within the following domains:

- Audio (
[SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr),
[ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio),
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
Expand Down
12 changes: 6 additions & 6 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ pit [func]
:noindex:


sdr [func]
~~~~~~~~~~
signal_distortion_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.sdr
.. autofunction:: torchmetrics.functional.signal_distortion_ratio
:noindex:


si_sdr [func]
~~~~~~~~~~~~~
scale_invariant_signal_distortion_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_sdr
.. autofunction:: torchmetrics.functional.scale_invariant_signal_distortion_ratio
:noindex:


Expand Down
12 changes: 6 additions & 6 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@ PIT
.. autoclass:: torchmetrics.PIT
:noindex:

SDR
~~~
SignalDistortionRatio
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SDR
.. autoclass:: torchmetrics.SignalDistortionRatio
:noindex:

SI_SDR
~~~~~~
ScaleInvariantSignalDistortionRatio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SI_SDR
.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
:noindex:

SI_SNR
Expand Down
10 changes: 6 additions & 4 deletions tests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import PIT
from torchmetrics.functional import pit, si_sdr, snr
from torchmetrics.functional import pit, scale_invariant_signal_distortion_ratio, snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -97,16 +97,18 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten


snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=snr, eval_func="max")
si_sdr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=si_sdr, eval_func="max")
si_sdr_pit_scipy = partial(
naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max"
)


@pytest.mark.parametrize(
"preds, target, sk_metric, metric_func, eval_func",
[
(inputs1.preds, inputs1.target, snr_pit_scipy, snr, "max"),
(inputs1.preds, inputs1.target, si_sdr_pit_scipy, si_sdr, "max"),
(inputs1.preds, inputs1.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"),
(inputs2.preds, inputs2.target, snr_pit_scipy, snr, "max"),
(inputs2.preds, inputs2.target, si_sdr_pit_scipy, si_sdr, "max"),
(inputs2.preds, inputs2.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"),
],
)
class TestPIT(MetricTester):
Expand Down
26 changes: 13 additions & 13 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio import SDR
from torchmetrics.functional import sdr
from torchmetrics.audio import SignalDistortionRatio
from torchmetrics.functional import signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8

seed_all(42)
Expand All @@ -49,7 +49,7 @@ def sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
preds = preds.detach().cpu().numpy()
mss = []
for b in range(preds.shape[0]):
sdr_val_np, sir_val_np, sar_val_np, perm = bss_eval_sources(target[b], preds[b], compute_permutation)
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
mss.append(sdr_val_np)
return torch.tensor(mss)

Expand Down Expand Up @@ -83,7 +83,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
ddp,
preds,
target,
SDR,
SignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(),
Expand All @@ -93,7 +93,7 @@ def test_sdr_functional(self, preds, target, sk_metric):
self.run_functional_metric_test(
preds,
target,
sdr,
signal_distortion_ratio,
sk_metric,
metric_args=dict(),
)
Expand All @@ -103,8 +103,8 @@ def test_sdr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)

Expand All @@ -115,8 +115,8 @@ def test_sdr_half_cpu(self, preds, target, sk_metric):
self.run_precision_test_cpu(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)

Expand All @@ -125,13 +125,13 @@ def test_sdr_half_gpu(self, preds, target, sk_metric):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)


def test_error_on_different_shape(metric_class=SDR):
def test_error_on_different_shape(metric_class=SignalDistortionRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Expand All @@ -143,7 +143,7 @@ def test_on_real_audio():
rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))
assert torch.allclose(
sdr(torch.from_numpy(deg), torch.from_numpy(ref)).float(),
signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)).float(),
torch.tensor(0.2211),
rtol=0.0001,
atol=1e-4,
Expand Down
18 changes: 9 additions & 9 deletions tests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SDR
from torchmetrics.functional import si_sdr
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
from torchmetrics.functional import scale_invariant_signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_ste
ddp,
preds,
target,
SI_SDR,
ScaleInvariantSignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(zero_mean=zero_mean),
Expand All @@ -94,7 +94,7 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
si_sdr,
scale_invariant_signal_distortion_ratio,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)
Expand All @@ -103,8 +103,8 @@ def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_module=ScaleInvariantSignalDistortionRatio,
metric_functional=scale_invariant_signal_distortion_ratio,
metric_args={"zero_mean": zero_mean},
)

Expand All @@ -119,13 +119,13 @@ def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_module=ScaleInvariantSignalDistortionRatio,
metric_functional=scale_invariant_signal_distortion_ratio,
metric_args={"zero_mean": zero_mean},
)


def test_error_on_different_shape(metric_class=SI_SDR):
def test_error_on_different_shape(metric_class=ScaleInvariantSignalDistortionRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
12 changes: 11 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR # noqa: E402
from torchmetrics.audio import ( # noqa: E402
PIT,
SDR,
SI_SDR,
SI_SNR,
SNR,
ScaleInvariantSignalDistortionRatio,
SignalDistortionRatio,
)
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
Expand Down Expand Up @@ -142,6 +150,8 @@
"ROC",
"SacreBLEUScore",
"SDR",
"SignalDistortionRatio",
"ScaleInvariantSignalDistortionRatio",
"SI_SDR",
"SI_SNR",
"SNR",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.audio.pit import PIT # noqa: F401
from torchmetrics.audio.sdr import SDR # noqa: F401
from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
Loading

0 comments on commit fdf5b3f

Please sign in to comment.