Skip to content

Commit

Permalink
feat(utils): add "soft" option to Powerset.to_multilabel conversion (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Oct 22, 2023
1 parent 03f8265 commit 0b45103
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- feat(pipeline): add support for list of hooks with `Hooks`
- BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead)
- fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization`
- feat(utils): add `"soft"` option to `Powerset.to_multilabel`

## Version 3.0.1 (2023-09-28)

Expand Down
22 changes: 14 additions & 8 deletions pyannote/audio/utils/powerset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,32 @@ def build_cardinality(self) -> torch.Tensor:
powerset_k += 1
return cardinality

def to_multilabel(self, powerset: torch.Tensor) -> torch.Tensor:
"""Convert predictions from (soft) powerset to (hard) multi-label
def to_multilabel(self, powerset: torch.Tensor, soft: bool = False) -> torch.Tensor:
"""Convert predictions from powerset to multi-label
Parameter
---------
powerset : (batch_size, num_frames, num_powerset_classes) torch.Tensor
Soft predictions in "powerset" space.
soft : bool, optional
Return soft multi-label predictions. Defaults to False (i.e. hard predictions)
Assumes that `powerset` are "logits" (not "probabilities").
Returns
-------
multi_label : (batch_size, num_frames, num_classes) torch.Tensor
Hard predictions in "multi-label" space.
Predictions in "multi-label" space.
"""

hard_powerset = torch.nn.functional.one_hot(
torch.argmax(powerset, dim=-1),
self.num_powerset_classes,
).float()
if soft:
powerset_probs = torch.exp(powerset)
else:
powerset_probs = torch.nn.functional.one_hot(
torch.argmax(powerset, dim=-1),
self.num_powerset_classes,
).float()

return torch.matmul(hard_powerset, self.mapping)
return torch.matmul(powerset_probs, self.mapping)

def forward(self, powerset: torch.Tensor) -> torch.Tensor:
"""Alias for `to_multilabel`"""
Expand Down

0 comments on commit 0b45103

Please sign in to comment.