Skip to content

Commit

Permalink
Added consistency projection, addressed comments for the notebook
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Jul 24, 2023
1 parent c8203e6 commit 30687b6
Show file tree
Hide file tree
Showing 6 changed files with 1,393 additions and 1,161 deletions.
9 changes: 9 additions & 0 deletions nemo/collections/asr/models/enhancement_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

if 'mixture_consistency' in self._cfg:
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
else:
self.mixture_consistency = None

# Future enhancement:
# If subclasses need to modify the config before calling super()
# Check ASRBPE* classes do with their mixin
Expand Down Expand Up @@ -370,6 +375,10 @@ def forward(self, input_signal, input_length=None):
# Mask-based processor in the encoded domain
processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)

# Mixture consistency
if self.mixture_consistency is not None:
processed = self.mixture_consistency(mixture=encoded, estimate=processed)

# Decoder
processed, processed_length = self.decoder(input=processed, input_length=processed_length)

Expand Down
69 changes: 69 additions & 0 deletions nemo/collections/asr/modules/audio_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,3 +878,72 @@ def forward(
output, output_length = self.filter(input=output, input_length=input_length, power=power)

return output.to(io_dtype), output_length


class MixtureConsistencyProjection(NeuralModule):
    """Make a set of estimated sources consistent with the input mixture.

    The difference between the mixture and the sum of the estimated sources
    is distributed back onto the individual estimates using per-source
    weights. The input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, every source receives the same weight `1 / M`, where
            `M` is the number of estimated sources. If `power`, the weight
            of a source is its power normalized by the total power of all
            estimated sources.
        eps: Small positive value added to the denominator when normalizing
            the power-based weights.

    Raises:
        NotImplementedError: If `weighting` is neither `None` nor `power`.

    Reference:
        Wisdom et al., Differentiable consistency constraints for improved
        deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        # Reject unsupported modes at construction time.
        supported_modes = (None, 'power')
        if self.weighting not in supported_modes:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        axes = ('B', 'C', 'D', 'T')
        return {
            "mixture": NeuralType(axes, SpectrogramType()),
            "estimate": NeuralType(axes, SpectrogramType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        axes = ('B', 'C', 'D', 'T')
        return {
            "output": NeuralType(axes, SpectrogramType()),
        }

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)

        Raises:
            NotImplementedError: If the configured weighting mode is not supported.
        """
        # The source axis is the third one from the end.
        source_dim = -3
        num_sources = estimate.size(source_dim)

        # Mixture implied by the current estimates; the reduced axis is kept
        # so the result broadcasts against `estimate`.
        summed_estimate = estimate.sum(dim=source_dim, keepdim=True)

        # Per-source share of the mismatch between the two mixtures.
        if self.weighting is None:
            share = 1 / num_sources
        elif self.weighting == 'power':
            source_power = estimate.abs().pow(2)
            total_power = source_power.sum(dim=source_dim, keepdim=True)
            # eps keeps the denominator positive when all estimates are zero.
            share = source_power / (total_power + self.eps)
        else:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # Add the weighted mismatch back onto every estimate.
        return estimate + share * (mixture - summed_estimate)
Loading

0 comments on commit 30687b6

Please sign in to comment.