Skip to content

Commit

Permalink
Replace operators.aggregate() with functional.DelayedAggregation (#…
Browse files Browse the repository at this point in the history
…16)

* Move output window aggregation to functional module

* Add docstring to DelayedAggregation. Fix incorrect resolution in FrameWiseModel
  • Loading branch information
juanmc2005 authored Dec 17, 2021
1 parent 5894ad6 commit a96e7ae
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 55 deletions.
102 changes: 96 additions & 6 deletions src/diart/functional.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import numpy as np
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature
from pyannote.audio.utils.signal import Binarize as PyanBinarize
from pyannote.audio.pipelines.utils import PipelineModel, get_model, get_devices
from typing import Union, Optional, List, Iterable, Tuple
from typing import Union, Optional, List, Literal, Iterable, Tuple
import warnings

from .mapping import SpeakerMap, SpeakerMapBuilder

Expand All @@ -21,11 +22,13 @@ def __call__(self, waveform: SlidingWindowFeature) -> SlidingWindowFeature:
wave = torch.from_numpy(waveform.data.T[np.newaxis])
output = self.model(wave.to(self.model.device)).cpu().numpy()[0]
# Temporal resolution of the output
resolution = self.model.introspection.frames
resolution = self.model.specifications.duration / output.shape[0]
# Temporal shift to keep track of current start time
resolution = SlidingWindow(start=waveform.sliding_window.start,
duration=resolution.duration,
step=resolution.step)
resolution = SlidingWindow(
start=waveform.sliding_window.start,
duration=resolution,
step=resolution
)
return SlidingWindowFeature(output, resolution)


Expand Down Expand Up @@ -84,6 +87,93 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
return norm_embs


class DelayedAggregation:
"""Aggregate aligned overlapping windows of the same duration
across sliding buffers with a specific step and latency.
Parameters
----------
step: float
Shift between two consecutive buffers, in seconds.
latency: float, optional
Desired latency, in seconds. Defaults to step.
The higher the latency, the more overlapping windows to aggregate.
strategy: ("mean", "hamming", "any"), optional
Specifies how to aggregate overlapping windows. Defaults to "hamming".
"mean": simple average
"hamming": average weighted by the Hamming window values (aligned to the buffer)
"any": no aggregation, pick the first overlapping window
Example
--------
>>> duration = 5
>>> frames = 500
>>> step = 0.5
>>> speakers = 2
>>> start_time = 10
>>> resolution = duration / frames
>>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean")
>>> buffers = [
>>> SlidingWindowFeature(
>>> np.random.rand(frames, speakers),
>>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution)
>>> )
>>> for i in range(dagg.num_overlapping_windows)
>>> ]
>>> dagg.num_overlapping_windows
... 4
>>> dagg(buffers).data.shape
... (51, 2) # Rounding errors are possible when cropping the buffers
"""

def __init__(
self,
step: float,
latency: Optional[float] = None,
strategy: Literal["mean", "hamming", "any"] = "hamming",
):
self.step = step
self.latency = latency
self.strategy = strategy

if self.latency is None:
self.latency = self.step

assert self.step <= self.latency, "Invalid latency requested"
assert self.strategy in ["mean", "hamming", "any"]

self.num_overlapping_windows = int(round(self.latency / self.step))

if self.strategy == "hamming":
warnings.warn("'hamming' aggregation is not supported yet, defaulting to 'mean'")

def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:
# Determine overlapping region to aggregate
real_time = buffers[-1].extent.end
start_time = 0
if buffers[0].extent.start > 0:
start_time = real_time - self.latency
required = Segment(start_time, real_time - self.latency + self.step)
# Stack all overlapping regions
intersection = np.stack([
buffer.crop(required, fixed=required.duration)
for buffer in buffers
])
# Aggregate according to strategy
if self.strategy in ("mean", "hamming"):
aggregation = np.mean(intersection, axis=0)
else:
aggregation = intersection[0]
# Determine resolution
resolution = buffers[-1].sliding_window
resolution = SlidingWindow(
start=required.start,
duration=resolution.duration,
step=resolution.step
)
return SlidingWindowFeature(aggregation, resolution)


class OnlineSpeakerClustering:
def __init__(
self,
Expand Down
49 changes: 2 additions & 47 deletions src/diart/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from rx import operators as ops
from rx.core import Observable
from dataclasses import dataclass
from typing import Callable, Optional, List, Literal
from typing import Callable, Optional
import numpy as np
from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature
import warnings
from pyannote.core import SlidingWindow, SlidingWindowFeature


Operator = Callable[[Observable], Observable]
Expand Down Expand Up @@ -86,47 +85,3 @@ def accumulate(state: AudioBufferState, value: np.ndarray):
# Transform state into a SlidingWindowFeature containing the new chunk
ops.map(AudioBufferState.to_sliding_window(sample_rate))
)


def aggregate(
duration: float,
step: float,
latency: Optional[float] = None,
strategy: Literal["mean", "hamming", "any"] = "mean",
):
if latency is None:
latency = step
assert duration >= latency >= step
assert strategy in ["mean", "hamming", "any"]
if strategy == "hamming":
warnings.warn("'hamming' aggregation is not supported yet, defaulting to 'mean'")
num_overlapping = int(round(latency / step))

def apply(buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:
# Determine overlapping region to aggregate
real_time = buffers[-1].extent.end
start_time = 0
if buffers[0].extent.start > 0:
start_time = real_time - latency
required = Segment(start_time, real_time - latency + step)
# Stack all overlapping regions
intersection = np.stack([
buffer.crop(required, fixed=required.duration)
for buffer in buffers
])
# Aggregate according to strategy
if strategy in ("mean", "hamming"):
aggregation = np.mean(intersection, axis=0)
else:
aggregation = intersection[0]
# Determine resolution
resolution = buffers[-1].sliding_window
resolution = SlidingWindow(start=required.start, duration=resolution.duration, step=resolution.step)
return SlidingWindowFeature(aggregation, resolution)

return ops.pipe(
# Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
ops.buffer_with_count(num_overlapping, 1),
# Aggregate buffered chunks
ops.map(apply)
)
11 changes: 9 additions & 2 deletions src/diart/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,22 @@ def from_source(self, source: AudioSource, output_waveform: bool = False) -> rx.
clustering = fn.OnlineSpeakerClustering(
self.tau_active, self.rho_update, self.delta_new, "cosine", self.max_speakers
)
aggregation = fn.DelayedAggregation(self.step, self.latency, strategy="mean")
pipeline = rx.zip(segmentation_stream, embedding_stream).pipe(
ops.starmap(clustering),
my_ops.aggregate(self.duration, self.step, self.latency, "mean"),
# Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
ops.buffer_with_count(aggregation.num_overlapping_windows, 1),
# Aggregate overlapping output windows
ops.map(aggregation),
# Binarize output
ops.map(fn.Binarize(source.uri, self.tau_active)),
)
if output_waveform:
window_selector = fn.DelayedAggregation(self.step, self.latency, strategy="any")
pipeline = pipeline.pipe(
ops.zip(regular_stream.pipe(
my_ops.aggregate(self.duration, self.step, self.latency, "any")
ops.buffer_with_count(window_selector.num_overlapping_windows, 1),
ops.map(window_selector),
))
)
return pipeline

0 comments on commit a96e7ae

Please sign in to comment.