From 3ac12b9a88feaf1c9e19ca0310699235ffdd6d45 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:06:58 -0700 Subject: [PATCH] [Audio] Metric with Squim objective and MOS (#9751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Metric with Squim Objective and MOS Signed-off-by: Ante Jukić * Removed utility functions Signed-off-by: Ante Jukić --------- Signed-off-by: Ante Jukić --- examples/audio/audio_to_audio_eval.py | 22 +- .../asr/parts/utils/transcribe_utils.py | 2 +- nemo/collections/audio/metrics/__init__.py | 3 + nemo/collections/audio/metrics/audio.py | 3 + nemo/collections/audio/metrics/squim.py | 197 ++++++++++++++++++ tests/collections/audio/test_audio_metrics.py | 131 ++++++++++++ 6 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 nemo/collections/audio/metrics/squim.py diff --git a/examples/audio/audio_to_audio_eval.py b/examples/audio/audio_to_audio_eval.py index 4e60b2ec2b528..c7b9db6efb800 100644 --- a/examples/audio/audio_to_audio_eval.py +++ b/examples/audio/audio_to_audio_eval.py @@ -75,7 +75,7 @@ from nemo.collections.audio.data import audio_to_audio_dataset from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset -from nemo.collections.audio.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.metrics import AudioMetricWrapper, SquimMOSMetric, SquimObjectiveMetric from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing import manifest from nemo.core.config import hydra_runner @@ -128,7 +128,17 @@ def get_evaluation_dataloader(config): def get_metrics(cfg: AudioEvaluationConfig): """Prepare a dictionary with metrics.""" - available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq'] + available_metrics = [ + 'sdr', + 'sisdr', + 'stoi', + 'estoi', + 'pesq', + 'squim_mos', + 'squim_stoi', + 'squim_pesq', + 
'squim_si_sdr', + ] metrics = dict() for name in sorted(set(cfg.metrics)): @@ -143,6 +153,14 @@ def get_metrics(cfg: AudioEvaluationConfig): metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True)) elif name == 'pesq': metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb')) + elif name == 'squim_mos': + metric = AudioMetricWrapper(metric=SquimMOSMetric(fs=cfg.sample_rate)) + elif name == 'squim_stoi': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='stoi', fs=cfg.sample_rate)) + elif name == 'squim_pesq': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='pesq', fs=cfg.sample_rate)) + elif name == 'squim_si_sdr': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='si_sdr', fs=cfg.sample_rate)) else: raise ValueError(f'Unexpected metric: {name}. Currently available metrics: {available_metrics}') diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index c270e5c3a0f7b..c26fa6f4984dc 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -289,7 +289,7 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: with open(cfg.dataset_manifest, "rt") as fh: for line in fh: item = json.loads(line) - item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest) + item[audio_key] = get_full_path(item[audio_key], cfg.dataset_manifest) if item.get("duration") is None and cfg.presort_manifest: raise ValueError( f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." 
diff --git a/nemo/collections/audio/metrics/__init__.py b/nemo/collections/audio/metrics/__init__.py index d9155f923f186..20c8fd2fa4e28 100644 --- a/nemo/collections/audio/metrics/__init__.py +++ b/nemo/collections/audio/metrics/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.collections.audio.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric diff --git a/nemo/collections/audio/metrics/audio.py b/nemo/collections/audio/metrics/audio.py index 096700eff24a0..0f8b5bee0fd23 100644 --- a/nemo/collections/audio/metrics/audio.py +++ b/nemo/collections/audio/metrics/audio.py @@ -21,6 +21,7 @@ from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility +from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric from nemo.utils import logging @@ -34,6 +35,8 @@ SignalNoiseRatio, PerceptualEvaluationSpeechQuality, ShortTimeObjectiveIntelligibility, + SquimMOSMetric, + SquimObjectiveMetric, ] diff --git a/nemo/collections/audio/metrics/squim.py b/nemo/collections/audio/metrics/squim.py new file mode 100644 index 0000000000000..c20be43f79f8e --- /dev/null +++ b/nemo/collections/audio/metrics/squim.py @@ -0,0 +1,197 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +from torchmetrics import Metric +from nemo.utils import logging + +try: + import torchaudio + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + + +class SquimMOSMetric(Metric): + """A metric calculating the average Torchaudio Squim MOS. + + Args: + fs: sampling rate of the input signals + """ + + sample_rate: int = 16000 # sample rate of the model + mos_sum: torch.Tensor + num_examples: torch.Tensor + higher_is_better: bool = True + + def __init__(self, fs: int, **kwargs: Any): + super().__init__(**kwargs) + + if not HAVE_TORCHAUDIO: + raise ModuleNotFoundError(f"{self.__class__.__name__} metric needs `torchaudio`.") + + if fs != self.sample_rate: + # Resampler: kaiser_best + self._squim_mos_metric_resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate) + self.fs = fs + + # MOS model + self._squim_mos_metric_model = torchaudio.pipelines.SQUIM_SUBJECTIVE.get_model() + + self.add_state('mos_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') + logging.debug('Setup metric %s with input fs=%s', self.__class__.__name__, self.fs) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update the metric by calculating the MOS score 
for the current batch. + + Args: + preds: tensor with predictions, shape (B, T) + target: tensor with target signals, shape (B, T). Target can be a non-matching reference. + """ + if self.fs != self.sample_rate: + preds = self._squim_mos_metric_resampler(preds) + target = self._squim_mos_metric_resampler(target) + + if preds.ndim == 1: + # Unsqueeze batch dimension + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + elif preds.ndim > 2: + raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals') + + mos_batch = self._squim_mos_metric_model(preds, target) + + self.mos_sum += mos_batch.sum() + self.num_examples += mos_batch.numel() + + def compute(self) -> torch.Tensor: + """Compute the underlying metric.""" + return self.mos_sum / self.num_examples + + def state_dict(self, *args, **kwargs): + """Do not save the MOS model and resampler in the state dict.""" + state_dict = super().state_dict(*args, **kwargs) + # Do not include resampler or mos_model in the state dict + remove_keys = [ + key + for key in state_dict.keys() + if '_squim_mos_metric_resampler' in key or '_squim_mos_metric_model' in key + ] + for key in remove_keys: + del state_dict[key] + return state_dict + + +class SquimObjectiveMetric(Metric): + """A metric calculating the average Torchaudio Squim objective metric. + + Args: + fs: sampling rate of the input signals + metric: the objective metric to calculate. 
One of 'stoi', 'pesq', 'si_sdr' + """ + + sample_rate: int = 16000 # sample rate of the model + metric_sum: torch.Tensor + num_examples: torch.Tensor + higher_is_better: bool = True + + def __init__(self, fs: int, metric: str, **kwargs: Any): + super().__init__(**kwargs) + + if not HAVE_TORCHAUDIO: + raise ModuleNotFoundError(f"{self.__class__.__name__} needs `torchaudio`.") + + if fs != self.sample_rate: + # Resampler: kaiser_best + self._squim_objective_metric_resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate) + self.fs = fs + + if metric not in ['stoi', 'pesq', 'si_sdr']: + raise ValueError(f'Unsupported metric {metric}. Supported metrics are "stoi", "pesq", "si_sdr".') + + self.metric = metric + + # Objective model + self._squim_objective_metric_model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model() + + self.add_state('metric_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') + logging.debug('Setup %s with metric=%s, input fs=%s', self.__class__.__name__, self.metric, self.fs) + + def update(self, preds: torch.Tensor, target: Any = None) -> None: + """Update the metric by calculating the selected metric score for the current batch. + + Args: + preds: tensor with predictions, shape (B, T) + target: None, not used. Keeping for interface compatibility with other metrics. 
+ """ + if self.fs != self.sample_rate: + preds = self._squim_objective_metric_resampler(preds) + + if preds.ndim == 1: + # Unsqueeze batch dimension + preds = preds.unsqueeze(0) + elif preds.ndim > 2: + raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals') + + stoi_batch, pesq_batch, si_sdr_batch = self._squim_objective_metric_model(preds) + + if self.metric == 'stoi': + metric_batch = stoi_batch + elif self.metric == 'pesq': + metric_batch = pesq_batch + elif self.metric == 'si_sdr': + metric_batch = si_sdr_batch + else: + raise ValueError(f'Unknown metric {self.metric}') + + self.metric_sum += metric_batch.sum() + self.num_examples += metric_batch.numel() + + def compute(self) -> torch.Tensor: + """Compute the underlying metric.""" + return self.metric_sum / self.num_examples + + def state_dict(self, *args, **kwargs): + """Do not save the objective model and resampler in the state dict.""" + state_dict = super().state_dict(*args, **kwargs) + # Do not include resampler or objective model in the state dict + remove_keys = [ + key + for key in state_dict.keys() + if '_squim_objective_metric_resampler' in key or '_squim_objective_metric_model' in key + ] + for key in remove_keys: + del state_dict[key] + return state_dict diff --git a/tests/collections/audio/test_audio_metrics.py b/tests/collections/audio/test_audio_metrics.py index 2d693bc4ab209..578b67fc24795 100644 --- a/tests/collections/audio/test_audio_metrics.py +++ b/tests/collections/audio/test_audio_metrics.py @@ -16,6 +16,14 @@ from torchmetrics.audio.snr import SignalNoiseRatio from nemo.collections.audio.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric + +try: + import torchaudio + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False class TestAudioMetricWrapper: @@ -140,3 +148,126 @@ def test_channel(self, channel): ref_metric.reset() wrapped_metric.reset() + + +class TestSquimMetrics: + 
@pytest.mark.unit + @pytest.mark.parametrize('fs', [16000, 24000]) + def test_squim_mos(self, fs: int): + """Test Squim MOS metric""" + if HAVE_TORCHAUDIO: + # Setup + num_batches = 4 + batch_size = 4 + atol = 1e-6 + + # UUT + squim_mos_metric = SquimMOSMetric(fs=fs) + + # Helper function + resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=16000, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + squim_mos_model = torchaudio.pipelines.SQUIM_SUBJECTIVE.get_model() + + def calculate_squim_mos(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if fs != 16000: + preds = resampler(preds) + target = resampler(target) + + # Calculate MOS + mos_batch = squim_mos_model(preds, target) + return mos_batch + + # Test + mos_sum = torch.tensor(0.0) + + for n in range(num_batches): + preds = torch.randn(batch_size, fs) + target = torch.randn(batch_size, fs) + + # UUT forward + squim_mos_metric.update(preds=preds, target=target) + + # Golden + mos_golden = calculate_squim_mos(preds=preds, target=target) + # Accumulate + mos_sum += mos_golden.sum() + + # Check the final value of the metric + mos_golden_final = mos_sum / (num_batches * batch_size) + assert torch.allclose(squim_mos_metric.compute(), mos_golden_final, atol=atol), f'Comparison failed' + + else: + with pytest.raises(ModuleNotFoundError): + SquimMOSMetric(fs=fs) + + @pytest.mark.unit + @pytest.mark.parametrize('metric', ['stoi', 'pesq', 'si_sdr']) + @pytest.mark.parametrize('fs', [16000, 24000]) + def test_squim_objective(self, metric: str, fs: int): + """Test Squim objective metric""" + if HAVE_TORCHAUDIO: + # Setup + num_batches = 4 + batch_size = 4 + atol = 1e-6 + + # UUT + squim_objective_metric = SquimObjectiveMetric(fs=fs, metric=metric) + + # Helper function + resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=16000, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + 
resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + squim_objective_model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model() + + def calculate_squim_objective(preds: torch.Tensor) -> torch.Tensor: + if fs != 16000: + preds = resampler(preds) + + # Calculate metric + stoi_batch, pesq_batch, si_sdr_batch = squim_objective_model(preds) + + if metric == 'stoi': + return stoi_batch + elif metric == 'pesq': + return pesq_batch + elif metric == 'si_sdr': + return si_sdr_batch + else: + raise ValueError(f'Unknown metric {metric}') + + # Test + metric_sum = torch.tensor(0.0) + + for n in range(num_batches): + preds = torch.randn(batch_size, fs) + + # UUT forward + squim_objective_metric.update(preds=preds, target=None) + + # Golden + metric_golden = calculate_squim_objective(preds=preds) + # Accumulate + metric_sum += metric_golden.sum() + + # Check the final value of the metric + metric_golden_final = metric_sum / (num_batches * batch_size) + assert torch.allclose( + squim_objective_metric.compute(), metric_golden_final, atol=atol + ), f'Comparison failed' + + else: + with pytest.raises(ModuleNotFoundError): + SquimObjectiveMetric(fs=fs, metric=metric)