Skip to content

Commit

Permalink
fix: use overloads to properly type itertracks
Browse files Browse the repository at this point in the history
Add overloads for itertracks so that its return type is properly reflected when
passed literal True or False (or nothing). Previously, when type checking

for segment, track_name, label in annotation.itertracks(yield_label=True): ...

the type checker would complain that the iterator could contain 2-tuples. A
similar issue existed when passing False.
  • Loading branch information
evfinkn committed Apr 3, 2024
1 parent 4872a68 commit d981be2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pyannote/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@
Tuple,
Iterator,
Text,
overload,
TYPE_CHECKING,
)

import numpy as np
from sortedcontainers import SortedDict
from typing_extensions import Literal

from . import (
PYANNOTE_SEGMENT,
Expand All @@ -136,7 +138,17 @@
from .timeline import Timeline
from .feature import SlidingWindowFeature
from .utils.generators import string_generator, int_generator
from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode
from .utils.types import (
Label,
Key,
Support,
LabelGenerator,
TrackName,
Track,
LabeledTrack,
TrackIterator,
CropMode,
)

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -259,9 +271,14 @@ def itersegments(self):
"""
return iter(self._tracks)

def itertracks(
self, yield_label: bool = False
) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
@overload
def itertracks(self, yield_label: Literal[False] = ...) -> Iterator[Track]: ...
@overload
def itertracks(self, yield_label: Literal[True]) -> Iterator[LabeledTrack]: ...
@overload
def itertracks(self, yield_label: bool) -> TrackIterator: ...

def itertracks(self, yield_label: bool = False) -> TrackIterator:
"""Iterate over tracks (in chronological order)
Parameters
Expand Down
3 changes: 3 additions & 0 deletions pyannote/core/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
LabelGeneratorMode = Literal['int', 'string']
LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]]
TrackName = Union[str, int]
Track = Tuple['Segment', TrackName]
LabeledTrack = Tuple['Segment', TrackName, Label]
TrackIterator = Union[Iterator[Track], Iterator[LabeledTrack]]
Key = Union['Segment', Tuple['Segment', TrackName]]
Resource = Union['Segment', 'Timeline', 'SlidingWindowFeature',
'Annotation']
Expand Down

0 comments on commit d981be2

Please sign in to comment.