Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: SDR & SI_SDR #711

Merged
merged 12 commits into from
Jan 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`


- Renamed audio SDR metrics: ([#711](https://github.com/PyTorchLightning/metrics/pull/711))
* `functional.sdr` -> `functional.signal_distortion_ratio`
* `functional.si_sdr` -> `functional.scale_invariant_signal_distortion_ratio`
* `SDR` -> `SignalDistortionRatio`
* `SI_SDR` -> `ScaleInvariantSignalDistortionRatio`


### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ acc = torchmetrics.functional.accuracy(preds, target)
We currently have implemented metrics within the following domains:

- Audio (
[SI_SDR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-sdr),
[ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio),
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
Expand Down
12 changes: 6 additions & 6 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ pit [func]
:noindex:


sdr [func]
~~~~~~~~~~
signal_distortion_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.sdr
.. autofunction:: torchmetrics.functional.signal_distortion_ratio
:noindex:


si_sdr [func]
~~~~~~~~~~~~~
scale_invariant_signal_distortion_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_sdr
.. autofunction:: torchmetrics.functional.scale_invariant_signal_distortion_ratio
:noindex:


Expand Down
12 changes: 6 additions & 6 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@ PIT
.. autoclass:: torchmetrics.PIT
:noindex:

SDR
~~~
SignalDistortionRatio
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SDR
.. autoclass:: torchmetrics.SignalDistortionRatio
:noindex:

SI_SDR
~~~~~~
ScaleInvariantSignalDistortionRatio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SI_SDR
.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
:noindex:

SI_SNR
Expand Down
10 changes: 6 additions & 4 deletions tests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import PIT
from torchmetrics.functional import pit, si_sdr, snr
from torchmetrics.functional import pit, scale_invariant_signal_distortion_ratio, snr
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -97,16 +97,18 @@ def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Ten


snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=snr, eval_func="max")
si_sdr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=si_sdr, eval_func="max")
si_sdr_pit_scipy = partial(
naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max"
)


@pytest.mark.parametrize(
"preds, target, sk_metric, metric_func, eval_func",
[
(inputs1.preds, inputs1.target, snr_pit_scipy, snr, "max"),
(inputs1.preds, inputs1.target, si_sdr_pit_scipy, si_sdr, "max"),
(inputs1.preds, inputs1.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"),
(inputs2.preds, inputs2.target, snr_pit_scipy, snr, "max"),
(inputs2.preds, inputs2.target, si_sdr_pit_scipy, si_sdr, "max"),
(inputs2.preds, inputs2.target, si_sdr_pit_scipy, scale_invariant_signal_distortion_ratio, "max"),
],
)
class TestPIT(MetricTester):
Expand Down
26 changes: 13 additions & 13 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio import SDR
from torchmetrics.functional import sdr
from torchmetrics.audio import SignalDistortionRatio
from torchmetrics.functional import signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8

seed_all(42)
Expand All @@ -49,7 +49,7 @@ def sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
preds = preds.detach().cpu().numpy()
mss = []
for b in range(preds.shape[0]):
sdr_val_np, sir_val_np, sar_val_np, perm = bss_eval_sources(target[b], preds[b], compute_permutation)
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
mss.append(sdr_val_np)
return torch.tensor(mss)

Expand Down Expand Up @@ -83,7 +83,7 @@ def test_sdr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
ddp,
preds,
target,
SDR,
SignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(),
Expand All @@ -93,7 +93,7 @@ def test_sdr_functional(self, preds, target, sk_metric):
self.run_functional_metric_test(
preds,
target,
sdr,
signal_distortion_ratio,
sk_metric,
metric_args=dict(),
)
Expand All @@ -103,8 +103,8 @@ def test_sdr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)

Expand All @@ -115,8 +115,8 @@ def test_sdr_half_cpu(self, preds, target, sk_metric):
self.run_precision_test_cpu(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)

Expand All @@ -125,13 +125,13 @@ def test_sdr_half_gpu(self, preds, target, sk_metric):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=SDR,
metric_functional=sdr,
metric_module=SignalDistortionRatio,
metric_functional=signal_distortion_ratio,
metric_args=dict(),
)


def test_error_on_different_shape(metric_class=SDR):
def test_error_on_different_shape(metric_class=SignalDistortionRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Expand All @@ -143,7 +143,7 @@ def test_on_real_audio():
rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))
assert torch.allclose(
sdr(torch.from_numpy(deg), torch.from_numpy(ref)).float(),
signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)).float(),
torch.tensor(0.2211),
rtol=0.0001,
atol=1e-4,
Expand Down
18 changes: 9 additions & 9 deletions tests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SDR
from torchmetrics.functional import si_sdr
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
from torchmetrics.functional import scale_invariant_signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -84,7 +84,7 @@ def test_si_sdr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_ste
ddp,
preds,
target,
SI_SDR,
ScaleInvariantSignalDistortionRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(zero_mean=zero_mean),
Expand All @@ -94,7 +94,7 @@ def test_si_sdr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
si_sdr,
scale_invariant_signal_distortion_ratio,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)
Expand All @@ -103,8 +103,8 @@ def test_si_sdr_differentiability(self, preds, target, sk_metric, zero_mean):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_module=ScaleInvariantSignalDistortionRatio,
metric_functional=scale_invariant_signal_distortion_ratio,
metric_args={"zero_mean": zero_mean},
)

Expand All @@ -119,13 +119,13 @@ def test_si_sdr_half_gpu(self, preds, target, sk_metric, zero_mean):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=SI_SDR,
metric_functional=si_sdr,
metric_module=ScaleInvariantSignalDistortionRatio,
metric_functional=scale_invariant_signal_distortion_ratio,
metric_args={"zero_mean": zero_mean},
)


def test_error_on_different_shape(metric_class=SI_SDR):
def test_error_on_different_shape(metric_class=ScaleInvariantSignalDistortionRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
12 changes: 11 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR # noqa: E402
from torchmetrics.audio import ( # noqa: E402
PIT,
SDR,
SI_SDR,
SI_SNR,
SNR,
ScaleInvariantSignalDistortionRatio,
SignalDistortionRatio,
)
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
AUROC,
Expand Down Expand Up @@ -142,6 +150,8 @@
"ROC",
"SacreBLEUScore",
"SDR",
"SignalDistortionRatio",
"ScaleInvariantSignalDistortionRatio",
"SI_SDR",
"SI_SNR",
"SNR",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.audio.pit import PIT # noqa: F401
from torchmetrics.audio.sdr import SDR # noqa: F401
from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
Loading