Skip to content

Commit

Permalink
use overload bool-literal typing
Browse files Browse the repository at this point in the history
  • Loading branch information
simonottenhauskenbun committed Apr 17, 2024
1 parent f81a75a commit 230cf3b
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions pyannote/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
from collections import defaultdict
from typing import (
Hashable,
Literal,
Optional,
Dict,
Union,
Expand All @@ -123,6 +124,7 @@
Text,
TYPE_CHECKING,
NamedTuple,
overload,
)

import numpy as np
Expand Down Expand Up @@ -224,7 +226,7 @@ def _updateLabels(self):

# accumulate segments for updated labels
_segments = {label: [] for label in update}
for segment, track, label in self.itertracks_with_labels():
for segment, track, label in self.itertracks(yield_label=True):
if label in update:
_segments[label].append(segment)

Expand Down Expand Up @@ -270,6 +272,13 @@ def itersegments(self):
"""
return iter(self._tracks)

@overload
def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[SegmentTrack]: ...
@overload
def itertracks(self, yield_label: Literal[True]) -> Iterator[SegmentTrackLabel]: ...
@overload
def itertracks(self, yield_label: bool) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: ...

def itertracks(
self, yield_label: bool = False
) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]:
Expand All @@ -292,7 +301,7 @@ def itertracks(
>>> for segment, track in annotation.itertracks():
... # do something with the track
>>> for segment, track, label in annotation.itertracks_with_labels():
>>> for segment, track, label in annotation.itertracks(yield_label=True):
... # do something with the track and its label
"""

Expand All @@ -307,11 +316,11 @@ def itertracks(

def itertracks_with_labels(self) -> Iterator[SegmentTrackLabel]:
"""Typed version of :func:`itertracks`(yield_label=True)"""
return self.itertracks(yield_label=True) # type: ignore
return self.itertracks(yield_label=True)

def itertracks_without_labels(self) -> Iterator[SegmentTrack]:
"""Typed version of :func:`itertracks`(yield_label=False)"""
return self.itertracks(yield_label=False) # type: ignore
return self.itertracks(yield_label=False)

def _updateTimeline(self):
self._timeline = Timeline(segments=self._tracks, uri=self.uri)
Expand Down Expand Up @@ -358,14 +367,14 @@ def __eq__(self, other: "Annotation"):
labels are equal.
"""
pairOfTracks = itertools.zip_longest(
self.itertracks_with_labels(), other.itertracks_with_labels()
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
)
return all(t1 == t2 for t1, t2 in pairOfTracks)

def __ne__(self, other: "Annotation"):
"""Inequality"""
pairOfTracks = itertools.zip_longest(
self.itertracks_with_labels(), other.itertracks_with_labels()
self.itertracks(yield_label=True), other.itertracks(yield_label=True)
)

return any(t1 != t2 for t1, t2 in pairOfTracks)
Expand Down Expand Up @@ -404,7 +413,7 @@ def _iter_rttm(self) -> Iterator[Text]:
f'containing spaces (got: "{uri}").'
)
raise ValueError(msg)
for segment, _, label in self.itertracks_with_labels():
for segment, _, label in self.itertracks(yield_label=True):
if isinstance(label, Text) and " " in label:
msg = (
f"Space-separated RTTM file format does not allow labels "
Expand Down Expand Up @@ -449,7 +458,7 @@ def _iter_lab(self) -> Iterator[Text]:
iterator: Iterator[str]
An iterator over LAB text lines
"""
for segment, _, label in self.itertracks_with_labels():
for segment, _, label in self.itertracks(yield_label=True):
if isinstance(label, Text) and " " in label:
msg = (
f"Space-separated LAB file format does not allow labels "
Expand Down Expand Up @@ -806,7 +815,7 @@ def __str__(self):
"""Human-friendly representation"""
# TODO: use pandas.DataFrame
return "\n".join(
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks_with_labels()]
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]
)

def __delitem__(self, key: Key):
Expand Down Expand Up @@ -1051,7 +1060,7 @@ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation":
result = self.copy() if copy else self

# TODO speed things up by working directly with annotation internals
for segment, track, label in annotation.itertracks_with_labels():
for segment, track, label in annotation.itertracks(yield_label=True):
result[segment, track] = label

return result
Expand Down Expand Up @@ -1255,7 +1264,7 @@ def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable
raise ValueError("generator must be 'string', 'int', or iterable")

# TODO speed things up by working directly with annotation internals
for s, _, label in self.itertracks_with_labels():
for s, _, label in self.itertracks(yield_label=True):
renamed[s, next(generator_)] = label
return renamed

Expand Down Expand Up @@ -1338,7 +1347,7 @@ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation":
generator = int_generator()

relabeled = self.empty()
for s, t, _ in self.itertracks_with_labels():
for s, t, _ in self.itertracks(yield_label=True):
relabeled[s, t] = next(generator)

return relabeled
Expand Down

0 comments on commit 230cf3b

Please sign in to comment.