FasNet TAC integration #306

Merged: 66 commits, Feb 8, 2021
Commits
95f9339
added transform average concat to masknn modules
popcornell Nov 3, 2020
b8eb5b1
applied black
popcornell Nov 3, 2020
69adeb2
added fasnet tac to models
popcornell Nov 3, 2020
c8ca4f6
added tac recipe
popcornell Nov 3, 2020
b15725d
fixed args
popcornell Nov 3, 2020
f75081f
added samplerate property
popcornell Nov 3, 2020
17d48e6
added licenses and get_infos for publishing
popcornell Nov 6, 2020
e06015e
added main to parse_data.py
popcornell Nov 6, 2020
1d9e731
fixed main.py
popcornell Nov 6, 2020
435b939
starting to do the eval script
popcornell Nov 6, 2020
afa95fa
included eval script into run.sh
popcornell Nov 6, 2020
a9803a4
Update asteroid/models/fasnet.py
mpariente Dec 8, 2020
5032ab1
added transform average concat to masknn modules
popcornell Nov 3, 2020
a65d410
applied black
popcornell Nov 3, 2020
8cd479f
added fasnet tac to models
popcornell Nov 3, 2020
4ebcf2c
added tac recipe
popcornell Nov 3, 2020
09fe5af
fixed args
popcornell Nov 3, 2020
054cddc
added samplerate property
popcornell Nov 3, 2020
ea0b837
added licenses and get_infos for publishing
popcornell Nov 6, 2020
6a0e03a
added main to parse_data.py
popcornell Nov 6, 2020
af09cd0
fixed main.py
popcornell Nov 6, 2020
b69fa41
starting to do the eval script
popcornell Nov 6, 2020
0c032f6
included eval script into run.sh
popcornell Nov 6, 2020
0f36e37
Update asteroid/models/fasnet.py
mpariente Dec 8, 2020
c2b0b34
added comments to train custom system class
popcornell Feb 2, 2021
f506ffc
added cross-correlation to DSP
popcornell Feb 2, 2021
4b8c4b5
Merge remote-tracking branch 'origin/fasnet_tac' into fasnet_tac
popcornell Feb 2, 2021
fec32d3
removed dataset configs --> using git to clone Luo's original repo in…
popcornell Feb 2, 2021
bfb129c
added docstring to fasnet_tac
popcornell Feb 2, 2021
4bd2966
added docstring + comments
popcornell Feb 2, 2021
6e46648
added utils symlink
popcornell Feb 2, 2021
0ee30b6
removed ipdb
popcornell Feb 3, 2021
9901794
downloading fixes
popcornell Feb 3, 2021
7a0ec07
sampl_rate argument
popcornell Feb 3, 2021
69740f6
various fixes
popcornell Feb 3, 2021
9df8d11
applied linters
popcornell Feb 3, 2021
cb4900e
removed samplerate property
popcornell Feb 3, 2021
bff3942
Merge branch 'master' into fasnet_tac
mpariente Feb 4, 2021
832c377
fixed conf_file and commented out data prep
popcornell Feb 5, 2021
77fc0c8
fixed conf_file and commented out data prep
popcornell Feb 5, 2021
7472cfc
added tests, modified xcorr shape and fasnet accordingly
popcornell Feb 5, 2021
cde746f
added docstring for forward
popcornell Feb 6, 2021
34b663b
polishing text
popcornell Feb 6, 2021
95c0d15
adding docstring in tac forward
popcornell Feb 6, 2021
2c202de
added docstring for TACDataset
popcornell Feb 6, 2021
02c555f
reverted to only one conf file
popcornell Feb 7, 2021
d0634b5
added mkdir logs
popcornell Feb 7, 2021
f10c421
Edit xcorr
mpariente Feb 8, 2021
8fa8a6a
Edit tac.py Docstrings
mpariente Feb 8, 2021
2642ac5
Edit fasnet.py
mpariente Feb 8, 2021
dc9bc5c
Edit tac.py again: caps
mpariente Feb 8, 2021
e607be5
samplerate to sample_rate
mpariente Feb 8, 2021
ce0e491
Update run.sh
mpariente Feb 8, 2021
c928f92
Update train.py
mpariente Feb 8, 2021
c70a3f2
Merge branch 'master' into fasnet_tac
mpariente Feb 8, 2021
220390d
Use Base class for sample_rate and in_channels
mpariente Feb 8, 2021
9db3b9f
Add default to valid_mics for inference
mpariente Feb 8, 2021
8f46433
Feature_dim issue
mpariente Feb 8, 2021
fdf12bd
Fix imports and docs
mpariente Feb 8, 2021
c83f59b
Improve _default_test_model
mpariente Feb 8, 2021
43cbb73
Add FasNet test for save/load
mpariente Feb 8, 2021
568eb00
solving feature dim confusion
popcornell Feb 8, 2021
1a5bb86
Merge remote-tracking branch 'origin/fasnet_tac' into fasnet_tac
popcornell Feb 8, 2021
10e4eb6
test for broadcasting
popcornell Feb 8, 2021
1332a88
New FasNetTAC error with enc_dim
mpariente Feb 8, 2021
6b601fc
fixed enc_dim issue
popcornell Feb 8, 2021
57 changes: 57 additions & 0 deletions asteroid/dsp/spatial.py
@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F


def xcorr(inp, ref, normalized=True, eps=1e-8):
    r"""Multi-channel cross correlation.

    The two signals can have different lengths, but the reference signal should
    not be longer than the input signal.

    .. note:: The cross correlation is computed between each pair of microphone channels and not
        between all possible pairs. E.g. if both inp and ref have shape ``(1, 2, 100)``,
        the output will have shape ``(1, 2, 1)``: the first element is the cross correlation
        between the first mic channel of inp and the first mic channel of ref.
        If either inp or ref has only one channel, e.g. inp: ``(1, 3, 100)`` and ref: ``(1, 1, 100)``,
        then the output will be ``(1, 3, 1)``, as ref is broadcast to match the shape of inp.

Args:
        inp (:class:`torch.Tensor`): multi-channel input signal. Shape: :math:`(batch, mic\_channels, seq\_len\_inp)`.
        ref (:class:`torch.Tensor`): multi-channel reference signal. Shape: :math:`(batch, mic\_channels, seq\_len\_ref)`.
        normalized (bool, optional): whether to normalize the cross correlation with the L2 norms of the signals.
        eps (float, optional): epsilon used for numerical stability when normalization is enabled.

    Returns:
        out (:class:`torch.Tensor`): cross correlation between the two multi-channel signals.
            Shape: :math:`(batch, mic\_channels, seq\_len\_inp - seq\_len\_ref + 1)`.

"""
# inp: batch, nmics2, seq_len2 || ref: batch, nmics1, seq_len1
    assert inp.size(0) == ref.size(0), "inp and ref signals should have the same batch size."
    assert inp.size(2) >= ref.size(2), "inp signal should not be shorter than the ref signal."

inp = inp.permute(1, 0, 2).contiguous()
ref = ref.permute(1, 0, 2).contiguous()
bsz = inp.size(1)
inp_mics = inp.size(0)

    if ref.size(0) > inp.size(0):
        inp = inp.expand(ref.size(0), inp.size(1), inp.size(2)).contiguous()  # (ref_mics, batch, seq_len_inp)
        inp_mics = ref.size(0)
    elif ref.size(0) < inp.size(0):
        ref = ref.expand(inp.size(0), ref.size(1), ref.size(2)).contiguous()  # (inp_mics, batch, seq_len_ref)
    # Batched cross correlation via one grouped 1D convolution per (batch, mic) pair.
    out = F.conv1d(
        inp.view(1, -1, inp.size(2)), ref.view(-1, 1, ref.size(2)), groups=inp_mics * bsz
    )  # (1, inp_mics * batch, seq_len_inp - seq_len_ref + 1)

    # Normalize by the L2 norms to obtain a sliding cosine similarity.
    if normalized:
        # Sliding L2 norm of inp over windows of length seq_len_ref.
        inp_norm = F.conv1d(
            inp.view(1, -1, inp.size(2)).pow(2),
            torch.ones(inp.size(0) * inp.size(1), 1, ref.size(2)).to(inp),
            groups=inp_mics * bsz,
        )  # (1, inp_mics * batch, seq_len_inp - seq_len_ref + 1)
        inp_norm = inp_norm.sqrt() + eps
        ref_norm = ref.norm(2, dim=2).view(1, -1, 1) + eps  # (1, inp_mics * batch, 1)
        out = out / (inp_norm * ref_norm)
return out.view(inp_mics, bsz, -1).permute(1, 0, 2).contiguous()
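For intuition, the normalized cross correlation above is, per microphone pair, a sliding cosine similarity between ``ref`` and windows of ``inp``. A minimal pure-Python sketch of a single channel pair (``xcorr_1d`` is a hypothetical helper, not part of this PR):

```python
import math

def xcorr_1d(inp, ref, normalized=True, eps=1e-8):
    """Single-channel sketch of what xcorr computes per mic pair.

    inp and ref are lists of floats with len(inp) >= len(ref); the output has
    len(inp) - len(ref) + 1 lags, matching the torch implementation above.
    """
    n = len(ref)
    ref_norm = math.sqrt(sum(r * r for r in ref)) + eps
    out = []
    for lag in range(len(inp) - n + 1):
        win = inp[lag : lag + n]
        dot = sum(w * r for w, r in zip(win, ref))
        if normalized:
            # Cosine similarity between the window and the reference.
            win_norm = math.sqrt(sum(w * w for w in win)) + eps
            dot /= win_norm * ref_norm
        out.append(dot)
    return out

# The peak lag localizes ref inside inp:
scores = xcorr_1d([0.0, 0.0, 1.0, 2.0, 0.0], [1.0, 2.0])
best_lag = max(range(len(scores)), key=scores.__getitem__)  # best_lag == 2
```

The batched torch version reaches the same result with a single grouped ``F.conv1d`` call instead of the explicit Python loop.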
94 changes: 94 additions & 0 deletions asteroid/masknn/tac.py
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
from . import activations, norms


class TAC(nn.Module):
    r"""Transform-Average-Concatenate (TAC): a permutation-invariant inter-microphone-channel communication block [1].

    Args:
        input_dim (int): Number of features of the input representation.
        hidden_dim (int, optional): Size of the hidden layers in the TAC operations.
        activation (str, optional): Activation function used. See ``asteroid.masknn.activations``.
        norm_type (str, optional): Normalization layer used. See ``asteroid.masknn.norms``.

    .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`,
        as in FasNet-TAC. The operations are applied independently for each element in
        ``chunk_size`` and ``n_chunks``. The output has the same shape as the input.

    References
        [1] Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
        speech separation." ICASSP 2020.
    """

def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
super().__init__()
self.hidden_dim = hidden_dim
self.input_tf = nn.Sequential(
nn.Linear(input_dim, hidden_dim), activations.get(activation)()
)
self.avg_tf = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
)
self.concat_tf = nn.Sequential(
nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
)
self.norm = norms.get(norm_type)(input_dim)

def forward(self, x, valid_mics=None):
        r"""
        Args:
            x (:class:`torch.Tensor`): Input multi-channel DPRNN features.
                Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
            valid_mics (:class:`torch.LongTensor`): Tensor containing the effective number of
                microphones for each batch element. Batches can mix examples coming from arrays
                with different numbers of microphones, in which case the ``mic_channels``
                dimension is padded. E.g. ``torch.tensor([4, 3])`` means the first example has
                4 valid channels and the second 3. Shape: :math:`(batch)`.

        Returns:
            output (:class:`torch.Tensor`): Features for each mic channel after TAC inter-channel processing.
                Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
        """
# Input is 5D because it is multi-channel DPRNN. DPRNN single channel is 4D.
batch_size, nmics, channels, chunk_size, n_chunks = x.size()
if valid_mics is None:
valid_mics = torch.LongTensor([nmics] * batch_size)
# First operation: transform the input for each frame and independently on each mic channel.
output = self.input_tf(
x.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics * chunk_size * n_chunks, channels)
).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)

        # Mean pooling across the microphone channels (dim 3 of the reshaped output).
        if valid_mics.max() == 0:
            # Fixed geometry array: all channels are valid.
            mics_mean = output.mean(3)  # (batch, chunk_size, n_chunks, hidden_dim)
        else:
            # Only consider valid channels in each batch element:
            # examples can come from arrays with different numbers of microphones.
            mics_mean = [
                output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0) for b in range(batch_size)
            ]  # list of (1, chunk_size, n_chunks, hidden_dim)
            mics_mean = torch.cat(mics_mean, 0)  # (batch, chunk_size, n_chunks, hidden_dim)

# The average is processed by a non-linear transform
mics_mean = self.avg_tf(
mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
)
mics_mean = (
mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
.unsqueeze(3)
.expand_as(output)
)

# Concatenate the transformed average in each channel with the original feats and
# project back to same number of features
output = torch.cat([output, mics_mean], -1)
output = self.concat_tf(
output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
output = self.norm(
output.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics, -1, chunk_size, n_chunks)
).reshape(batch_size, nmics, -1, chunk_size, n_chunks)

output += x
return output
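The average-then-concatenate structure is what makes the block invariant to the ordering of microphone channels. A toy pure-Python sketch of one TAC step (hypothetical ``tac_step`` helper, with identity mappings standing in for the learned ``Linear`` + activation stages):

```python
def tac_step(channels):
    """channels: list of per-mic feature vectors (lists of floats).

    The transforms are taken as identity here; the real block wraps the
    input, average, and concatenation stages in Linear layers + activations.
    """
    dim = len(channels[0])
    # Average: pool the features across microphones (order-independent).
    avg = [sum(ch[d] for ch in channels) / len(channels) for d in range(dim)]
    # Concatenate: each channel keeps its own features plus the shared average.
    return [ch + avg for ch in channels]

out = tac_step([[1.0, 2.0], [3.0, 4.0]])
out_perm = tac_step([[3.0, 4.0], [1.0, 2.0]])  # same mics, permuted
# Permuting the input channels permutes the outputs the same way.
assert out_perm == [out[1], out[0]]
```

Because the pooled average is identical for any channel ordering, stacking such blocks lets the network handle arrays with unknown microphone count and geometry.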