From 138027339c6d5158f456965d976929be0c9f9abd Mon Sep 17 00:00:00 2001 From: hadware Date: Fri, 13 Dec 2019 11:21:43 +0100 Subject: [PATCH] feat: add type hinting and python3-ize code base --- pyannote/core/annotation.py | 136 +++++++++++++++------------- pyannote/core/feature.py | 63 ++++++++----- pyannote/core/json.py | 21 +++-- pyannote/core/notebook.py | 97 +++++++++++--------- pyannote/core/scores.py | 99 ++++++++++++--------- pyannote/core/segment.py | 142 ++++++++++++++++++------------ pyannote/core/timeline.py | 109 ++++++++++++++--------- pyannote/core/utils/distance.py | 2 +- pyannote/core/utils/generators.py | 13 ++- pyannote/core/utils/helper.py | 2 +- pyannote/core/utils/types.py | 15 ++++ setup.py | 2 + 12 files changed, 419 insertions(+), 282 deletions(-) create mode 100644 pyannote/core/utils/types.py diff --git a/pyannote/core/annotation.py b/pyannote/core/annotation.py index b31564c..c1c62c1 100755 --- a/pyannote/core/annotation.py +++ b/pyannote/core/annotation.py @@ -108,19 +108,22 @@ """ import itertools +from typing import Optional, Dict, Union, Iterable, List, Set, TextIO, Tuple, Iterator + import numpy as np -from typing import TextIO +import pandas as pd +from sortedcontainers import SortedDict from . import PYANNOTE_URI, PYANNOTE_MODALITY, \ PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL -from sortedcontainers import SortedDict +from .json import PYANNOTE_JSON, PYANNOTE_JSON_CONTENT from .segment import Segment from .timeline import Timeline -from .json import PYANNOTE_JSON, PYANNOTE_JSON_CONTENT from .utils.generators import string_generator, int_generator +from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode -class Annotation(object): +class Annotation: """Annotation Parameters @@ -138,7 +141,10 @@ class Annotation(object): """ @classmethod - def from_df(cls, df, uri=None, modality=None): + def from_df(cls, + df: pd.DataFrame, + uri: Optional[str] = None, + modality: Optional[str] = None) -> 'Annotation': df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]] @@ -159,32 +165,32 @@ def from_df(cls, df, uri=None, modality=None): return annotation - def __init__(self, uri=None, modality=None): - - super(Annotation, self).__init__() + def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None): - self._uri = uri - self.modality = modality + self._uri: Optional[str] = uri + self.modality: Optional[str] = modality # sorted dictionary # keys: annotated segments # values: {track: label} dictionary - self._tracks = SortedDict() + self._tracks: Dict[Segment, Dict[TrackName, Label]] = SortedDict() # dictionary # key: label # value: timeline - self._labels = {} - self._labelNeedsUpdate = {} + self._labels: Dict[Label, Timeline] = {} + self._labelNeedsUpdate: [Label, bool] = {} # timeline meant to store all annotated segments - self._timeline = None - self._timelineNeedsUpdate = True + self._timeline: Timeline = None + self._timelineNeedsUpdate: bool = True - def _get_uri(self): + @property + def uri(self): return self._uri - def _set_uri(self, uri): + @uri.setter + def uri(self, uri: str): # update uri for all internal timelines for label in self.labels(): timeline = self.label_timeline(label, copy=False) @@ -193,8 +199,6 @@ def _set_uri(self, uri): timeline.uri = uri self._uri = uri - uri = property(_get_uri, fset=_set_uri, doc="Resource identifier") - def _updateLabels(self): # list of labels that needs to be updated @@ -250,7 +254,11 @@ def itersegments(self): """ return iter(self._tracks) - def itertracks(self, yield_label=False): + def itertracks(self, yield_label: bool = False) \ + -> Iterator[Union[ + Tuple[Segment, TrackName], + Tuple[Segment, TrackName, Label] + ]]: """Iterate over tracks (in chronological order) Parameters @@ -282,7 +290,7 @@ def _updateTimeline(self): self._timeline = Timeline(segments=self._tracks, uri=self.uri) self._timelineNeedsUpdate = False - def get_timeline(self, copy=True): + def get_timeline(self, copy: bool = True) -> Timeline: """Get timeline made of all annotated segments Parameters @@ -309,7 +317,7 @@ def get_timeline(self, copy=True): return self._timeline.copy() return self._timeline - def __eq__(self, other): + def __eq__(self, other: 'Annotation'): """Equality >>> annotation == other @@ -322,7 +330,7 @@ def __eq__(self, other): other.itertracks(yield_label=True)) return all(t1 == t2 for t1, t2 in pairOfTracks) - def __ne__(self, other): + def __ne__(self, other: 'Annotation'): """Inequality""" pairOfTracks = itertools.zip_longest( self.itertracks(yield_label=True), @@ -330,7 +338,7 @@ def __ne__(self, other): return any(t1 != t2 for t1, t2 in pairOfTracks) - def __contains__(self, included): + def __contains__(self, included: Union[Segment, Timeline]): """Inclusion Check whether every segment of `included` does exist in annotation. @@ -361,17 +369,17 @@ def write_rttm(self, file: TextIO): >>> with open('file.rttm', 'w') as file: ... annotation.write_rttm(file) """ - + uri = self.uri if self.uri else "" - - for segment, _, label in self.itertracks(yield_label=True): + + for segment, _, label in self.itertracks(yield_label=True): line = ( f'SPEAKER {uri} 1 {segment.start:.3f} {segment.duration:.3f} ' f' {label} \n' ) file.write(line) - def crop(self, support, mode='intersection'): + def crop(self, support: Support, mode: CropMode = 'intersection'): """Crop annotation to new support Parameters @@ -414,8 +422,7 @@ def crop(self, support, mode='intersection'): _labels = set([]) for segment, _ in \ - self.get_timeline(copy=False).co_iter(support): - + self.get_timeline(copy=False).co_iter(support): tracks = dict(self._tracks[segment]) _tracks[segment] = tracks _labels.update(tracks.values()) @@ -471,8 +478,7 @@ def crop(self, support, mode='intersection'): else: raise NotImplementedError("unsupported mode: '%s'" % mode) - - def get_tracks(self, segment): + def get_tracks(self, segment: Segment) -> Set[TrackName]: """Query tracks by segment Parameters @@ -489,9 +495,9 @@ def get_tracks(self, segment): ---- This will return an empty set if segment does not exist. """ - return set(self._tracks.get(segment, {})) + return set(self._tracks.get(segment, {}).keys()) - def has_track(self, segment, track): + def has_track(self, segment: Segment, track: TrackName) -> bool: """Check whether a given track exists Parameters @@ -508,7 +514,7 @@ def has_track(self, segment, track): """ return track in self._tracks.get(segment, {}) - def copy(self): + def copy(self) -> 'Annotation': """Get a copy of the annotation Returns @@ -536,7 +542,9 @@ def copy(self): return copied - def new_track(self, segment, candidate=None, prefix=None): + def new_track(self, segment: Segment, + candidate: Optional[TrackName] = None, + prefix: Optional[str] = None) -> TrackName: """Generate a new track name for given segment Ensures that the returned track name does not already @@ -587,7 +595,7 @@ def __str__(self): return "\n".join(["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]) - def __delitem__(self, key): + def __delitem__(self, key: Key): """Delete one track >>> del annotation[segment, track] @@ -639,7 +647,7 @@ def __delitem__(self, key): 'Deletion only works with Segment or (Segment, track) keys.') # label = annotation[segment, track] - def __getitem__(self, key): + def __getitem__(self, key: Key) -> Label: """Get track label >>> label = annotation[segment, track] @@ -656,7 +664,7 @@ def __getitem__(self, key): return self._tracks[key[0]][key[1]] # annotation[segment, track] = label - def __setitem__(self, key, label): + def __setitem__(self, key: Key, label: Label): """Add new or update existing track >>> annotation[segment, track] = label @@ -698,7 +706,7 @@ def __setitem__(self, key, label): self._tracks[segment][track] = label self._labelNeedsUpdate[label] = True - def empty(self): + def empty(self) -> 'Annotation': """Return an empty copy Returns @@ -709,7 +717,7 @@ def empty(self): """ return self.__class__(uri=self.uri, modality=self.modality) - def labels(self): + def labels(self) -> List[Label]: """Get sorted list of labels Returns @@ -721,7 +729,7 @@ def labels(self): self._updateLabels() return sorted(self._labels, key=str) - def get_labels(self, segment, unique=True): + def get_labels(self, segment: Segment, unique: bool = True) -> Set[Label]: """Query labels by segment Parameters @@ -757,7 +765,8 @@ def get_labels(self, segment, unique=True): return labels - def subset(self, labels, invert=False): + def subset(self, labels: Iterable[Label], invert: bool = False) \ + -> 'Annotation': """Filter annotation by labels Parameters @@ -785,7 +794,7 @@ def subset(self, labels, invert=False): _tracks, _labels = {}, set([]) for segment, tracks in self._tracks.items(): sub_tracks = {track: label for track, label in tracks.items() - if label in labels} + if label in labels} if sub_tracks: _tracks[segment] = sub_tracks _labels.update(sub_tracks.values()) @@ -800,7 +809,8 @@ def subset(self, labels, invert=False): return sub - def update(self, annotation, copy=False): + def update(self, annotation: 'Annotation', copy: bool = False) \ + -> 'Annotation': """Add every track of an existing annotation (in place) Parameters @@ -829,7 +839,7 @@ def update(self, annotation, copy=False): return result - def label_timeline(self, label, copy=True): + def label_timeline(self, label: Label, copy: bool = True) -> Timeline: """Query segments by label Parameters @@ -868,7 +878,7 @@ def label_timeline(self, label, copy=True): return self._labels[label] - def label_support(self, label): + def label_support(self, label: Label) -> Timeline: """Label support Equivalent to ``Annotation.label_timeline(label).support()`` @@ -891,7 +901,7 @@ def label_support(self, label): """ return self.label_timeline(label, copy=False).support() - def label_duration(self, label): + def label_duration(self, label: Label) -> float: """Label duration Equivalent to ``Annotation.label_timeline(label).duration()`` @@ -915,7 +925,7 @@ def label_duration(self, label): return self.label_timeline(label, copy=False).duration() - def chart(self, percent=False): + def chart(self, percent: bool = False) -> List[Tuple[Label, float]]: """Get labels chart (from longest to shortest duration) Parameters @@ -939,7 +949,7 @@ def chart(self, percent=False): return chart - def argmax(self, support=None): + def argmax(self, support: Optional[Support] = None) -> Optional[Label]: """Get label with longest duration Parameters @@ -977,7 +987,8 @@ def argmax(self, support=None): return max(((_, cropped.label_duration(_)) for _ in cropped.labels()), key=lambda x: x[1])[0] - def rename_tracks(self, generator='string'): + def rename_tracks(self, generator: LabelGenerator = 'string') \ + -> 'Annotation': """Rename all tracks Parameters @@ -1023,7 +1034,10 @@ def rename_tracks(self, generator='string'): renamed[s, next(generator)] = label return renamed - def rename_labels(self, mapping=None, generator='string', copy=True): + def rename_labels(self, + mapping: Optional[Dict] = None, + generator: LabelGenerator = 'string', + copy: bool = True) -> 'Annotation': """Rename labels Parameters @@ -1073,7 +1087,8 @@ def rename_labels(self, mapping=None, generator='string', copy=True): return renamed - def relabel_tracks(self, generator='string'): + def relabel_tracks(self, generator: LabelGenerator = 'string') \ + -> 'Annotation': """Relabel tracks Create a new annotation where each track has a unique label. @@ -1101,7 +1116,7 @@ def relabel_tracks(self, generator='string'): return relabeled - def support(self, collar=0.): + def support(self, collar: float = 0.) -> 'Annotation': """Annotation support The support of an annotation is an annotation where contiguous tracks @@ -1160,7 +1175,10 @@ def support(self, collar=0.): return support - def co_iter(self, other): + def co_iter(self, other: 'Annotation') \ + -> Iterator[Tuple[Tuple[Segment, TrackName], + Tuple[Segment, TrackName]] + ]: """Iterate over pairs of intersecting tracks Parameters @@ -1171,7 +1189,7 @@ def co_iter(self, other): Returns ------- iterable : (Segment, object), (Segment, object) iterable - Yields pairs of intersectins tracks, in chronological (then + Yields pairs of intersecting tracks, in chronological (then alphabetical) order. See also @@ -1187,7 +1205,7 @@ def co_iter(self, other): for t, T in itertools.product(tracks, other_tracks): yield (s, t), (S, T) - def __mul__(self, other): + def __mul__(self, other: 'Annotation') -> np.ndarray: """Cooccurrence (or confusion) matrix >>> matrix = annotation * other @@ -1226,7 +1244,7 @@ def __mul__(self, other): return matrix - def for_json(self): + def for_json(self) -> Dict: """Serialization See also @@ -1250,7 +1268,7 @@ def for_json(self): return data @classmethod - def from_json(cls, data): + def from_json(cls, data: Dict) -> 'Annotation': """Deserialization See also diff --git a/pyannote/core/feature.py b/pyannote/core/feature.py index 7477990..d008ada 100755 --- a/pyannote/core/feature.py +++ b/pyannote/core/feature.py @@ -34,16 +34,19 @@ See :class:`pyannote.core.SlidingWindowFeature` for the complete reference. """ +import numbers +import warnings +from typing import Tuple, Optional, Union, Iterator import numpy as np -import numbers + +from pyannote.core.utils.types import CropMode from .segment import Segment from .segment import SlidingWindow from .timeline import Timeline class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin): - """Periodic feature vectors Parameters @@ -55,27 +58,40 @@ class SlidingWindowFeature(np.lib.mixins.NDArrayOperatorsMixin): """ - def __init__(self, data, sliding_window): - super(SlidingWindowFeature, self).__init__() - self.sliding_window = sliding_window + def __init__(self, data: np.ndarray, sliding_window: SlidingWindow): + self.sliding_window: SlidingWindow = sliding_window self.data = data - self.__i = -1 + self.__i: int = -1 def __len__(self): + """Number of feature vectors""" return self.data.shape[0] + @property + def extent(self): + return self.sliding_window.range_to_segment(0, len(self)) + + @property + def dimension(self): + """Dimension of feature vectors""" + return self.data.shape[1] + def getNumber(self): - """Number of feature vectors""" + warnings.warn("This is deprecated in favor of `__len__`", + DeprecationWarning) return self.data.shape[0] def getDimension(self): - """Dimension of feature vectors""" - return self.data.shape[1] + warnings.warn("This is deprecated in favor of `dimension` property", + DeprecationWarning) + return self.dimension def getExtent(self): - return self.sliding_window.rangeToSegment(0, self.getNumber()) + warnings.warn("This is deprecated in favor of `extent` property", + DeprecationWarning) + return self.extent - def __getitem__(self, i): + def __getitem__(self, i: int) -> np.ndarray: """Get ith feature vector""" return self.data[i] @@ -83,7 +99,7 @@ def __iter__(self): self.__i = -1 return self - def __next__(self): + def __next__(self) -> Tuple[Segment, np.ndarray]: self.__i += 1 try: return self.sliding_window[self.__i], self.data[self.__i] @@ -93,7 +109,8 @@ def __next__(self): def next(self): return self.__next__() - def iterfeatures(self, window=False): + def iterfeatures(self, window: Optional[bool] = False) \ + -> Iterator[Union[Tuple[np.ndarray, Segment], np.ndarray]]: """Feature vector iterator Parameters @@ -103,14 +120,19 @@ def iterfeatures(self, window=False): Default is to only yield feature vector """ - nSamples = self.data.shape[0] - for i in range(nSamples): + n_samples = self.data.shape[0] + for i in range(n_samples): if window: yield self.data[i], self.sliding_window[i] else: yield self.data[i] - def crop(self, focus, mode='loose', fixed=None, return_data=True): + def crop(self, + focus: Union[Segment, Timeline], + mode: CropMode = 'loose', + fixed: Optional[float] = None, + return_data: bool = True) \ + -> Union[np.ndarray, 'SlidingWindowFeature']: """Extract frames Parameters @@ -175,7 +197,7 @@ def crop(self, focus, mode='loose', fixed=None, return_data=True): [self.data[start: end, :] for start, end in clipped_ranges]) else: # if all ranges are out of bounds, just return empty data - shape = (0, ) + self.data.shape[1:] + shape = (0,) + self.data.shape[1:] data = np.empty(shape) # corner case when 'fixed' duration cropping is requested: @@ -183,11 +205,11 @@ def crop(self, focus, mode='loose', fixed=None, return_data=True): if fixed is not None: data = np.vstack([ # repeat first sample as many times as needed - np.tile(self.data[0], (repeat_first, ) + (1,) * n_dimensions), + np.tile(self.data[0], (repeat_first,) + (1,) * n_dimensions), data, # repeat last sample as many times as needed np.tile(self.data[n_samples - 1], - (repeat_last,) + (1, ) * n_dimensions)]) + (repeat_last,) + (1,) * n_dimensions)]) # return data if return_data: @@ -206,7 +228,7 @@ def _repr_png_(self): _HANDLED_TYPES = (np.ndarray, numbers.Number) - def __array__(self): + def __array__(self) -> np.ndarray: return self.data def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @@ -242,4 +264,5 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if __name__ == "__main__": import doctest + doctest.testmod() diff --git a/pyannote/core/json.py b/pyannote/core/json.py index 5581b2b..05def99 100644 --- a/pyannote/core/json.py +++ b/pyannote/core/json.py @@ -25,8 +25,11 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr +from pathlib import Path +from typing import Union, TextIO import simplejson as json +from .utils.types import Resource PYANNOTE_JSON = 'pyannote' PYANNOTE_JSON_CONTENT = 'content' @@ -47,7 +50,7 @@ def object_hook(d): return d -def load(fp): +def load(fp: TextIO) -> Resource: """Deserialize Parameters @@ -63,7 +66,7 @@ def load(fp): return json.load(fp, encoding='utf-8', object_hook=object_hook) -def loads(s): +def loads(s: str) -> Resource: """Deserialize Parameters @@ -78,12 +81,12 @@ def loads(s): return json.loads(s, encoding='utf-8', object_hook=object_hook) -def load_from(path): +def load_from(path: Union[str, Path]) -> Resource: """Deserialize Parameters ---------- - path : string + path : string or Path Path to file containing serialized `pyannote.core` data structure Returns @@ -95,13 +98,13 @@ def load_from(path): return load(fp) -def dump(resource, fp): +def dump(resource: Resource, fp: TextIO): """Serialize Parameters ---------- resource : `pyannote.core` data structure - Resource to deserialize + Resource to serialize fp : file File in which `resource` serialization is written """ @@ -109,7 +112,7 @@ def dump(resource, fp): json.dump(resource, fp, encoding='utf-8', for_json=True) -def dumps(resource): +def dumps(resource: Resource) -> str: """Serialize to string Parameters @@ -124,13 +127,13 @@ def dumps(resource): return json.dumps(resource, encoding='utf-8', for_json=True) -def dump_to(resource, path): +def dump_to(resource: Resource, path: Union[str, Path]): """Serialize Parameters ---------- resource : `pyannote.core` data structure - Resource to deserialize + Resource to serialize path : string Path to file in which `resource` serialization is written """ diff --git a/pyannote/core/notebook.py b/pyannote/core/notebook.py index f255ae0..27ee264 100644 --- a/pyannote/core/notebook.py +++ b/pyannote/core/notebook.py @@ -31,6 +31,9 @@ Visualization ############# """ +from typing import Iterable, Union, Dict, Optional + +from .utils.types import Label, LabelStyle, Resource try: from IPython.core.pylabtools import print_figure @@ -46,10 +49,9 @@ from .feature import SlidingWindowFeature -class Notebook(object): +class Notebook: def __init__(self): - super(Notebook, self).__init__() self.reset() def reset(self): @@ -60,33 +62,39 @@ def reset(self): colors = [cm(1. * i / 8) for i in range(9)] self._style_generator = cycle(product(linestyle, linewidth, colors)) - self._style = {None: ('solid', 1, (0.0, 0.0, 0.0))} + self._style: Dict[Optional[Label], LabelStyle] = { + None: ('solid', 1, (0.0, 0.0, 0.0)) + } del self.crop del self.width - def crop(): - doc = "The crop property." - def fget(self): - return self._crop - def fset(self, segment): - self._crop = segment - def fdel(self): - self._crop = None - return locals() - crop = property(**crop()) - - def width(): - doc = "The width property." - def fget(self): - return self._width - def fset(self, value): - self._width = value - def fdel(self): - self._width = 20 - return locals() - width = property(**width()) - - def __getitem__(self, label): + @property + def crop(self): + """The crop property.""" + return self._crop + + @crop.setter + def crop(self, segment: Segment): + self._crop = segment + + @crop.deleter + def crop(self): + self._crop = None + + @property + def width(self): + """The width property""" + return self._width + + @width.setter + def width(self, value: int): + self._width = value + + @width.deleter + def width(self): + self._width = 20 + + def __getitem__(self, label: Label) -> LabelStyle: if label not in self._style: self._style[label] = next(self._style_generator) return self._style[label] @@ -104,7 +112,7 @@ def setup(self, ax=None, ylim=(0, 1), yaxis=False, time=True): ax.axes.get_yaxis().set_visible(yaxis) return ax - def draw_segment(self, ax, segment, y, label=None, boundaries=True): + def draw_segment(self, ax, segment: Segment, y, label=None, boundaries=True): # do nothing if segment is empty if not segment: @@ -114,7 +122,7 @@ def draw_segment(self, ax, segment, y, label=None, boundaries=True): # draw segment ax.hlines(y, segment.start, segment.end, color, - linewidth=linewidth, linestyle=linestyle, label=label) + linewidth=linewidth, linestyle=linestyle, label=label) if boundaries: ax.vlines(segment.start, y + 0.05, y - 0.05, color, linewidth=1, linestyle='solid') @@ -124,13 +132,13 @@ def draw_segment(self, ax, segment, y, label=None, boundaries=True): if label is None: return - def get_y(self, segments): + def get_y(self, segments: Iterable[Segment]) -> np.ndarray: """ Parameters ---------- - segments : iterator - `Segment` iterator (sorted) + segments : Iterable + `Segment` iterable (sorted) Returns ------- @@ -169,8 +177,9 @@ def get_y(self, segments): return y - - def __call__(self, resource, time=True, legend=True): + def __call__(self, resource: Resource, + time: bool = True, + legend: bool = True): if isinstance(resource, Segment): self.plot_segment(resource, time=time) @@ -184,6 +193,8 @@ def __call__(self, resource, time=True, legend=True): elif isinstance(resource, Scores): self.plot_scores(resource, time=time, legend=legend) + elif isinstance(resource, SlidingWindowFeature): + self.plot_feature(resource, time=time) def plot_segment(self, segment, ax=None, time=True): @@ -193,7 +204,7 @@ def plot_segment(self, segment, ax=None, time=True): ax = self.setup(ax=ax, time=time) self.draw_segment(ax, segment, 0.5) - def plot_timeline(self, timeline, ax=None, time=True): + def plot_timeline(self, timeline: Timeline, ax=None, time=True): if not self.crop and timeline: self.crop = timeline.extent() @@ -207,7 +218,7 @@ def plot_timeline(self, timeline, ax=None, time=True): # ax.set_aspect(3. / self.crop.duration) - def plot_annotation(self, annotation, ax=None, time=True, legend=True): + def plot_annotation(self, annotation: Annotation, ax=None, time=True, legend=True): if not self.crop: self.crop = annotation.get_timeline(copy=False).extent() @@ -233,7 +244,7 @@ def plot_annotation(self, annotation, ax=None, time=True, legend=True): ax.legend(H, L, bbox_to_anchor=(0, 1), loc=3, ncol=5, borderaxespad=0., frameon=False) - def plot_scores(self, scores, ax=None, time=True, legend=True): + def plot_scores(self, scores: Scores, ax=None, time=True, legend=True): if not self.crop: self.crop = scores.to_annotation().get_timeline(copy=False).extent() @@ -264,7 +275,8 @@ def plot_scores(self, scores, ax=None, time=True, legend=True): ax.legend(H, L, bbox_to_anchor=(0, 1), loc=3, ncol=5, borderaxespad=0., frameon=False) - def plot_feature(self, feature, ax=None, time=True, ylim=None): + def plot_feature(self, feature: SlidingWindowFeature, + ax=None, time=True, ylim=None): if not self.crop: self.crop = feature.getExtent() @@ -289,10 +301,11 @@ def plot_feature(self, feature, ax=None, time=True, ylim=None): ax.plot(t, data) ax.set_xlim(xlim) + notebook = Notebook() -def repr_segment(segment): +def repr_segment(segment: Segment): """Get `png` data for `segment`""" import matplotlib.pyplot as plt figsize = plt.rcParams['figure.figsize'] @@ -305,7 +318,7 @@ def repr_segment(segment): return data -def repr_timeline(timeline): +def repr_timeline(timeline: Timeline): """Get `png` data for `timeline`""" import matplotlib.pyplot as plt figsize = plt.rcParams['figure.figsize'] @@ -318,7 +331,7 @@ def repr_timeline(timeline): return data -def repr_annotation(annotation): +def repr_annotation(annotation: Annotation): """Get `png` data for `annotation`""" import matplotlib.pyplot as plt figsize = plt.rcParams['figure.figsize'] @@ -331,7 +344,7 @@ def repr_annotation(annotation): return data -def repr_scores(scores): +def repr_scores(scores: Scores): """Get `png` data for `scores`""" import matplotlib.pyplot as plt figsize = plt.rcParams['figure.figsize'] @@ -344,7 +357,7 @@ def repr_scores(scores): return data -def repr_feature(feature): +def repr_feature(feature: SlidingWindowFeature): """Get `png` data for `feature`""" import matplotlib.pyplot as plt figsize = plt.rcParams['figure.figsize'] diff --git a/pyannote/core/scores.py b/pyannote/core/scores.py index a8a071b..241558a 100644 --- a/pyannote/core/scores.py +++ b/pyannote/core/scores.py @@ -25,31 +25,34 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr +from typing import Optional, Callable, Iterable, Hashable, List, Set, Tuple import numpy as np +from dataclasses import fields, astuple from pandas import Index, MultiIndex, DataFrame, pivot_table from . import PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL, PYANNOTE_SCORE from .annotation import Annotation from .segment import Segment from .timeline import Timeline +from .utils.types import Key, Label, LabelGenerator, Support, TrackName -class Unknown(object): +class Unknown: + # TODO : document this class nextID = 0 @classmethod def reset(cls): - cls.nextID = 0 + cls.nextID: int = 0 @classmethod - def getNewID(cls): + def getNewID(cls) -> int: cls.nextID += 1 return cls.nextID - def __init__(self, format='#{id:d}'): - super(Unknown, self).__init__() + def __init__(self, format: str = '#{id:d}'): self.ID = Unknown.getNewID() self._format = format @@ -78,8 +81,7 @@ def __gt__(self, other): return True - -class Scores(object): +class Scores: """ Parameters @@ -107,10 +109,13 @@ class Scores(object): >>> s[Segment(2,3), 's1', 'C'] = 0.3 """ + @classmethod def from_df( - cls, df, - uri=None, modality=None, aggfunc=np.mean + cls, df: DataFrame, + uri: Optional[str] = None, + modality: Optional[str] = None, + aggfunc: Callable = np.mean ): """ @@ -149,19 +154,21 @@ def from_df( annotation=annotation, labels=labels, values=dataframe.values) - def __init__(self, uri=None, modality=None, - annotation=None, labels=None, - values=None, dtype=None): + def __init__(self, + uri: Optional[str] = None, + modality: Optional[str] = None, + annotation: Optional[Annotation] = None, + labels: Iterable[Hashable] = None, + values: Optional[np.ndarray] = None, + dtype=None): # TODO maybe this should get removed - super(Scores, self).__init__() - - names = [PYANNOTE_SEGMENT + '_' + field - for field in Segment._fields] + [PYANNOTE_TRACK] + names = [PYANNOTE_SEGMENT + '_' + field.name + for field in fields(Segment)] + [PYANNOTE_TRACK] if annotation: annotation = annotation.copy() index = Index( - [s + (t, ) for s, t in annotation.itertracks()], + [s + (t,) for s, t in annotation.itertracks()], name=names) else: @@ -183,7 +190,7 @@ def __init__(self, uri=None, modality=None, self.modality = modality self.uri = uri - def copy(self): + def copy(self) -> 'Scores': self._reindexIfNeeded() copied = self.__class__(uri=self.uri, modality=self.modality) copied.dataframe_ = self.dataframe_.copy() @@ -194,7 +201,7 @@ def copy(self): # del scores[segment] # del scores[segment, :] # del scores[segment, track] - def __delitem__(self, key): + def __delitem__(self, key: Key): if isinstance(key, Segment): segment = key @@ -204,7 +211,7 @@ def __delitem__(self, key): elif isinstance(key, tuple) and len(key) == 2: segment, track = key - self.dataframe_.drop(tuple(segment) + (track, ), + self.dataframe_.drop(tuple(segment) + (track,), axis=0, inplace=True) del self.annotation_[segment, track] self.hasChanged_ = True @@ -219,7 +226,7 @@ def __getitem__(self, key): key = (key[0], '_', key[1]) segment, track, label = key - return self.dataframe_.at[tuple(segment) + (track, ), label] + return self.dataframe_.at[tuple(segment) + (track,), label] # scores[segment, track, label] = value # scores[segment, label] ==== scores[segment, '_', label] @@ -274,7 +281,7 @@ def __reversed__(self): def itersegments(self): return iter(self) - def tracks(self, segment): + def tracks(self, segment: Segment): """Set of tracks for query segment Parameters @@ -289,7 +296,7 @@ def tracks(self, segment): """ return self.annotation_.get_tracks(segment) - def has_track(self, segment, track): + def has_track(self, segment: Segment, track): """Check whether a given track exists Parameters @@ -306,7 +313,7 @@ def has_track(self, segment, track): """ return self.annotation_.has_track(segment, track) - def get_track_by_name(self, track): + def get_track_by_name(self, track: TrackName) -> List[Tuple[Segment]]: """Get all tracks with given name Parameters @@ -319,17 +326,20 @@ def get_track_by_name(self, track): tracks : list List of (segment, track) tuples """ + # WARNING: this doesn't call a valid class return self.annotation_.get_track_by_name(track) - def new_track(self, segment, candidate=None, prefix=None): + def new_track(self, + segment: Segment, + candidate: Optional[TrackName]=None, + prefix: Optional[str]=None): """Track name generator Parameters ---------- segment : Segment - prefix : str, optional candidate : any valid track name - + prefix : str, optional Returns ------- @@ -360,7 +370,7 @@ def itervalues(self): if not np.isnan(value): yield segment, track, label, value - def get_track_scores(self, segment, track): + def get_track_scores(self, segment: Segment, track): """Get all scores for a given track. Parameters @@ -374,9 +384,9 @@ def get_track_scores(self, segment, track): scores : dict {label: score} dictionary """ - return dict(self.dataframe_.xs(tuple(segment) + (track, ))) + return dict(self.dataframe_.xs(tuple(segment) + (track,))) - def labels(self): + def labels(self) -> List[Label]: """List of labels Returns @@ -395,11 +405,11 @@ def _reindexIfNeeded(self): if not self.hasChanged_: return - names = [PYANNOTE_SEGMENT + '_' + field - for field in Segment._fields] + [PYANNOTE_TRACK] + names = [PYANNOTE_SEGMENT + '_' + field.name + for field in fields(Segment)] + [PYANNOTE_TRACK] new_index = Index( - [s + (t, ) for s, t in self.annotation_.itertracks()], + [astuple(s) + (t,) for s, t in self.annotation_.itertracks()], name=names) self.dataframe_ = self.dataframe_.reindex(new_index) @@ -408,7 +418,7 @@ def _reindexIfNeeded(self): return - def rename_tracks(self, generator='int'): + def rename_tracks(self, generator: LabelGenerator = 'int'): """Rename tracks""" self._reindexIfNeeded() @@ -417,16 +427,16 @@ def rename_tracks(self, generator='int'): annotation = self.annotation_.rename_tracks(generator=generator) retracked.annotation_ = annotation - names = [PYANNOTE_SEGMENT + '_' + field - for field in Segment._fields] + [PYANNOTE_TRACK] + names = [PYANNOTE_SEGMENT + '_' + field.name + for field in fields(Segment)] + [PYANNOTE_TRACK] new_index = Index( - [s + (t, ) for s, t in annotation.itertracks()], + [astuple(s) + (t,) for s, t in annotation.itertracks()], name=names) retracked.dataframe_.index = new_index return retracked - def apply(self, func, axis=0): + def apply(self, func: Callable, axis=0): applied = self.copy() applied.dataframe_ = self.dataframe_.apply(func, axis=axis) @@ -434,7 +444,7 @@ def apply(self, func, axis=0): return applied - def rank(self, ascending=False): + def rank(self, ascending: bool = False): """ Parameters @@ -454,7 +464,7 @@ def rank(self, ascending=False): ranked.hasChanged_ = True return ranked - def nbest(self, n, ascending=False): + def nbest(self, n: int, ascending: bool = False): """ Parameters @@ -478,7 +488,7 @@ def nbest(self, n, ascending=False): filtered.hasChanged_ = True return filtered - def subset(self, labels, invert=False): + def subset(self, labels: Set[Label], invert: bool = False): """Scores subset Extract scores subset based on labels @@ -512,7 +522,7 @@ def subset(self, labels, invert=False): return subset - def to_annotation(self, threshold=-np.inf, posterior=False): + def to_annotation(self, threshold: float = -np.inf, posterior: bool = False): """ Parameters @@ -558,7 +568,7 @@ def to_annotation(self, threshold=-np.inf, posterior=False): return annotation - def map(self, func): + def map(self, func: Callable): """Apply function to all values""" mapped = self.copy() @@ -566,7 +576,7 @@ def map(self, func): mapped.hasChanged_ = True return mapped - def crop(self, focus, mode='strict'): + def crop(self, focus: Support, mode: str = 'strict') -> Support: """Crop on focus Parameters @@ -653,4 +663,5 @@ def _repr_png_(self): if __name__ == "__main__": import doctest + doctest.testmod() diff --git a/pyannote/core/segment.py b/pyannote/core/segment.py index f9e717c..2bf5e1d 100755 --- a/pyannote/core/segment.py +++ b/pyannote/core/segment.py @@ -66,14 +66,19 @@ See :class:`pyannote.core.Segment` for the complete reference. """ -from collections import namedtuple +import warnings +from typing import Union, Optional, Tuple, List, Iterator +from .utils.types import Alignment + import numpy as np +from dataclasses import dataclass # 1 μs (one microsecond) SEGMENT_PRECISION = 1e-6 - -class Segment(namedtuple('Segment', ['start', 'end'])): +# setting 'frozen' to True makes it hashable and immutable +@dataclass(frozen=True, order=True) +class Segment: """ Time interval @@ -106,9 +111,8 @@ class Segment(namedtuple('Segment', ['start', 'end'])): - `segment.start == other_segment.start` and `segment.end < other_segment.end` """ - - def __new__(cls, start=0., end=0.): - return super(Segment, cls).__new__(cls, float(start), float(end)) + start: float = 0.0 + end: float = 0.0 def __bool__(self): """Emptiness @@ -125,17 +129,17 @@ def __bool__(self): """ return (self.end - self.start) > SEGMENT_PRECISION - def _get_duration(self): + @property + def duration(self) -> float: + """Segment duration (read-only)""" return self.end - self.start if self else 0. - duration = property(fget=_get_duration) - """Segment duration (read-only)""" - def _get_middle(self): + @property + def middle(self) -> float: + """Segment mid-time (read-only)""" return .5 * (self.start + self.end) - middle = property(fget=_get_middle) - """Segment mid-time (read-only)""" - def __iter__(self): + def __iter__(self) -> Iterator[float]: """Unpack segment boundaries >>> segment = Segment(start, end) >>> start, end = segment @@ -143,7 +147,7 @@ def __iter__(self): yield self.start yield self.end - def copy(self): + def copy(self) -> 'Segment': """Get a copy of the segment Returns @@ -157,7 +161,7 @@ def copy(self): # Inclusion (in), intersection (&), union (|) and gap (^) # # ------------------------------------------------------- # - def __contains__(self, other): + def __contains__(self, other: 'Segment'): """Inclusion >>> segment = Segment(start=0, end=10) @@ -190,7 +194,7 @@ def __and__(self, other): end = min(self.end, other.end) return Segment(start=start, end=end) - def intersects(self, other): + def intersects(self, other: 'Segment') -> bool: """Check whether two segments intersect each other Parameters @@ -210,7 +214,7 @@ def intersects(self, other): self.start < other.end - SEGMENT_PRECISION) or \ (self.start == other.start) - def overlaps(self, t): + def overlaps(self, t: float) -> bool: """Check if segment overlaps a given time Parameters @@ -225,7 +229,7 @@ def overlaps(self, t): """ return self.start <= t and self.end >= t - def __or__(self, other): + def __or__(self, other: 'Segment') -> 'Segment': """Union >>> segment = Segment(0, 10) @@ -255,7 +259,7 @@ def __or__(self, other): end = max(self.end, other.end) return Segment(start=start, end=end) - def __xor__(self, other): + def __xor__(self, other: 'Segment') -> 'Segment': """Gap >>> segment = Segment(0, 10) @@ -283,7 +287,7 @@ def __xor__(self, other): end = max(self.start, other.start) return Segment(start=start, end=end) - def _str_helper(self, seconds): + def _str_helper(self, seconds: float) -> str: from datetime import timedelta negative = seconds < 0 seconds = abs(seconds) @@ -349,7 +353,7 @@ def _repr_png_(self): return repr_segment(self) -class SlidingWindow(object): +class SlidingWindow: """Sliding window Parameters @@ -383,7 +387,6 @@ class SlidingWindow(object): """ def __init__(self, duration=0.030, step=0.010, start=0.000, end=None): - super(SlidingWindow, self).__init__() # duration must be a float > 0 if duration <= 0: @@ -393,44 +396,44 @@ def __init__(self, duration=0.030, step=0.010, start=0.000, end=None): # step must be a float > 0 if step <= 0: raise ValueError("'step' must be a float > 0.") - self.__step = step + self.__step: float = step # start must be a float. - self.__start = start + self.__start: float = start # if end is not provided, set it to infinity if end is None: - self.__end = np.inf + self.__end: float = np.inf else: # end must be greater than start if end <= start: raise ValueError("'end' must be greater than 'start'.") - self.__end = end + self.__end: float = end # current index of iterator - self.__i = -1 + self.__i: int = -1 - def __get_start(self): + @property + def start(self) -> float: + """Sliding window start time in seconds.""" return self.__start - start = property(fget=__get_start) - """Sliding window start time in seconds.""" - def __get_end(self): + @property + def end(self) -> float: + """Sliding window end time in seconds.""" return self.__end - end = property(fget=__get_end) - """Sliding window end time in seconds.""" - def __get_step(self): + @property + def step(self) -> float: + """Sliding window step in seconds.""" return self.__step - step = property(fget=__get_step) - """Sliding window step in seconds.""" - def __get_duration(self): + @property + def duration(self) -> float: + """Sliding window duration in seconds.""" return self.__duration - duration = property(fget=__get_duration) - """Sliding window duration in seconds.""" - def closest_frame(self, t): + def closest_frame(self, t: float) -> int: """Closest frame to timestamp. Parameters @@ -448,7 +451,7 @@ def closest_frame(self, t): (t - self.__start - .5 * self.__duration) / self.__step )) - def samples(self, from_duration, mode='strict'): + def samples(self, from_duration: float, mode: Alignment = 'strict') -> int: """Number of frames Parameters @@ -475,7 +478,11 @@ def samples(self, from_duration, mode='strict'): elif mode == 'center': return int(np.rint((from_duration / self.step))) - def crop(self, focus, mode='loose', fixed=None, return_ranges=False): + def crop(self, focus: Union[Segment, 'Timeline'], + mode: Alignment = 'loose', + fixed: Optional[float] = None, + return_ranges: Optional[bool] = False) -> \ + Union[np.ndarray, List[List[int]]]: """Crop sliding window Parameters @@ -598,7 +605,12 @@ def crop(self, focus, mode='loose', fixed=None, return_ranges=False): return np.array(range(*rng), dtype=np.int64) - def segmentToRange(self, segment): + def segmentToRange(self, segment: Segment) -> Tuple[int, int]: + warnings.warn("Deprecated in favor of `segment_to_range`", + DeprecationWarning) + return self.segment_to_range(segment) + + def segment_to_range(self, segment: Segment) -> Tuple[int, int]: """Convert segment to 0-indexed frame range Parameters @@ -616,7 +628,7 @@ def segmentToRange(self, segment): -------- >>> window = SlidingWindow() - >>> print window.segmentToRange(Segment(10, 15)) + >>> print window.segment_to_range(Segment(10, 15)) i0, n """ @@ -628,7 +640,12 @@ def segmentToRange(self, segment): return i0, n - def rangeToSegment(self, i0, n): + def rangeToSegment(self, i0: int, n: int) -> Segment: + warnings.warn("This is deprecated in favor of `range_to_segment`", + DeprecationWarning) + return self.range_to_segment(i0, n) + + def range_to_segment(self, i0: int, n: int) -> Segment: """Convert 0-indexed frame range to segment Each frame represents a unique segment of duration 'step', centered on @@ -652,7 +669,7 @@ def rangeToSegment(self, i0, n): -------- >>> window = SlidingWindow() - >>> print window.rangeToSegment(3, 2) + >>> print window.range_to_segment(3, 2) [ --> ] """ @@ -673,16 +690,25 @@ def rangeToSegment(self, i0, n): return Segment(start, end) - def samplesToDuration(self, nSamples): + def samplesToDuration(self, nSamples: int) -> float: + warnings.warn("This is deprecated in favor of `samples_to_duration`", + DeprecationWarning) + return self.samples_to_duration(nSamples) + + def samples_to_duration(self, n_samples: int) -> float: """Returns duration of samples""" - return self.rangeToSegment(0, nSamples).duration + return self.range_to_segment(0, n_samples).duration + def durationToSamples(self, duration: float) -> int: + warnings.warn("This is deprecated in favor of `duration_to_samples`", + DeprecationWarning) + return self.duration_to_samples(duration) - def durationToSamples(self, duration): + def duration_to_samples(self, duration: float) -> int: """Returns samples in duration""" - return self.segmentToRange(Segment(0, duration))[1] + return self.segment_to_range(Segment(0, duration))[1] - def __getitem__(self, i): + def __getitem__(self, i: int) -> Segment: """ Parameters ---------- @@ -706,10 +732,10 @@ def __getitem__(self, i): return Segment(start=start, end=start + self.__duration) - def next(self): + def next(self) -> Segment: return self.__next__() - def __next__(self): + def __next__(self) -> Segment: self.__i += 1 window = self[self.__i] @@ -718,7 +744,7 @@ def __next__(self): else: raise StopIteration() - def __iter__(self): + def __iter__(self) -> 'SlidingWindow': """Sliding window iterator Use expression 'for segment in sliding_window' @@ -745,7 +771,7 @@ def __iter__(self): self.__i = -1 return self - def __len__(self): + def __len__(self) -> int: """Number of positions Equivalent to len([segment for segment in window]) @@ -764,19 +790,19 @@ def __len__(self): # based on frame closest to the end i = self.closest_frame(self.__end) - while(self[i]): + while (self[i]): i += 1 length = i return length - def copy(self): + def copy(self) -> 'SlidingWindow': """Duplicate sliding window""" duration = self.duration step = self.step start = self.start end = self.end - sliding_window = SlidingWindow( + sliding_window = self.__class__( duration=duration, step=step, start=start, end=end ) return sliding_window diff --git a/pyannote/core/timeline.py b/pyannote/core/timeline.py index 862aa3a..313f527 100755 --- a/pyannote/core/timeline.py +++ b/pyannote/core/timeline.py @@ -88,12 +88,22 @@ See :class:`pyannote.core.Timeline` for the complete reference. """ +from typing import (Optional, Iterable, List, Union, Callable, + TextIO, Tuple, TYPE_CHECKING, Iterator, Dict) -from typing import TextIO -from .segment import Segment +import pandas as pd from sortedcontainers import SortedList + from . import PYANNOTE_URI, PYANNOTE_SEGMENT from .json import PYANNOTE_JSON, PYANNOTE_JSON_CONTENT +from .segment import Segment +from .utils.types import Support, Label, CropMode + +#  this is a moderately ugly way to import `Annotation` to the namespace +# without causing some circular imports : +# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports +if TYPE_CHECKING: + from .annotation import Annotation # ===================================================================== @@ -101,7 +111,7 @@ # ===================================================================== -class Timeline(object): +class Timeline: """ Ordered set of segments. @@ -123,15 +133,14 @@ class Timeline(object): """ @classmethod - def from_df(cls, df, uri=None): + def from_df(cls, df: pd.DataFrame, uri: Optional[str] = None) -> 'Timeline': segments = list(df[PYANNOTE_SEGMENT]) timeline = cls(segments=segments, uri=uri) return timeline - def __init__(self, segments=None, uri=None): - - super(Timeline, self).__init__() - + def __init__(self, + segments: Optional[Iterable[Segment]] = None, + uri: str = None): if segments is None: segments = () @@ -151,7 +160,7 @@ def __init__(self, segments=None, uri=None): self.segments_boundaries_ = SortedList(boundaries) # path to (or any identifier of) segmented resource - self.uri = uri + self.uri: str = uri def __len__(self): """Number of segments @@ -174,7 +183,7 @@ def __bool__(self): """ return len(self.segments_set_) > 0 - def __iter__(self): + def __iter__(self) -> Iterable[Segment]: """Iterate over segments (in chronological order) >>> for segment in timeline: @@ -186,7 +195,7 @@ def __iter__(self): """ return iter(self.segments_list_) - def __getitem__(self, k): + def __getitem__(self, k: int) -> Segment: """Get segment by index (in chronological order) >>> first_segment = timeline[0] @@ -194,7 +203,7 @@ def __getitem__(self, k): """ return self.segments_list_[k] - def __eq__(self, other): + def __eq__(self, other: 'Timeline'): """Equality Two timelines are equal if and only if their segments are equal. @@ -209,11 +218,11 @@ def __eq__(self, other): """ return self.segments_set_ == other.segments_set_ - def __ne__(self, other): + def __ne__(self, other: 'Timeline'): """Inequality""" return self.segments_set_ != other.segments_set_ - def index(self, segment): + def index(self, segment: Segment) -> int: """Get index of (existing) segment Parameters @@ -232,7 +241,7 @@ def index(self, segment): """ return self.segments_list_.index(segment) - def add(self, segment): + def add(self, segment: Segment) -> 'Timeline': """Add a segment (in place) Parameters @@ -268,7 +277,7 @@ def add(self, segment): return self - def remove(self, segment): + def remove(self, segment: Segment) -> 'Timeline': """Remove a segment (in place) Parameters @@ -300,7 +309,7 @@ def remove(self, segment): return self - def discard(self, segment): + def discard(self, segment: Segment) -> 'Timeline': """Same as `remove` See also @@ -309,10 +318,10 @@ def discard(self, segment): """ return self.remove(segment) - def __ior__(self, timeline): + def __ior__(self, timeline: 'Timeline') -> 'Timeline': return self.update(timeline) - def update(self, timeline): + def update(self, timeline: Segment) -> 'Timeline': """Add every segments of an existing timeline (in place) Parameters @@ -345,10 +354,10 @@ def update(self, timeline): return self - def __or__(self, timeline): + def __or__(self, timeline: 'Timeline') -> 'Timeline': return self.union(timeline) - def union(self, timeline): + def union(self, timeline: 'Timeline') -> 'Timeline': """Create new timeline made of union of segments Parameters @@ -369,7 +378,7 @@ def union(self, timeline): segments = self.segments_set_ | timeline.segments_set_ return Timeline(segments=segments, uri=self.uri) - def co_iter(self, other): + def co_iter(self, other: 'Timeline') -> Iterator[Tuple[Segment, Segment]]: """Iterate over pairs of intersecting segments >>> timeline1 = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)]) @@ -399,7 +408,11 @@ def co_iter(self, other): if segment.intersects(other_segment): yield segment, other_segment - def crop_iter(self, support, mode='intersection', returns_mapping=False): + def crop_iter(self, + support: Support, + mode: CropMode = 'intersection', + returns_mapping: bool = False) \ + -> Iterator[Union[Tuple[Segment, Segment], Segment]]: """Like `crop` but returns a segment iterator instead See also @@ -450,7 +463,11 @@ def crop_iter(self, support, mode='intersection', returns_mapping=False): else: yield mapped_to - def crop(self, support, mode='intersection', returns_mapping=False): + def crop(self, + support: Support, + mode: CropMode = 'intersection', + returns_mapping: bool = False) \ + -> Union['Timeline', Tuple['Timeline', Dict[Segment, Segment]]]: """Crop timeline to new support Parameters @@ -507,7 +524,7 @@ def crop(self, support, mode='intersection', returns_mapping=False): return Timeline(segments=self.crop_iter(support, mode=mode), uri=self.uri) - def overlapping(self, t): + def overlapping(self, t: float) -> List[Segment]: """Get list of segments overlapping `t` Parameters @@ -522,7 +539,7 @@ def overlapping(self, t): """ return list(self.overlapping_iter(t)) - def overlapping_iter(self, t): + def overlapping_iter(self, t: float) -> Iterator[Segment]: """Like `overlapping` but returns a segment iterator instead See also @@ -530,7 +547,6 @@ def overlapping_iter(self, t): :func:`pyannote.core.Timeline.overlapping` """ segment = Segment(start=t, end=t) - iterable = self.segments_list_.irange(maximum=segment) for segment in self.segments_list_.irange(maximum=segment): if segment.overlaps(t): yield segment @@ -549,7 +565,7 @@ def __str__(self): string = "[" for i, segment in enumerate(self.segments_list_): string += str(segment) - string += "\n " if i+1 < n else "" + string += "\n " if i + 1 < n else "" string += "]" return string @@ -564,7 +580,7 @@ def __repr__(self): return "" % (self.uri, list(self.segments_list_)) - def __contains__(self, included): + def __contains__(self, included: Union[Segment, 'Timeline']): """Inclusion Check whether every segment of `included` does exist in timeline. @@ -596,14 +612,14 @@ def __contains__(self, included): return included in self.segments_set_ elif isinstance(included, Timeline): - return self.segments_set_.issuperset(included._segments) + return self.segments_set_.issuperset(included.segments_set_) else: raise TypeError( 'Checking for inclusion only supports Segment and ' 'Timeline instances') - def empty(self): + def empty(self) -> 'Timeline': """Return an empty copy Returns @@ -614,7 +630,8 @@ def empty(self): """ return Timeline(uri=self.uri) - def copy(self, segment_func=None): + def copy(self, segment_func: Optional[Callable[[Segment], Segment]] = None) \ + -> 'Timeline': """Get a copy of the timeline If `segment_func` is provided, it is applied to each segment first. @@ -642,7 +659,7 @@ def copy(self, segment_func=None): return Timeline(segments=[segment_func(s) for s in self.segments_list_], uri=self.uri) - def extent(self): + def extent(self) -> Segment: """Extent The extent of a timeline is the segment of minimum duration that @@ -679,8 +696,8 @@ def extent(self): import numpy as np return Segment(start=np.inf, end=-np.inf) - def support_iter(self): - """Like `support` but returns a segment iterator instead + def support_iter(self) -> Iterator[Segment]: + """Like `support` but returns a segment generator instead See also -------- @@ -721,7 +738,7 @@ def support_iter(self): # Add new segment to the timeline support yield new_segment - def support(self): + def support(self) -> 'Timeline': """Timeline support The support of a timeline is the timeline with the minimum number of @@ -745,7 +762,7 @@ def support(self): """ return Timeline(segments=self.support_iter(), uri=self.uri) - def duration(self): + def duration(self) -> float: """Timeline duration The timeline duration is the sum of the durations of the segments @@ -761,8 +778,8 @@ def duration(self): # of the segments in the timeline support. return sum(s.duration for s in self.support_iter()) - def gaps_iter(self, support=None): - """Like `gaps` but returns a segment iterator instead + def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]: + """Like `gaps` but returns a segment generator instead See also -------- @@ -809,7 +826,8 @@ def gaps_iter(self, support=None): for gap in self.gaps_iter(support=segment): yield gap - def gaps(self, support=None): + def gaps(self, support: Optional[Support] = None) \ + -> 'Timeline': """Gaps A picture is worth a thousand words:: @@ -840,7 +858,7 @@ def gaps(self, support=None): return Timeline(segments=self.gaps_iter(support=support), uri=self.uri) - def segmentation(self): + def segmentation(self) -> 'Timeline': """Segmentation Create the unique timeline with same support and same set of segment @@ -899,7 +917,10 @@ def segmentation(self): return Timeline(segments=segments, uri=self.uri) - def to_annotation(self, generator='string', modality=None): + def to_annotation(self, + generator: Union[str, Iterable[Label], None, None] = 'string', + modality: Optional[str] = None) \ + -> 'Annotation': """Turn timeline into an annotation Each segment is labeled by a unique label. @@ -931,7 +952,7 @@ def to_annotation(self, generator='string', modality=None): return annotation - def write_uem(self,file: TextIO): + def write_uem(self, file: TextIO): """Dump timeline to file using UEM format Parameters @@ -945,7 +966,7 @@ def write_uem(self,file: TextIO): """ uri = self.uri if self.uri else "" - + for segment in self: line = f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n" file.write(line) diff --git a/pyannote/core/utils/distance.py b/pyannote/core/utils/distance.py index 73aa48c..f3971fa 100644 --- a/pyannote/core/utils/distance.py +++ b/pyannote/core/utils/distance.py @@ -31,7 +31,7 @@ import scipy.cluster.hierarchy -def l2_normalize(X): +def l2_normalize(X: np.ndarray): """L2 normalize vectors Parameters diff --git a/pyannote/core/utils/generators.py b/pyannote/core/utils/generators.py index de8be59..00a9600 100644 --- a/pyannote/core/utils/generators.py +++ b/pyannote/core/utils/generators.py @@ -28,16 +28,18 @@ import itertools from string import ascii_uppercase +from typing import Iterable, Union, List, Set, Optional, Iterator -def pairwise(iterable): - "s -> (s0,s1), (s1,s2), (s2, s3), ..." +def pairwise(iterable: Iterable): + """s -> (s0,s1), (s1,s2), (s2, s3), ...""" a, b = itertools.tee(iterable) next(b, None) return zip(a, b) -def string_generator(skip=[]): +def string_generator(skip: Optional[Union[List, Set]] = None) \ + -> Iterator[str]: """Label generator Parameters @@ -61,6 +63,8 @@ def string_generator(skip=[]): next(t) -> 'AAA' # then 3-letters labels ... # (you get the idea) """ + if skip is None: + skip = list() # label length r = 1 @@ -77,7 +81,8 @@ def string_generator(skip=[]): # increment label length when all possibilities are exhausted r = r + 1 -def int_generator(): + +def int_generator() -> Iterator[int]: i = 0 while True: yield i diff --git a/pyannote/core/utils/helper.py b/pyannote/core/utils/helper.py index ed63cc5..ddd37c2 100644 --- a/pyannote/core/utils/helper.py +++ b/pyannote/core/utils/helper.py @@ -31,7 +31,7 @@ def get_class_by_name(class_name: str, - default_module_name: Optional[str] = None) -> type : + default_module_name: Optional[str] = None) -> type: """Load class by its name Parameters diff --git a/pyannote/core/utils/types.py b/pyannote/core/utils/types.py new file mode 100644 index 0000000..218e03e --- /dev/null +++ b/pyannote/core/utils/types.py @@ -0,0 +1,15 @@ +from typing import Hashable, Union, Tuple, Iterator + +from typing_extensions import Literal + +Label = Hashable +Support = Union['Segment', 'Timeline'] +LabelGeneratorMode = Literal['int', 'string'] +LabelGenerator = Union[LabelGeneratorMode, Iterator[Label]] +TrackName = Union[str, int] +Key = Union['Segment', Tuple['Segment', TrackName]] +Resource = Union['Segment', 'Timeline', 'Score', 'SlidingWindowFeature', + 'Annotation'] +CropMode = Literal['intersection', 'loose', 'strict'] +Alignment = Literal['center', 'loose', 'strict'] +LabelStyle = Tuple[str, int, Tuple[float, float, float]] diff --git a/setup.py b/setup.py index d2094f1..01e259c 100755 --- a/setup.py +++ b/setup.py @@ -42,6 +42,8 @@ 'pandas >= 0.17.1', 'simplejson >= 3.8.1', 'matplotlib >= 2.0.0', + "dataclasses >= 0.7; python_version <'3.7'", + 'typing-extensions >= 3.7.4.1' ], # versioneer version=versioneer.get_version(),