From 95266f7be84950c31cf227588ae447ed08dca8fa Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:03:50 +0800 Subject: [PATCH 01/56] finish functional version --- torchmetrics/functional/__init__.py | 1 + torchmetrics/functional/audio/__init__.py | 1 + torchmetrics/functional/audio/pit.py | 149 ++++++++++++++++++++++ 3 files changed, 151 insertions(+) create mode 100644 torchmetrics/functional/audio/pit.py diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index df25a3e7269..60807af20f5 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.audio.pit import pit, permutate # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index d5bb919a914..1ad0c0c5386 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.audio.pit import pit, permutate # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py new file mode 100644 index 00000000000..24704481cc0 --- /dev/null +++ b/torchmetrics/functional/audio/pit.py @@ -0,0 +1,149 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.utilities.checks import _check_same_shape + +from itertools import permutations +from typing import Callable, List, Union +import torch +from torch.tensor import Tensor +from scipy.optimize import linear_sum_assignment + +_ps_dict: dict = {} # cache +_ps_idx_dict: dict = {} # cache + + +def _find_best_perm_by_hungarian_method(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): + mmtx = metric_mtx.detach().cpu() + best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]).to(metric_mtx.device) + best_metric = torch.gather(metric_mtx, 2, best_perm[:, :, None]).mean([-1, -2]) + return best_metric, best_perm # shape [batch], shape [batch, spk] + + +def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): + # create/read/cache the permutations and its indexes + # reading from cache would be much faster than creating in CPU then moving to GPU + batch_size, spk_num = metric_mtx.shape[:2] + key = str(spk_num) + str(metric_mtx.device) + if key not in _ps_dict: + # all the permutations, shape [perm_num, spk_num] + ps = torch.tensor(list(permutations(range(spk_num))), device=metric_mtx.device) + # shape [perm_num * spk_num] + inc = torch.arange(0, spk_num * spk_num, step=spk_num, device=metric_mtx.device, dtype=ps.dtype).repeat(ps.shape[0]) + # the indexes for all permutations, shape [perm_num*spk_num] + ps_idx = ps.view(-1) + inc + # cache ps and ps_idx + _ps_idx_dict[key] = ps_idx + _ps_dict[key] = ps + else: + ps_idx = _ps_idx_dict[key] # the indexes for all permutations, shape [perm_num*spk_num] + ps = _ps_dict[key] # all the permutations, shape [perm_num, spk_num] + + # find the metric of each permutation + metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape(batch_size, -1, spk_num) # shape [batch_size, perm_num, spk_num] + metric_of_ps = metric_of_ps_details.mean(dim=2) # shape [batch_size, perm_num] + + # find the best metric and best permutation + best_metric, best_indexes = eval_func(metric_of_ps, dim=1) + best_indexes = best_indexes.detach() + best_perm = ps[best_indexes, :] + return best_metric, best_perm # shape [batch], shape [batch, spk] + + +def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, return_best_perm: bool = False, eval_func: Union[torch.min, torch.max] = torch.max, **kwargs) -> List: + """ Permutation invariant training metric + + Args: + target: + shape [batch, spk, ...] + preds: + shape [batch, spk, ...] + metric_func: + a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] + return_best_perm: + whether to return the best permutation + eval_func: + the function to find the best permutation, can be torch.min or torch.max, i.e. the smaller the better or the larger the better. + kwargs: + additional args for metric_func + + Returns: + List: + best_metric of shape [batch], + best_perm of shape [batch] if return_best_perm == True + + Example: + >>> import torch + >>> from torchmetrics.functional.audio import si_snr, pit, permutate + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> best_metric, best_perm = pit(preds, target, si_snr, True, torch.min) + >>> best_metric + tensor([-29.3482, -11.2862, -9.2508]) + >>> best_perm + tensor([[0, 1], + [1, 0], + [0, 1]]) + >>> preds_pmted = permutate(preds, best_perm) + + Reference: + [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. + + """ + _check_same_shape(preds, target) + + if len(target.shape) < 2: + raise TypeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") + + batch_size, spk_num = target.shape[0:2] + + # calculate the metric matrix + metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device) + for t in range(spk_num): + for e in range(spk_num): + metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) + + if spk_num < 3: + best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, eval_func) + else: + best_metric, best_perm = _find_best_perm_by_hungarian_method(metric_mtx, eval_func) + + # returns + ret = [] + ret.append(best_metric) + if return_best_perm: + ret.append(best_perm) + return ret + + +def permutate(preds: Tensor, perm: Tensor) -> Tensor: + """ permutate estimate according to perm + + Args: + preds (Tensor): the estimates you want to permutate, shape [batch, spk, ...] + perm (Tensor): the permutation returned from pit, shape [batch, spk] + + Returns: + Tensor: the permutated version of estimate + + Example: + >>> import torch + >>> from torchmetrics.functional.audio import si_snr, pit, permutate + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> best_metric, best_perm = pit(preds, target, si_snr, True, torch.min) + >>> preds_pmted = permutate(preds, best_perm) + + """ + preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]) + return preds_pmted From 3eaf5060a5356dde8b422aedccbb7582d2e5a999 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:18:41 +0800 Subject: [PATCH 02/56] change eval_func to str type --- torchmetrics/functional/audio/pit.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 24704481cc0..5545655ac39 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -60,7 +60,7 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un return best_metric, best_perm # shape [batch], shape [batch, spk] -def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, return_best_perm: bool = False, eval_func: Union[torch.min, torch.max] = torch.max, **kwargs) -> List: +def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', return_best_perm: bool = False, **kwargs) -> List: """ Permutation invariant training metric Args: @@ -70,10 +70,10 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, return shape [batch, spk, ...] metric_func: a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] + eval_func: + the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better. return_best_perm: whether to return the best permutation - eval_func: - the function to find the best permutation, can be torch.min or torch.max, i.e. the smaller the better or the larger the better. kwargs: additional args for metric_func @@ -87,7 +87,7 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, return >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, True, torch.min) + >>> best_metric, best_perm = pit(preds, target, si_snr, 'min', True) >>> best_metric tensor([-29.3482, -11.2862, -9.2508]) >>> best_perm @@ -98,10 +98,10 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, return Reference: [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. - """ _check_same_shape(preds, target) - + assert (eval_func == '') + eval_func = torch.max if eval_func == 'max' else torch.min if len(target.shape) < 2: raise TypeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") @@ -141,7 +141,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, True, torch.min) + >>> best_metric, best_perm = pit(preds, target, si_snr, 'min', True) >>> preds_pmted = permutate(preds, best_perm) """ From f63fedb05a2af0aacaf8a53c0321c1539ce3aa7f Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:35:23 +0800 Subject: [PATCH 03/56] remove return_best_perm --- torchmetrics/functional/audio/pit.py | 33 ++++++++++------------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 5545655ac39..ba52021f371 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -14,7 +14,7 @@ from torchmetrics.utilities.checks import _check_same_shape from itertools import permutations -from typing import Callable, List, Union +from typing import Callable, List, Tuple, Union import torch from torch.tensor import Tensor from scipy.optimize import linear_sum_assignment @@ -60,7 +60,7 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un return best_metric, best_perm # shape [batch], shape [batch, spk] -def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', return_best_perm: bool = False, **kwargs) -> List: +def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', **kwargs) -> Tuple[Tensor, Tensor]: """ Permutation invariant training metric Args: @@ -72,22 +72,19 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] eval_func: the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better. - return_best_perm: - whether to return the best permutation kwargs: additional args for metric_func Returns: - List: - best_metric of shape [batch], - best_perm of shape [batch] if return_best_perm == True + best_metric of shape [batch], + best_perm of shape [batch] Example: >>> import torch >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'min', True) + >>> best_metric, best_perm = pit(preds, target, si_snr, 'min') >>> best_metric tensor([-29.3482, -11.2862, -9.2508]) >>> best_perm @@ -100,30 +97,24 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. """ _check_same_shape(preds, target) - assert (eval_func == '') - eval_func = torch.max if eval_func == 'max' else torch.min + assert (eval_func == 'max' or eval_func == 'min'), 'eval_func can only be "max" or "min"' if len(target.shape) < 2: raise TypeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") - batch_size, spk_num = target.shape[0:2] - # calculate the metric matrix + batch_size, spk_num = target.shape[0:2] metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device) for t in range(spk_num): for e in range(spk_num): metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) + # find best if spk_num < 3: - best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, eval_func) + best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max if eval_func == 'max' else torch.min) else: - best_metric, best_perm = _find_best_perm_by_hungarian_method(metric_mtx, eval_func) + best_metric, best_perm = _find_best_perm_by_hungarian_method(metric_mtx, torch.max if eval_func == 'max' else torch.min) - # returns - ret = [] - ret.append(best_metric) - if return_best_perm: - ret.append(best_perm) - return ret + return best_metric, best_perm def permutate(preds: Tensor, perm: Tensor) -> Tensor: @@ -141,7 +132,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'min', True) + >>> best_metric, best_perm = pit(preds, target, si_snr, 'min') >>> preds_pmted = permutate(preds, best_perm) """ From 58dcf1d7ac20690aeca703c92ba57c458352a477 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:53:05 +0800 Subject: [PATCH 04/56] raise --- torchmetrics/functional/audio/pit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index ba52021f371..c53b53e5f84 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -97,7 +97,8 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. """ _check_same_shape(preds, target) - assert (eval_func == 'max' or eval_func == 'min'), 'eval_func can only be "max" or "min"' + if eval_func not in ['max', 'min']: + raise TypeError('eval_func can only be "max" or "min"') if len(target.shape) < 2: raise TypeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") From bb186d6811182f5804bcaf6976326c6e6732f0a4 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:54:37 +0800 Subject: [PATCH 05/56] PIT --- torchmetrics/audio/pit.py | 109 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 torchmetrics/audio/pit.py diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py new file mode 100644 index 00000000000..c40721222e6 --- /dev/null +++ b/torchmetrics/audio/pit.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Union, Callable, Optional, Dict + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.audio.pit import pit, permutate +from torchmetrics.metric import Metric + + +class PIT(Metric): + """ Permutation invariant training metric + + Forward accepts + + - ``preds``: ``shape [..., time]`` + - ``target``: ``shape [..., time]`` + + Args: + metric_func: + a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] + eval_func: + the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better. + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + kwargs: + additional args for metric_func + + Returns: + average PIT metric + + Example: + >>> import torch + >>> from torchmetrics import PIT + >>> from torchmetrics.functional.audio import si_snr + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> pit = PIT(si_snr, 'max') + >>> avg_pit_metric = pit(preds, target) + + Reference: + [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. + """ + + def __init__( + self, + metric_func: Callable, + eval_func: str = 'max', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.metric_func = metric_func + self.eval_func = eval_func + self.kwargs = kwargs + + self.add_state("sum_pit_metric", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + pit_metric = pit(preds, target, self.metric_func, self.eval_func, **self.kwargs)[0] + + self.sum_pit_metric += pit_metric.sum() + self.total += pit_metric.numel() + + def compute(self) -> Tensor: + """ + Computes average PIT metric. + """ + return self.sum_pit_metric / self.total + + @property + def is_differentiable(self) -> bool: + return True From 13be4e4d61943aaa98e167166a92571653807295 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 03:57:12 +0800 Subject: [PATCH 06/56] max --- torchmetrics/functional/audio/pit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index c53b53e5f84..3dd9729fedd 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -84,7 +84,7 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'min') + >>> best_metric, best_perm = pit(preds, target, si_snr, 'max') >>> best_metric tensor([-29.3482, -11.2862, -9.2508]) >>> best_perm @@ -133,9 +133,8 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: >>> from torchmetrics.functional.audio import si_snr, pit, permutate >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'min') + >>> best_metric, best_perm = pit(preds, target, si_snr, 'max') >>> preds_pmted = permutate(preds, best_perm) - """ preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]) return preds_pmted From 1c722de40012be261189ed20b3d20955a956c231 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 05:39:23 +0800 Subject: [PATCH 07/56] add test --- tests/audio/test_pit.py | 154 +++++++++++++++++++++++++++ torchmetrics/__init__.py | 2 +- torchmetrics/audio/__init__.py | 1 + torchmetrics/functional/audio/pit.py | 10 +- 4 files changed, 161 insertions(+), 6 deletions(-) create mode 100644 tests/audio/test_pit.py diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py new file mode 100644 index 00000000000..c188590dbc3 --- /dev/null +++ b/tests/audio/test_pit.py @@ -0,0 +1,154 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from enum import auto +from functools import partial +from torchmetrics.functional.audio.si_snr import si_snr +from typing import Callable + +import pytest +import torch +import numpy as np +from torch import Tensor + +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.audio import PIT +from torchmetrics.functional import snr, si_sdr, pit +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from scipy.optimize import linear_sum_assignment + +seed_all(42) + +Time = 10 + +Input = namedtuple('Input', ["preds", "target"]) + +inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, Time), +) +inputs2 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, Time), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, Time), +) + + +def scipy_version(preds: Tensor, target: Tensor, metric_func: Callable, eval_func: str): + batch_size, spk_num = target.shape[0:2] + metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device) + for t in range(spk_num): + for e in range(spk_num): + metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...]) + + # pit_r = PIT(metric_func, eval_func)(preds, target) + metric_mtx = metric_mtx.detach().cpu().numpy() + best_metrics = [] + best_perms = [] + for b in range(batch_size): + row_idx, col_idx = linear_sum_assignment(metric_mtx[b, ...], eval_func == 'max') + best_metrics.append(metric_mtx[b, row_idx, col_idx].mean()) + best_perms.append(col_idx) + return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) + + +def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return metric_func(preds, target)[0].mean() + + +snr_pit_scipy = partial(scipy_version, metric_func=snr, eval_func='max') +si_sdr_pit_scipy = partial(scipy_version, metric_func=si_sdr, eval_func='max') + + +@pytest.mark.parametrize( + "preds, target, sk_metric, metric_func, eval_func", + [ + (inputs1.preds, inputs1.target, snr_pit_scipy, snr, 'max'), + (inputs1.preds, inputs1.target, si_sdr_pit_scipy, si_sdr, 'max'), + (inputs2.preds, inputs2.target, snr_pit_scipy, snr, 'max'), + (inputs2.preds, inputs2.target, si_sdr_pit_scipy, si_sdr, 'max'), + ], +) +class TestPIT(MetricTester): + atol = 1e-2 + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + PIT, + sk_metric=partial(average_metric, metric_func=sk_metric), + dist_sync_on_step=dist_sync_on_step, + metric_args=dict(metric_func=metric_func, eval_func=eval_func), + ) + + def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func): + device = 'cuda' if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else 'cpu' + + # move to device + preds = preds.to(device) + target = target.to(device) + + for i in range(NUM_BATCHES): + best_metric, best_perm = pit(preds[i], target[i], metric_func, eval_func) + best_metric_sk, best_perm_sk = sk_metric(preds[i].cpu(), target[i].cpu()) + + # assert its the same + assert np.allclose(best_metric.detach().cpu().numpy(), best_metric_sk.detach().cpu().numpy(), atol=self.atol) + assert (best_perm.detach().cpu().numpy() == best_perm_sk.detach().cpu().numpy()).all() + + def test_pit_differentiability(self, preds, target, sk_metric, metric_func, eval_func): + + def pit_diff(preds, target, metric_func, eval_func): + return pit(preds, target, metric_func, eval_func)[0] + + self.run_differentiability_test(preds=preds, target=target, metric_module=PIT, metric_functional=pit_diff, metric_args={'metric_func': metric_func, 'eval_func': eval_func}) + + @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6') + def test_pit_half_cpu(self, preds, target, sk_metric, metric_func, eval_func): + pytest.xfail("PIT metric does not support cpu + half precision") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') + def test_pit_half_gpu(self, preds, target, sk_metric, metric_func, eval_func): + self.run_precision_test_gpu(preds=preds, + target=target, + metric_module=PIT, + metric_functional=partial(pit, metric_func=metric_func, eval_func=eval_func), + metric_args={ + 'metric_func': metric_func, + 'eval_func': eval_func + }) + + +def test_error_on_different_shape() -> None: + metric = PIT(snr, 'max') + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(3, 3, 10), torch.randn(3, 2, 10)) + + +def test_error_on_wrong_eval_func() -> None: + metric = PIT(snr, 'xxx') + with pytest.raises(RuntimeError, match='eval_func can only be "max" or "min"'): + metric(torch.randn(3, 3, 10), torch.randn(3, 3, 10)) + + +def test_error_on_wrong_shape() -> None: + metric = PIT(snr, 'max') + with pytest.raises(RuntimeError, match='Inputs must be of shape *'): + metric(torch.randn(3), torch.randn(3)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 430552fa961..8c81bec3fb7 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,7 +11,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: F401 E402 +from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: F401 E402 from torchmetrics.average import AverageMeter # noqa: F401 E402 from torchmetrics.classification import ( # noqa: F401 E402 AUC, diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 1ca805ab15a..94d20384e0a 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.audio.pit import PIT # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 from torchmetrics.audio.snr import SNR # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 3dd9729fedd..0f723e04fb8 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -23,7 +23,7 @@ _ps_idx_dict: dict = {} # cache -def _find_best_perm_by_hungarian_method(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): +def _find_best_perm_by_linear_sum_assignment(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): mmtx = metric_mtx.detach().cpu() best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]).to(metric_mtx.device) best_metric = torch.gather(metric_mtx, 2, best_perm[:, :, None]).mean([-1, -2]) @@ -98,13 +98,13 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f """ _check_same_shape(preds, target) if eval_func not in ['max', 'min']: - raise TypeError('eval_func can only be "max" or "min"') + raise RuntimeError('eval_func can only be "max" or "min"') if len(target.shape) < 2: - raise TypeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") + raise RuntimeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") # calculate the metric matrix batch_size, spk_num = target.shape[0:2] - metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device) + metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=preds.dtype, device=target.device) for t in range(spk_num): for e in range(spk_num): metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) @@ -113,7 +113,7 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f if spk_num < 3: best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max if eval_func == 'max' else torch.min) else: - best_metric, best_perm = _find_best_perm_by_hungarian_method(metric_mtx, torch.max if eval_func == 'max' else torch.min) + best_metric, best_perm = _find_best_perm_by_linear_sum_assignment(metric_mtx, torch.max if eval_func == 'max' else torch.min) return best_metric, best_perm From 123a6869ab3d37b098b5a88b21bb97658c5d56cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Jul 2021 21:46:01 +0000 Subject: [PATCH 08/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_pit.py | 45 +++++++++++++++-------- torchmetrics/audio/pit.py | 4 +- torchmetrics/functional/__init__.py | 2 +- torchmetrics/functional/audio/__init__.py | 2 +- torchmetrics/functional/audio/pit.py | 37 +++++++++++++------ 5 files changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index c188590dbc3..f17713f4e1e 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -14,20 +14,20 @@ from collections import namedtuple from enum import auto from functools import partial -from torchmetrics.functional.audio.si_snr import si_snr from typing import Callable +import numpy as np import pytest import torch -import numpy as np +from scipy.optimize import linear_sum_assignment from torch import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.audio import PIT -from torchmetrics.functional import snr, si_sdr, pit +from torchmetrics.functional import pit, si_sdr, snr +from torchmetrics.functional.audio.si_snr import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from scipy.optimize import linear_sum_assignment seed_all(42) @@ -110,7 +110,9 @@ def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func): best_metric_sk, best_perm_sk = sk_metric(preds[i].cpu(), target[i].cpu()) # assert its the same - assert np.allclose(best_metric.detach().cpu().numpy(), best_metric_sk.detach().cpu().numpy(), atol=self.atol) + assert np.allclose( + best_metric.detach().cpu().numpy(), best_metric_sk.detach().cpu().numpy(), atol=self.atol + ) assert (best_perm.detach().cpu().numpy() == best_perm_sk.detach().cpu().numpy()).all() def test_pit_differentiability(self, preds, target, sk_metric, metric_func, eval_func): @@ -118,22 +120,35 @@ def test_pit_differentiability(self, preds, target, sk_metric, metric_func, eval def pit_diff(preds, target, metric_func, eval_func): return pit(preds, target, metric_func, eval_func)[0] - self.run_differentiability_test(preds=preds, target=target, metric_module=PIT, metric_functional=pit_diff, metric_args={'metric_func': metric_func, 'eval_func': eval_func}) + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=PIT, + metric_functional=pit_diff, + metric_args={ + 'metric_func': metric_func, + 'eval_func': eval_func + } + ) - @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6') + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' + ) def test_pit_half_cpu(self, preds, target, sk_metric, metric_func, eval_func): pytest.xfail("PIT metric does not support cpu + half precision") @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') def test_pit_half_gpu(self, preds, target, sk_metric, metric_func, eval_func): - self.run_precision_test_gpu(preds=preds, - target=target, - metric_module=PIT, - metric_functional=partial(pit, metric_func=metric_func, eval_func=eval_func), - metric_args={ - 'metric_func': metric_func, - 'eval_func': eval_func - }) + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=PIT, + metric_functional=partial(pit, metric_func=metric_func, eval_func=eval_func), + metric_args={ + 'metric_func': metric_func, + 'eval_func': eval_func + } + ) def test_error_on_different_shape() -> None: diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index c40721222e6..0be58bece8d 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union, Callable, Optional, Dict +from typing import Any, Callable, Dict, Optional, Union import torch from torch import Tensor, tensor -from torchmetrics.functional.audio.pit import pit, permutate +from torchmetrics.functional.audio.pit import permutate, pit from torchmetrics.metric import Metric diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 811ea834a5b..4c397a9beda 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import pit, permutate # noqa: F401 +from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index 1ad0c0c5386..d6f05c3b3c6 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import pit, permutate # noqa: F401 +from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 0f723e04fb8..eee9af8135e 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.utilities.checks import _check_same_shape - from itertools import permutations from typing import Callable, List, Tuple, Union + import torch -from torch.tensor import Tensor from scipy.optimize import linear_sum_assignment +from torch.tensor import Tensor + +from torchmetrics.utilities.checks import _check_same_shape _ps_dict: dict = {} # cache _ps_idx_dict: dict = {} # cache @@ -25,7 +26,8 @@ def _find_best_perm_by_linear_sum_assignment(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): mmtx = metric_mtx.detach().cpu() - best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]).to(metric_mtx.device) + best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] + for pwm in mmtx]).to(metric_mtx.device) best_metric = torch.gather(metric_mtx, 2, best_perm[:, :, None]).mean([-1, -2]) return best_metric, best_perm # shape [batch], shape [batch, spk] @@ -39,7 +41,8 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un # all the permutations, shape [perm_num, spk_num] ps = torch.tensor(list(permutations(range(spk_num))), device=metric_mtx.device) # shape [perm_num * spk_num] - inc = torch.arange(0, spk_num * spk_num, step=spk_num, device=metric_mtx.device, dtype=ps.dtype).repeat(ps.shape[0]) + inc = torch.arange(0, spk_num * spk_num, step=spk_num, device=metric_mtx.device, + dtype=ps.dtype).repeat(ps.shape[0]) # the indexes for all permutations, shape [perm_num*spk_num] ps_idx = ps.view(-1) + inc # cache ps and ps_idx @@ -50,7 +53,9 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un ps = _ps_dict[key] # all the permutations, shape [perm_num, spk_num] # find the metric of each permutation - metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape(batch_size, -1, spk_num) # shape [batch_size, perm_num, spk_num] + metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape( + batch_size, -1, spk_num + ) # shape [batch_size, perm_num, spk_num] metric_of_ps = metric_of_ps_details.mean(dim=2) # shape [batch_size, perm_num] # find the best metric and best permutation @@ -60,7 +65,11 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un return best_metric, best_perm # shape [batch], shape [batch, spk] -def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', **kwargs) -> Tuple[Tensor, Tensor]: +def pit(preds: torch.Tensor, + target: torch.Tensor, + metric_func: Callable, + eval_func: str = 'max', + **kwargs) -> Tuple[Tensor, Tensor]: """ Permutation invariant training metric Args: @@ -76,9 +85,9 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f additional args for metric_func Returns: - best_metric of shape [batch], + best_metric of shape [batch], best_perm of shape [batch] - + Example: >>> import torch >>> from torchmetrics.functional.audio import si_snr, pit, permutate @@ -111,9 +120,13 @@ def pit(preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_f # find best if spk_num < 3: - best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max if eval_func == 'max' else torch.min) + best_metric, best_perm = _find_best_perm_by_exhuastive_method( + metric_mtx, torch.max if eval_func == 'max' else torch.min + ) else: - best_metric, best_perm = _find_best_perm_by_linear_sum_assignment(metric_mtx, torch.max if eval_func == 'max' else torch.min) + best_metric, best_perm = _find_best_perm_by_linear_sum_assignment( + metric_mtx, torch.max if eval_func == 'max' else torch.min + ) return best_metric, best_perm @@ -127,7 +140,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: Returns: Tensor: the permutated version of estimate - + Example: >>> import torch >>> from torchmetrics.functional.audio import si_snr, pit, permutate From 5d2977c01ebb825a7eb11415d260cac03e9980bc Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 05:52:49 +0800 Subject: [PATCH 09/56] add --- CHANGELOG.md | 3 +++ docs/source/references/functional.rst | 7 +++++++ docs/source/references/modules.rst | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8155dd581..12c009560a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375)) +- Added Permutation Invariant Training metric (PIT) ([#294](https://github.com/PyTorchLightning/metrics/issues/294)) + + ### Changed - Moved `psnr` and `ssim` from `torchmetrics.functional.regression.*` to `torchmetrics.functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382)) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 92672a3c0d7..3c627e325cc 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -9,6 +9,13 @@ Functional metrics Audio Metrics ************* +pit [func] +~~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.pit + :noindex: + + si_sdr [func] ~~~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 3b45600d7c8..af0a79d6aa8 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -40,6 +40,12 @@ the metric will be computed over the ``time`` dimension. >>> snr_val tensor(16.1805) +PIT +~~~~~~ + +.. autoclass:: torchmetrics.PIT + :noindex: + SI_SDR ~~~~~~ From 3b9172fea32f074823d3856282fcbbc4996c2b27 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:04:14 +0800 Subject: [PATCH 10/56] fix pep8 --- torchmetrics/audio/pit.py | 15 ++++++---- torchmetrics/functional/audio/pit.py | 45 ++++++++++++++++------------ 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index 0be58bece8d..0914194250e 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional -import torch from torch import Tensor, tensor -from torchmetrics.functional.audio.pit import permutate, pit +from torchmetrics.functional.audio.pit import pit from torchmetrics.metric import Metric @@ -30,9 +29,11 @@ class PIT(Metric): Args: metric_func: - a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] + a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], + estimate[:, j, ...]), and returns a batch of metric tensors [batch] eval_func: - the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better. + the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better + or the larger the better. compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. default: True dist_sync_on_step: @@ -59,7 +60,9 @@ class PIT(Metric): >>> avg_pit_metric = pit(preds, target) Reference: - [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. + [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for + speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech + Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. """ def __init__( diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index eee9af8135e..b82d98330ea 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import permutations -from typing import Callable, List, Tuple, Union +from typing import Callable, Tuple, Union import torch from scipy.optimize import linear_sum_assignment @@ -26,8 +26,8 @@ def _find_best_perm_by_linear_sum_assignment(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): mmtx = metric_mtx.detach().cpu() - best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] - for pwm in mmtx]).to(metric_mtx.device) + best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]) + best_perm = best_perm.to(metric_mtx.device) best_metric = torch.gather(metric_mtx, 2, best_perm[:, :, None]).mean([-1, -2]) return best_metric, best_perm # shape [batch], shape [batch, spk] @@ -41,8 +41,7 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un # all the permutations, shape [perm_num, spk_num] ps = torch.tensor(list(permutations(range(spk_num))), device=metric_mtx.device) # shape [perm_num * spk_num] - inc = torch.arange(0, spk_num * spk_num, step=spk_num, device=metric_mtx.device, - dtype=ps.dtype).repeat(ps.shape[0]) + inc = torch.arange(0, spk_num**2, step=spk_num, device=metric_mtx.device, dtype=ps.dtype).repeat(ps.shape[0]) # the indexes for all permutations, shape [perm_num*spk_num] ps_idx = ps.view(-1) + inc # cache ps and ps_idx @@ -53,10 +52,10 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un ps = _ps_dict[key] # all the permutations, shape [perm_num, spk_num] # find the metric of each permutation - metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape( - batch_size, -1, spk_num - ) # shape [batch_size, perm_num, spk_num] - metric_of_ps = metric_of_ps_details.mean(dim=2) # shape [batch_size, perm_num] + # shape [batch_size, perm_num, spk_num] + metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape(batch_size, -1, spk_num) + # shape [batch_size, perm_num] + metric_of_ps = metric_of_ps_details.mean(dim=2) # find the best metric and best permutation best_metric, best_indexes = eval_func(metric_of_ps, dim=1) @@ -65,11 +64,13 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un return best_metric, best_perm # shape [batch], shape [batch, spk] -def pit(preds: torch.Tensor, - target: torch.Tensor, - metric_func: Callable, - eval_func: str = 'max', - **kwargs) -> Tuple[Tensor, Tensor]: +def pit( + preds: torch.Tensor, + target: torch.Tensor, + metric_func: Callable, + eval_func: str = 'max', + **kwargs, +) -> Tuple[Tensor, Tensor]: """ Permutation invariant training metric Args: @@ -78,9 +79,11 @@ def pit(preds: torch.Tensor, preds: shape [batch, spk, ...] metric_func: - a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] + a metric function accept a batch of target and estimate, + i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] eval_func: - the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better. + the function to find the best permutation, can be 'min' or 'max', + i.e. the smaller the better or the larger the better. kwargs: additional args for metric_func @@ -103,7 +106,9 @@ def pit(preds: torch.Tensor, >>> preds_pmted = permutate(preds, best_perm) Reference: - [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. + [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for + speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech + Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. """ _check_same_shape(preds, target) if eval_func not in ['max', 'min']: @@ -121,11 +126,13 @@ def pit(preds: torch.Tensor, # find best if spk_num < 3: best_metric, best_perm = _find_best_perm_by_exhuastive_method( - metric_mtx, torch.max if eval_func == 'max' else torch.min + metric_mtx, + torch.max if eval_func == 'max' else torch.min, ) else: best_metric, best_perm = _find_best_perm_by_linear_sum_assignment( - metric_mtx, torch.max if eval_func == 'max' else torch.min + metric_mtx, + torch.max if eval_func == 'max' else torch.min, ) return best_metric, best_perm From 2c01d18410f9169ed36f97241c7f280bd3ef3229 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Jul 2021 22:04:49 +0000 Subject: [PATCH 11/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index 0914194250e..efcbf25cc12 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -29,7 +29,7 @@ class PIT(Metric): Args: metric_func: - a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], + a metric function accept a batch of target and estimate, i.e. metric_func(target[:, i, ...], estimate[:, j, ...]), and returns a batch of metric tensors [batch] eval_func: the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better From 7c2078f5240c7825aa76d316cddae493d42861a6 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:06:19 +0800 Subject: [PATCH 12/56] fix pep8 --- tests/audio/test_pit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index f17713f4e1e..4e6381a4db5 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple -from enum import auto from functools import partial from typing import Callable @@ -26,7 +25,6 @@ from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.audio import PIT from torchmetrics.functional import pit, si_sdr, snr -from torchmetrics.functional.audio.si_snr import si_snr from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) From 85daf5c751ba26a082f895286a2d8e52650c6d42 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:10:46 +0800 Subject: [PATCH 13/56] fix mypy --- torchmetrics/functional/audio/pit.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index b82d98330ea..d4c26b71892 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from itertools import permutations -from typing import Callable, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union import torch from scipy.optimize import linear_sum_assignment @@ -24,7 +24,9 @@ _ps_idx_dict: dict = {} # cache -def _find_best_perm_by_linear_sum_assignment(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): +def _find_best_perm_by_linear_sum_assignment( + metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], +) -> Tuple[Tensor, Tensor]: mmtx = metric_mtx.detach().cpu() best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]) best_perm = best_perm.to(metric_mtx.device) @@ -32,7 +34,9 @@ def _find_best_perm_by_linear_sum_assignment(metric_mtx: torch.Tensor, eval_func return best_metric, best_perm # shape [batch], shape [batch, spk] -def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max]): +def _find_best_perm_by_exhuastive_method( + metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max] +) -> Tuple[Tensor, Tensor]: # create/read/cache the permutations and its indexes # reading from cache would be much faster than creating in CPU then moving to GPU batch_size, spk_num = metric_mtx.shape[:2] @@ -65,11 +69,7 @@ def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, eval_func: Un def pit( - preds: torch.Tensor, - target: torch.Tensor, - metric_func: Callable, - eval_func: str = 'max', - **kwargs, + preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: """ Permutation invariant training metric From e1298961f674653c1746b1814adb88258602c556 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:14:54 +0800 Subject: [PATCH 14/56] add scipy --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index b47541f6e94..a1d8b2a6f11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.17.2 torch>=1.3.1 packaging +scipy From bcf43a37a3dcf3720e4cfbef2f54fed8245cbc12 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Jul 2021 22:15:51 +0000 Subject: [PATCH 15/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d4c26b71892..fce81dd3810 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -25,7 +25,8 @@ def _find_best_perm_by_linear_sum_assignment( - metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], + metric_mtx: torch.Tensor, + eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: mmtx = metric_mtx.detach().cpu() best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]) @@ -34,9 +35,8 @@ def _find_best_perm_by_linear_sum_assignment( return best_metric, best_perm # shape [batch], shape [batch, spk] -def _find_best_perm_by_exhuastive_method( - metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max] -) -> Tuple[Tensor, Tensor]: +def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, + eval_func: Union[torch.min, torch.max]) -> Tuple[Tensor, Tensor]: # create/read/cache the permutations and its indexes # reading from cache would be much faster than creating in CPU then moving to GPU batch_size, spk_num = metric_mtx.shape[:2] @@ -69,7 +69,11 @@ def _find_best_perm_by_exhuastive_method( def pit( - preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = 'max', **kwargs: Dict[str, Any] + preds: torch.Tensor, + target: torch.Tensor, + metric_func: Callable, + eval_func: str = 'max', + **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: """ Permutation invariant training metric From 801b80c43ac6a40334437260bd1a3f4d99b26707 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:20:52 +0800 Subject: [PATCH 16/56] fix --- torchmetrics/functional/audio/pit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d4c26b71892..c2a344dedc4 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -15,8 +15,9 @@ from typing import Any, Callable, Dict, Tuple, Union import torch +from torch import Tensor + from scipy.optimize import linear_sum_assignment -from torch.tensor import Tensor from torchmetrics.utilities.checks import _check_same_shape From f9f770f3a7632587487f8da8b0c2b744689a3cc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Jul 2021 22:21:28 +0000 Subject: [PATCH 17/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 0774d82f907..0615cf33832 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -15,9 +15,8 @@ from typing import Any, Callable, Dict, Tuple, Union import torch -from torch import Tensor - from scipy.optimize import linear_sum_assignment +from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape From dee14ea71b8bd79714bcfcf42c22693070411ea0 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:38:24 +0800 Subject: [PATCH 18/56] fix doctest --- torchmetrics/functional/audio/pit.py | 34 ++++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 0774d82f907..dcb5a6f3598 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -98,17 +98,19 @@ def pit( Example: >>> import torch - >>> from torchmetrics.functional.audio import si_snr, pit, permutate - >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] - >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'max') + >>> from torchmetrics.functional.audio import si_sdr, pit, permutate + >>> # [batch, spk, time] + >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) + >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) + >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') >>> best_metric - tensor([-29.3482, -11.2862, -9.2508]) + tensor([-5.1089]) >>> best_perm - tensor([[0, 1], - [1, 0], - [0, 1]]) + tensor([[0, 1]]) >>> preds_pmted = permutate(preds, best_perm) + >>> preds_pmted + tensor([[[-0.0579, 0.3560, -0.9604], + [-0.1719, 0.3205, 0.2951]]]) Reference: [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for @@ -155,11 +157,19 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: Example: >>> import torch - >>> from torchmetrics.functional.audio import si_snr, pit, permutate - >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] - >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> best_metric, best_perm = pit(preds, target, si_snr, 'max') + >>> from torchmetrics.functional.audio import si_sdr, pit, permutate + >>> # [batch, spk, time] + >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) + >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) + >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') + >>> best_metric + tensor([-5.1089]) + >>> best_perm + tensor([[0, 1]]) >>> preds_pmted = permutate(preds, best_perm) + >>> preds_pmted + tensor([[[-0.0579, 0.3560, -0.9604], + [-0.1719, 0.3205, 0.2951]]]) """ preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]) return preds_pmted From de37c928f06d3d08b7e0a0fcfbb5dec89b930d62 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:43:56 +0800 Subject: [PATCH 19/56] fix doctest --- torchmetrics/functional/audio/pit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 72f63ff8f0d..90e4195d71d 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -103,7 +103,7 @@ def pit( >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') >>> best_metric - tensor([-5.1089]) + tensor([-5.1091]) >>> best_perm tensor([[0, 1]]) >>> preds_pmted = permutate(preds, best_perm) @@ -162,7 +162,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) >>> best_metric, best_perm = pit(preds, target, si_sdr, 'max') >>> best_metric - tensor([-5.1089]) + tensor([-5.1091]) >>> best_perm tensor([[0, 1]]) >>> preds_pmted = permutate(preds, best_perm) From 05e061bf4cd2db98fdf574df44ce620d6a35ef33 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 06:50:03 +0800 Subject: [PATCH 20/56] fix doctest --- torchmetrics/functional/audio/pit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 90e4195d71d..0969fbcaa58 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -109,7 +109,7 @@ def pit( >>> preds_pmted = permutate(preds, best_perm) >>> preds_pmted tensor([[[-0.0579, 0.3560, -0.9604], - [-0.1719, 0.3205, 0.2951]]]) + [-0.1719, 0.3205, 0.2951]]]) Reference: [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for @@ -168,7 +168,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: >>> preds_pmted = permutate(preds, best_perm) >>> preds_pmted tensor([[[-0.0579, 0.3560, -0.9604], - [-0.1719, 0.3205, 0.2951]]]) + [-0.1719, 0.3205, 0.2951]]]) """ preds_pmted = torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)]) return preds_pmted From ae12aa0bb2090024798d33ea746ad29c0415f61b Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 20:27:45 +0800 Subject: [PATCH 21/56] remove scipy dep & add warning --- requirements.txt | 1 - torchmetrics/functional/audio/pit.py | 10 +++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index a1d8b2a6f11..b47541f6e94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ numpy>=1.17.2 torch>=1.3.1 packaging -scipy diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 0969fbcaa58..ca69ccfa817 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -18,11 +18,15 @@ from scipy.optimize import linear_sum_assignment from torch import Tensor +from importlib.util import find_spec + from torchmetrics.utilities.checks import _check_same_shape _ps_dict: dict = {} # cache _ps_idx_dict: dict = {} # cache +_has_scipy = find_spec("scipy") + def _find_best_perm_by_linear_sum_assignment( metric_mtx: torch.Tensor, @@ -130,7 +134,11 @@ def pit( metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) # find best - if spk_num < 3: + if spk_num < 3 or _has_scipy is None: + if spk_num >= 3 and _has_scipy is None: + import warnings + warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance (we have better implementation based on scipy)!") + best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, torch.max if eval_func == 'max' else torch.min, From 112bc16458cdf20c6d1a987b48a5c9fd1805270f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jul 2021 12:28:21 +0000 Subject: [PATCH 22/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index ca69ccfa817..ba94ea30fac 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from importlib.util import find_spec from itertools import permutations from typing import Any, Callable, Dict, Tuple, Union @@ -18,8 +19,6 @@ from scipy.optimize import linear_sum_assignment from torch import Tensor -from importlib.util import find_spec - from torchmetrics.utilities.checks import _check_same_shape _ps_dict: dict = {} # cache @@ -137,7 +136,9 @@ def pit( if spk_num < 3 or _has_scipy is None: if spk_num >= 3 and _has_scipy is None: import warnings - warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance (we have better implementation based on scipy)!") + warnings.warn( + f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance (we have better implementation based on scipy)!" + ) best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, From dc5e6892825d8b5dda42e1f43ea8094f4792b384 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 20:30:11 +0800 Subject: [PATCH 23/56] change warn --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index ca69ccfa817..aff01f284f8 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -137,7 +137,7 @@ def pit( if spk_num < 3 or _has_scipy is None: if spk_num >= 3 and _has_scipy is None: import warnings - warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance (we have better implementation based on scipy)!") + warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance!") best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, From 208b4791b85d865060f42214c7cc54b4957bb839 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 20:34:05 +0800 Subject: [PATCH 24/56] add --- torchmetrics/audio/pit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index efcbf25cc12..d430117a980 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -64,6 +64,8 @@ class PIT(Metric): speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154. """ + sum_pit_metric: Tensor + total: Tensor def __init__( self, From 736e354e0c1bc58c7848665bc27f40f401151164 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Sun, 18 Jul 2021 20:39:34 +0800 Subject: [PATCH 25/56] move scipy import to inner function --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index fee9989cdaa..d378187ad4f 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Dict, Tuple, Union import torch -from scipy.optimize import linear_sum_assignment from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape @@ -31,6 +30,7 @@ def _find_best_perm_by_linear_sum_assignment( metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: + from scipy.optimize import linear_sum_assignment mmtx = metric_mtx.detach().cpu() best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]) best_perm = best_perm.to(metric_mtx.device) From 271d899fbdeb0e03147bfe30d44b53c25ac84846 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 20 Jul 2021 13:38:54 +0200 Subject: [PATCH 26/56] Apply suggestions from code review --- docs/source/references/functional.rst | 2 +- torchmetrics/functional/audio/pit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 8de95e6d394..888b50611e5 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -10,7 +10,7 @@ Audio Metrics ************* pit [func] -~~~~~~~~~~~~~ +~~~~~~~~~~ .. autofunction:: torchmetrics.functional.pit :noindex: diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d378187ad4f..eea99d9fe0b 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -39,7 +39,7 @@ def _find_best_perm_by_linear_sum_assignment( def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, - eval_func: Union[torch.min, torch.max]) -> Tuple[Tensor, Tensor]: + eval_func: Union[torch.min, torch.max],) -> Tuple[Tensor, Tensor]: # create/read/cache the permutations and its indexes # reading from cache would be much faster than creating in CPU then moving to GPU batch_size, spk_num = metric_mtx.shape[:2] From 2e1ae3f3f4ca80e8cef7845dcf875e120ebe4be0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jul 2021 11:39:19 +0000 Subject: [PATCH 27/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index eea99d9fe0b..f5d0fb78663 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -38,8 +38,10 @@ def _find_best_perm_by_linear_sum_assignment( return best_metric, best_perm # shape [batch], shape [batch, spk] -def _find_best_perm_by_exhuastive_method(metric_mtx: torch.Tensor, - eval_func: Union[torch.min, torch.max],) -> Tuple[Tensor, Tensor]: +def _find_best_perm_by_exhuastive_method( + metric_mtx: torch.Tensor, + eval_func: Union[torch.min, torch.max], +) -> Tuple[Tensor, Tensor]: # create/read/cache the permutations and its indexes # reading from cache would be much faster than creating in CPU then moving to GPU batch_size, spk_num = metric_mtx.shape[:2] From 105957524666a2e73fc1bbdb7beaa9e4de67e7c4 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 00:38:52 +0800 Subject: [PATCH 28/56] simplyfied --- torchmetrics/functional/audio/pit.py | 30 +++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index f5d0fb78663..e1f18171c42 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -20,8 +20,9 @@ from torchmetrics.utilities.checks import _check_same_shape -_ps_dict: dict = {} # cache -_ps_idx_dict: dict = {} # cache +# _ps_dict: cache of permutations +# it's necessary to cache it, otherwise it will consume a large amount of time +_ps_dict: dict = {} # _ps_dict[str(spk_num)+str(device)] = permutations _has_scipy = find_spec("scipy") @@ -47,29 +48,26 @@ def _find_best_perm_by_exhuastive_method( batch_size, spk_num = metric_mtx.shape[:2] key = str(spk_num) + str(metric_mtx.device) if key not in _ps_dict: - # all the permutations, shape [perm_num, spk_num] - ps = torch.tensor(list(permutations(range(spk_num))), device=metric_mtx.device) - # shape [perm_num * spk_num] - inc = torch.arange(0, spk_num**2, step=spk_num, device=metric_mtx.device, dtype=ps.dtype).repeat(ps.shape[0]) - # the indexes for all permutations, shape [perm_num*spk_num] - ps_idx = ps.view(-1) + inc - # cache ps and ps_idx - _ps_idx_dict[key] = ps_idx + # ps: all the permutations, shape [spk_num, perm_num] + # ps: In i-th permutation, the predcition corresponds to the j-th target is ps[j,i] + ps = torch.tensor(list(permutations(range(spk_num))), device=metric_mtx.device).T _ps_dict[key] = ps else: - ps_idx = _ps_idx_dict[key] # the indexes for all permutations, shape [perm_num*spk_num] - ps = _ps_dict[key] # all the permutations, shape [perm_num, spk_num] + ps = _ps_dict[key] # all the permutations, shape [spk_num, perm_num] # find the metric of each permutation - # shape [batch_size, perm_num, spk_num] - metric_of_ps_details = metric_mtx.view(batch_size, -1)[:, ps_idx].reshape(batch_size, -1, spk_num) + perm_num = ps.shape[-1] + # shape [batch_size, spk_num, perm_num] + bps = ps[None, ...].expand(batch_size, spk_num, perm_num) + # shape [batch_size, spk_num, perm_num] + metric_of_ps_details = torch.gather(metric_mtx, 2, bps) # shape [batch_size, perm_num] - metric_of_ps = metric_of_ps_details.mean(dim=2) + metric_of_ps = metric_of_ps_details.mean(dim=1) # find the best metric and best permutation best_metric, best_indexes = eval_func(metric_of_ps, dim=1) best_indexes = best_indexes.detach() - best_perm = ps[best_indexes, :] + best_perm = ps.T[best_indexes, :] return best_metric, best_perm # shape [batch], shape [batch, spk] From 07ecb2691ece52af4fd5ff4fbfb58e0d1b26bcaf Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 00:48:04 +0800 Subject: [PATCH 29/56] add test_consistency_of_two_implementations --- tests/audio/test_pit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 4e6381a4db5..f0a2329bbf3 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -165,3 +165,15 @@ def test_error_on_wrong_shape() -> None: metric = PIT(snr, 'max') with pytest.raises(RuntimeError, match='Inputs must be of shape *'): metric(torch.randn(3), torch.randn(3)) + + +def test_consistency_of_two_implementations() -> None: + from torchmetrics.functional.audio.pit import _find_best_perm_by_exhuastive_method + from torchmetrics.functional.audio.pit import _find_best_perm_by_linear_sum_assignment + shapes_test = [(5, 2, 2), (4, 3, 3), (4, 4, 4), (3, 5, 5)] + for shp in shapes_test: + metric_mtx = torch.randn(size=shp) + bm1, bp1 = _find_best_perm_by_linear_sum_assignment(metric_mtx, torch.max) + bm2, bp2 = _find_best_perm_by_exhuastive_method(metric_mtx, torch.max) + assert torch.allclose(bm1, bm2) + assert (bp1 == bp2).all() From 16d70f9385253fdf89f4a0acb0432105dc0080d6 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 00:53:23 +0800 Subject: [PATCH 30/56] use _SCIPY_AVAILABLE --- torchmetrics/functional/audio/pit.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index e1f18171c42..f3db17dd322 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from importlib.util import find_spec from itertools import permutations from typing import Any, Callable, Dict, Tuple, Union - +from torchmetrics.utilities.imports import _SCIPY_AVAILABLE import torch from torch import Tensor @@ -24,8 +23,6 @@ # it's necessary to cache it, otherwise it will consume a large amount of time _ps_dict: dict = {} # _ps_dict[str(spk_num)+str(device)] = permutations -_has_scipy = find_spec("scipy") - def _find_best_perm_by_linear_sum_assignment( metric_mtx: torch.Tensor, @@ -133,8 +130,8 @@ def pit( metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) # find best - if spk_num < 3 or _has_scipy is None: - if spk_num >= 3 and _has_scipy is None: + if spk_num < 3 or _SCIPY_AVAILABLE is None: + if spk_num >= 3 and _SCIPY_AVAILABLE is None: import warnings warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance!") From b452792349e84d5b4e3d1f2404fcc2bc5a35a64e Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 01:04:51 +0800 Subject: [PATCH 31/56] add docstring --- torchmetrics/functional/audio/pit.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index f3db17dd322..d2f8073fcc0 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -40,6 +40,20 @@ def _find_best_perm_by_exhuastive_method( metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: + """Find the best metric value and permutation of each prediction and target pair in one batch using a exhuastive method + + Args: + metric_mtx: + the metric matrix, shape [batch_size, spk_num, spk_num] + eval_func: + the function to reduce the permutations of one pair + + Returns: + best_metric: + shape [batch] + best_perm: + shape [batch, spk] + """ # create/read/cache the permutations and its indexes # reading from cache would be much faster than creating in CPU then moving to GPU batch_size, spk_num = metric_mtx.shape[:2] From 6367adc28b53b511cfd63b9958fd7156854adbf0 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 01:08:13 +0800 Subject: [PATCH 32/56] to TIME --- tests/audio/test_pit.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index f0a2329bbf3..8cde57c98c2 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -29,17 +29,17 @@ seed_all(42) -Time = 10 +TIME = 10 Input = namedtuple('Input', ["preds", "target"]) inputs1 = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, Time), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), ) inputs2 = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, Time), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), ) From 29744ab59371fdef6fd1e3b7b803b2a72b6078a8 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 01:21:45 +0800 Subject: [PATCH 33/56] a_naive_implementation_of_pit_based_on_scipy --- tests/audio/test_pit.py | 43 ++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 8cde57c98c2..b83660cbc9f 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import namedtuple from functools import partial -from typing import Callable +from typing import Callable, Tuple import numpy as np import pytest @@ -33,17 +33,38 @@ Input = namedtuple('Input', ["preds", "target"]) +# three speaker examples to test _find_best_perm_by_linear_sum_assignment inputs1 = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), ) +# two speaker examples to test _find_best_perm_by_exhuastive_method inputs2 = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), ) -def scipy_version(preds: Tensor, target: Tensor, metric_func: Callable, eval_func: str): +def a_naive_implementation_of_pit_based_on_scipy( + preds: Tensor, + target: Tensor, + metric_func: Callable, + eval_func: str, +) -> Tuple[Tensor, Tensor]: + """a naive implementation of pit based on scipy + + Args: + preds: predictions, shape[batch, spk, time] + target: targets, shape[batch, spk, time] + metric_func: which metric + eval_func: min or max + + Returns: + best_metric: + shape [batch] + best_perm: + shape [batch, spk] + """ batch_size, spk_num = target.shape[0:2] metric_mtx = torch.empty((batch_size, spk_num, spk_num), device=target.device) for t in range(spk_num): @@ -61,14 +82,22 @@ def scipy_version(preds: Tensor, target: Tensor, metric_func: Callable, eval_fun return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) -def average_metric(preds: Tensor, target: Tensor, metric_func: Callable): - # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] - # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] +def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: + """average the metric values + + Args: + preds: predictions, shape[batch, spk, time] + target: targets, shape[batch, spk, time] + metric_func: a function which return best_metric and best_perm + + Returns: + the average of best_metric + """ return metric_func(preds, target)[0].mean() -snr_pit_scipy = partial(scipy_version, metric_func=snr, eval_func='max') -si_sdr_pit_scipy = partial(scipy_version, metric_func=si_sdr, eval_func='max') +snr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=snr, eval_func='max') +si_sdr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=si_sdr, eval_func='max') @pytest.mark.parametrize( From 2070d3a97f00f4743f33aa74eaf4076e00085048 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Jul 2021 17:22:20 +0000 Subject: [PATCH 34/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_pit.py | 6 ++++-- torchmetrics/functional/audio/pit.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index b83660cbc9f..7fc158de59f 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -197,8 +197,10 @@ def test_error_on_wrong_shape() -> None: def test_consistency_of_two_implementations() -> None: - from torchmetrics.functional.audio.pit import _find_best_perm_by_exhuastive_method - from torchmetrics.functional.audio.pit import _find_best_perm_by_linear_sum_assignment + from torchmetrics.functional.audio.pit import ( + _find_best_perm_by_exhuastive_method, + _find_best_perm_by_linear_sum_assignment, + ) shapes_test = [(5, 2, 2), (4, 3, 3), (4, 4, 4), (3, 5, 5)] for shp in shapes_test: metric_mtx = torch.randn(size=shp) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d2f8073fcc0..b1ac7e0353d 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -13,11 +13,12 @@ # limitations under the License. from itertools import permutations from typing import Any, Callable, Dict, Tuple, Union -from torchmetrics.utilities.imports import _SCIPY_AVAILABLE + import torch from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.imports import _SCIPY_AVAILABLE # _ps_dict: cache of permutations # it's necessary to cache it, otherwise it will consume a large amount of time From 859e25e3218b8ca31b54a8f8a49bb95bb882117e Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Wed, 21 Jul 2021 01:55:09 +0800 Subject: [PATCH 35/56] add docstring for _find_best_perm_by_linear_sum_assignment --- torchmetrics/functional/audio/pit.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d2f8073fcc0..46996ed34f2 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -28,6 +28,21 @@ def _find_best_perm_by_linear_sum_assignment( metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: + """Solves the linear sum assignment problem using scipy, and returns the best metric values and the corresponding + permutations + + Args: + metric_mtx: + the metric matrix, shape [batch_size, spk_num, spk_num] + eval_func: + the function to reduce the metric values of different the permutations + + Returns: + best_metric: + shape [batch] + best_perm: + shape [batch, spk] + """ from scipy.optimize import linear_sum_assignment mmtx = metric_mtx.detach().cpu() best_perm = torch.tensor([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]) @@ -40,13 +55,14 @@ def _find_best_perm_by_exhuastive_method( metric_mtx: torch.Tensor, eval_func: Union[torch.min, torch.max], ) -> Tuple[Tensor, Tensor]: - """Find the best metric value and permutation of each prediction and target pair in one batch using a exhuastive method + """Solves the linear sum assignment problem using exhuastive method, i.e. exhuastively calculates the metric values + of all possible permutations, and returns the best metric values and the corresponding permutations Args: metric_mtx: the metric matrix, shape [batch_size, spk_num, spk_num] eval_func: - the function to reduce the permutations of one pair + the function to reduce the metric values of different the permutations Returns: best_metric: From 70adaa33d47fbdd2e82efb13e63eeb671573dc28 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 22:39:00 +0800 Subject: [PATCH 36/56] Update torchmetrics/functional/audio/pit.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 68e6de6ba9d..636910d88b1 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -149,7 +149,7 @@ def pit( """ _check_same_shape(preds, target) if eval_func not in ['max', 'min']: - raise RuntimeError('eval_func can only be "max" or "min"') + raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}') if len(target.shape) < 2: raise RuntimeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") From a17b357ea096a0502d7e693959aa91140f228cb0 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 22:39:33 +0800 Subject: [PATCH 37/56] Update torchmetrics/functional/audio/pit.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 636910d88b1..fff7af1d2ed 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -151,7 +151,7 @@ def pit( if eval_func not in ['max', 'min']: raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}') if len(target.shape) < 2: - raise RuntimeError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") + raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") # calculate the metric matrix batch_size, spk_num = target.shape[0:2] From 39105aea7ed95d477ea9fc64897086a727d3dd2d Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 22:39:44 +0800 Subject: [PATCH 38/56] Update torchmetrics/functional/audio/pit.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/pit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index fff7af1d2ed..80f928d78ad 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -163,7 +163,6 @@ def pit( # find best if spk_num < 3 or _SCIPY_AVAILABLE is None: if spk_num >= 3 and _SCIPY_AVAILABLE is None: - import warnings warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance!") best_metric, best_perm = _find_best_perm_by_exhuastive_method( From dbb0e5a6837f008dd20601c0683cb033fdb046e7 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 22:40:15 +0800 Subject: [PATCH 39/56] Update torchmetrics/functional/audio/pit.py Co-authored-by: Nicki Skafte --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 80f928d78ad..4c41302dad6 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -163,7 +163,7 @@ def pit( # find best if spk_num < 3 or _SCIPY_AVAILABLE is None: if spk_num >= 3 and _SCIPY_AVAILABLE is None: - warnings.warn(f"Speaker-num is {spk_num}>3, you'd better install scipy to improve pit's performance!") + warnings.warn(f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance") best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, From c308c026db08bc690bd1ceb9b448679f91ee92e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jul 2021 14:40:57 +0000 Subject: [PATCH 40/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 4c41302dad6..48c7b01ecb5 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -163,7 +163,9 @@ def pit( # find best if spk_num < 3 or _SCIPY_AVAILABLE is None: if spk_num >= 3 and _SCIPY_AVAILABLE is None: - warnings.warn(f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance") + warnings.warn( + f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance" + ) best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, From a4e347a34793d97c824ebfba922e304af949e27d Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 23:31:24 +0800 Subject: [PATCH 41/56] use self.run_functional_metric_test --- tests/audio/test_pit.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 7fc158de59f..e2a30e46c7d 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -126,21 +126,13 @@ def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_s ) def test_pit_functional(self, preds, target, sk_metric, metric_func, eval_func): - device = 'cuda' if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else 'cpu' - - # move to device - preds = preds.to(device) - target = target.to(device) - - for i in range(NUM_BATCHES): - best_metric, best_perm = pit(preds[i], target[i], metric_func, eval_func) - best_metric_sk, best_perm_sk = sk_metric(preds[i].cpu(), target[i].cpu()) - - # assert its the same - assert np.allclose( - best_metric.detach().cpu().numpy(), best_metric_sk.detach().cpu().numpy(), atol=self.atol - ) - assert (best_perm.detach().cpu().numpy() == best_perm_sk.detach().cpu().numpy()).all() + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=pit, + sk_metric=sk_metric, + metric_args=dict(metric_func=metric_func, eval_func=eval_func), + ) def test_pit_differentiability(self, preds, target, sk_metric, metric_func, eval_func): From 01d484edabd99d35c988c6615845edb00d5bd40c Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Thu, 22 Jul 2021 23:44:46 +0800 Subject: [PATCH 42/56] add more description --- torchmetrics/audio/pit.py | 7 ++++--- torchmetrics/functional/audio/pit.py | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index d430117a980..7d7f786ecf5 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -20,12 +20,13 @@ class PIT(Metric): - """ Permutation invariant training metric + """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in + speech separation field in order to calculate audio metrics in a permutation invariant way. Forward accepts - - ``preds``: ``shape [..., time]`` - - ``target``: ``shape [..., time]`` + - ``preds``: ``shape [batch, spk, ...]`` + - ``target``: ``shape [batch, spk, ...]`` Args: metric_func: diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 48c7b01ecb5..529d34b9359 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -106,8 +106,9 @@ def pit( eval_func: str = 'max', **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: - """ Permutation invariant training metric - + """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in + speech separation field in order to calculate audio metrics in a permutation invariant way. + Args: target: shape [batch, spk, ...] From 5a974e913494b118a44b13859dfabb54bdd85933 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jul 2021 15:45:27 +0000 Subject: [PATCH 43/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 529d34b9359..3a2cb535575 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -108,7 +108,7 @@ def pit( ) -> Tuple[Tensor, Tensor]: """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in speech separation field in order to calculate audio metrics in a permutation invariant way. - + Args: target: shape [batch, spk, ...] From 400b32369a7ff68293ca91ed6737509522c92fc4 Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Fri, 23 Jul 2021 00:08:31 +0800 Subject: [PATCH 44/56] import warnings --- torchmetrics/functional/audio/pit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 529d34b9359..be2406ab8fa 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -13,6 +13,7 @@ # limitations under the License. from itertools import permutations from typing import Any, Callable, Dict, Tuple, Union +import warnings import torch from torch import Tensor @@ -108,7 +109,7 @@ def pit( ) -> Tuple[Tensor, Tensor]: """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in speech separation field in order to calculate audio metrics in a permutation invariant way. - + Args: target: shape [batch, spk, ...] From 458b912075aa7333c792e4509931e0e9a3e79c2d Mon Sep 17 00:00:00 2001 From: quancs <1017241746@qq.com> Date: Fri, 23 Jul 2021 00:10:19 +0800 Subject: [PATCH 45/56] check ValueError --- tests/audio/test_pit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index e2a30e46c7d..a648b118799 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -178,13 +178,13 @@ def test_error_on_different_shape() -> None: def test_error_on_wrong_eval_func() -> None: metric = PIT(snr, 'xxx') - with pytest.raises(RuntimeError, match='eval_func can only be "max" or "min"'): + with pytest.raises(ValueError, match='eval_func can only be "max" or "min"'): metric(torch.randn(3, 3, 10), torch.randn(3, 3, 10)) def test_error_on_wrong_shape() -> None: metric = PIT(snr, 'max') - with pytest.raises(RuntimeError, match='Inputs must be of shape *'): + with pytest.raises(ValueError, match='Inputs must be of shape *'): metric(torch.randn(3), torch.randn(3)) From 5d709d85b64c7090d3094b0390a000227acd3f72 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jul 2021 16:10:54 +0000 Subject: [PATCH 46/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index be2406ab8fa..d9326cc1591 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from itertools import permutations from typing import Any, Callable, Dict, Tuple, Union -import warnings import torch from torch import Tensor From 8d3ed756c2995e01dfecb3f1c707215c027de665 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jul 2021 09:18:01 +0000 Subject: [PATCH 47/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index d4456119b9e..a3aa79fd52c 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,9 +11,9 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: E402 F401 -from torchmetrics.average import AverageMeter # noqa: E402 F401 -from torchmetrics.classification import ( # noqa: E402 F401 +from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: E402, F401 +from torchmetrics.average import AverageMeter # noqa: E402, F401 +from torchmetrics.classification import ( # noqa: E402, F401 AUC, AUROC, F1, From ca3425daf92ff7e9218952f7eb06b662cc9ad228 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 28 Jul 2021 11:27:49 +0200 Subject: [PATCH 48/56] Apply suggestions from code review Co-authored-by: Nicki Skafte --- docs/source/references/modules.rst | 2 +- tests/audio/test_pit.py | 8 ++++---- torchmetrics/audio/pit.py | 2 +- torchmetrics/functional/audio/pit.py | 14 +++++--------- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 8336d9041af..a2468761f59 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -41,7 +41,7 @@ the metric will be computed over the ``time`` dimension. tensor(16.1805) PIT -~~~~~~ +~~~ .. autoclass:: torchmetrics.PIT :noindex: diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index a648b118799..36ae067d0e0 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -45,7 +45,7 @@ ) -def a_naive_implementation_of_pit_based_on_scipy( +def naive_implementation_pit_scipy( preds: Tensor, target: Tensor, metric_func: Callable, @@ -82,7 +82,7 @@ def a_naive_implementation_of_pit_based_on_scipy( return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms)) -def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: +def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor: """average the metric values Args: @@ -96,8 +96,8 @@ def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tens return metric_func(preds, target)[0].mean() -snr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=snr, eval_func='max') -si_sdr_pit_scipy = partial(a_naive_implementation_of_pit_based_on_scipy, metric_func=si_sdr, eval_func='max') +snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=snr, eval_func='max') +si_sdr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=si_sdr, eval_func='max') @pytest.mark.parametrize( diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index 7d7f786ecf5..075f5e4119e 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -58,7 +58,7 @@ class PIT(Metric): >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] >>> pit = PIT(si_snr, 'max') - >>> avg_pit_metric = pit(preds, target) + >>> pit(preds, target) Reference: [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d9326cc1591..2334ace2d3e 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -129,8 +129,7 @@ def pit( best_perm of shape [batch] Example: - >>> import torch - >>> from torchmetrics.functional.audio import si_sdr, pit, permutate + >>> from torchmetrics.functional.audio import si_sdr >>> # [batch, spk, time] >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) @@ -139,8 +138,7 @@ def pit( tensor([-5.1091]) >>> best_perm tensor([[0, 1]]) - >>> preds_pmted = permutate(preds, best_perm) - >>> preds_pmted + >>> pit_permutate(preds, best_perm) tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) @@ -182,7 +180,7 @@ def pit( return best_metric, best_perm -def permutate(preds: Tensor, perm: Tensor) -> Tensor: +def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor: """ permutate estimate according to perm Args: @@ -193,8 +191,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: Tensor: the permutated version of estimate Example: - >>> import torch - >>> from torchmetrics.functional.audio import si_sdr, pit, permutate + >>> from torchmetrics.functional.audio import si_sdr >>> # [batch, spk, time] >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]]) @@ -203,8 +200,7 @@ def permutate(preds: Tensor, perm: Tensor) -> Tensor: tensor([-5.1091]) >>> best_perm tensor([[0, 1]]) - >>> preds_pmted = permutate(preds, best_perm) - >>> preds_pmted + >>> pit_permutate(preds, best_perm) tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]]) """ From 1bd7f5dabddfecbb2f6446d9a734c7a5d9809eec Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 28 Jul 2021 18:48:11 +0200 Subject: [PATCH 49/56] fix imports --- torchmetrics/functional/__init__.py | 2 +- torchmetrics/functional/audio/__init__.py | 2 +- torchmetrics/functional/audio/pit.py | 9 +++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 07a622b0d1f..cad5bfa4ec2 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401 +from torchmetrics.functional.audio.pit import pit_permutate, pit # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index d6f05c3b3c6..9607e4232cc 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import permutate, pit # noqa: F401 +from torchmetrics.functional.audio.pit import pit_permutate, pit # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 2334ace2d3e..d3c5dde85e2 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -161,20 +161,21 @@ def pit( metric_mtx[:, t, e] = metric_func(preds[:, e, ...], target[:, t, ...], **kwargs) # find best - if spk_num < 3 or _SCIPY_AVAILABLE is None: - if spk_num >= 3 and _SCIPY_AVAILABLE is None: + op = torch.max if eval_func == 'max' else torch.min + if spk_num < 3 or not _SCIPY_AVAILABLE: + if spk_num >= 3 and not _SCIPY_AVAILABLE: warnings.warn( f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance" ) best_metric, best_perm = _find_best_perm_by_exhuastive_method( metric_mtx, - torch.max if eval_func == 'max' else torch.min, + op ) else: best_metric, best_perm = _find_best_perm_by_linear_sum_assignment( metric_mtx, - torch.max if eval_func == 'max' else torch.min, + op ) return best_metric, best_perm From 05cacbbcffdb7923ae9beec103d01f98338358d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jul 2021 16:49:04 +0000 Subject: [PATCH 50/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/__init__.py | 2 +- torchmetrics/functional/audio/__init__.py | 2 +- torchmetrics/functional/audio/pit.py | 10 ++-------- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index cad5bfa4ec2..58f89b521e2 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import pit_permutate, pit # noqa: F401 +from torchmetrics.functional.audio.pit import pit, pit_permutate # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index 9607e4232cc..f701dad2f11 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.audio.pit import pit_permutate, pit # noqa: F401 +from torchmetrics.functional.audio.pit import pit, pit_permutate # noqa: F401 from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 from torchmetrics.functional.audio.snr import snr # noqa: F401 diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index d3c5dde85e2..7a6f90ac5d4 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -168,15 +168,9 @@ def pit( f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance" ) - best_metric, best_perm = _find_best_perm_by_exhuastive_method( - metric_mtx, - op - ) + best_metric, best_perm = _find_best_perm_by_exhuastive_method(metric_mtx, op) else: - best_metric, best_perm = _find_best_perm_by_linear_sum_assignment( - metric_mtx, - op - ) + best_metric, best_perm = _find_best_perm_by_linear_sum_assignment(metric_mtx, op) return best_metric, best_perm From 9db192f35c99b1669e13b03fba7ce63b17cec2d5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 28 Jul 2021 18:49:43 +0200 Subject: [PATCH 51/56] fix test --- tests/audio/test_pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index 36ae067d0e0..b63484ecda6 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -120,7 +120,7 @@ def test_pit(self, preds, target, sk_metric, metric_func, eval_func, ddp, dist_s preds, target, PIT, - sk_metric=partial(average_metric, metric_func=sk_metric), + sk_metric=partial(_average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(metric_func=metric_func, eval_func=eval_func), ) From 65a27ffba5c41b7efa730d275df8df299cdaa18a Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 28 Jul 2021 19:03:52 +0200 Subject: [PATCH 52/56] docs --- torchmetrics/__init__.py | 2 +- torchmetrics/audio/pit.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index a3aa79fd52c..0af9089a4bf 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,7 +11,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from torchmetrics.audio import SI_SDR, SI_SNR, SNR # noqa: E402, F401 +from torchmetrics.audio import SI_SDR, SI_SNR, SNR, PIT # noqa: E402, F401 from torchmetrics.average import AverageMeter # noqa: E402, F401 from torchmetrics.classification import ( # noqa: E402, F401 AUC, diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index 075f5e4119e..b3d9e8400e0 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -54,11 +54,13 @@ class PIT(Metric): Example: >>> import torch >>> from torchmetrics import PIT - >>> from torchmetrics.functional.audio import si_snr + >>> from torchmetrics.functional import si_snr + >>> _ = torch.manual_seed(42) >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] >>> pit = PIT(si_snr, 'max') >>> pit(preds, target) + tensor(-2.1065) Reference: [1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for From a63297d7c23c8d94d1f0d6dc2ec36f1871fef46b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jul 2021 17:05:00 +0000 Subject: [PATCH 53/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 0af9089a4bf..70b1b3f22dc 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -11,7 +11,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -from torchmetrics.audio import SI_SDR, SI_SNR, SNR, PIT # noqa: E402, F401 +from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402, F401 from torchmetrics.average import AverageMeter # noqa: E402, F401 from torchmetrics.classification import ( # noqa: E402, F401 AUC, From 8f3dc55036877c855be0b917f6eb60d09f21e032 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 29 Jul 2021 12:51:34 +0200 Subject: [PATCH 54/56] Apply suggestions from code review Co-authored-by: thomas chaton --- tests/audio/test_pit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index b63484ecda6..912108df1c9 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -51,7 +51,7 @@ def naive_implementation_pit_scipy( metric_func: Callable, eval_func: str, ) -> Tuple[Tensor, Tensor]: - """a naive implementation of pit based on scipy + """A naive implementation of `Permutation Invariant Training` based on Scipy Args: preds: predictions, shape[batch, spk, time] From 23ab5765165059368256e81fe95ce475c5a666f9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 29 Jul 2021 13:32:01 +0200 Subject: [PATCH 55/56] Apply suggestions from code review --- torchmetrics/audio/pit.py | 5 +++-- torchmetrics/functional/audio/pit.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index b3d9e8400e0..eabc2ea9435 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -20,8 +20,9 @@ class PIT(Metric): - """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in - speech separation field in order to calculate audio metrics in a permutation invariant way. + """ + Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method [1] + in speech separation field in order to calculate audio metrics in a permutation invariant way. Forward accepts diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 7a6f90ac5d4..91c08e83ae9 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -107,8 +107,9 @@ def pit( eval_func: str = 'max', **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: - """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method in - speech separation field in order to calculate audio metrics in a permutation invariant way. + """ + Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method [1] + in speech separation field in order to calculate audio metrics in a permutation invariant way. Args: target: @@ -150,7 +151,7 @@ def pit( _check_same_shape(preds, target) if eval_func not in ['max', 'min']: raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}') - if len(target.shape) < 2: + if target.ndim < 2: raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") # calculate the metric matrix From 72eace59044e8aece6fffa65f6da26cde28b890a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 11:32:33 +0000 Subject: [PATCH 56/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/audio/pit.py | 2 +- torchmetrics/functional/audio/pit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/audio/pit.py b/torchmetrics/audio/pit.py index eabc2ea9435..77c4579ac99 100644 --- a/torchmetrics/audio/pit.py +++ b/torchmetrics/audio/pit.py @@ -20,7 +20,7 @@ class PIT(Metric): - """ + """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method [1] in speech separation field in order to calculate audio metrics in a permutation invariant way. diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 91c08e83ae9..aafd31382c1 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -107,7 +107,7 @@ def pit( eval_func: str = 'max', **kwargs: Dict[str, Any] ) -> Tuple[Tensor, Tensor]: - """ + """ Permutation invariant training (PIT). The PIT implements the famous Permutation Invariant Training method [1] in speech separation field in order to calculate audio metrics in a permutation invariant way.