Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP : python3-ization and type hinting #37

Merged
merged 11 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
- name: Install
run: |
python -m pip install --upgrade pip
pip install .
pip install -r docs/requirements.txt
pip install .[docs]
- name: Build documentation
run: |
make --directory=docs html
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7]
python-version: [3.7, 3.8, 3.9, "3.10"]

steps:
- uses: actions/checkout@v1
Expand All @@ -28,7 +28,7 @@ jobs:
- name: Install from source
run: |
python -m pip install --upgrade pip
pip install .
pip install .[tests]
- name: Lint with flake8
run: |
pip install flake8
Expand All @@ -38,5 +38,4 @@ jobs:
flake8 ./pyannote --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pip install pytest
pytest
3 changes: 0 additions & 3 deletions docs/requirements.txt

This file was deleted.

3 changes: 1 addition & 2 deletions pyannote/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

from .base import f_measure

from ._version import get_versions
from .base import f_measure

__version__ = get_versions()["version"]
del get_versions
Expand Down
45 changes: 23 additions & 22 deletions pyannote/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@

# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr
from typing import List, Dict


import scipy.stats
import pandas as pd
import numpy as np
import pandas as pd
import scipy.stats

from pyannote.metrics.types import Details, MetricComponents


class BaseMetric(object):
Expand All @@ -43,17 +45,17 @@ class BaseMetric(object):
"""

@classmethod
def metric_name(cls):
def metric_name(cls) -> str:
raise NotImplementedError(
cls.__name__ + " is missing a 'metric_name' class method. "
"It should return the name of the metric as string."
"It should return the name of the metric as string."
)

@classmethod
def metric_components(cls):
def metric_components(cls) -> MetricComponents:
raise NotImplementedError(
cls.__name__ + " is missing a 'metric_components' class method. "
"It should return the list of names of metric components."
"It should return the list of names of metric components."
)

def __init__(self, **kwargs):
Expand All @@ -67,15 +69,15 @@ def init_components(self):

def reset(self):
"""Reset accumulated components and metric values"""
self.accumulated_ = dict()
self.results_ = list()
self.accumulated_: Dict[str, float] = dict()
hadware marked this conversation as resolved.
Show resolved Hide resolved
self.results_: List = list()
for value in self.components_:
self.accumulated_[value] = 0.0

def __get_name(self):
return self.__class__.metric_name()

name = property(fget=__get_name, doc="Metric name.")
@property
def name(self):
"""Metric name."""
return self.metric_name()

# TODO: use joblib/locky to allow parallel processing?
# TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...)
Expand Down Expand Up @@ -241,7 +243,7 @@ def __iter__(self):
for uri, component in self.results_:
yield uri, component

def compute_components(self, reference, hypothesis, **kwargs):
def compute_components(self, reference, hypothesis, **kwargs) -> Dict[str, float]:
"""Compute metric components

Parameters
Expand All @@ -260,11 +262,11 @@ def compute_components(self, reference, hypothesis, **kwargs):
"""
raise NotImplementedError(
self.__class__.__name__ + " is missing a 'compute_components' method."
"It should return a dictionary where keys are component names "
"and values are component values."
"It should return a dictionary where keys are component names "
"and values are component values."
)

def compute_metric(self, components):
def compute_metric(self, components: Details):
"""Compute metric value from computed `components`

Parameters
Expand All @@ -280,8 +282,8 @@ def compute_metric(self, components):
"""
raise NotImplementedError(
self.__class__.__name__ + " is missing a 'compute_metric' method. "
"It should return the actual value of the metric based "
"on the precomputed component dictionary given as input."
"It should return the actual value of the metric based "
"on the precomputed component dictionary given as input."
)

def confidence_interval(self, alpha=0.9):
Expand Down Expand Up @@ -374,7 +376,7 @@ def metric_name(cls):
def metric_components(cls):
return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED]

def compute_metric(self, components):
def compute_metric(self, components) -> float:
"""Compute recall from `components`"""
numerator = components[RECALL_RELEVANT_RETRIEVED]
denominator = components[RECALL_RELEVANT]
Expand All @@ -387,7 +389,7 @@ def compute_metric(self, components):
return numerator / denominator


def f_measure(precision, recall, beta=1.0):
def f_measure(precision: float, recall: float, beta=1.0) -> float:
"""Compute f-measure

f-measure is defined as follows:
Expand All @@ -398,4 +400,3 @@ def f_measure(precision, recall, beta=1.0):
if precision + recall == 0.0:
return 0
return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)

33 changes: 21 additions & 12 deletions pyannote/metrics/binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

import numpy as np
from collections import Counter
from typing import Tuple

import numpy as np
import sklearn.metrics
from pandas._typing import ArrayLike
hadware marked this conversation as resolved.
Show resolved Hide resolved
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection._split import _CVIterableWrapper

from .types import CalibrationMethod

def det_curve(y_true, scores, distances=False):

def det_curve(y_true: ArrayLike, scores: ArrayLike, distances: bool = False) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
"""DET curve

Parameters
Expand Down Expand Up @@ -71,13 +77,16 @@ def det_curve(y_true, scores, distances=False):

# estimate equal error rate
eer_index = np.where(fpr > fnr)[0][0]
eer = .25 * (fpr[eer_index-1] + fpr[eer_index] +
fnr[eer_index-1] + fnr[eer_index])
eer = .25 * (fpr[eer_index - 1] + fpr[eer_index] +
fnr[eer_index - 1] + fnr[eer_index])

return fpr, fnr, thresholds, eer


def precision_recall_curve(y_true, scores, distances=False):
def precision_recall_curve(y_true: ArrayLike,
scores: ArrayLike,
distances: bool = False) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
"""Precision-recall curve

Parameters
Expand Down Expand Up @@ -120,18 +129,18 @@ class _Passthrough(BaseEstimator):
"""Dummy binary classifier used by score Calibration class"""

def __init__(self):
super(_Passthrough, self).__init__()
super().__init__()
self.classes_ = np.array([False, True], dtype=np.bool)

def fit(self, scores, y_true):
return self

def decision_function(self, scores):
def decision_function(self, scores: ArrayLike):
"""Returns the input scores unchanged"""
return scores


class Calibration(object):
class Calibration:
"""Probability calibration for binary classification tasks

Parameters
Expand All @@ -154,12 +163,12 @@ class Calibration(object):

"""

def __init__(self, equal_priors=False, method='isotonic'):
super(Calibration, self).__init__()
def __init__(self, equal_priors: bool = False,
method: CalibrationMethod = 'isotonic'):
self.method = method
self.equal_priors = equal_priors

def fit(self, scores, y_true):
def fit(self, scores: ArrayLike, y_true: ArrayLike):
"""Train calibration

Parameters
Expand Down Expand Up @@ -209,7 +218,7 @@ def fit(self, scores, y_true):

return self

def transform(self, scores):
def transform(self, scores: ArrayLike):
"""Calibrate scores into probabilities

Parameters
Expand Down
40 changes: 16 additions & 24 deletions pyannote/metrics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,52 +90,44 @@

"""

# command line parsing
from docopt import docopt

import sys
import functools
import json
import sys
import warnings
import functools

import numpy as np
import pandas as pd
from tabulate import tabulate

from pyannote.core import Timeline
# command line parsing
from docopt import docopt
from pyannote.core import Annotation
from pyannote.database.util import load_rttm

from pyannote.core import Timeline
# evaluation protocols
from pyannote.database import get_protocol
from pyannote.database.util import get_annotated
from pyannote.database.util import load_rttm
from tabulate import tabulate

from pyannote.metrics.detection import DetectionErrorRate
from pyannote.metrics.detection import DetectionAccuracy
from pyannote.metrics.detection import DetectionRecall
from pyannote.metrics.detection import DetectionErrorRate
from pyannote.metrics.detection import DetectionPrecision

from pyannote.metrics.segmentation import SegmentationPurity
from pyannote.metrics.segmentation import SegmentationCoverage
from pyannote.metrics.segmentation import SegmentationPrecision
from pyannote.metrics.segmentation import SegmentationRecall

from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.metrics.detection import DetectionRecall
from pyannote.metrics.diarization import DiarizationCoverage
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.metrics.diarization import DiarizationPurity
from pyannote.metrics.diarization import DiarizationCoverage

from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.metrics.identification import IdentificationErrorRate
from pyannote.metrics.identification import IdentificationPrecision
from pyannote.metrics.identification import IdentificationRecall

from pyannote.metrics.segmentation import SegmentationCoverage
from pyannote.metrics.segmentation import SegmentationPrecision
from pyannote.metrics.segmentation import SegmentationPurity
from pyannote.metrics.segmentation import SegmentationRecall
from pyannote.metrics.spotting import LowLatencySpeakerSpotting

showwarning_orig = warnings.showwarning


def showwarning(message, category, *args, **kwargs):
import sys

print(category.__name__ + ":", str(message))


Expand Down
Loading