Add special audio source for Apple devices (#182)
* Blacken entire code base

* Add AppleDeviceAudioSource
juanmc2005 authored Oct 11, 2023
1 parent 25d2196 commit 1d1d826
Showing 23 changed files with 620 additions and 226 deletions.
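Per the commit message, the functional change is the new AppleDeviceAudioSource; the bulk of the added lines is mechanical Black reformatting. A minimal usage sketch of the new source (the truncated diff below never reaches sources.py, so the import path and constructor signature are assumptions modeled on diart's existing microphone source, not confirmed API):

from diart.sources import AppleDeviceAudioSource

# Hypothetical usage: signature assumed to mirror MicrophoneAudioSource;
# check src/diart/sources.py in this commit for the real one.
mic = AppleDeviceAudioSource(sample_rate=16000)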
1 change: 1 addition & 0 deletions src/diart/argdoc.py
@@ -1,5 +1,6 @@
 SEGMENTATION = "Segmentation model name from pyannote"
 EMBEDDING = "Embedding model name from pyannote"
+DURATION = "Chunk duration (in seconds)"
 STEP = "Sliding window step (in seconds)"
 LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
 TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
56 changes: 37 additions & 19 deletions src/diart/blocks/aggregation.py
@@ -18,14 +18,18 @@ class AggregationStrategy(ABC):
     """

     def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"):
-        assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`"
+        assert cropping_mode in [
+            "strict",
+            "loose",
+            "center",
+        ], f"Invalid cropping mode `{cropping_mode}`"
         self.cropping_mode = cropping_mode

     @staticmethod
     def build(
         name: Literal["mean", "hamming", "first"],
-        cropping_mode: Literal["strict", "loose", "center"] = "loose"
-    ) -> 'AggregationStrategy':
+        cropping_mode: Literal["strict", "loose", "center"] = "loose",
+    ) -> "AggregationStrategy":
         """Build an AggregationStrategy instance based on its name"""
         assert name in ("mean", "hamming", "first")
         if name == "mean":
@@ -35,7 +39,9 @@ def build(
         else:
             return FirstOnlyStrategy(cropping_mode)

-    def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature:
+    def __call__(
+        self, buffers: List[SlidingWindowFeature], focus: Segment
+    ) -> SlidingWindowFeature:
         """Aggregate chunks over a specific region.

         Parameters
@@ -53,21 +59,23 @@ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature:
         aggregation = self.aggregate(buffers, focus)
         resolution = focus.duration / aggregation.shape[0]
         resolution = SlidingWindow(
-            start=focus.start,
-            duration=resolution,
-            step=resolution
+            start=focus.start, duration=resolution, step=resolution
         )
         return SlidingWindowFeature(aggregation, resolution)

     @abstractmethod
-    def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+    def aggregate(
+        self, buffers: List[SlidingWindowFeature], focus: Segment
+    ) -> np.ndarray:
         pass


 class HammingWeightedAverageStrategy(AggregationStrategy):
     """Compute the average weighted by the corresponding Hamming-window aligned to each buffer"""

-    def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+    def aggregate(
+        self, buffers: List[SlidingWindowFeature], focus: Segment
+    ) -> np.ndarray:
         num_frames, num_speakers = buffers[0].data.shape
         hamming, intersection = [], []
         for buffer in buffers:
@@ -87,19 +95,25 @@ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
 class AverageStrategy(AggregationStrategy):
     """Compute a simple average over the focus region"""

-    def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+    def aggregate(
+        self, buffers: List[SlidingWindowFeature], focus: Segment
+    ) -> np.ndarray:
         # Stack all overlapping regions
-        intersection = np.stack([
-            buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
-            for buffer in buffers
-        ])
+        intersection = np.stack(
+            [
+                buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
+                for buffer in buffers
+            ]
+        )
         return np.mean(intersection, axis=0)


 class FirstOnlyStrategy(AggregationStrategy):
     """Instead of aggregating, keep the first focus region in the buffer list"""

-    def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+    def aggregate(
+        self, buffers: List[SlidingWindowFeature], focus: Segment
+    ) -> np.ndarray:
         return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration)

@@ -151,12 +165,16 @@ def __init__(
         step: float,
         latency: Optional[float] = None,
         strategy: Literal["mean", "hamming", "first"] = "hamming",
-        cropping_mode: Literal["strict", "loose", "center"] = "loose"
+        cropping_mode: Literal["strict", "loose", "center"] = "loose",
     ):
         self.step = step
         self.latency = latency
         self.strategy = strategy
-        assert cropping_mode in ["strict", "loose", "center"], f"Invalid cropping mode `{cropping_mode}`"
+        assert cropping_mode in [
+            "strict",
+            "loose",
+            "center",
+        ], f"Invalid cropping mode `{cropping_mode}`"
         self.cropping_mode = cropping_mode

         if self.latency is None:
@@ -171,7 +189,7 @@ def _prepend(
         self,
         output_window: SlidingWindowFeature,
         output_region: Segment,
-        buffers: List[SlidingWindowFeature]
+        buffers: List[SlidingWindowFeature],
     ):
         # FIXME instead of prepending the output of the first chunk,
         # add padding of `chunk_duration - latency` seconds at the
@@ -189,7 +207,7 @@ def _prepend(
             resolution = output_region.end / first_output.shape[0]
             output_window = SlidingWindowFeature(
                 first_output,
-                SlidingWindow(start=0, duration=resolution, step=resolution)
+                SlidingWindow(start=0, duration=resolution, step=resolution),
             )
         return output_window
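As context for the reformatted block above (not part of the diff): these strategies merge overlapping sliding-window predictions from consecutive chunks into a single output. A toy sketch of the AggregationStrategy.build API shown in this file, with random activations standing in for real model output:

import numpy as np
from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature
from diart.blocks.aggregation import AggregationStrategy

# Two overlapping chunks of fake speaker activations (10 frames x 3 speakers,
# one frame every 0.5 s); the second chunk starts 0.5 s after the first.
frames1 = SlidingWindow(duration=0.5, step=0.5, start=0.0)
frames2 = SlidingWindow(duration=0.5, step=0.5, start=0.5)
chunk1 = SlidingWindowFeature(np.random.rand(10, 3), frames1)
chunk2 = SlidingWindowFeature(np.random.rand(10, 3), frames2)

# Hamming-weighted average over the region covered by both chunks.
strategy = AggregationStrategy.build("hamming", cropping_mode="loose")
merged = strategy([chunk1, chunk2], focus=Segment(0.5, 5.0))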
7 changes: 3 additions & 4 deletions src/diart/blocks/base.py
@@ -17,7 +17,7 @@ class HyperParameter:
     high: float

     @staticmethod
-    def from_name(name: Text) -> 'HyperParameter':
+    def from_name(name: Text) -> "HyperParameter":
         if name == "tau_active":
             return TauActive
         if name == "rho_update":
@@ -55,7 +55,7 @@ def sample_rate(self) -> int:

     @staticmethod
     @abstractmethod
-    def from_dict(data: Any) -> 'PipelineConfig':
+    def from_dict(data: Any) -> "PipelineConfig":
         pass

     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -96,7 +96,6 @@ def set_timestamp_shift(self, shift: float):

     @abstractmethod
     def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature]
+        self, waveforms: Sequence[SlidingWindowFeature]
     ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
         pass
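A quick note on the interface above: from_name resolves a tunable hyperparameter by its CLI-style name. A tiny sketch (the low/high bounds are defined elsewhere in this module, outside this diff):

from diart.blocks.base import HyperParameter

# "tau_active" resolves to the TauActive instance, which carries the
# low/high bounds used when tuning the pipeline.
param = HyperParameter.from_name("tau_active")
print(param.low, param.high)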
34 changes: 18 additions & 16 deletions src/diart/blocks/clustering.py
@@ -27,13 +27,14 @@ class OnlineSpeakerClustering:
     max_speakers: int
         Maximum number of global speakers to track through a conversation. Defaults to 20.
     """
+
     def __init__(
         self,
         tau_active: float,
         rho_update: float,
         delta_new: float,
         metric: Optional[str] = "cosine",
-        max_speakers: int = 20
+        max_speakers: int = 20,
     ):
         self.tau_active = tau_active
         self.rho_update = rho_update
@@ -116,9 +117,7 @@ def add_center(self, embedding: np.ndarray) -> int:
         return center

     def identify(
-        self,
-        segmentation: SlidingWindowFeature,
-        embeddings: torch.Tensor
+        self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
     ) -> SpeakerMap:
         """Identify the centroids to which the input speaker embeddings belong.
@@ -135,15 +134,18 @@ def identify(
         A mapping from local speakers to global speakers.
         """
         embeddings = embeddings.detach().cpu().numpy()
-        active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
-        long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
+        active_speakers = np.where(
+            np.max(segmentation.data, axis=0) >= self.tau_active
+        )[0]
+        long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
+            0
+        ]
         num_local_speakers = segmentation.data.shape[1]

         if self.centers is None:
             self.init_centers(embeddings.shape[1])
             assignments = [
-                (spk, self.add_center(embeddings[spk]))
-                for spk in active_speakers
+                (spk, self.add_center(embeddings[spk])) for spk in active_speakers
             ]
             return SpeakerMapBuilder.hard_map(
                 shape=(num_local_speakers, self.max_speakers),
@@ -154,18 +156,16 @@ def identify(
         # Obtain a mapping based on distances between embeddings and centers
         dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric)
         # Remove any assignments containing invalid speakers
-        inactive_speakers = np.array([
-            spk for spk in range(num_local_speakers)
-            if spk not in active_speakers
-        ])
+        inactive_speakers = np.array(
+            [spk for spk in range(num_local_speakers) if spk not in active_speakers]
+        )
         dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers)
         # Keep assignments under the distance threshold
         valid_map = dist_map.unmap_threshold(self.delta_new)

         # Some speakers might be unidentified
         missed_speakers = [
-            s for s in active_speakers
-            if not valid_map.is_source_speaker_mapped(s)
+            s for s in active_speakers if not valid_map.is_source_speaker_mapped(s)
         ]

         # Add assignments to new centers if possible
@@ -205,8 +205,10 @@ def identify(

         return valid_map

-    def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature:
+    def __call__(
+        self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
+    ) -> SlidingWindowFeature:
         return SlidingWindowFeature(
             self.identify(segmentation, embeddings).apply(segmentation.data),
-            segmentation.sliding_window
+            segmentation.sliding_window,
         )
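Reading aid for the thresholds above (not part of the commit): tau_active gates which local speakers count as active (max activation), rho_update requires enough average speech before a centroid is updated, and delta_new is the distance beyond which a new centroid is created. A toy call with placeholder tensors:

import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature
from diart.blocks.clustering import OnlineSpeakerClustering

clustering = OnlineSpeakerClustering(
    tau_active=0.6,  # active if max activation reaches 0.6
    rho_update=0.3,  # centroids update only with >= 30% mean activation
    delta_new=1.0,   # cosine distance above which a new centroid is added
)

# One chunk of local segmentation (50 frames x 3 local speakers) and one
# embedding per local speaker; random placeholders, dimension assumed 512.
frames = SlidingWindow(duration=0.1, step=0.1, start=0.0)
segmentation = SlidingWindowFeature(np.random.rand(50, 3), frames)
embeddings = torch.randn(3, 512)

# Same activations with columns mapped to consistent global speaker slots.
global_segmentation = clustering(segmentation, embeddings)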
31 changes: 22 additions & 9 deletions src/diart/blocks/diarization.py
@@ -37,7 +37,9 @@ def __init__(
         # Default segmentation model is pyannote/segmentation
         self.segmentation = segmentation
         if self.segmentation is None:
-            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+            self.segmentation = m.SegmentationModel.from_pyannote(
+                "pyannote/segmentation"
+            )

         self._duration = duration
         self._sample_rate: Optional[int] = None
@@ -67,7 +69,7 @@ def __init__(
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     @staticmethod
-    def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
+    def from_dict(data: Any) -> "SpeakerDiarizationConfig":
         # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
         device = utils.get(data, "device", None)
         if device is None:
@@ -136,9 +138,15 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
         assert self._config.step <= self._config.latency <= self._config.duration, msg

-        self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
+        self.segmentation = SpeakerSegmentation(
+            self._config.segmentation, self._config.device
+        )
         self.embedding = OverlapAwareSpeakerEmbedding(
-            self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device
+            self._config.embedding,
+            self._config.gamma,
+            self._config.beta,
+            norm=1,
+            device=self._config.device,
         )
         self.pred_aggregation = DelayedAggregation(
             self._config.step,
@@ -191,8 +199,7 @@ def reset(self):
         self.chunk_buffer, self.pred_buffer = [], []

     def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature]
+        self, waveforms: Sequence[SlidingWindowFeature]
     ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
@@ -201,13 +208,17 @@ def __call__(
         # Create batch from chunk sequence, shape (batch, samples, channels)
         batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])

-        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+        expected_num_samples = int(
+            np.rint(self.config.duration * self.config.sample_rate)
+        )
         msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
         assert batch.shape[1] == expected_num_samples, msg

         # Extract segmentation and embeddings
         segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
-        embeddings = self.embedding(batch, segmentations)  # shape (batch, speakers, emb_dim)
+        embeddings = self.embedding(
+            batch, segmentations
+        )  # shape (batch, speakers, emb_dim)

         seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
@@ -236,7 +247,9 @@ def __call__(
             # Shift prediction timestamps if required
             if self.timestamp_shift != 0:
                 shifted_agg_prediction = Annotation(agg_prediction.uri)
-                for segment, track, speaker in agg_prediction.itertracks(yield_label=True):
+                for segment, track, speaker in agg_prediction.itertracks(
+                    yield_label=True
+                ):
                     new_segment = Segment(
                         segment.start + self.timestamp_shift,
                         segment.end + self.timestamp_shift,
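The batched __call__ above takes pre-chunked audio and returns one (Annotation, SlidingWindowFeature) pair per chunk. A minimal sketch of driving it by hand (normally diart's inference utilities do the chunking; attribute names come from this diff, and default construction is assumed to download the pyannote models):

import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature
from diart.blocks.diarization import SpeakerDiarization, SpeakerDiarizationConfig

pipeline = SpeakerDiarization(SpeakerDiarizationConfig())
config = pipeline.config

# One chunk of silence shaped (samples, channels), sized to satisfy the
# expected_num_samples assertion in __call__.
num_samples = int(np.rint(config.duration * config.sample_rate))
resolution = 1 / config.sample_rate
chunk = SlidingWindowFeature(
    np.zeros((num_samples, 1), dtype=np.float32),
    SlidingWindow(duration=resolution, step=resolution, start=0.0),
)

annotation, activations = pipeline([chunk])[0]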
22 changes: 16 additions & 6 deletions src/diart/blocks/embedding.py
@@ -22,12 +22,14 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None):
     def from_pyannote(
         model,
         use_hf_token: Union[Text, bool, None] = True,
-        device: Optional[torch.device] = None
-    ) -> 'SpeakerEmbedding':
+        device: Optional[torch.device] = None,
+    ) -> "SpeakerEmbedding":
         emb_model = EmbeddingModel.from_pyannote(model, use_hf_token)
         return SpeakerEmbedding(emb_model, device)

-    def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
+    def __call__(
+        self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None
+    ) -> torch.Tensor:
         """
         Calculate speaker embeddings of input audio.
         If weights are given, calculate many speaker embeddings from the same waveform.
@@ -58,7 +60,7 @@ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
                 self.model(inputs, weights),
                 "(batch spk) feat -> batch spk feat",
                 batch=batch_size,
-                spk=num_speakers
+                spk=num_speakers,
             )
         else:
             output = self.model(inputs)
@@ -76,6 +78,7 @@ class OverlappedSpeechPenalty:
         Temperature parameter (actually 1/beta) to lower joint speaker activations.
         Defaults to 10.
     """
+
     def __init__(self, gamma: float = 3, beta: float = 10):
         self.gamma = gamma
         self.beta = beta
@@ -106,7 +109,11 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
         batch_size2, num_speakers2, _ = embeddings.shape
         assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
         with torch.no_grad():
-            norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+            norm_embs = (
+                self.norm
+                * embeddings
+                / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+            )
         return norm_embs


@@ -131,6 +138,7 @@ class OverlapAwareSpeakerEmbedding:
         The device on which to run the embedding model.
         Defaults to GPU if available or CPU if not.
     """
+
     def __init__(
         self,
         model: EmbeddingModel,
@@ -155,5 +163,7 @@ def from_pyannote(
         model = EmbeddingModel.from_pyannote(model, use_hf_token)
         return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)

-    def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
+    def __call__(
+        self, waveform: TemporalFeatures, segmentation: TemporalFeatures
+    ) -> torch.Tensor:
         return self.normalize(self.embedding(waveform, self.osp(segmentation)))
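For intuition about OverlappedSpeechPenalty (its body is not visible in this truncated diff): it turns raw speaker activations into embedding weights that favor frames where one speaker clearly dominates. The sketch below is one plausible reading of the gamma/beta docstring above, not the library's confirmed formula:

import torch

def overlapped_speech_penalty(
    seg: torch.Tensor, gamma: float = 3.0, beta: float = 10.0
) -> torch.Tensor:
    # seg: (batch, frames, speakers) activations in [0, 1].
    # The softmax across speakers is flatter on frames with overlapping
    # speech, so raising it to gamma down-weights those frames.
    probs = torch.softmax(beta * seg, dim=-1)
    weights = seg.pow(gamma) * probs.pow(gamma)
    return weights.clamp_min(1e-8)  # avoid all-zero weights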