diff --git a/src/diart/demo.py b/src/diart/demo.py
index 8e4344bf..8442e86e 100644
--- a/src/diart/demo.py
+++ b/src/diart/demo.py
@@ -1,9 +1,11 @@
 import argparse
 from pathlib import Path
 
+import diart.operators as dops
 import diart.sources as src
+import rx.operators as ops
 from diart.pipelines import OnlineSpeakerDiarization
-from diart.sinks import OutputBuilder
+from diart.sinks import RealTimePlot, RTTMWriter
 
 # Define script arguments
 parser = argparse.ArgumentParser()
@@ -37,7 +39,6 @@
 )
 
 # Manage audio source
-uri = args.source
 if args.source != "microphone":
     args.source = Path(args.source).expanduser()
     uri = args.source.name.split(".")[0]
@@ -45,25 +46,25 @@
     audio_source = src.FileAudioSource(
         file=args.source,
         uri=uri,
-        sample_rate=args.sample_rate,
         reader=src.RegularAudioFileReader(
-            args.sample_rate, pipeline.duration, args.step
+            args.sample_rate, pipeline.duration, pipeline.step
         ),
     )
 else:
     output_dir = Path("~/").expanduser() if args.output is None else Path(args.output)
     audio_source = src.MicrophoneAudioSource(args.sample_rate)
 
-# Configure output builder to write an RTTM file and to draw in real time
-output_builder = OutputBuilder(
-    duration=pipeline.duration,
-    step=args.step,
-    latency=args.latency,
-    output_path=output_dir / "output.rttm",
-    visualization="slide",
-)
-# Build pipeline from audio source and stream results to the output builder
-pipeline.from_source(audio_source, output_waveform=True).subscribe(output_builder)
+# Build pipeline from audio source and stream predictions to a real-time plot
+pipeline.from_source(audio_source).pipe(
+    ops.do(RTTMWriter(path=output_dir / "output.rttm")),
+    dops.buffer_output(
+        duration=pipeline.duration,
+        step=pipeline.step,
+        latency=pipeline.latency,
+        sample_rate=audio_source.sample_rate
+    ),
+).subscribe(RealTimePlot(pipeline.duration, pipeline.latency))
+
 # Read audio source as a stream
 if args.source == "microphone":
     print("Recording...")
diff --git a/src/diart/operators.py b/src/diart/operators.py
index 31de0eb8..5b3a6551 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Any
+from typing import Callable, Optional, List, Any, Tuple
 
 import numpy as np
 import rx
-from pyannote.core import SlidingWindow, SlidingWindowFeature
+from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment
 from rx import operators as ops
 from rx.core import Observable
 
@@ -94,3 +94,189 @@ def accumulate(state: List[Any], value: Any) -> List[Any]:
             return new_state[1:]
         return new_state
     return rx.pipe(ops.scan(accumulate, []))
+
+
+@dataclass
+class PredictionWithAudio:
+    prediction: Annotation
+    waveform: Optional[SlidingWindowFeature] = None
+
+    @property
+    def has_audio(self) -> bool:
+        return self.waveform is not None
+
+
+@dataclass
+class OutputAccumulationState:
+    annotation: Optional[Annotation]
+    waveform: Optional[SlidingWindowFeature]
+    real_time: float
+    next_sample: Optional[int]
+
+    @staticmethod
+    def initial() -> 'OutputAccumulationState':
+        return OutputAccumulationState(None, None, 0, 0)
+
+    @property
+    def cropped_waveform(self) -> SlidingWindowFeature:
+        return SlidingWindowFeature(
+            self.waveform[:self.next_sample],
+            self.waveform.sliding_window,
+        )
+
+    def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]:
+        return self.annotation, self.cropped_waveform, self.real_time
+
+
+def accumulate_output(
+    duration: float,
+    step: float,
+    patch_collar: float = 0.05,
+) -> Operator:
+    """Accumulate predictions and audio to infinity: O(N) space complexity.
+    Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
+
+    Parameters
+    ----------
+    duration: float
+        Buffer duration in seconds.
+    step: float
+        Duration of the chunks at each event in seconds.
+        The first chunk may be bigger given the latency.
+    patch_collar: float, optional
+        Collar to merge speaker turns of the same speaker, in seconds.
+        Defaults to 0.05 (i.e. 50ms).
+
+    Returns
+    -------
+    A reactive x operator implementing this behavior.
+    """
+    def accumulate(
+        state: OutputAccumulationState,
+        value: Tuple[Annotation, Optional[SlidingWindowFeature]]
+    ) -> OutputAccumulationState:
+        value = PredictionWithAudio(*value)
+        annotation, waveform = None, None
+
+        # Determine the real time of the stream
+        real_time = duration if state.annotation is None else state.real_time + step
+
+        # Update total annotation with current predictions
+        if state.annotation is None:
+            annotation = value.prediction
+        else:
+            annotation = state.annotation.update(value.prediction).support(patch_collar)
+
+        # Update total waveform if there's audio in the input
+        new_next_sample = 0
+        if value.has_audio:
+            num_new_samples = value.waveform.data.shape[0]
+            new_next_sample = state.next_sample + num_new_samples
+            sw_holder = state
+            if state.waveform is None:
+                # Initialize the audio buffer with 10 times the size of the first chunk
+                waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value
+            elif new_next_sample < state.waveform.data.shape[0]:
+                # The buffer still has enough space to accommodate the chunk
+                waveform = state.waveform.data
+            else:
+                # The buffer is full, double its size
+                waveform = np.concatenate(
+                    (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0
+                )
+            # Copy chunk into buffer
+            waveform[state.next_sample:new_next_sample] = value.waveform.data
+            waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window)
+
+        return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
+
+    return rx.pipe(
+        ops.scan(accumulate, OutputAccumulationState.initial()),
+        ops.map(OutputAccumulationState.to_tuple),
+    )
+
+
+def buffer_output(
+    duration: float,
+    step: float,
+    latency: float,
+    sample_rate: int,
+    patch_collar: float = 0.05,
+) -> Operator:
+    """Store last predictions and audio inside a fixed buffer.
+    Provides the best time/space complexity trade-off if the past data is not needed.
+
+    Parameters
+    ----------
+    duration: float
+        Buffer duration in seconds.
+    step: float
+        Duration of the chunks at each event in seconds.
+        The first chunk may be bigger given the latency.
+    latency: float
+        Latency of the system in seconds.
+    sample_rate: int
+        Sample rate of the audio source.
+    patch_collar: float, optional
+        Collar to merge speaker turns of the same speaker, in seconds.
+        Defaults to 0.05 (i.e. 50ms).
+
+    Returns
+    -------
+    A reactive x operator implementing this behavior.
+ """ + # Define some useful constants + num_samples = int(round(duration * sample_rate)) + num_step_samples = int(round(step * sample_rate)) + resolution = 1 / sample_rate + + def accumulate( + state: OutputAccumulationState, + value: Tuple[Annotation, Optional[SlidingWindowFeature]] + ) -> OutputAccumulationState: + value = PredictionWithAudio(*value) + annotation, waveform = None, None + + # Determine the real time of the stream and the start time of the buffer + real_time = duration if state.annotation is None else state.real_time + step + start_time = max(0., real_time - latency - duration) + + # Update annotation and constrain its bounds to the buffer + if state.annotation is None: + annotation = value.prediction + else: + annotation = state.annotation.update(value.prediction) \ + .support(patch_collar) \ + .extrude(Segment(0, start_time)) + + # Update the audio buffer if there's audio in the input + new_next_sample = state.next_sample + num_step_samples + if value.has_audio: + if state.waveform is None: + # Determine the size of the first chunk + expected_duration = duration + step - latency + expected_samples = int(round(expected_duration * sample_rate)) + # Shift indicator to start copying new audio in the buffer + new_next_sample = state.next_sample + expected_samples + # Buffer size is duration + step + waveform = np.zeros((num_samples + num_step_samples, 1)) + # Copy first chunk into buffer (slicing because of rounding errors) + waveform[:expected_samples] = value.waveform.data[:expected_samples] + elif state.next_sample <= num_samples: + # The buffer isn't full, copy into next free buffer chunk + waveform = state.waveform.data + waveform[state.next_sample:new_next_sample] = value.waveform.data + else: + # The buffer is full, shift values to the left and copy into last buffer chunk + waveform = np.roll(state.waveform.data, -num_step_samples, axis=0) + waveform[-num_step_samples:] = value.waveform.data + + # Wrap waveform in a sliding window feature to include timestamps + window = SlidingWindow(start=start_time, duration=resolution, step=resolution) + waveform = SlidingWindowFeature(waveform, window) + + return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) + + return rx.pipe( + ops.scan(accumulate, OutputAccumulationState.initial()), + ops.map(OutputAccumulationState.to_tuple), + ) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 981c8c16..e31d4d84 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -5,8 +5,8 @@ from pyannote.audio.pipelines.utils import PipelineModel from . import functional as fn -from . import operators as my_ops -from .sources import AudioSource +from . import operators as dops +from . 
 
 
 class OnlineSpeakerDiarization:
@@ -41,17 +41,23 @@ def __init__(
         self.beta = beta
         self.max_speakers = max_speakers
 
-    def get_end_time(self, source: AudioSource) -> Optional[float]:
+    @property
+    def sample_rate(self) -> int:
+        return self.segmentation.model.audio.sample_rate
+
+    def get_end_time(self, source: src.AudioSource) -> Optional[float]:
         if source.duration is not None:
             return source.duration - source.duration % self.step
         return None
 
-    def from_source(self, source: AudioSource, output_waveform: bool = False) -> rx.Observable:
+    def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable:
+        msg = f"Audio source has sample rate {source.sample_rate}, expected {self.sample_rate}"
+        assert source.sample_rate == self.sample_rate, msg
         # Regularize the stream to a specific chunk duration and step
         regular_stream = source.stream
         if not source.is_regular:
             regular_stream = source.stream.pipe(
-                my_ops.regularize_stream(self.duration, self.step, source.sample_rate)
+                dops.regularize_stream(self.duration, self.step, source.sample_rate)
             )
         # Branch the stream to calculate chunk segmentation
         segmentation_stream = regular_stream.pipe(
@@ -76,20 +82,21 @@ def from_source(self, source: AudioSource, output_waveform: bool = False) -> rx.
         pipeline = rx.zip(segmentation_stream, embedding_stream).pipe(
             ops.starmap(clustering),
             # Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
-            my_ops.buffer_slide(aggregation.num_overlapping_windows),
+            dops.buffer_slide(aggregation.num_overlapping_windows),
             # Aggregate overlapping output windows
             ops.map(aggregation),
             # Binarize output
             ops.map(fn.Binarize(source.uri, self.tau_active)),
         )
+        # Add corresponding waveform to the output
         if output_waveform:
             window_selector = fn.DelayedAggregation(
                 self.step, self.latency, strategy="first", stream_end=end_time
             )
-            pipeline = pipeline.pipe(
-                ops.zip(regular_stream.pipe(
-                    my_ops.buffer_slide(window_selector.num_overlapping_windows),
-                    ops.map(window_selector),
-                ))
+            waveform_stream = regular_stream.pipe(
+                dops.buffer_slide(window_selector.num_overlapping_windows),
+                ops.map(window_selector),
             )
-        return pipeline
+            return rx.zip(pipeline, waveform_stream)
+        # No waveform needed, add None for consistency
+        return pipeline.pipe(ops.map(lambda ann: (ann, None)))
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 98ba4e14..0f54f8a1 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -3,88 +3,111 @@
 from typing import Literal, Union, Text, Optional, Tuple
 
 import matplotlib.pyplot as plt
-import numpy as np
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 from pyannote.database.util import load_rttm
 from pyannote.metrics.diarization import DiarizationErrorRate
 from rx.core import Observer
 
 
-class OutputBuilder(Observer):
+class RTTMWriter(Observer):
+    def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05):
+        super().__init__()
+        self.patch_collar = patch_collar
+        self.path = Path(path)
+        if self.path.exists():
+            self.path.unlink()
+
+    def patch_rttm(self):
+        """Stitch same-speaker turns that are close to each other"""
+        annotation = list(load_rttm(self.path).values())[0]
+        with open(self.path, 'w') as file:
+            annotation.support(self.patch_collar).write_rttm(file)
+
+    def on_next(self, value: Tuple[Annotation, Optional[SlidingWindowFeature]]):
+        with open(self.path, 'a') as file:
+            value[0].write_rttm(file)
+
+    def on_error(self, error: Exception):
+        try:
+            self.patch_rttm()
+        except Exception:
print("Error while patching RTTM file:") + print_exc() + exit(1) + + def on_completed(self): + self.patch_rttm() + + +class RealTimePlot(Observer): def __init__( self, duration: float, - step: float, latency: float, - output_path: Optional[Union[Path, Text]] = None, - merge_collar: float = 0.05, visualization: Literal["slide", "accumulate"] = "slide", reference: Optional[Union[Path, Text]] = None, ): super().__init__() assert visualization in ["slide", "accumulate"] - self.collar = merge_collar self.visualization = visualization - self.output_path = output_path self.reference = reference if self.reference is not None: - self.reference = load_rttm(reference) - uri = list(self.reference.keys())[0] - self.reference = self.reference[uri] - self.output: Optional[Annotation] = None - self.waveform: Optional[SlidingWindowFeature] = None - self.window_duration: float = duration - self.step = step + self.reference = list(load_rttm(reference).values())[0] + self.window_duration = duration self.latency = latency - self.real_time = 0 self.figure, self.axs, self.num_axs = None, None, -1 - def init_num_axs(self): + def _init_num_axs(self, waveform: Optional[SlidingWindowFeature]): if self.num_axs == -1: self.num_axs = 1 - if self.waveform is not None: + if waveform is not None: self.num_axs += 1 if self.reference is not None: self.num_axs += 1 - def init_figure(self): - self.init_num_axs() + def _init_figure(self, waveform: Optional[SlidingWindowFeature]): + self._init_num_axs(waveform) self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs)) if self.num_axs == 1: self.axs = [self.axs] - def draw(self): - # Initialize figure if first call - if self.figure is None: - self.init_figure() - - # Clear all axs + def _clear_axs(self): for i in range(self.num_axs): self.axs[i].clear() - # Determine plot bounds + def get_plot_bounds(self, real_time: float) -> Segment: start_time = 0 - end_time = self.real_time - self.latency + end_time = real_time - self.latency if self.visualization == "slide": start_time = max(0., end_time - self.window_duration) - notebook.crop = Segment(start_time, end_time) + return Segment(start_time, end_time) - # Plot internal state + def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): + prediction, waveform, real_time = values + # Initialize figure if first call + if self.figure is None: + self._init_figure(waveform) + # Clear previous plots + self._clear_axs() + # Set plot bounds + notebook.crop = self.get_plot_bounds(real_time) + + # Plot current values if self.reference is not None: metric = DiarizationErrorRate() - mapping = metric.optimal_mapping(self.reference, self.output) - self.output.rename_labels(mapping=mapping, copy=False) - notebook.plot_annotation(self.output, self.axs[0]) + mapping = metric.optimal_mapping(self.reference, prediction) + prediction.rename_labels(mapping=mapping, copy=False) + notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") if self.num_axs == 2: - if self.waveform is not None: - notebook.plot_feature(self.waveform, self.axs[1]) + if waveform is not None: + notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") elif self.reference is not None: notebook.plot_annotation(self.reference, self.axs[1]) self.axs[1].set_title("Reference") elif self.num_axs == 3: - notebook.plot_feature(self.waveform, self.axs[1]) + notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") notebook.plot_annotation(self.reference, self.axs[2]) 
             self.axs[2].set_title("Reference")
@@ -95,40 +118,6 @@ def draw(self):
         self.figure.canvas.flush_events()
         plt.pause(0.05)
 
-    def on_next(self, value: Union[Annotation, Tuple[Annotation, SlidingWindowFeature]]):
-        if isinstance(value, Annotation):
-            annotation, waveform = value, None
-        else:
-            annotation, waveform = value
-
-        # Update output annotation
-        if self.output is None:
-            self.output = annotation
-            self.real_time = self.window_duration
-        else:
-            self.output = self.output.update(annotation).support(self.collar)
-            self.real_time += self.step
-
-        # Update waveform
-        if waveform is not None:
-            if self.waveform is None:
-                self.waveform = waveform
-            else:
-                # FIXME time complexity can be better with pre-allocation of a numpy array
-                new_samples = np.concatenate([self.waveform.data, waveform.data], axis=0)
-                self.waveform = SlidingWindowFeature(new_samples, self.waveform.sliding_window)
-
-        # Draw new output
-        self.draw()
-
-        # Save RTTM if possible
-        if self.output_path is not None:
-            with open(Path(self.output_path), 'w') as file:
-                self.output.write_rttm(file)
-
     def on_error(self, error: Exception):
         print_exc()
         exit(1)
-
-    def on_completed(self):
-        print("Stream completed")
diff --git a/src/diart/sources.py b/src/diart/sources.py
index e42be30a..3fa11b12 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -24,7 +24,6 @@ def __init__(self, uri: Text, sample_rate: int):
         self.uri = uri
         self.sample_rate = sample_rate
         self.stream = Subject()
-        self.resolution = 1 / sample_rate
 
     @property
     def is_regular(self) -> bool:
@@ -163,8 +162,6 @@ class FileAudioSource(AudioSource):
         The file to stream.
     uri: Text
         Unique identifier of the audio source.
-    sample_rate: int
-        Sample rate of the audio source.
    reader: AudioFileReader
         Determines how the file will be read.
     """
@@ -172,10 +169,9 @@ def __init__(
         self,
         file: AudioFile,
         uri: Text,
-        sample_rate: int,
         reader: AudioFileReader
     ):
-        super().__init__(uri, sample_rate)
+        super().__init__(uri, reader.sample_rate)
         self.reader = reader
         self._duration = self.reader.get_duration(file)
         self.file = file
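
For illustration, here is a minimal sketch (not part of the patch) of the fixed-size buffer update that the new buffer_output operator performs at each emission. The sample rate, durations, and the push_chunk helper below are assumptions for the example; the special handling of the first chunk and of latency is omitted.

# Illustrative sketch: fixed-size audio buffer as used by buffer_output.
# Parameter values and the push_chunk helper are hypothetical.
import numpy as np

sample_rate = 16000                               # assumed source sample rate
duration, step = 5.0, 0.5                         # assumed buffer and chunk durations
num_samples = int(round(duration * sample_rate))
num_step_samples = int(round(step * sample_rate))

# Buffer size is duration + step, as in the operator
buffer = np.zeros((num_samples + num_step_samples, 1))
next_sample = 0

def push_chunk(buffer, next_sample, chunk):
    """Copy a new chunk into the buffer, shifting left once the buffer is full."""
    end = next_sample + num_step_samples
    if next_sample <= num_samples:
        # Buffer not full yet: write into the next free region
        buffer[next_sample:end] = chunk
    else:
        # Buffer full: drop the oldest step and append the newest one
        buffer = np.roll(buffer, -num_step_samples, axis=0)
        buffer[-num_step_samples:] = chunk
    return buffer, end

for _ in range(20):
    chunk = np.random.randn(num_step_samples, 1)  # stand-in for a real audio chunk
    buffer, next_sample = push_chunk(buffer, next_sample, chunk)

This keeps memory constant regardless of stream length, which is the trade-off the buffer_output docstring describes, whereas accumulate_output instead grows a buffer by doubling so the full history can be plotted.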