From 230cf3b70b82ef795763f126e0348d3ace5db296 Mon Sep 17 00:00:00 2001 From: Simon Ottenhaus Date: Wed, 17 Apr 2024 08:37:51 +0200 Subject: [PATCH] use overload bool-literal typing --- pyannote/core/annotation.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/pyannote/core/annotation.py b/pyannote/core/annotation.py index 7a9c382..2383140 100755 --- a/pyannote/core/annotation.py +++ b/pyannote/core/annotation.py @@ -111,6 +111,7 @@ from collections import defaultdict from typing import ( Hashable, + Literal, Optional, Dict, Union, @@ -123,6 +124,7 @@ Text, TYPE_CHECKING, NamedTuple, + overload, ) import numpy as np @@ -224,7 +226,7 @@ def _updateLabels(self): # accumulate segments for updated labels _segments = {label: [] for label in update} - for segment, track, label in self.itertracks_with_labels(): + for segment, track, label in self.itertracks(yield_label=True): if label in update: _segments[label].append(segment) @@ -270,6 +272,13 @@ def itersegments(self): """ return iter(self._tracks) + @overload + def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[SegmentTrack]: ... + @overload + def itertracks(self, yield_label: Literal[True]) -> Iterator[SegmentTrackLabel]: ... + @overload + def itertracks(self, yield_label: bool) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: ... + def itertracks( self, yield_label: bool = False ) -> Iterator[Union[SegmentTrack, SegmentTrackLabel]]: @@ -292,7 +301,7 @@ def itertracks( >>> for segment, track in annotation.itertracks(): ... # do something with the track - >>> for segment, track, label in annotation.itertracks_with_labels(): + >>> for segment, track, label in annotation.itertracks(yield_label=True): ... # do something with the track and its label """ @@ -307,11 +316,11 @@ def itertracks( def itertracks_with_labels(self) -> Iterator[SegmentTrackLabel]: """Typed version of :func:`itertracks`(yield_label=True)""" - return self.itertracks(yield_label=True) # type: ignore + return self.itertracks(yield_label=True) def itertracks_without_labels(self) -> Iterator[SegmentTrack]: """Typed version of :func:`itertracks`(yield_label=False)""" - return self.itertracks(yield_label=False) # type: ignore + return self.itertracks(yield_label=False) def _updateTimeline(self): self._timeline = Timeline(segments=self._tracks, uri=self.uri) @@ -358,14 +367,14 @@ def __eq__(self, other: "Annotation"): labels are equal. """ pairOfTracks = itertools.zip_longest( - self.itertracks_with_labels(), other.itertracks_with_labels() + self.itertracks(yield_label=True), other.itertracks(yield_label=True) ) return all(t1 == t2 for t1, t2 in pairOfTracks) def __ne__(self, other: "Annotation"): """Inequality""" pairOfTracks = itertools.zip_longest( - self.itertracks_with_labels(), other.itertracks_with_labels() + self.itertracks(yield_label=True), other.itertracks(yield_label=True) ) return any(t1 != t2 for t1, t2 in pairOfTracks) @@ -404,7 +413,7 @@ def _iter_rttm(self) -> Iterator[Text]: f'containing spaces (got: "{uri}").' ) raise ValueError(msg) - for segment, _, label in self.itertracks_with_labels(): + for segment, _, label in self.itertracks(yield_label=True): if isinstance(label, Text) and " " in label: msg = ( f"Space-separated RTTM file format does not allow labels " @@ -449,7 +458,7 @@ def _iter_lab(self) -> Iterator[Text]: iterator: Iterator[str] An iterator over LAB text lines """ - for segment, _, label in self.itertracks_with_labels(): + for segment, _, label in self.itertracks(yield_label=True): if isinstance(label, Text) and " " in label: msg = ( f"Space-separated LAB file format does not allow labels " @@ -806,7 +815,7 @@ def __str__(self): """Human-friendly representation""" # TODO: use pandas.DataFrame return "\n".join( - ["%s %s %s" % (s, t, l) for s, t, l in self.itertracks_with_labels()] + ["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)] ) def __delitem__(self, key: Key): @@ -1051,7 +1060,7 @@ def update(self, annotation: "Annotation", copy: bool = False) -> "Annotation": result = self.copy() if copy else self # TODO speed things up by working directly with annotation internals - for segment, track, label in annotation.itertracks_with_labels(): + for segment, track, label in annotation.itertracks(yield_label=True): result[segment, track] = label return result @@ -1255,7 +1264,7 @@ def rename_tracks(self, generator: Union[LabelGenerator, Iterable[str], Iterable raise ValueError("generator must be 'string', 'int', or iterable") # TODO speed things up by working directly with annotation internals - for s, _, label in self.itertracks_with_labels(): + for s, _, label in self.itertracks(yield_label=True): renamed[s, next(generator_)] = label return renamed @@ -1338,7 +1347,7 @@ def relabel_tracks(self, generator: LabelGenerator = "string") -> "Annotation": generator = int_generator() relabeled = self.empty() - for s, t, _ in self.itertracks_with_labels(): + for s, t, _ in self.itertracks(yield_label=True): relabeled[s, t] = next(generator) return relabeled