Modular OutputBuilder + better demo performance (#20)
* Add `RTTMWriter` observer to write system outputs

* Add `accumulate_output()` and `buffer_output()` operators to concatenate system outputs, either by accumulating all past predictions (`accumulate_output`) or by keeping only the most recent ones (`buffer_output`)

* Add `RealTimePlot` observer to draw system outputs in real time

* `OnlineSpeakerDiarization` pipeline yields a `None` waveform when `output_waveform=False`, so downstream consumers always receive `(prediction, waveform)` pairs

* Remove redundant sample rate parameter in `FileAudioSource`

* Make `accumulate_output()` more efficient by pre-allocating the audio buffer, which reduces the number of costly concatenation operations (see the usage sketch after the changed-files summary below)
juanmc2005 authored Jan 3, 2022
1 parent 35f60f1 commit 88ba337
Showing 5 changed files with 281 additions and 102 deletions.
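
Before the diffs, here is a minimal usage sketch of the modular API introduced by this commit, mirroring the updated `demo.py` below. It is not part of the commit itself: the audio file path and output path are placeholders, the pipeline's constructor arguments are assumed to have usable defaults (they are not shown in the visible hunks), and the final `read()` call is inferred from the demo's "Read audio source as a stream" step.

```python
from pathlib import Path

import rx.operators as ops

import diart.operators as dops
import diart.sources as src
from diart.pipelines import OnlineSpeakerDiarization
from diart.sinks import RealTimePlot, RTTMWriter

# Assumed: hyper-parameters (models, duration, step, latency, ...) have defaults.
pipeline = OnlineSpeakerDiarization()

# Placeholder file path; note that FileAudioSource no longer takes a sample rate,
# it is now set through the reader.
audio_source = src.FileAudioSource(
    file=Path("~/audio/sample.wav").expanduser(),
    uri="sample",
    reader=src.RegularAudioFileReader(
        pipeline.sample_rate, pipeline.duration, pipeline.step
    ),
)

# Write each prediction to an RTTM file as a side effect, keep only the last
# `duration` seconds of audio and predictions, and plot them in real time.
pipeline.from_source(audio_source).pipe(
    ops.do(RTTMWriter(path=Path("output.rttm"))),
    dops.buffer_output(
        duration=pipeline.duration,
        step=pipeline.step,
        latency=pipeline.latency,
        sample_rate=audio_source.sample_rate,
    ),
).subscribe(RealTimePlot(pipeline.duration, pipeline.latency))

audio_source.read()  # assumed entry point that starts streaming the file
```
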
29 changes: 15 additions & 14 deletions src/diart/demo.py
@@ -1,9 +1,11 @@
import argparse
from pathlib import Path

import diart.operators as dops
import diart.sources as src
import rx.operators as ops
from diart.pipelines import OnlineSpeakerDiarization
from diart.sinks import OutputBuilder
from diart.sinks import RealTimePlot, RTTMWriter

# Define script arguments
parser = argparse.ArgumentParser()
@@ -37,33 +39,32 @@
)

# Manage audio source
uri = args.source
if args.source != "microphone":
args.source = Path(args.source).expanduser()
uri = args.source.name.split(".")[0]
output_dir = args.source.parent if args.output is None else Path(args.output)
audio_source = src.FileAudioSource(
file=args.source,
uri=uri,
sample_rate=args.sample_rate,
reader=src.RegularAudioFileReader(
args.sample_rate, pipeline.duration, args.step
args.sample_rate, pipeline.duration, pipeline.step
),
)
else:
output_dir = Path("~/").expanduser() if args.output is None else Path(args.output)
audio_source = src.MicrophoneAudioSource(args.sample_rate)

# Configure output builder to write an RTTM file and to draw in real time
output_builder = OutputBuilder(
duration=pipeline.duration,
step=args.step,
latency=args.latency,
output_path=output_dir / "output.rttm",
visualization="slide",
)
# Build pipeline from audio source and stream results to the output builder
pipeline.from_source(audio_source, output_waveform=True).subscribe(output_builder)
# Build pipeline from audio source and stream predictions to a real-time plot
pipeline.from_source(audio_source).pipe(
ops.do(RTTMWriter(path=output_dir / "output.rttm")),
dops.buffer_output(
duration=pipeline.duration,
step=pipeline.step,
latency=pipeline.latency,
sample_rate=audio_source.sample_rate
),
).subscribe(RealTimePlot(pipeline.duration, pipeline.latency))

# Read audio source as a stream
if args.source == "microphone":
print("Recording...")
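
A hypothetical variant of the demo above, assuming the rest of the script is unchanged: swapping `buffer_output()` for `accumulate_output()` (also added in this commit) keeps the full prediction and audio history instead of only the last buffer. Note that `accumulate_output()` takes neither `latency` nor `sample_rate`; whether `RealTimePlot` handles a growing window is an assumption here.

```python
# Hypothetical variant of the demo: accumulate the whole history
# instead of keeping a fixed-size buffer.
pipeline.from_source(audio_source).pipe(
    ops.do(RTTMWriter(path=output_dir / "output.rttm")),
    dops.accumulate_output(
        duration=pipeline.duration,
        step=pipeline.step,
    ),
).subscribe(RealTimePlot(pipeline.duration, pipeline.latency))
```
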
190 changes: 188 additions & 2 deletions src/diart/operators.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from typing import Callable, Optional, List, Any
from typing import Callable, Optional, List, Any, Tuple

import numpy as np
import rx
from pyannote.core import SlidingWindow, SlidingWindowFeature
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment
from rx import operators as ops
from rx.core import Observable

@@ -94,3 +94,189 @@ def accumulate(state: List[Any], value: Any) -> List[Any]:
return new_state[1:]
return new_state
return rx.pipe(ops.scan(accumulate, []))


@dataclass
class PredictionWithAudio:
prediction: Annotation
waveform: Optional[SlidingWindowFeature] = None

@property
def has_audio(self) -> bool:
return self.waveform is not None


@dataclass
class OutputAccumulationState:
annotation: Optional[Annotation]
waveform: Optional[SlidingWindowFeature]
real_time: float
next_sample: Optional[int]

@staticmethod
def initial() -> 'OutputAccumulationState':
return OutputAccumulationState(None, None, 0, 0)

@property
def cropped_waveform(self) -> SlidingWindowFeature:
return SlidingWindowFeature(
self.waveform[:self.next_sample],
self.waveform.sliding_window,
)

def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]:
return self.annotation, self.cropped_waveform, self.real_time


def accumulate_output(
duration: float,
step: float,
patch_collar: float = 0.05,
) -> Operator:
"""Accumulate predictions and audio to infinity: O(N) space complexity.
Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
Parameters
----------
duration: float
Buffer duration in seconds.
step: float
Duration of the chunks at each event in seconds.
The first chunk may be bigger given the latency.
patch_collar: float, optional
Collar to merge speaker turns of the same speaker, in seconds.
Defaults to 0.05 (i.e. 50ms).
Returns
-------
A ReactiveX (rx) operator implementing this behavior.
"""
def accumulate(
state: OutputAccumulationState,
value: Tuple[Annotation, Optional[SlidingWindowFeature]]
) -> OutputAccumulationState:
value = PredictionWithAudio(*value)
annotation, waveform = None, None

# Determine the real time of the stream
real_time = duration if state.annotation is None else state.real_time + step

# Update total annotation with current predictions
if state.annotation is None:
annotation = value.prediction
else:
annotation = state.annotation.update(value.prediction).support(patch_collar)

# Update total waveform if there's audio in the input
new_next_sample = 0
if value.has_audio:
num_new_samples = value.waveform.data.shape[0]
new_next_sample = state.next_sample + num_new_samples
sw_holder = state
if state.waveform is None:
# Initialize the audio buffer with 10 times the size of the first chunk
waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value
elif new_next_sample < state.waveform.data.shape[0]:
# The buffer still has enough space to accommodate the chunk
waveform = state.waveform.data
else:
# The buffer is full, double its size
waveform = np.concatenate(
(state.waveform.data, np.zeros_like(state.waveform.data)), axis=0
)
# Copy chunk into buffer
waveform[state.next_sample:new_next_sample] = value.waveform.data
waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window)

return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)

return rx.pipe(
ops.scan(accumulate, OutputAccumulationState.initial()),
ops.map(OutputAccumulationState.to_tuple),
)


def buffer_output(
duration: float,
step: float,
latency: float,
sample_rate: int,
patch_collar: float = 0.05,
) -> Operator:
"""Store last predictions and audio inside a fixed buffer.
Provides the best time/space complexity trade-off if the past data is not needed.
Parameters
----------
duration: float
Buffer duration in seconds.
step: float
Duration of the chunks at each event in seconds.
The first chunk may be bigger given the latency.
latency: float
Latency of the system in seconds.
sample_rate: int
Sample rate of the audio source.
patch_collar: float, optional
Collar to merge speaker turns of the same speaker, in seconds.
Defaults to 0.05 (i.e. 50ms).
Returns
-------
A ReactiveX (rx) operator implementing this behavior.
"""
# Define some useful constants
num_samples = int(round(duration * sample_rate))
num_step_samples = int(round(step * sample_rate))
resolution = 1 / sample_rate

def accumulate(
state: OutputAccumulationState,
value: Tuple[Annotation, Optional[SlidingWindowFeature]]
) -> OutputAccumulationState:
value = PredictionWithAudio(*value)
annotation, waveform = None, None

# Determine the real time of the stream and the start time of the buffer
real_time = duration if state.annotation is None else state.real_time + step
start_time = max(0., real_time - latency - duration)

# Update annotation and constrain its bounds to the buffer
if state.annotation is None:
annotation = value.prediction
else:
annotation = state.annotation.update(value.prediction) \
.support(patch_collar) \
.extrude(Segment(0, start_time))

# Update the audio buffer if there's audio in the input
new_next_sample = state.next_sample + num_step_samples
if value.has_audio:
if state.waveform is None:
# Determine the size of the first chunk
expected_duration = duration + step - latency
expected_samples = int(round(expected_duration * sample_rate))
# Shift indicator to start copying new audio in the buffer
new_next_sample = state.next_sample + expected_samples
# Buffer size is duration + step
waveform = np.zeros((num_samples + num_step_samples, 1))
# Copy first chunk into buffer (slicing because of rounding errors)
waveform[:expected_samples] = value.waveform.data[:expected_samples]
elif state.next_sample <= num_samples:
# The buffer isn't full, copy into next free buffer chunk
waveform = state.waveform.data
waveform[state.next_sample:new_next_sample] = value.waveform.data
else:
# The buffer is full, shift values to the left and copy into last buffer chunk
waveform = np.roll(state.waveform.data, -num_step_samples, axis=0)
waveform[-num_step_samples:] = value.waveform.data

# Wrap waveform in a sliding window feature to include timestamps
window = SlidingWindow(start=start_time, duration=resolution, step=resolution)
waveform = SlidingWindowFeature(waveform, window)

return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)

return rx.pipe(
ops.scan(accumulate, OutputAccumulationState.initial()),
ops.map(OutputAccumulationState.to_tuple),
)
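
The docstring of `accumulate_output()` above promises O(N) space with only O(log N) concatenations thanks to a pre-allocated buffer that doubles its size when full. The following self-contained sketch (not diart code; the `GrowingBuffer` name is invented for the example) illustrates that amortized-growth idea in isolation.

```python
import numpy as np


class GrowingBuffer:
    """Toy version of the doubling buffer behind accumulate_output()."""

    def __init__(self, initial_samples: int):
        self.data = np.zeros((initial_samples, 1))
        self.next_sample = 0

    def append(self, chunk: np.ndarray):
        end = self.next_sample + chunk.shape[0]
        # Double the buffer (one concatenation) only when it runs out of space
        while end > self.data.shape[0]:
            self.data = np.concatenate(
                (self.data, np.zeros_like(self.data)), axis=0
            )
        # Copy the chunk into the pre-allocated region
        self.data[self.next_sample:end] = chunk
        self.next_sample = end

    @property
    def cropped(self) -> np.ndarray:
        # Only the part that has been written so far is meaningful
        return self.data[:self.next_sample]


buffer = GrowingBuffer(initial_samples=10)
for _ in range(100):
    buffer.append(np.random.rand(5, 1))
print(buffer.cropped.shape)  # (500, 1), reached with only ~6 concatenations
```
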
31 changes: 19 additions & 12 deletions src/diart/pipelines.py
@@ -5,8 +5,8 @@
from pyannote.audio.pipelines.utils import PipelineModel

from . import functional as fn
from . import operators as my_ops
from .sources import AudioSource
from . import operators as dops
from . import sources as src


class OnlineSpeakerDiarization:
@@ -41,17 +41,23 @@ def __init__(
self.beta = beta
self.max_speakers = max_speakers

def get_end_time(self, source: AudioSource) -> Optional[float]:
@property
def sample_rate(self) -> int:
return self.segmentation.model.audio.sample_rate

def get_end_time(self, source: src.AudioSource) -> Optional[float]:
if source.duration is not None:
return source.duration - source.duration % self.step
return None

def from_source(self, source: AudioSource, output_waveform: bool = False) -> rx.Observable:
def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable:
msg = f"Audio source has sample rate {source.sample_rate}, expected {self.sample_rate}"
assert source.sample_rate == self.sample_rate, msg
# Regularize the stream to a specific chunk duration and step
regular_stream = source.stream
if not source.is_regular:
regular_stream = source.stream.pipe(
my_ops.regularize_stream(self.duration, self.step, source.sample_rate)
dops.regularize_stream(self.duration, self.step, source.sample_rate)
)
# Branch the stream to calculate chunk segmentation
segmentation_stream = regular_stream.pipe(
@@ -76,20 +82,21 @@ def from_source(self, source: AudioSource, output_waveform: bool = False) -> rx.
pipeline = rx.zip(segmentation_stream, embedding_stream).pipe(
ops.starmap(clustering),
# Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
my_ops.buffer_slide(aggregation.num_overlapping_windows),
dops.buffer_slide(aggregation.num_overlapping_windows),
# Aggregate overlapping output windows
ops.map(aggregation),
# Binarize output
ops.map(fn.Binarize(source.uri, self.tau_active)),
)
# Add corresponding waveform to the output
if output_waveform:
window_selector = fn.DelayedAggregation(
self.step, self.latency, strategy="first", stream_end=end_time
)
pipeline = pipeline.pipe(
ops.zip(regular_stream.pipe(
my_ops.buffer_slide(window_selector.num_overlapping_windows),
ops.map(window_selector),
))
waveform_stream = regular_stream.pipe(
dops.buffer_slide(window_selector.num_overlapping_windows),
ops.map(window_selector),
)
return pipeline
return rx.zip(pipeline, waveform_stream)
# No waveform needed, add None for consistency
return pipeline.pipe(ops.map(lambda ann: (ann, None)))
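
Finally, a short sketch of the `output_waveform=False` path added above, reusing the demo's `pipeline`, `audio_source` and `output_dir`, and assuming `RTTMWriter`, being an observer, can be subscribed directly: each emitted item is still a pair, now `(annotation, None)`, so consumers written for `(annotation, waveform)` pairs keep working without audio.

```python
# Sketch: RTTM output only, no audio kept downstream.
# Each item is (annotation, None) thanks to the consistency change above.
pipeline.from_source(audio_source, output_waveform=False).subscribe(
    RTTMWriter(path=output_dir / "output.rttm")
)
```
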