From d981be262d050fa13d19c4910fbae14c49a4d118 Mon Sep 17 00:00:00 2001 From: evfinkn Date: Wed, 3 Apr 2024 16:26:36 -0500 Subject: [PATCH] fix: use overloads to properly type itertracks Add overloads for itertracks so that its return type is properly reflected when passed literal True or False (or nothing). Previously, when type checking for segment, track_name, label in annotation.itertracks(yield_label=True): ... the type checker would complain that the iterator could contain 2-tuples. A similar issue existed when passing False. --- pyannote/core/annotation.py | 25 +++++++++++++++++++++---- pyannote/core/utils/types.py | 3 +++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pyannote/core/annotation.py b/pyannote/core/annotation.py index fc2d17a..bb77ebc 100755 --- a/pyannote/core/annotation.py +++ b/pyannote/core/annotation.py @@ -121,11 +121,13 @@ Tuple, Iterator, Text, + overload, TYPE_CHECKING, ) import numpy as np from sortedcontainers import SortedDict +from typing_extensions import Literal from . import ( PYANNOTE_SEGMENT, @@ -136,7 +138,17 @@ from .timeline import Timeline from .feature import SlidingWindowFeature from .utils.generators import string_generator, int_generator -from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode +from .utils.types import ( + Label, + Key, + Support, + LabelGenerator, + TrackName, + Track, + LabeledTrack, + TrackIterator, + CropMode, +) if TYPE_CHECKING: import pandas as pd @@ -259,9 +271,14 @@ def itersegments(self): """ return iter(self._tracks) - def itertracks( - self, yield_label: bool = False - ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]: + @overload + def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[Track]: ... + @overload + def itertracks(self, yield_label: Literal[True]) -> Iterator[LabeledTrack]: ... + @overload + def itertracks(self, yield_label: bool) -> TrackIterator: ... + + def itertracks(self, yield_label: bool = False) -> TrackIterator: """Iterate over tracks (in chronological order) Parameters diff --git a/pyannote/core/utils/types.py b/pyannote/core/utils/types.py index c4225a5..5de0a1a 100644 --- a/pyannote/core/utils/types.py +++ b/pyannote/core/utils/types.py @@ -13,6 +13,9 @@ LabelGeneratorMode = Literal['int', 'string'] LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]] TrackName = Union[str, int] +Track = Tuple['Segment', TrackName] +LabeledTrack = Tuple['Segment', TrackName, Label] +TrackIterator = Union[Iterator[Track], Iterator[LabeledTrack]] Key = Union['Segment', Tuple['Segment', TrackName]] Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature', 'Annotation']