diff --git a/asteroid/dsp/spatial.py b/asteroid/dsp/spatial.py
new file mode 100644
index 000000000..af33c20f7
--- /dev/null
+++ b/asteroid/dsp/spatial.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn.functional as F
+
+
+def xcorr(inp, ref, normalized=True, eps=1e-8):
+    r"""Multi-channel cross correlation.
+
+    The two signals can have different lengths, but the reference signal should not be longer than the input signal.
+
+    .. note:: The cross correlation is computed between each pair of microphone channels and not
+        between all possible pairs, e.g. if both input and ref have shape ``(1, 2, 100)``
+        the output will be ``(1, 2, 1)``: the first element is the xcorr between
+        the first mic channel of input and the first mic channel of ref.
+        If either input or ref has only one channel, e.g. input: ``(1, 3, 100)`` and ref: ``(1, 1, 100)``,
+        then output will be ``(1, 3, 1)`` as ref will be broadcast to have the same shape as input.
+
+    Args:
+        inp (:class:`torch.Tensor`): multi-channel input signal. Shape: :math:`(batch, mic\_channels, seq\_len)`.
+        ref (:class:`torch.Tensor`): multi-channel reference signal. Shape: :math:`(batch, mic\_channels, seq\_len)`.
+        normalized (bool, optional): whether to normalize the cross-correlation with the l2 norm of input signals.
+        eps (float, optional): machine epsilon used for numerical stabilization when normalization is used.
+
+    Returns:
+        out (:class:`torch.Tensor`): cross correlation between the two multi-channel signals.
+            Shape: :math:`(batch, mic\_channels, seq\_len\_input - seq\_len\_ref + 1)`.
+
+    """
+    # inp: batch, nmics2, seq_len2 || ref: batch, nmics1, seq_len1
+    assert inp.size(0) == ref.size(0), "ref and inp signals should have same batch size."
+    assert inp.size(2) >= ref.size(2), "Input signal should not be shorter than the ref signal."
+
+    inp = inp.permute(1, 0, 2).contiguous()
+    ref = ref.permute(1, 0, 2).contiguous()
+    bsz = inp.size(1)
+    inp_mics = inp.size(0)
+
+    if ref.size(0) > inp.size(0):
+        inp = inp.expand(ref.size(0), inp.size(1), inp.size(2)).contiguous()  # nmic2, L, seg1
+        inp_mics = ref.size(0)
+    elif ref.size(0) < inp.size(0):
+        ref = ref.expand(inp.size(0), ref.size(1), ref.size(2)).contiguous()  # nmic1, L, seg2
+    # Cross correlation via grouped conv1d (becomes a cosine similarity once normalized below)
+    out = F.conv1d(
+        inp.view(1, -1, inp.size(2)), ref.view(-1, 1, ref.size(2)), groups=inp_mics * bsz
+    )  # 1, inp_mics*L, seg1-seg2+1
+
+    # L2 norms
+    if normalized:
+        inp_norm = F.conv1d(
+            inp.view(1, -1, inp.size(2)).pow(2),
+            torch.ones(inp.size(0) * inp.size(1), 1, ref.size(2)).type(inp.type()),
+            groups=inp_mics * bsz,
+        )  # 1, inp_mics*L, seg1-seg2+1
+        inp_norm = inp_norm.sqrt() + eps
+        ref_norm = ref.norm(2, dim=2).view(1, -1, 1) + eps  # 1, inp_mics*L, 1
+        out = out / (inp_norm * ref_norm)
+    return out.view(inp_mics, bsz, -1).permute(1, 0, 2).contiguous()
diff --git a/asteroid/masknn/tac.py b/asteroid/masknn/tac.py
new file mode 100644
index 000000000..5b86df4f9
--- /dev/null
+++ b/asteroid/masknn/tac.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+from . import activations, norms
+
+
+class TAC(nn.Module):
+    """Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1].
+
+    Args:
+        input_dim (int): Number of features of input representation.
+        hidden_dim (int, optional): size of hidden layers in TAC operations.
+        activation (str, optional): type of activation used. See asteroid.masknn.activations.
+        norm_type (str, optional): type of normalization layer used. See asteroid.masknn.norms.
+
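+    Example (a minimal usage sketch; the sizes below are illustrative assumptions, not requirements):
+        >>> import torch
+        >>> from asteroid.masknn.tac import TAC
+        >>> tac = TAC(input_dim=64)
+        >>> x = torch.randn(2, 4, 64, 50, 10)      # (batch, mic_channels, features, chunk_size, n_chunks)
+        >>> valid_mics = torch.LongTensor([4, 2])  # second example only has 2 real channels
+        >>> out = tac(x, valid_mics)               # same shape as x
+
+    .. 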
note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)` + as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and ``n_chunks``. + Output is of same shape as input. + + References + [1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel + speech separation." ICASSP 2020. + """ + + def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"): + super().__init__() + self.hidden_dim = hidden_dim + self.input_tf = nn.Sequential( + nn.Linear(input_dim, hidden_dim), activations.get(activation)() + ) + self.avg_tf = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), activations.get(activation)() + ) + self.concat_tf = nn.Sequential( + nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)() + ) + self.norm = norms.get(norm_type)(input_dim) + + def forward(self, x, valid_mics=None): + """ + Args: + x: (:class:`torch.Tensor`): Input multi-channel DPRNN features. + Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`. + valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch. + Batches can be composed of examples coming from arrays with a different + number of microphones and thus the ``mic_channels`` dimension is padded. + E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3. + Shape: :math`(batch)`. + + Returns: + output (:class:`torch.Tensor`): features for each mic_channel after TAC inter-channel processing. + Shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`. + """ + # Input is 5D because it is multi-channel DPRNN. DPRNN single channel is 4D. + batch_size, nmics, channels, chunk_size, n_chunks = x.size() + if valid_mics is None: + valid_mics = torch.LongTensor([nmics] * batch_size) + # First operation: transform the input for each frame and independently on each mic channel. + output = self.input_tf( + x.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics * chunk_size * n_chunks, channels) + ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim) + + # Mean pooling across channels + if valid_mics.max() == 0: + # Fixed geometry array + mics_mean = output.mean(1) + else: + # Only consider valid channels in each batch element: each example can have different number of microphones. 
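+            # e.g. with valid_mics = torch.tensor([4, 3]) and mic_channels padded to 6, the first example
+            # is averaged over its first 4 channels and the second over its first 3, so the zero-padded
+            # channels never contribute to the average.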
+ mics_mean = [ + output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0) for b in range(batch_size) + ] # 1, dim1*dim2, H + mics_mean = torch.cat(mics_mean, 0) # B*dim1*dim2, H + + # The average is processed by a non-linear transform + mics_mean = self.avg_tf( + mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim) + ) + mics_mean = ( + mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim) + .unsqueeze(3) + .expand_as(output) + ) + + # Concatenate the transformed average in each channel with the original feats and + # project back to same number of features + output = torch.cat([output, mics_mean], -1) + output = self.concat_tf( + output.reshape(batch_size * chunk_size * n_chunks * nmics, -1) + ).reshape(batch_size, chunk_size, n_chunks, nmics, -1) + output = self.norm( + output.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics, -1, chunk_size, n_chunks) + ).reshape(batch_size, nmics, -1, chunk_size, n_chunks) + + output += x + return output diff --git a/asteroid/models/fasnet.py b/asteroid/models/fasnet.py new file mode 100644 index 000000000..d9b091f8f --- /dev/null +++ b/asteroid/models/fasnet.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from .base_models import BaseModel +from ..masknn.recurrent import DPRNNBlock +from ..masknn import norms +from ..masknn.tac import TAC +from ..dsp.spatial import xcorr + + +class FasNetTAC(BaseModel): + r"""FasNetTAC separation model with optional Transform-Average-Concatenate (TAC) module[1]. + + Args: + n_src (int): Maximum number of sources the model can separate. + enc_dim (int, optional): Length of analysis filter. Defaults to 64. + feature_dim (int, optional): Size of hidden representation in DPRNN blocks after bottleneck. + Defaults to 64. + hidden_dim (int, optional): Number of neurons in the RNNs cell state in DPRNN blocks. + Defaults to 128. + n_layers (int, optional): Number of DPRNN blocks. Default to 4. + window_ms (int, optional): Beamformer window_length in milliseconds. Defaults to 4. + stride (int, optional): Stride for Beamforming windows. Defaults to window_ms // 2. + context_ms (int, optional): Context for each Beamforming window. Defaults to 16. + Effective window is 2*context_ms+window_ms. + sample_rate (int, optional): Samplerate of input signal. + tac_hidden_dim (int, optional): Size for TAC module hidden dimensions. Default to 384 neurons. + norm_type (str, optional): Normalization layer used. Default is Layer Normalization. + chunk_size (int, optional): Chunk size used for dual-path processing in DPRNN blocks. + Default to 50 samples. + hop_size (int, optional): Hop-size used for dual-path processing in DPRNN blocks. + Default to `chunk_size // 2` (50% overlap). + bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN + (Intra-Chunk is always bidirectional). + rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``, ``'LSTM'`` and ``'GRU'``. + dropout (float, optional): Dropout ratio, must be in [0,1]. + use_tac (bool, optional): whether to use Transform-Average-Concatenate for inter-mic-channels + communication. Defaults to True. + + References + [1] Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel + speech separation." ICASSP 2020. 
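+
+    Example (a minimal usage sketch; the sizes below are illustrative assumptions, not requirements):
+        >>> import torch
+        >>> from asteroid.models.fasnet import FasNetTAC
+        >>> model = FasNetTAC(n_src=2, sample_rate=16000)
+        >>> mix = torch.randn(1, 4, 16000)        # (batch, mic_channels, samples), one second of audio
+        >>> valid_mics = torch.LongTensor([4])    # all four channels are real microphones
+        >>> est_sources = model(mix, valid_mics)  # (batch, n_src, samples)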
+ """ + + def __init__( + self, + n_src, + enc_dim=64, + feature_dim=64, + hidden_dim=128, + n_layers=4, + window_ms=4, + stride=None, + context_ms=16, + sample_rate=16000, + tac_hidden_dim=384, + norm_type="gLN", + chunk_size=50, + hop_size=25, + bidirectional=True, + rnn_type="LSTM", + dropout=0.0, + use_tac=True, + ): + super().__init__(sample_rate=sample_rate, in_channels=None) + + self.enc_dim = enc_dim + self.feature_dim = feature_dim + self.hidden_dim = hidden_dim + self.n_layers = n_layers + self.n_src = n_src + assert window_ms % 2 == 0, "Window length should be even" + # Parameters + self.window_ms = window_ms + self.context_ms = context_ms + self.window = int(self.sample_rate * window_ms / 1000) + self.context = int(self.sample_rate * context_ms / 1000) + if not stride: + self.stride = self.window // 2 + else: + self.stride = int(self.sample_rate * stride / 1000) + self.filter_dim = self.context * 2 + 1 + self.output_dim = self.context * 2 + 1 + self.tac_hidden_dim = tac_hidden_dim + self.norm_type = norm_type + self.chunk_size = chunk_size + self.hop_size = hop_size + self.bidirectional = bidirectional + self.rnn_type = rnn_type + self.dropout = dropout + self.use_tac = use_tac + + # waveform encoder + self.encoder = nn.Conv1d(1, self.enc_dim, self.context * 2 + self.window, bias=False) + self.enc_LN = norms.get(norm_type)(self.enc_dim) + + # DPRNN here + TAC at each layer + self.bottleneck = nn.Conv1d(self.filter_dim + self.enc_dim, self.feature_dim, 1, bias=False) + + self.DPRNN_TAC = nn.ModuleList([]) + for i in range(self.n_layers): + tmp = nn.ModuleList( + [ + DPRNNBlock( + self.feature_dim, + self.hidden_dim, + norm_type, + bidirectional, + rnn_type, + dropout=dropout, + ) + ] + ) + if self.use_tac: + tmp.append(TAC(self.feature_dim, tac_hidden_dim, norm_type=norm_type)) + self.DPRNN_TAC.append(tmp) + + # DPRNN output layers + self.conv_2D = nn.Sequential( + nn.PReLU(), nn.Conv2d(self.feature_dim, self.n_src * self.feature_dim, 1) + ) + self.tanh = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1), nn.Tanh()) + self.gate = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1), nn.Sigmoid()) + + @staticmethod + def windowing_with_context(x, window, context): + batch_size, nmic, nsample = x.shape + unfolded = F.unfold( + x.unsqueeze(-1), + kernel_size=(window + 2 * context, 1), + padding=(context + window, 0), + stride=(window // 2, 1), + ) + + n_chunks = unfolded.size(-1) + unfolded = unfolded.reshape(batch_size, nmic, window + 2 * context, n_chunks) + return ( + unfolded[:, :, context : context + window].transpose(2, -1), + unfolded.transpose(2, -1), + ) + + def forward(self, x, valid_mics=None): + r""" + Args: + x: (:class:`torch.Tensor`): multi-channel input signal. Shape: :math:`(batch, mic\_channels, samples)`. + valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch. + Batches can be composed of examples coming from arrays with a different + number of microphones and thus the ``mic_channels`` dimension is padded. + E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3. + Shape: :math`(batch)`. + + Returns: + bf_signal (:class:`torch.Tensor`): beamformed signal with shape :math:`(batch, n\_src, samples)`. 
+ """ + if valid_mics is None: + valid_mics = torch.LongTensor([x.shape[1]] * x.shape[0]) + n_samples = x.size(-1) # Original number of samples of multichannel audio + all_seg, all_mic_context = self.windowing_with_context(x, self.window, self.context) + batch_size, n_mics, seq_length, feats = all_mic_context.size() + # All_seg contains only the central window, all_mic_context contains also the right and left context + + # Encoder applies a filter on each all_mic_context feats + enc_output = ( + self.encoder(all_mic_context.reshape(batch_size * n_mics * seq_length, 1, feats)) + .reshape(batch_size * n_mics, seq_length, self.enc_dim) + .transpose(1, 2) + .contiguous() + ) # B*n_mics, seq_len, enc_dim + enc_output = self.enc_LN(enc_output).reshape( + batch_size, n_mics, self.enc_dim, seq_length + ) # apply norm + + # For each context window cosine similarity is computed. The first channel is chosen as a reference + ref_seg = all_seg[:, 0].reshape(batch_size * seq_length, self.window).unsqueeze(1) + all_context = all_mic_context.transpose(1, 2).reshape( + batch_size * seq_length, n_mics, self.context * 2 + self.window + ) + + all_cos_sim = xcorr(all_context, ref_seg) + all_cos_sim = ( + all_cos_sim.view(n_mics, batch_size, seq_length, self.filter_dim) + .permute(1, 0, 3, 2) + .contiguous() + ) # B, nmic, 2*context + 1, seq_len + + # Encoder features and cosine similarity features are concatenated + input_feature = torch.cat([enc_output, all_cos_sim], 2) + # Apply bottleneck to reduce parameters and feed to DPRNN + input_feature = self.bottleneck(input_feature.reshape(batch_size * n_mics, -1, seq_length)) + # We unfold the features for dual path processing + unfolded = F.unfold( + input_feature.unsqueeze(-1), + kernel_size=(self.chunk_size, 1), + padding=(self.chunk_size, 0), + stride=(self.hop_size, 1), + ) + n_chunks = unfolded.size(-1) + unfolded = unfolded.reshape( + batch_size * n_mics, self.feature_dim, self.chunk_size, n_chunks + ) + + for i in range(self.n_layers): + # At each layer we apply DPRNN to process each mic independently and then TAC for inter-mic processing. 
+ dprnn = self.DPRNN_TAC[i][0] + unfolded = dprnn(unfolded) + if self.use_tac: + b, ch, chunk_size, n_chunks = unfolded.size() + tac = self.DPRNN_TAC[i][1] + unfolded = unfolded.reshape(-1, n_mics, ch, chunk_size, n_chunks) + unfolded = tac(unfolded, valid_mics).reshape( + batch_size * n_mics, self.feature_dim, self.chunk_size, n_chunks + ) + # Output, 2D conv to get different feats for each source + unfolded = self.conv_2D(unfolded).reshape( + batch_size * n_mics * self.n_src, self.feature_dim * self.chunk_size, n_chunks + ) + # Dual path processing is done we fold back + folded = F.fold( + unfolded, + (seq_length, 1), + kernel_size=(self.chunk_size, 1), + padding=(self.chunk_size, 0), + stride=(self.hop_size, 1), + ) + # Dividing to assure perfect reconstruction + folded = folded.squeeze(-1) / (self.chunk_size / self.hop_size) + # apply gating to output and scaling to -1 and 1 + folded = self.tanh(folded) * self.gate(folded) + folded = folded.view(batch_size, n_mics, self.n_src, -1, seq_length) + + # Beamforming + # Convolving with all mic context --> Filter and Sum + all_mic_context = all_mic_context.unsqueeze(2).repeat(1, 1, self.n_src, 1, 1) + all_bf_output = F.conv1d( + all_mic_context.view(1, -1, self.context * 2 + self.window), + folded.transpose(3, -1).contiguous().view(-1, 1, self.filter_dim), + groups=batch_size * n_mics * self.n_src * seq_length, + ) + all_bf_output = all_bf_output.view(batch_size, n_mics, self.n_src, seq_length, self.window) + + # Fold back to obtain signal + all_bf_output = F.fold( + all_bf_output.reshape( + batch_size * n_mics * self.n_src, seq_length, self.window + ).transpose(1, -1), + (n_samples, 1), + kernel_size=(self.window, 1), + padding=(self.window, 0), + stride=(self.window // 2, 1), + ) + bf_signal = all_bf_output.reshape(batch_size, n_mics, self.n_src, n_samples) + + # We sum over mics after filtering (filters will realign the signals --> delay and sum) + if valid_mics.max() == 0: + bf_signal = bf_signal.mean(1) + else: + bf_signal = [ + bf_signal[b, : valid_mics[b]].mean(0).unsqueeze(0) for b in range(batch_size) + ] + bf_signal = torch.cat(bf_signal, 0) + + return bf_signal + + def get_model_args(self): + config = { + "n_src": self.n_src, + "enc_dim": self.enc_dim, + "feature_dim": self.feature_dim, + "hidden_dim": self.hidden_dim, + "n_layers": self.n_layers, + "window_ms": self.window_ms, + "stride": self.stride, + "context_ms": self.context_ms, + "sample_rate": self.sample_rate, + "tac_hidden_dim": self.tac_hidden_dim, + "norm_type": self.norm_type, + "chunk_size": self.chunk_size, + "hop_size": self.hop_size, + "bidirectional": self.bidirectional, + "rnn_type": self.rnn_type, + "dropout": self.dropout, + "use_tac": self.use_tac, + } + return config diff --git a/egs/TAC/eval.py b/egs/TAC/eval.py new file mode 100644 index 000000000..87c8a37be --- /dev/null +++ b/egs/TAC/eval.py @@ -0,0 +1,126 @@ +import os +import random +import soundfile as sf +import torch +import yaml +import json +import argparse +import pandas as pd +from tqdm import tqdm +from pprint import pprint + +from asteroid.models.fasnet import FasNetTAC +from asteroid.metrics import get_metrics +from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr +from local.tac_dataset import TACDataset +from asteroid.models import save_publishable +from asteroid.utils import tensors_to_device + +parser = argparse.ArgumentParser() + +parser.add_argument("--test_json", type=str, required=True, help="Test json file") +parser.add_argument( + "--use_gpu", type=int, default=0, 
help="Whether to use the GPU for model execution" +) +parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") +parser.add_argument( + "--n_save_ex", type=int, default=50, help="Number of audio examples to save, -1 means all" +) + +compute_metrics = ["si_sdr"] # , "sdr", "sir", "sar", "stoi"] + + +def main(conf): + model_path = os.path.join(conf["exp_dir"], "best_model.pth") + model = FasNetTAC.from_pretrained(model_path) + # Handle device placement + if conf["use_gpu"]: + model.cuda() + model_device = next(model.parameters()).device + test_set = TACDataset(args.test_json, train=False) + + # Used to reorder sources only + loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") + + # Randomly choose the indexes of sentences to save. + ex_save_dir = os.path.join(conf["exp_dir"], "examples/") + if conf["n_save_ex"] == -1: + conf["n_save_ex"] = len(test_set) + save_idx = random.sample(range(len(test_set)), conf["n_save_ex"]) + series_list = [] + torch.no_grad().__enter__() + for idx in tqdm(range(len(test_set))): + + # Forward the network on the mixture. + mix, sources, valid_mics = tensors_to_device(test_set[idx], device=model_device) + valid_mics = torch.tensor([valid_mics]).to(sources.device) + est_sources = model(mix[None], valid_mics[None]) + loss, reordered_sources = loss_func(est_sources, sources[None][:, 0], return_est=True) + mix_np = mix.cpu().data.numpy() + sources_np = sources[0].cpu().data.numpy() + est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy() + utt_metrics = get_metrics( + mix_np[0], + sources_np, + est_sources_np, + sample_rate=conf["sample_rate"], + metrics_list=compute_metrics, + ) + utt_metrics["mix_path"] = test_set.examples[idx]["1"]["mixture"] + series_list.append(pd.Series(utt_metrics)) + + # Save some examples in a folder. Wav files and metrics as text. + if idx in save_idx: + local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx)) + os.makedirs(local_save_dir, exist_ok=True) + sf.write(local_save_dir + "mixture.wav", mix_np[0], conf["sample_rate"]) + # Loop over the sources and estimates + for src_idx, src in enumerate(sources_np): + sf.write(local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"]) + for src_idx, est_src in enumerate(est_sources_np): + sf.write( + local_save_dir + "s{}_estimate.wav".format(src_idx + 1), + est_src, + conf["sample_rate"], + ) + # Write local metrics to the example folder. + with open(local_save_dir + "metrics.json", "w") as f: + json.dump(utt_metrics, f, indent=0) + + # Save all metrics to the experiment folder. 
+ all_metrics_df = pd.DataFrame(series_list) + all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv")) + + # Print and save summary metrics + final_results = {} + for metric_name in compute_metrics: + input_metric_name = "input_" + metric_name + ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name] + final_results[metric_name] = all_metrics_df[metric_name].mean() + final_results[metric_name + "_imp"] = ldf.mean() + print("Overall metrics :") + pprint(final_results) + with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f: + json.dump(final_results, f, indent=0) + model_dict = torch.load(model_path, map_location="cpu") + + publishable = save_publishable( + os.path.join(conf["exp_dir"], "publish_dir"), + model_dict, + metrics=final_results, + train_conf=train_conf, + ) + + +if __name__ == "__main__": + args = parser.parse_args() + arg_dic = dict(vars(args)) + + # Load training config + conf_path = os.path.join(args.exp_dir, "conf.yml") + with open(conf_path) as f: + train_conf = yaml.safe_load(f) + arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] + arg_dic["train_conf"] = train_conf + + main(arg_dic) diff --git a/egs/TAC/local/conf.yml b/egs/TAC/local/conf.yml new file mode 100644 index 000000000..9cc4b5802 --- /dev/null +++ b/egs/TAC/local/conf.yml @@ -0,0 +1,30 @@ +data: + sample_rate: 16000 + segment: + train_json: ./data/train.json + dev_json: ./data/validation.json + test_json: ./data/test.json +net: + enc_dim: 64 + chunk_size: 50 + hop_size: 25 + feature_dim: 64 + hidden_dim: 128 + n_layers: 4 + n_src: 2 + window_ms: 4 + context_ms: 16 + sample_rate: 16000 +optim: + lr: 0.001 + weight_decay: !!float 1e-5 +training: + epochs: 200 + batch_size: 4 + gradient_clipping: 5 + accumulate_batches: 1 + save_top_k: 10 + num_workers: 8 + patience: 30 + half_lr: true + early_stop: true diff --git a/egs/TAC/local/parse_data.py b/egs/TAC/local/parse_data.py new file mode 100644 index 000000000..535b11f2b --- /dev/null +++ b/egs/TAC/local/parse_data.py @@ -0,0 +1,41 @@ +import os +import json +import soundfile as sf +import argparse +import glob +import re +from pathlib import Path + +parser = argparse.ArgumentParser("parsing tac dataset") +parser.add_argument("--in_dir", type=str) +parser.add_argument("--out_json", type=str) + + +def parse_dataset(in_dir, out_json): + + examples = [] + for n_mic_f in glob.glob(os.path.join(in_dir, "*")): + for sample_dir in glob.glob(os.path.join(n_mic_f, "*")): + c_ex = {} + for wav in glob.glob(os.path.join(sample_dir, "*.wav")): + + source_or_mix = Path(wav).stem.split("_")[0] + n_mic = int(re.findall("\d+", Path(wav).stem.split("_")[-1])[0]) + length = len(sf.SoundFile(wav)) + + if n_mic not in c_ex.keys(): + c_ex[n_mic] = {source_or_mix: wav, "length": length} + else: + assert c_ex[n_mic]["length"] == length + c_ex[n_mic][source_or_mix] = wav + examples.append(c_ex) + + os.makedirs(Path(out_json).parent, exist_ok=True) + + with open(out_json, "w") as f: + json.dump(examples, f, indent=4) + + +if __name__ == "__main__": + args = parser.parse_args() + parse_dataset(args.in_dir, args.out_json) diff --git a/egs/TAC/local/tac_dataset.py b/egs/TAC/local/tac_dataset.py new file mode 100644 index 000000000..c20eb8926 --- /dev/null +++ b/egs/TAC/local/tac_dataset.py @@ -0,0 +1,138 @@ +from torch.utils.data import Dataset +import json +import soundfile as sf +import torch +import numpy as np +from pathlib import Path +from asteroid.data.librimix_dataset import librispeech_license + + +class TACDataset(Dataset): + 
"""Multi-channel Librispeech-derived dataset used in Transform Average Concatenate. + + Args: + json_file (str): Path to json file resulting from the data prep script which contains parsed examples. + segment (float, optional): Length of the segments used for training, in seconds. + If None, use full utterances (e.g. for test). + sample_rate (int, optional): The sampling rate of the wav files. + max_mics (int, optional): Maximum number of microphones for an array in the dataset. + train (bool, optional): If True randomly permutes the microphones on each example. + """ + + dataset_name = "TACDataset" + + def __init__(self, json_file, segment=None, sample_rate=16000, max_mics=6, train=True): + self.segment = segment + self.sample_rate = sample_rate + self.max_mics = max_mics + self.train = train + + with open(json_file, "r") as f: + examples = json.load(f) + + if self.segment: + target_len = int(segment * sample_rate) + self.examples = [] + for ex in examples: + if ex["1"]["length"] < target_len: + continue + self.examples.append(ex) + print( + "Discarded {} out of {} because too short".format( + len(examples) - len(self.examples), len(examples) + ) + ) + else: + self.examples = examples + if not train: + # sort examples based on number + self.examples = sorted( + self.examples, key=lambda x: str(Path(x["2"]["spk1"]).parent).strip("sample") + ) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, item): + """Returns mixtures, sources and the number of mics in the recording, padded to `max_mics`.""" + c_ex = self.examples[item] + # randomly select ref mic + mics = [x for x in c_ex.keys()] + if self.train: + np.random.shuffle(mics) # randomly permute during training to change ref mics + + mixtures = [] + sources = [] + for i in range(len(mics)): + c_mic = c_ex[mics[i]] + + if self.segment: + offset = 0 + if c_mic["length"] > int(self.segment * self.sample_rate): + offset = np.random.randint( + 0, c_mic["length"] - int(self.segment * self.sample_rate) + ) + + # we load mixture + mixture, fs = sf.read( + c_mic["mixture"], + start=offset, + stop=offset + int(self.segment * self.sample_rate), + dtype="float32", + ) + spk1, fs = sf.read( + c_mic["spk1"], + start=offset, + stop=offset + int(self.segment * self.sample_rate), + dtype="float32", + ) + spk2, fs = sf.read( + c_mic["spk2"], + start=offset, + stop=offset + int(self.segment * self.sample_rate), + dtype="float32", + ) + else: + mixture, fs = sf.read(c_mic["mixture"], dtype="float32") # load all + spk1, fs = sf.read(c_mic["spk1"], dtype="float32") + spk2, fs = sf.read(c_mic["spk2"], dtype="float32") + + mixture = torch.from_numpy(mixture).unsqueeze(0) + spk1 = torch.from_numpy(spk1).unsqueeze(0) + spk2 = torch.from_numpy(spk2).unsqueeze(0) + + assert fs == self.sample_rate + mixtures.append(mixture) + sources.append(torch.cat((spk1, spk2), 0)) + + mixtures = torch.cat(mixtures, 0) + sources = torch.stack(sources) + # we pad till max_mic + valid_mics = mixtures.shape[0] + if mixtures.shape[0] < self.max_mics: + dummy = torch.zeros((self.max_mics - mixtures.shape[0], mixtures.shape[-1])) + mixtures = torch.cat((mixtures, dummy), 0) + sources = torch.cat((sources, dummy.unsqueeze(1).repeat(1, sources.shape[1], 1)), 0) + return mixtures, sources, valid_mics + + def get_infos(self): + """Get dataset infos (for publishing models). + + Returns: + dict, dataset infos with keys `dataset`, `task` and `licences`. 
+ """ + infos = dict() + infos["dataset"] = self.dataset_name + infos["task"] = "separate_noisy" + infos["licenses"] = [librispeech_license, tac_license] + return infos + + +tac_license = dict( + title="End-to-end Microphone Permutation and Number Invariant Multi-channel Speech Separation", + title_link="https://arxiv.org/abs/1910.14104", + author="Yi Luo, Zhuo Chen, Nima Mesgarani, Takuya Yoshioka", + license="CC BY 4.0", + license_link="https://creativecommons.org/licenses/by/4.0/", + non_commercial=False, +) diff --git a/egs/TAC/run.sh b/egs/TAC/run.sh new file mode 100644 index 000000000..f668c9a04 --- /dev/null +++ b/egs/TAC/run.sh @@ -0,0 +1,119 @@ +#!/bin/bash + +# Exit on error +set -e +set -o pipefail + +# Main storage directory where dataset will be stored +storage_dir=$(readlink -m ./datasets) +librispeech_dir=$storage_dir/LibriSpeech +noise_dir=$storage_dir/Nonspeech +# After running the recipe a first time, you can run it from stage 3 directly to train new models. + +# Path to the python you'll use for the experiment. Defaults to the current python +# You can run ./utils/prepare_python_env.sh to create a suitable python environment, paste the output here. +python_path=python + +# Example usage +# ./run.sh --stage 3 --tag my_tag --id 0,1 + +# General +stage=0 # Controls from which stage to start +tag="" # Controls the directory name associated to the experiment +# You can ask for several GPUs using id (passed to CUDA_VISIBLE_DEVICES) +id=0 +eval_use_gpu=1 + +# Dataset option +dataset_type=adhoc + +. utils/parse_options.sh + +dumpdir=data/$suffix # directory to put generated json file + +# check if gpuRIR installed +if ! ( pip list | grep -F gpuRIR ); then + echo 'This recipe requires gpuRIR. Please install gpuRIR.' + exit +fi + +if [[ $stage -le 0 ]]; then + echo "Stage 0: Downloading required Datasets" + + if ! test -e $librispeech_dir/train-clean-100; then + echo "Downloading LibriSpeech/train-clean-100 into $storage_dir" + wget -c --tries=0 --read-timeout=20 http://www.openslr.org/resources/12/train-clean-100.tar.gz -P $storage_dir + tar -xzf $storage_dir/train-clean-100.tar.gz -C $storage_dir + rm -rf $storage_dir/train-clean-100.tar.gz + fi + + if ! test -e $librispeech_dir/dev-clean; then + echo "Downloading LibriSpeech/dev-clean into $storage_dir" + wget -c --tries=0 --read-timeout=20 http://www.openslr.org/resources/12/dev-clean.tar.gz -P $storage_dir + tar -xzf $storage_dir/dev-clean.tar.gz -C $storage_dir + rm -rf $storage_dir/dev-clean.tar.gz + fi + + if ! test -e $librispeech_dir/test-clean; then + echo "Downloading LibriSpeech/test-clean into $storage_dir" + wget -c --tries=0 --read-timeout=20 http://www.openslr.org/resources/12/test-clean.tar.gz -P $storage_dir + tar -xzf $storage_dir/test-clean.tar.gz -C $storage_dir + rm -rf $storage_dir/test-clean.tar.gz + fi + + if ! 
test -e $storage_dir/Nonspeech; then + echo "Downloading Noises into $storage_dir" + wget -c --tries=0 --read-timeout=20 http://web.cse.ohio-state.edu/pnl/corpus/HuNonspeech/Nonspeech.zip -P $storage_dir + unzip $storage_dir/Nonspeech.zip -d $storage_dir + rm -rf $storage_dir/Nonspeech.zip + fi + +fi + +if [[ $stage -le 1 ]]; then + echo "Stage 1: Creating Synthetic Datasets" + git clone https://github.com/yluo42/TAC ./local/TAC + cd local/TAC/data + $python_path create_dataset.py \ + --output-path=$storage_dir \ + --dataset=$dataset_type \ + --libri-path=$librispeech_dir \ + --noise-path=$noise_dir + cd ../../../ +fi + +if [[ $stage -le 2 ]]; then + echo "Parsing dataset to json to speed up subsequent experiments" + for split in train validation test; do + $python_path ./local/parse_data.py --in_dir $storage_dir/MC_Libri_${dataset_type}/$split --out_json $dumpdir/${split}.json + done +fi + +# Generate a random ID for the run if no tag is specified +uuid=$($python_path -c 'import uuid, sys; print(str(uuid.uuid4())[:8])') +if [[ -z ${tag} ]]; then + tag=${uuid} +fi +expdir=exp/train_TAC_${tag} +mkdir -p $expdir && echo $uuid >> $expdir/run_uuid.txt +echo "Results from the following experiment will be stored in $expdir" + +if [[ $stage -le 3 ]]; then + echo "Stage 3: Training" + mkdir -p logs + CUDA_VISIBLE_DEVICES=$id $python_path train.py --exp_dir ${expdir} | tee logs/train_${tag}.log + cp logs/train_${tag}.log $expdir/train.log + + # Get ready to publish + mkdir -p $expdir/publish_dir + echo "TAC/TAC" > $expdir/publish_dir/recipe_name.txt +fi + + +if [[ $stage -le 4 ]]; then + echo "Stage 4 : Evaluation" + CUDA_VISIBLE_DEVICES=$id $python_path eval.py --test_json $dumpdir/test.json \ + --use_gpu $eval_use_gpu \ + --exp_dir ${expdir} | tee logs/eval_${tag}.log + cp logs/eval_${tag}.log $expdir/eval.log +fi diff --git a/egs/TAC/train.py b/egs/TAC/train.py new file mode 100644 index 000000000..05c318c9d --- /dev/null +++ b/egs/TAC/train.py @@ -0,0 +1,158 @@ +import os +import argparse +import json + +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + +from asteroid.models import save_publishable +from local.tac_dataset import TACDataset +from asteroid.engine.optimizers import make_optimizer +from asteroid.engine.system import System +from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr +from asteroid.models.fasnet import FasNetTAC + +# Keys which are not in the conf.yml file can be added here. +# In the hierarchical dictionary created when parsing, the key `key` can be +# found at dic['main_args'][key] + +# By default train.py will use all available GPUs. The `id` option in run.sh +# will limit the number of available GPUs for train.py. +parser = argparse.ArgumentParser() +parser.add_argument("--exp_dir", default="exp/tmp", help="Full path to save best validation model") + + +class TACSystem(System): + def common_step(self, batch, batch_nb, train=True): + inputs, targets, valid_channels = batch + # valid_channels contains a list of valid microphone channels for each example. + # each example can have a varying number of microphone channels (can come from different arrays). + # e.g. [[2], [4], [1]] three examples with 2 mics 4 mics and 1 mics. 
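+        # In practice, after default collation valid_channels arrives as a LongTensor, e.g. torch.tensor([2, 4, 1]).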
+ est_targets = self.model(inputs, valid_channels) + loss = self.loss_func(est_targets, targets[:, 0]).mean() # first channel is used as ref + + return loss + + +def main(conf): + + train_set = TACDataset(conf["data"]["train_json"], conf["data"]["segment"], train=True) + val_set = TACDataset(conf["data"]["dev_json"], conf["data"]["segment"], train=False) + + train_loader = DataLoader( + train_set, + shuffle=True, + batch_size=conf["training"]["batch_size"], + num_workers=conf["training"]["num_workers"], + drop_last=True, + ) + val_loader = DataLoader( + val_set, + shuffle=False, + batch_size=conf["training"]["batch_size"], + num_workers=conf["training"]["num_workers"], + drop_last=True, + ) + + model = FasNetTAC(**conf["net"]) + optimizer = make_optimizer(model.parameters(), **conf["optim"]) + # Define scheduler + if conf["training"]["half_lr"]: + scheduler = ReduceLROnPlateau( + optimizer=optimizer, factor=0.5, patience=conf["training"]["patience"] + ) + else: + scheduler = None + # Just after instantiating, save the args. Easy loading in the future. + exp_dir = conf["main_args"]["exp_dir"] + os.makedirs(exp_dir, exist_ok=True) + conf_path = os.path.join(exp_dir, "conf.yml") + with open(conf_path, "w") as outfile: + yaml.safe_dump(conf, outfile) + + # Define Loss function. + loss_func = PITLossWrapper(pairwise_neg_sisdr) + system = TACSystem( + model=model, + loss_func=loss_func, + optimizer=optimizer, + train_loader=train_loader, + val_loader=val_loader, + scheduler=scheduler, + config=conf, + ) + + # Define callbacks + # Define callbacks + callbacks = [] + checkpoint_dir = os.path.join(exp_dir, "checkpoints/") + checkpoint = ModelCheckpoint( + checkpoint_dir, + monitor="val_loss", + mode="min", + save_top_k=conf["training"]["save_top_k"], + verbose=True, + ) + callbacks.append(checkpoint) + if conf["training"]["early_stop"]: + callbacks.append( + EarlyStopping( + monitor="val_loss", mode="min", patience=conf["training"]["patience"], verbose=True + ) + ) + + # Don't ask GPU if they are not available. + gpus = -1 if torch.cuda.is_available() else None + trainer = pl.Trainer( + max_epochs=conf["training"]["epochs"], + callbacks=callbacks, + default_root_dir=exp_dir, + gpus=gpus, + distributed_backend="ddp", + gradient_clip_val=conf["training"]["gradient_clipping"], + ) + trainer.fit(system) + + best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} + with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: + json.dump(best_k, f, indent=0) + + state_dict = torch.load(checkpoint.best_model_path) + system.load_state_dict(state_dict=state_dict["state_dict"]) + system.cpu() + + to_save = system.model.serialize() + to_save.update(train_set.get_infos()) + torch.save(to_save, os.path.join(exp_dir, "best_model.pth")) + save_publishable( + os.path.join(exp_dir, "publish_dir"), + to_save, + metrics=dict(), + train_conf=conf, + recipe="asteroid/TAC", + ) + + +if __name__ == "__main__": + import yaml + from pprint import pprint as print + from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict + + # We start with opening the config file conf.yml as a dictionary from + # which we can create parsers. Each top level key in the dictionary defined + # by the YAML file creates a group in the parser. 
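+    # e.g. the `lr` key under `optim` in conf.yml becomes a `--lr` command-line flag, and after
+    # parsing it is available again as arg_dic["optim"]["lr"].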
+ with open("./local/conf.yml") as f: + def_conf = yaml.safe_load(f) + parser = prepare_parser_from_dict(def_conf, parser=parser) + # Arguments are then parsed into a hierarchical dictionary (instead of + # flat, as returned by argparse) to facilitate calls to the different + # asteroid methods (see in main). + # plain_args is the direct output of parser.parse_args() and contains all + # the attributes in an non-hierarchical structure. It can be useful to also + # have it so we included it here but it is not used. + arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True) + print(arg_dic) + main(arg_dic) diff --git a/egs/TAC/utils b/egs/TAC/utils new file mode 120000 index 000000000..025d106f8 --- /dev/null +++ b/egs/TAC/utils @@ -0,0 +1 @@ +../wham/ConvTasNet/utils/ \ No newline at end of file diff --git a/tests/dsp/spatial_test.py b/tests/dsp/spatial_test.py new file mode 100644 index 000000000..f2e287181 --- /dev/null +++ b/tests/dsp/spatial_test.py @@ -0,0 +1,25 @@ +import torch +import pytest +import numpy as np +from asteroid.dsp.spatial import xcorr + + +@pytest.mark.parametrize("seq_len_input", [1390]) +@pytest.mark.parametrize("seq_len_ref", [1390, 1290]) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("n_mics_input", [1]) +@pytest.mark.parametrize("n_mics_ref", [1, 2]) +@pytest.mark.parametrize("normalized", [False, True]) +def test_xcorr(seq_len_input, seq_len_ref, batch_size, n_mics_input, n_mics_ref, normalized): + target = torch.rand((batch_size, n_mics_input, seq_len_input)) + ref = torch.rand((batch_size, n_mics_ref, seq_len_ref)) + result = xcorr(target, ref, normalized) + assert result.shape[-1] == (seq_len_input - seq_len_ref) + 1 + + if normalized == False: + for b in range(batch_size): + for m in range(n_mics_input): + npy_result = np.correlate(target[b, m].numpy(), ref[b, m].numpy()) + np.testing.assert_array_almost_equal( + result[b, m, : len(npy_result)].numpy(), npy_result, decimal=2 + ) diff --git a/tests/models/fasnet_test.py b/tests/models/fasnet_test.py new file mode 100644 index 000000000..7345cf2b0 --- /dev/null +++ b/tests/models/fasnet_test.py @@ -0,0 +1,27 @@ +import torch +import pytest + +from asteroid.models.fasnet import FasNetTAC + + +@pytest.mark.parametrize("samples", [8372]) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("n_mics", [1, 2]) +@pytest.mark.parametrize("n_src", [1, 2, 3]) +@pytest.mark.parametrize("use_tac", [True, False]) +@pytest.mark.parametrize("enc_dim", [4]) +@pytest.mark.parametrize("feature_dim", [8]) +@pytest.mark.parametrize("window", [2]) +@pytest.mark.parametrize("context", [3]) +def test_fasnet(batch_size, n_mics, samples, n_src, use_tac, enc_dim, feature_dim, window, context): + mixture = torch.rand((batch_size, n_mics, samples)) + valid_mics = torch.tensor([n_mics for x in range(batch_size)]) + fasnet = FasNetTAC( + n_src, + use_tac=use_tac, + enc_dim=enc_dim, + feature_dim=feature_dim, + window_ms=window, + context_ms=context, + ) + fasnet(mixture, valid_mics) diff --git a/tests/models/models_test.py b/tests/models/models_test.py index 3baf3e44d..9225dba40 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -7,6 +7,7 @@ from asteroid import models from asteroid.filterbanks import make_enc_dec from asteroid.dsp import LambdaOverlapAdd +from asteroid.models.fasnet import FasNetTAC from asteroid.separate import separate from asteroid.models import ( ConvTasNet, @@ -194,6 +195,14 @@ def test_dptnet(fb, sample_rate): 
_default_test_model(DPTNet(2, ff_hid=10, chunk_size=4, n_repeats=2, sample_rate=sample_rate)) +@pytest.mark.parametrize("use_tac", [True, False]) +def test_fasnet(use_tac): + _default_test_model( + FasNetTAC(n_src=2, feature_dim=8, hidden_dim=10, n_layers=2, use_tac=use_tac), + test_input=torch.randn(3, 2, 8372), + ) + + def test_dcunet(): n_fft = 1024 _, istft = make_enc_dec( @@ -231,8 +240,9 @@ def test_dccrnet(): DCCRNet("mini").masker(torch.zeros((1, 42, 3), dtype=torch.complex64)) -def _default_test_model(model, input_samples=801): - test_input = torch.randn(1, input_samples) +def _default_test_model(model, input_samples=801, test_input=None): + if test_input is None: + test_input = torch.randn(1, input_samples) model_conf = model.serialize() reconstructed_model = model.__class__.from_pretrained(model_conf)
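For reference, a minimal sketch of exercising the new `asteroid.dsp.spatial.xcorr` helper outside the test suite (the tensor sizes are arbitrary assumptions):

    import torch
    from asteroid.dsp.spatial import xcorr

    inp = torch.randn(2, 4, 400)  # (batch, mic_channels, seq_len_input)
    ref = torch.randn(2, 1, 320)  # single-channel reference, broadcast to every input channel
    out = xcorr(inp, ref)         # (2, 4, 400 - 320 + 1), normalized cross-correlation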