From 9495aec1357ab72107f6e520ccc87fa356463605 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Fri, 26 Mar 2021 23:33:17 +0100 Subject: [PATCH 1/7] Add beamforming modules: first draft --- asteroid/dsp/beamforming.py | 182 ++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 asteroid/dsp/beamforming.py diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py new file mode 100644 index 000000000..3e71f4303 --- /dev/null +++ b/asteroid/dsp/beamforming.py @@ -0,0 +1,182 @@ +import torch +from torch import nn + + +class SCM(nn.Module): + def forward(self, x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = True): + """Compute the spatial covariance matrix from a STFT signal x. + + Args: + x (torch.ComplexTensor): shape [batch, mics, freqs, frames] + mask (torch.Tensor): [batch, 1, freqs, frames] or [batch, 1, freqs, frames]. Optional + normalize (bool): Whether to normalize with the mask mean per bin. + + Returns: + torch.ComplexTensor, the SCM with shape (batch, mics, mics, freqs) + """ + batch, mics, freqs, frames = x.shape + if mask is None: + mask = torch.ones(batch, 1, freqs, frames) + if mask.ndim == 3: + mask = mask[:, None] + + psd = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj()) + if normalize: + psd /= mask.sum(-1, keepdim=True).transpose(-1, -2) + return psd + + +class _BeamFormer(nn.Module): + @staticmethod + def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor): + """Apply the beamforming vector to the mixture. Output (batch, freqs, frames). + + Args: + bf_vector: shape (batch, mics, freqs) + mix: shape (batch, mics, freqs, frames). + """ + return torch.einsum("...mf,...mft->...ft", bf_vector.conj(), mix) + + +class MvdrBeamformer(_BeamFormer): + def forward( + self, + mix: torch.Tensor, + target_scm: torch.Tensor, + noise_scm: torch.Tensor, + ): + """Compute and apply MVDR beamformer from the speech and noise SCM matrices + + Args: + mix (torch.ComplexTensor): shape (batch, mics, freqs, frames) + target_scm (torch.ComplexTensor): (batch, mics, mics, freqs) + noise_scm (torch.ComplexTensor): (batch, mics, mics, freqs) + + Returns: + Filtered mixture. torch.ComplexTensor (batch, freqs, frames) + """ + # Get Acoustic transfer function (1st PCA of Σss) + e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True) + atf_vect = e_vec[..., -1] # bfm + return self.from_atf_vect(mix=mix, atf_vec=atf_vect.transpose(-1, -2), noise_scm=noise_scm) + + def from_atf_vect( + self, + mix: torch.Tensor, + atf_vec: torch.Tensor, + noise_scm: torch.Tensor, + ): + """Compute and apply MVDR beamformer from the ATF vector and noise SCM matrix. + + Args: + mix (torch.ComplexTensor): shape (batch, mics, freqs, frames) + atf_vec (torch.ComplexTensor): (batch, mics, freqs) + noise_scm (torch.ComplexTensor): (batch, mics, mics, freqs) + + Returns: + Filtered mixture. torch.ComplexTensor (batch, freqs, frames) + """ + noise_scm_t = noise_scm.permute(0, 3, 1, 2) # -> bfmm + atf_vec_t = atf_vec.transpose(-1, -2).unsqueeze(-1) # -> bfm1 + + # numerator, _ = torch.solve(atf_vec_t, noise_scm_t) # -> bfm1 + numerator = stable_solve(atf_vec_t, noise_scm_t) # -> bfm1 + + denominator = torch.matmul(atf_vec_t.conj().transpose(-1, -2), numerator) # -> bf11 + bf_vect = (numerator / denominator).squeeze(-1).transpose(-1, -2) # -> bfm1 -> bmf + output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft + return output + + +class SdwMwfBeamformer(_BeamFormer): + def __init__(self, mu=1.0): + super().__init__() + self.mu = mu + + def forward( + self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor, ref_mic: int = 0 + ): + """Compute and apply MVDR beamformer. + + Args: + mix (torch.ComplexTensor): shape (batch, mics, freqs, frames) + target_scm (torch.ComplexTensor): (batch, mics, mics, freqs) + noise_scm (torch.ComplexTensor): (batch, mics, mics, freqs) + ref_mic (int): reference microphone. + + Returns: + Filtered mixture. torch.ComplexTensor (batch, freqs, frames) + """ + noise_scm_t = noise_scm.permute(0, 3, 1, 2) # -> bfmm + target_scm_t = target_scm.permute(0, 3, 1, 2) # -> bfmm + + denominator = target_scm_t + self.mu * noise_scm_t + bf_vect, _ = torch.solve(target_scm_t, denominator) + bf_vect = bf_vect[..., ref_mic].transpose(-1, -2) # -> bfm1 -> bmf + output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft + return output + + +class GEVBeamformer(_BeamFormer): + def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor): + """Compute and apply the GEV beamformer. + We compute the principal component of noise_scm^-1 @ target_scm by solving the GEV decomposition + + Args: + mix: shape (batch, mics, freqs, frames) + target_scm: (batch, mics, mics, freqs) + noise_scm: (batch, mics, mics, freqs) + + Returns: + Filtered mixture. torch.ComplexTensor (batch, freqs, frames) + """ + noise_scm_t = noise_scm.permute(0, 3, 1, 2) + noise_scm_t = condition_covariance(noise_scm_t, 1e-6) + e_val, e_vec = generalized_eigenvalue_decomposition( + target_scm.permute(0, 3, 1, 2), noise_scm_t + ) + bf_vect = e_vec[..., -1] + # Normalize + bf_vect /= torch.norm(bf_vect, dim=-1, keepdim=True) + bf_vect = bf_vect.squeeze(-1).transpose(-1, -2) # -> bft + output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft + return output + + +def stable_solve(inp, mat): + """Return torch.solve in mat is non-singular, else regularize `mat` and torch.solve.""" + try: + return torch.solve(inp, mat)[0] + except RuntimeError: + mat = condition_covariance(mat, 1e-6) + return torch.solve(inp, mat)[0] + + +def condition_covariance(x, gamma, dim1=-2, dim2=-1): + """see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)""" + # Assume 4d with ...mm + if dim1 != -2 or dim2 != -1: + raise NotImplementedError + scale = gamma * batch_trace(x, dim1=dim1, dim2=dim2)[..., None, None] / x.shape[dim1] + scaled_eye = torch.eye(x.shape[dim1])[None, None] * scale + return (x + scaled_eye) / (1 + gamma) + + +def batch_trace(x, dim1=-2, dim2=-1): + """Compute the trace along dim1 and dim2 for a any matrix ndim>=2.""" + return torch.diagonal(x, dim1=dim1, dim2=dim2).sum(-1) + + +def generalized_eigenvalue_decomposition(a, b): + """Solves the generalized eigenvalue decomposition through cholesky decomposition. + Returns eigen values and eigen vectors. + """ + cholesky = torch.cholesky(b) + inv_cholesky = torch.inverse(cholesky) + # Compute C matrix L⁻1 A L^-T + cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2) + # Performing the eigenvalue decomposition + e_val, e_vec = torch.symeig(cmat, eigenvectors=True) + # Collecting the eigenvectors + e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec) + return e_val, e_vec From fcd5ad66051fd4e00e7ae0aa9fcdb73c124781f6 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Fri, 26 Mar 2021 23:48:11 +0100 Subject: [PATCH 2/7] Edit name --- asteroid/dsp/beamforming.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py index 3e71f4303..81c37e303 100644 --- a/asteroid/dsp/beamforming.py +++ b/asteroid/dsp/beamforming.py @@ -20,10 +20,10 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = if mask.ndim == 3: mask = mask[:, None] - psd = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj()) + scm = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj()) if normalize: - psd /= mask.sum(-1, keepdim=True).transpose(-1, -2) - return psd + scm /= mask.sum(-1, keepdim=True).transpose(-1, -2) + return scm class _BeamFormer(nn.Module): From b11cf35bc970b203fa92ae1de9cc7fe9162f2dc8 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Sat, 27 Mar 2021 20:17:29 +0100 Subject: [PATCH 3/7] More docs --- asteroid/dsp/beamforming.py | 40 +++++++++++++++++---------- docs/source/package_reference/dsp.rst | 9 ++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py index 81c37e303..51a1bcc41 100644 --- a/asteroid/dsp/beamforming.py +++ b/asteroid/dsp/beamforming.py @@ -45,7 +45,10 @@ def forward( target_scm: torch.Tensor, noise_scm: torch.Tensor, ): - """Compute and apply MVDR beamformer from the speech and noise SCM matrices + r"""Compute and apply MVDR beamformer from the speech and noise SCM matrices. + + :math:`\mathbf{w} = \displaystyle \frac{\Sigma_{nn}^{-1} \mathbf{a}}{ + \mathbf{a}^H \Sigma_{nn}^{-1} \mathbf{a}}` where :math:`\mathbf{a}` is the ATF estimated from the target SCM. Args: mix (torch.ComplexTensor): shape (batch, mics, freqs, frames) @@ -55,7 +58,7 @@ def forward( Returns: Filtered mixture. torch.ComplexTensor (batch, freqs, frames) """ - # Get Acoustic transfer function (1st PCA of Σss) + # Get acoustic transfer function (1st PCA of Σss) e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True) atf_vect = e_vec[..., -1] # bfm return self.from_atf_vect(mix=mix, atf_vec=atf_vect.transpose(-1, -2), noise_scm=noise_scm) @@ -96,7 +99,9 @@ def __init__(self, mu=1.0): def forward( self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor, ref_mic: int = 0 ): - """Compute and apply MVDR beamformer. + """Compute and apply SDW-MWF beamformer. + + :math:`\mathbf{w} = \displaystyle (\Sigma_{ss} + \mu \Sigma_{nn})^{-1} \Sigma_{ss}`. Args: mix (torch.ComplexTensor): shape (batch, mics, freqs, frames) @@ -120,7 +125,9 @@ def forward( class GEVBeamformer(_BeamFormer): def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor): """Compute and apply the GEV beamformer. - We compute the principal component of noise_scm^-1 @ target_scm by solving the GEV decomposition + + :math:`\mathbf{w} = \displaystyle MaxEig\{ \Sigma_{nn}^{-1}\Sigma_{ss} \}`, where + MaxEig extracts the eigenvector corresponding to the maximum eigenvalue (using the GEV decomposition). Args: mix: shape (batch, mics, freqs, frames) @@ -131,7 +138,7 @@ def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch. Filtered mixture. torch.ComplexTensor (batch, freqs, frames) """ noise_scm_t = noise_scm.permute(0, 3, 1, 2) - noise_scm_t = condition_covariance(noise_scm_t, 1e-6) + noise_scm_t = condition_scm(noise_scm_t, 1e-6) e_val, e_vec = generalized_eigenvalue_decomposition( target_scm.permute(0, 3, 1, 2), noise_scm_t ) @@ -143,17 +150,20 @@ def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch. return output -def stable_solve(inp, mat): - """Return torch.solve in mat is non-singular, else regularize `mat` and torch.solve.""" +def stable_solve(b, a): + """Return torch.solve in matrix `a` is non-singular, else regularize `a` and return torch.solve.""" try: - return torch.solve(inp, mat)[0] + return torch.solve(b, a)[0] except RuntimeError: - mat = condition_covariance(mat, 1e-6) - return torch.solve(inp, mat)[0] + a = condition_scm(a, 1e-6) + return torch.solve(b, a)[0] -def condition_covariance(x, gamma, dim1=-2, dim2=-1): - """see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)""" +def condition_scm(x, gamma=1e-6, dim1=-2, dim2=-1): + """Condition input SCM with (x + gamma tr(x) I) / (1 + gamma) along `dim1` and `dim2`. + + See https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3). + """ # Assume 4d with ...mm if dim1 != -2 or dim2 != -1: raise NotImplementedError @@ -163,13 +173,13 @@ def condition_covariance(x, gamma, dim1=-2, dim2=-1): def batch_trace(x, dim1=-2, dim2=-1): - """Compute the trace along dim1 and dim2 for a any matrix ndim>=2.""" + """Compute the trace along `dim1` and `dim2` for a any matrix `ndim>=2`.""" return torch.diagonal(x, dim1=dim1, dim2=dim2).sum(-1) def generalized_eigenvalue_decomposition(a, b): - """Solves the generalized eigenvalue decomposition through cholesky decomposition. - Returns eigen values and eigen vectors. + """Solves the generalized eigenvalue decomposition through Cholesky decomposition. + Returns eigen values and eigen vectors (ascending order). """ cholesky = torch.cholesky(b) inv_cholesky = torch.inverse(cholesky) diff --git a/docs/source/package_reference/dsp.rst b/docs/source/package_reference/dsp.rst index a51791f58..131ac80f4 100644 --- a/docs/source/package_reference/dsp.rst +++ b/docs/source/package_reference/dsp.rst @@ -5,6 +5,15 @@ DSP Modules :class: hidden-section + +:hidden:`Beamforming` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: asteroid.dsp.beamforming.MvdrBeamformer +.. autoclass:: asteroid.dsp.beamforming.SdwMwfBeamformer +.. autoclass:: asteroid.dsp.beamforming.GEVBeamformer +.. autoclass:: asteroid.dsp.beamforming.SCM + + :hidden:`LambdaOverlapAdd` ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: asteroid.dsp.LambdaOverlapAdd From 4745e75e39d704690b6775e6dfbbe80dadb45d1a Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Sat, 27 Mar 2021 20:29:43 +0100 Subject: [PATCH 4/7] Add simple tests --- tests/dsp/beamforming_test.py | 48 +++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/dsp/beamforming_test.py diff --git a/tests/dsp/beamforming_test.py b/tests/dsp/beamforming_test.py new file mode 100644 index 000000000..754de0a0d --- /dev/null +++ b/tests/dsp/beamforming_test.py @@ -0,0 +1,48 @@ +import torch +import pytest +from asteroid_filterbanks import make_enc_dec, transforms as tr + +from asteroid.dsp.beamforming import ( + _BeamFormer, + SCM, + MvdrBeamformer, + SdwMwfBeamformer, + GEVBeamformer, +) + +_stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128) +stft = lambda x: tr.to_torch_complex(_stft(x)) +istft = lambda x: _istft(tr.from_torch_complex(x)) + + +def _default_beamformer_test(beamformer: _BeamFormer, n_mics=4, *args, **kwargs): + scm = SCM() + + speech = torch.randn(1, n_mics, 16000 * 6) + noise = torch.randn(1, n_mics, 16000 * 6) + mix = speech + noise + # GeV Beamforming + mix_stft = stft(mix) + speech_stft = stft(speech) + noise_stft = stft(noise) + sigma_ss = scm(speech_stft) + sigma_nn = scm(noise_stft) + + Ys_gev = beamformer.forward(mix=mix_stft, target_scm=sigma_ss, noise_scm=sigma_nn) + ys_gev = istft(Ys_gev) + + +@pytest.mark.parametrize("n_mics", [2, 3, 4]) +def test_gev(n_mics): + _default_beamformer_test(GEVBeamformer(), n_mics=n_mics) + + +@pytest.mark.parametrize("n_mics", [2, 3, 4]) +def test_mvdr(n_mics): + _default_beamformer_test(MvdrBeamformer(), n_mics=n_mics) + + +@pytest.mark.parametrize("n_mics", [2, 3, 4]) +@pytest.mark.parametrize("mu", [1.0, 2.0, 0]) +def test_mwf(n_mics, mu): + _default_beamformer_test(SdwMwfBeamformer(mu=mu), n_mics=n_mics) From fbc8561a723489c0a351085b6ffe202652033ad8 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Fri, 9 Apr 2021 10:08:16 +0200 Subject: [PATCH 5/7] Skip test if torch version is under 1.8 --- tests/dsp/beamforming_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/dsp/beamforming_test.py b/tests/dsp/beamforming_test.py index 754de0a0d..acf4f78d4 100644 --- a/tests/dsp/beamforming_test.py +++ b/tests/dsp/beamforming_test.py @@ -9,12 +9,17 @@ SdwMwfBeamformer, GEVBeamformer, ) +from asteroid.utils.test_utils import torch_version_tuple + + +torch_has_complex_support = torch_version_tuple()[1] == 8 _stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128) stft = lambda x: tr.to_torch_complex(_stft(x)) istft = lambda x: _istft(tr.from_torch_complex(x)) +@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") def _default_beamformer_test(beamformer: _BeamFormer, n_mics=4, *args, **kwargs): scm = SCM() @@ -32,16 +37,19 @@ def _default_beamformer_test(beamformer: _BeamFormer, n_mics=4, *args, **kwargs) ys_gev = istft(Ys_gev) +@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) def test_gev(n_mics): _default_beamformer_test(GEVBeamformer(), n_mics=n_mics) +@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) def test_mvdr(n_mics): _default_beamformer_test(MvdrBeamformer(), n_mics=n_mics) +@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) @pytest.mark.parametrize("mu", [1.0, 2.0, 0]) def test_mwf(n_mics, mu): From 0f68727d3323a2cd63a1e7f894989e54ce93b296 Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Fri, 9 Apr 2021 20:38:10 +0200 Subject: [PATCH 6/7] Fix test --- tests/dsp/beamforming_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/dsp/beamforming_test.py b/tests/dsp/beamforming_test.py index acf4f78d4..9e4274ab4 100644 --- a/tests/dsp/beamforming_test.py +++ b/tests/dsp/beamforming_test.py @@ -9,10 +9,9 @@ SdwMwfBeamformer, GEVBeamformer, ) -from asteroid.utils.test_utils import torch_version_tuple -torch_has_complex_support = torch_version_tuple()[1] == 8 +torch_has_complex_support = tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 8) _stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128) stft = lambda x: tr.to_torch_complex(_stft(x)) @@ -37,19 +36,19 @@ def _default_beamformer_test(beamformer: _BeamFormer, n_mics=4, *args, **kwargs) ys_gev = istft(Ys_gev) -@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") +@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) def test_gev(n_mics): _default_beamformer_test(GEVBeamformer(), n_mics=n_mics) -@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") +@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) def test_mvdr(n_mics): _default_beamformer_test(MvdrBeamformer(), n_mics=n_mics) -@pytest.mark.skipif(not torch_has_complex_support, "No complex support ") +@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ") @pytest.mark.parametrize("n_mics", [2, 3, 4]) @pytest.mark.parametrize("mu", [1.0, 2.0, 0]) def test_mwf(n_mics, mu): From 6e12632b410860461251cc93807516bd04a5051e Mon Sep 17 00:00:00 2001 From: Manuel Pariente Date: Fri, 9 Apr 2021 20:43:40 +0200 Subject: [PATCH 7/7] Externalize compute_scm and _Beamformer -> Beamformer --- asteroid/dsp/beamforming.py | 53 +++++++++++++++++++---------------- tests/dsp/beamforming_test.py | 4 +-- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py index 51a1bcc41..090557be4 100644 --- a/asteroid/dsp/beamforming.py +++ b/asteroid/dsp/beamforming.py @@ -4,29 +4,11 @@ class SCM(nn.Module): def forward(self, x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = True): - """Compute the spatial covariance matrix from a STFT signal x. - - Args: - x (torch.ComplexTensor): shape [batch, mics, freqs, frames] - mask (torch.Tensor): [batch, 1, freqs, frames] or [batch, 1, freqs, frames]. Optional - normalize (bool): Whether to normalize with the mask mean per bin. - - Returns: - torch.ComplexTensor, the SCM with shape (batch, mics, mics, freqs) - """ - batch, mics, freqs, frames = x.shape - if mask is None: - mask = torch.ones(batch, 1, freqs, frames) - if mask.ndim == 3: - mask = mask[:, None] - - scm = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj()) - if normalize: - scm /= mask.sum(-1, keepdim=True).transpose(-1, -2) - return scm + """See :func:`compute_scm`.""" + return compute_scm(x, mask=mask, normalize=normalize) -class _BeamFormer(nn.Module): +class BeamFormer(nn.Module): @staticmethod def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor): """Apply the beamforming vector to the mixture. Output (batch, freqs, frames). @@ -38,7 +20,7 @@ def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor): return torch.einsum("...mf,...mft->...ft", bf_vector.conj(), mix) -class MvdrBeamformer(_BeamFormer): +class MvdrBeamformer(BeamFormer): def forward( self, mix: torch.Tensor, @@ -91,7 +73,7 @@ def from_atf_vect( return output -class SdwMwfBeamformer(_BeamFormer): +class SdwMwfBeamformer(BeamFormer): def __init__(self, mu=1.0): super().__init__() self.mu = mu @@ -122,7 +104,7 @@ def forward( return output -class GEVBeamformer(_BeamFormer): +class GEVBeamformer(BeamFormer): def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor): """Compute and apply the GEV beamformer. @@ -150,6 +132,29 @@ def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch. return output +def compute_scm(x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = True): + """Compute the spatial covariance matrix from a STFT signal x. + + Args: + x (torch.ComplexTensor): shape [batch, mics, freqs, frames] + mask (torch.Tensor): [batch, 1, freqs, frames] or [batch, 1, freqs, frames]. Optional + normalize (bool): Whether to normalize with the mask mean per bin. + + Returns: + torch.ComplexTensor, the SCM with shape (batch, mics, mics, freqs) + """ + batch, mics, freqs, frames = x.shape + if mask is None: + mask = torch.ones(batch, 1, freqs, frames) + if mask.ndim == 3: + mask = mask[:, None] + + scm = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj()) + if normalize: + scm /= mask.sum(-1, keepdim=True).transpose(-1, -2) + return scm + + def stable_solve(b, a): """Return torch.solve in matrix `a` is non-singular, else regularize `a` and return torch.solve.""" try: diff --git a/tests/dsp/beamforming_test.py b/tests/dsp/beamforming_test.py index 9e4274ab4..b7ab987a8 100644 --- a/tests/dsp/beamforming_test.py +++ b/tests/dsp/beamforming_test.py @@ -3,7 +3,7 @@ from asteroid_filterbanks import make_enc_dec, transforms as tr from asteroid.dsp.beamforming import ( - _BeamFormer, + BeamFormer, SCM, MvdrBeamformer, SdwMwfBeamformer, @@ -19,7 +19,7 @@ @pytest.mark.skipif(not torch_has_complex_support, "No complex support ") -def _default_beamformer_test(beamformer: _BeamFormer, n_mics=4, *args, **kwargs): +def _default_beamformer_test(beamformer: BeamFormer, n_mics=4, *args, **kwargs): scm = SCM() speech = torch.randn(1, n_mics, 16000 * 6)