diff --git a/pyannote/core/annotation.py b/pyannote/core/annotation.py index fc2d17a..bb77ebc 100755 --- a/pyannote/core/annotation.py +++ b/pyannote/core/annotation.py @@ -121,11 +121,13 @@ Tuple, Iterator, Text, + overload, TYPE_CHECKING, ) import numpy as np from sortedcontainers import SortedDict +from typing_extensions import Literal from . import ( PYANNOTE_SEGMENT, @@ -136,7 +138,17 @@ from .timeline import Timeline from .feature import SlidingWindowFeature from .utils.generators import string_generator, int_generator -from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode +from .utils.types import ( + Label, + Key, + Support, + LabelGenerator, + TrackName, + Track, + LabeledTrack, + TrackIterator, + CropMode, +) if TYPE_CHECKING: import pandas as pd @@ -259,9 +271,14 @@ def itersegments(self): """ return iter(self._tracks) - def itertracks( - self, yield_label: bool = False - ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]: + @overload + def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[Track]: ... + @overload + def itertracks(self, yield_label: Literal[True]) -> Iterator[LabeledTrack]: ... + @overload + def itertracks(self, yield_label: bool) -> TrackIterator: ... + + def itertracks(self, yield_label: bool = False) -> TrackIterator: """Iterate over tracks (in chronological order) Parameters diff --git a/pyannote/core/utils/types.py b/pyannote/core/utils/types.py index c4225a5..5de0a1a 100644 --- a/pyannote/core/utils/types.py +++ b/pyannote/core/utils/types.py @@ -13,6 +13,9 @@ LabelGeneratorMode = Literal['int', 'string'] LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]] TrackName = Union[str, int] +Track = Tuple['Segment', TrackName] +LabeledTrack = Tuple['Segment', TrackName, Label] +TrackIterator = Union[Iterator[Track], Iterator[LabeledTrack]] Key = Union['Segment', Tuple['Segment', TrackName]] Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature', 'Annotation']