Skip to content

Commit

Permalink
Fix type hinting and add named tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
simonottenhauskenbun committed Apr 5, 2024
1 parent 5d9c591 commit cc5b190
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/
MANIFEST
.Python
env/
venv/
bin/
build/
develop-eggs/
Expand Down Expand Up @@ -61,4 +62,4 @@ doc/.ipynb_checkpoints
# PyCharm
.idea/

.mypy_cache/
.mypy_cache/
81 changes: 58 additions & 23 deletions pyannote/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
Iterator,
Text,
TYPE_CHECKING,
NamedTuple,
)

import numpy as np
Expand All @@ -139,7 +140,17 @@
from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode

if TYPE_CHECKING:
import pandas as pd
import pandas as pd # type: ignore


class SegmentTrack(NamedTuple):
    """(segment, track) pair yielded when iterating tracks without labels.

    Being a named tuple, it still unpacks positionally
    (``segment, track = item``) and compares equal to a plain 2-tuple,
    while also exposing its fields by name.
    """

    # segment the track belongs to
    segment: Segment
    # name of the track within that segment
    track: TrackName

class SegmentTrackLabel(NamedTuple):
    """(segment, track, label) triple yielded when iterating tracks with labels.

    Being a named tuple, it still unpacks positionally
    (``segment, track, label = item``) and compares equal to a plain
    3-tuple, while also exposing its fields by name.
    """

    # segment the track belongs to
    segment: Segment
    # name of the track within that segment
    track: TrackName
    # label attached to the (segment, track) pair
    label: Label


class Annotation:
Expand Down Expand Up @@ -187,7 +198,7 @@ def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None):
self._labelNeedsUpdate: Dict[Label, bool] = {}

# timeline meant to store all annotated segments
self._timeline: Timeline = None
self._timeline: Optional[Timeline] = None
self._timelineNeedsUpdate: bool = True

@property
Expand All @@ -213,7 +224,7 @@ def _updateLabels(self):

# accumulate segments for updated labels
_segments = {label: [] for label in update}
for segment, track, label in self.itertracks(yield_label=True):
for segment, track, label in self.itertracks_with_labels():
if label in update:
_segments[label].append(segment)

Expand Down Expand Up @@ -261,9 +272,13 @@ def itersegments(self):

def itertracks(
self, yield_label: bool = False
) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]:
"""Iterate over tracks (in chronological order)
Typed version of :func:`itertracks`:
- :func:`itertracks_without_labels` yields (segment, track) tuples (SegmentTrack)
- :func:`itertracks_with_labels` yields (segment, track, label) tuples (SegmentTrackLabel)
Parameters
----------
yield_label : bool, optional
Expand All @@ -277,7 +292,7 @@ def itertracks(
>>> for segment, track in annotation.itertracks():
... # do something with the track
>>> for segment, track, label in annotation.itertracks(yield_label=True):
>>> for segment, track, label in annotation.itertracks_with_labels():
... # do something with the track and its label
"""

Expand All @@ -286,9 +301,17 @@ def itertracks(
tracks.items(), key=lambda tl: (str(tl[0]), str(tl[1]))
):
if yield_label:
yield segment, track, lbl
yield SegmentTrackLabel(segment, track, lbl)
else:
yield segment, track
yield SegmentTrack(segment, track)

def itertracks_with_labels(self) -> Iterator[SegmentTrackLabel]:
    """Iterate over tracks together with their labels.

    Typed counterpart of ``itertracks(yield_label=True)``.

    Returns
    -------
    iterator : Iterator[SegmentTrackLabel]
        The iterator produced by ``self.itertracks(yield_label=True)``,
        i.e. (segment, track, label) tuples.
    """
    # Delegate to itertracks(): having this method call itself would
    # recurse without end and raise RecursionError on first use.
    # The `type: ignore` is needed because itertracks() is annotated as
    # returning the Union of both tuple flavours.
    return self.itertracks(yield_label=True)  # type: ignore

def itertracks_without_labels(self) -> Iterator[SegmentTrack]:
    """Iterate over tracks, leaving their labels out.

    Typed counterpart of ``itertracks(yield_label=False)``.

    Returns
    -------
    iterator : Iterator[SegmentTrack]
        The iterator produced by ``self.itertracks(yield_label=False)``,
        i.e. (segment, track) tuples.
    """
    # itertracks() is annotated as returning the Union of both tuple
    # flavours, hence the `type: ignore` on the narrowed return.
    pairs = self.itertracks(yield_label=False)
    return pairs  # type: ignore

def _updateTimeline(self):
self._timeline = Timeline(segments=self._tracks, uri=self.uri)
Expand Down Expand Up @@ -317,9 +340,14 @@ def get_timeline(self, copy: bool = True) -> Timeline:
"""
if self._timelineNeedsUpdate:
self._updateTimeline()

timeline_ = self._timeline
if timeline_ is None:
timeline_ = Timeline(uri=self.uri)

if copy:
return self._timeline.copy()
return self._timeline
return timeline_.copy()
return timeline_

def __eq__(self, other: "Annotation"):
"""Equality
Expand All @@ -330,14 +358,14 @@ def __eq__(self, other: "Annotation"):
labels are equal.
"""
pairOfTracks = itertools.zip_longest(
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
self.itertracks_with_labels(), other.itertracks_with_labels()
)
return all(t1 == t2 for t1, t2 in pairOfTracks)

def __ne__(self, other: "Annotation"):
"""Inequality"""
pairOfTracks = itertools.zip_longest(
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
self.itertracks_with_labels(), other.itertracks_with_labels()
)

return any(t1 != t2 for t1, t2 in pairOfTracks)
Expand Down Expand Up @@ -376,7 +404,7 @@ def _iter_rttm(self) -> Iterator[Text]:
f'containing spaces (got: "{uri}").'
)
raise ValueError(msg)
for segment, _, label in self.itertracks(yield_label=True):
for segment, _, label in self.itertracks_with_labels():
if isinstance(label, Text) and " " in label:
msg = (
f"Space-separated RTTM file format does not allow labels "
Expand Down Expand Up @@ -421,7 +449,7 @@ def _iter_lab(self) -> Iterator[Text]:
iterator: Iterator[str]
An iterator over LAB text lines
"""
for segment, _, label in self.itertracks(yield_label=True):
for segment, _, label in self.itertracks_with_labels():
if isinstance(label, Text) and " " in label:
msg = (
f"Space-separated LAB file format does not allow labels "
Expand Down Expand Up @@ -556,6 +584,9 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation
else:
raise NotImplementedError("unsupported mode: '%s'" % mode)

else:
raise TypeError("unsupported support type: '%s'" % type(support))

def extrude(
self, removed: Support, mode: CropMode = "intersection"
) -> "Annotation":
Expand Down Expand Up @@ -775,7 +806,7 @@ def __str__(self):
"""Human-friendly representation"""
# TODO: use pandas.DataFrame
return "\n".join(
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks_with_labels()]
)

def __delitem__(self, key: Key):
Expand Down Expand Up @@ -1020,7 +1051,7 @@ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
result = self.copy() if copy else self

# TODO speed things up by working directly with annotation internals
for segment, track, label in annotation.itertracks(yield_label=True):
for segment, track, label in annotation.itertracks_with_labels():
result[segment, track] = label

return result
Expand Down Expand Up @@ -1178,7 +1209,7 @@ def argmax(self, support: Optional[Support] = None) -> Optional[Label]:
key=lambda x: x[1],
)[0]

def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable[int]] = "string") -> "Annotation":
"""Rename all tracks
Parameters
Expand Down Expand Up @@ -1215,13 +1246,17 @@ def rename_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
renamed = self.__class__(uri=self.uri, modality=self.modality)

if generator == "string":
generator = string_generator()
generator_ = string_generator()
elif generator == "int":
generator = int_generator()
generator_ = int_generator()
elif isinstance(generator, Iterable):
generator_ = iter(generator)
else:
raise ValueError("generator must be 'string', 'int', or iterable")

# TODO speed things up by working directly with annotation internals
for s, _, label in self.itertracks(yield_label=True):
renamed[s, next(generator)] = label
for s, _, label in self.itertracks_with_labels():
renamed[s, next(generator_)] = label
return renamed

def rename_labels(
Expand Down Expand Up @@ -1303,7 +1338,7 @@ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
generator = int_generator()

relabeled = self.empty()
for s, t, _ in self.itertracks(yield_label=True):
for s, t, _ in self.itertracks_with_labels():
relabeled[s, t] = next(generator)

return relabeled
Expand Down Expand Up @@ -1439,11 +1474,11 @@ def discretize(
duration: Optional[float] = None,
):
"""Discretize
Parameters
----------
support : Segment, optional
Part of annotation to discretize.
Part of annotation to discretize.
Defaults to annotation full extent.
resolution : float or SlidingWindow, optional
Defaults to 10ms frames.
Expand Down

0 comments on commit cc5b190

Please sign in to comment.