A quickstart speech enhancement tutorial #6492

Merged (2 commits, Aug 2, 2023)
@@ -1,6 +1,6 @@
-# This configuration contains the default values for training a multichannel speech enhancement model.
+# This configuration contains exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
 #
-name: "multichannel_enhancement"
+name: "beamforming"

model:
  sample_rate: 16000

@@ -78,10 +78,10 @@ model:

  optim:
    name: adamw
-    lr: 1e-3
+    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
-    weight_decay: 0
+    weight_decay: 1e-3

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
examples/audio_tasks/conf/masking.yaml (new file, 126 additions)
@@ -0,0 +1,126 @@
# This configuration contains exemplary values for training a multichannel speech enhancement model
# using mask-based processing of a reference channel.
#
name: "masking"

model:
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1

  train_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    audio_duration: 4.0 # in seconds, audio segment duration for training
    random_offset: true # if the file is longer than audio_duration, use a random offset to select a subsegment
    min_duration: ${model.train_ds.audio_duration}
    batch_size: 64 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    batch_size: 64 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  test_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    target_key: target_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # target signal is the first channel from files in target_key
    batch_size: 1 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  encoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
    fft_length: 512 # length of the window and FFT for calculating spectrogram
    hop_length: 256 # hop length for calculating spectrogram
    power: null

  decoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
    fft_length: 512 # length of the window and FFT for calculating spectrogram
    hop_length: 256 # hop length for calculating spectrogram

  mask_estimator:
    _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorRNN
    num_outputs: ${model.num_outputs}
    num_subbands: 257 # number of subbands of the input spectrogram (fft_length / 2 + 1)
    num_features: 256 # number of features at the RNN input
    num_layers: 5 # number of RNN layers
    bidirectional: true # use a bidirectional RNN

  mask_processor:
    _target_: nemo.collections.asr.modules.audio_modules.MaskReferenceChannel # apply the mask on the reference channel
    ref_channel: 0 # reference channel for the output

  loss:
    _target_: nemo.collections.asr.losses.SDRLoss
    scale_invariant: true # use scale-invariant SDR

  metrics:
    val:
      sdr: # output SDR
        _target_: torchmetrics.audio.SignalDistortionRatio
    test:
      sdr_ch0: # SDR on output channel 0
        _target_: torchmetrics.audio.SignalDistortionRatio
        channel: 0

  optim:
    name: adamw
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 25 # interval of logging
  enable_progress_bar: true
  resume_from_checkpoint: null # path to a checkpoint file to continue the training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.

Collaborator (inline review comment on resume_from_checkpoint): We usually don't expose this arg because exp_manager overrides it.

  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_loss"
    mode: "min"
    save_top_k: 5
    always_save_nemo: true # saves the checkpoints as .nemo files instead of PTL checkpoints

  # you need to set these two to true to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
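The manifest files referenced by manifest_filepath are JSON-lines files. Below is a minimal sketch of building one, assuming the usual NeMo convention that each line is a JSON object carrying the configured input_key and target_key; the paths and the duration field here are hypothetical:

```python
import json

# Hypothetical manifest entries. The key names match `input_key` and
# `target_key` in the config above; the `duration` field is an assumption
# based on common NeMo manifest conventions.
entries = [
    {
        'audio_filepath': '/data/noisy/utt_0001.wav',   # input (noisy) signal
        'target_filepath': '/data/clean/utt_0001.wav',  # target; channel 0 is used
        'duration': 4.5,                                # in seconds
    },
]

# Write one JSON object per line
with open('train_manifest.json', 'w') as f:
    for entry in entries:
        f.write(json.dumps(entry) + '\n')
```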
examples/audio_tasks/speech_enhancement.py

@@ -16,7 +16,7 @@
# Training the model

Basic run (on CPU for 50 epochs):
-    python examples/asr/experimental/audio_to_audio/speech_enhancement.py \
+    python examples/audio_tasks/speech_enhancement.py \
        # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
        model.train_ds.manifest_filepath="<path to manifest file>" \
        model.validation_ds.manifest_filepath="<path to manifest file>" \

@@ -36,7 +36,7 @@
from nemo.utils.exp_manager import exp_manager


-@hydra_runner(config_path="./conf", config_name="multichannel_enhancement")
+@hydra_runner(config_path="./conf", config_name="masking")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')
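The remainder of main is elided from the diff. Below is a sketch of what the full script presumably looks like, following the standard NeMo example-script pattern; the import path of EncMaskDecAudioToAudioModel and the exact call sequence are assumptions, not confirmed by this diff:

```python
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncMaskDecAudioToAudioModel  # assumed export path
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="./conf", config_name="masking")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')

    # Assumed NeMo pattern: build the trainer, attach exp_manager,
    # instantiate the model from the config, and train.
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
```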
nemo/collections/asr/data/audio_to_audio.py (1 addition, 1 deletion)

@@ -462,7 +462,7 @@ def get_samples_synchronized(

    if duration + fixed_offset > min_audio_duration:
        # The shortest file is shorter than the requested duration
-        logging.warning(
+        logging.debug(
            f'Shortest file ({min_audio_duration}s) is less than the desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration} seconds.'
        )
        offset = fixed_offset

Collaborator (inline review comment on the switch to logging.debug): Intentional?
nemo/collections/asr/models/enhancement_models.py (12 additions, 3 deletions)

@@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
        self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

+        if 'mixture_consistency' in self._cfg:
+            self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
+        else:
+            self.mixture_consistency = None

        # Future enhancement:
        # If subclasses need to modify the config before calling super()
        # Check what ASRBPE* classes do with their mixin

@@ -316,7 +321,7 @@ def input_types(self) -> Dict[str, NeuralType]:
            "input_signal": NeuralType(
                ('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
            ),  # multi-channel format, channel dimension can be 1 for single-channel audio
-            "input_length": NeuralType(tuple('B'), LengthsType()),
+            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property

@@ -325,7 +330,7 @@ def output_types(self) -> Dict[str, NeuralType]:
            "output_signal": NeuralType(
                ('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
            ),  # multi-channel format, channel dimension can be 1 for single-channel audio
-            "output_length": NeuralType(tuple('B'), LengthsType()),
+            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    def match_batch_length(self, input: torch.Tensor, batch_length: int):

@@ -346,7 +351,7 @@ def match_batch_length(self, input: torch.Tensor, batch_length: int):
        return torch.nn.functional.pad(input, pad, 'constant', 0)

    @typecheck()
-    def forward(self, input_signal, input_length):
+    def forward(self, input_signal, input_length=None):
        """
        Forward pass of the model.

@@ -370,6 +375,10 @@ def forward(self, input_signal, input_length):
        # Mask-based processor in the encoded domain
        processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)

+        # Mixture consistency
+        if self.mixture_consistency is not None:
+            processed = self.mixture_consistency(mixture=encoded, estimate=processed)

        # Decoder
        processed, processed_length = self.decoder(input=processed, input_length=processed_length)
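The constructor change above reads an optional mixture_consistency block from the model config. Enabling it would presumably amount to merging a block like the following into the model config; this is a sketch, not part of this PR's shipped configs, with the _target_ pointing at the MixtureConsistencyProjection module added further below:

```python
from omegaconf import OmegaConf

# Hypothetical config block (expressed as YAML) that the
# `if 'mixture_consistency' in self._cfg` branch above would pick up.
mixture_consistency_cfg = OmegaConf.create(
    """
    mixture_consistency:
      _target_: nemo.collections.asr.modules.audio_modules.MixtureConsistencyProjection
      weighting: power  # or null for uniform weighting across sources
      eps: 1e-8
    """
)
```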
nemo/collections/asr/modules/audio_modules.py (88 additions, 16 deletions)

@@ -251,14 +251,14 @@ def __init__(
        else:
            raise ValueError(f'Unknown rnn_type: {rnn_type}')

+        self.fc = torch.nn.Linear(
+            in_features=2 * num_features if bidirectional else num_features, out_features=num_features
+        )
+        self.norm = torch.nn.LayerNorm(num_features)

        # Each output shares the RNN and has a separate projection
        self.output_projections = torch.nn.ModuleList(
-            [
-                torch.nn.Linear(
-                    in_features=2 * num_features if bidirectional else num_features, out_features=num_subbands
-                )
-                for _ in range(num_outputs)
-            ]
+            [torch.nn.Linear(in_features=num_features, out_features=num_subbands) for _ in range(num_outputs)]
        )
        self.output_nonlinearity = torch.nn.Sigmoid()

@@ -310,33 +310,36 @@ def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        ).to(input.device)
        self.rnn.flatten_parameters()
        input_packed, _ = self.rnn(input_packed)
-        input, input_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
-        input_length = input_length.to(input.device)
+        output, output_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
+        output_length = output_length.to(input.device)

+        # Layer normalization and skip connection
+        output = self.norm(self.fc(output)) + input

        # Create `num_outputs` masks
-        output = []
+        masks = []
        for output_projection in self.output_projections:
            # Output projection
-            mask = output_projection(input)
+            mask = output_projection(output)
            mask = self.output_nonlinearity(mask)

            # Back to the original format
            # (B, N, F) -> (B, F, N)
            mask = mask.transpose(2, 1)

            # Append to the output
-            output.append(mask)
+            masks.append(mask)

        # Stack along channel dimension to get (B, M, F, N)
-        output = torch.stack(output, axis=1)
+        masks = torch.stack(masks, axis=1)

-        # Mask frames beyond input length
+        # Mask frames beyond output length
        length_mask: torch.Tensor = make_seq_mask_like(
-            lengths=input_length, like=output, time_dim=-1, valid_ones=False
+            lengths=output_length, like=masks, time_dim=-1, valid_ones=False
        )
-        output = output.masked_fill(length_mask, 0.0)
+        masks = masks.masked_fill(length_mask, 0.0)

-        return output, input_length
+        return masks, output_length
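The refactored forward pass now projects the (bidirectional) RNN output back to num_features, applies layer normalization, and adds the RNN input as a skip connection before the per-output mask projections. A minimal standalone sketch of that pattern in plain torch; the batch size, sequence length, and feature size are made up for illustration:

```python
import torch

B, N, num_features = 2, 50, 256  # batch, time frames, features (illustrative)

rnn = torch.nn.LSTM(
    input_size=num_features, hidden_size=num_features, batch_first=True, bidirectional=True
)
fc = torch.nn.Linear(in_features=2 * num_features, out_features=num_features)  # 2x for bidirectional
norm = torch.nn.LayerNorm(num_features)

x = torch.randn(B, N, num_features)  # RNN input
y, _ = rnn(x)                        # (B, N, 2 * num_features)
y = norm(fc(y)) + x                  # project back, normalize, add skip connection
print(y.shape)                       # torch.Size([2, 50, 256])
```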


class MaskReferenceChannel(NeuralModule):

@@ -875,3 +878,72 @@ def forward(
        output, output_length = self.filter(input=output, input_length=input_length, power=power)

        return output.to(io_dtype), output_length


class MixtureConsistencyProjection(NeuralModule):
    """Ensure estimated sources are consistent with the input mixture.

    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, use uniform weighting. If `power`, use the power of the
            estimated source as the weight.
        eps: Small positive value for regularization

    Reference:
        Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        if self.weighting not in [None, 'power']:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        return {
            "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)
        """
        # number of sources
        M = estimate.size(-3)
        # estimated mixture based on the estimated sources
        estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)

        # weighting
        if self.weighting is None:
            weight = 1 / M
        elif self.weighting == 'power':
            weight = estimate.abs().pow(2)
            weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
        else:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # consistent estimate
        consistent_estimate = estimate + weight * (mixture - estimated_mixture)

        return consistent_estimate
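As a sanity check on the projection: with uniform weighting, the residual (mixture - estimated_mixture) is split equally across the M sources, so the corrected sources sum exactly to the input mixture. A standalone sketch in plain torch with illustrative shapes:

```python
import torch

B, M, F, N = 2, 3, 257, 100  # batch, sources, subbands, frames (illustrative)
mixture = torch.randn(B, 1, F, N, dtype=torch.cfloat)   # single-channel mixture
estimate = torch.randn(B, M, F, N, dtype=torch.cfloat)  # M estimated sources

# Uniform weighting (weighting=None): distribute the residual equally
estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)
consistent = estimate + (1 / M) * (mixture - estimated_mixture)

# The projected sources now sum exactly to the mixture
assert torch.allclose(consistent.sum(dim=-3, keepdim=True), mixture, atol=1e-5)
```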