Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add gamma #59

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
- name: Install
run: |
python -m pip install --upgrade pip
pip install .
pip install -r docs/requirements.txt
pip install .[docs]
- name: Build documentation
run: |
make --directory=docs html
Expand Down
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7]
python-version: [3.7, 3.8, 3.9, "3.10"]

steps:
- uses: actions/checkout@v1
Expand All @@ -28,7 +28,7 @@ jobs:
- name: Install from source
run: |
python -m pip install --upgrade pip
pip install .
pip install .[tests]
- name: Lint with flake8
run: |
pip install flake8
Expand All @@ -38,5 +38,4 @@ jobs:
flake8 ./pyannote --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pip install pytest
pytest
3 changes: 0 additions & 3 deletions docs/requirements.txt

This file was deleted.

3 changes: 1 addition & 2 deletions pyannote/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

from .base import f_measure

from ._version import get_versions
from .base import f_measure

__version__ = get_versions()["version"]
del get_versions
Expand Down
70 changes: 39 additions & 31 deletions pyannote/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@

# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr
from typing import List, Union, Optional, Set, Tuple


import scipy.stats
import pandas as pd
import numpy as np
import pandas as pd
import scipy.stats
from pyannote.core import Annotation, Timeline

from pyannote.metrics.types import Details, MetricComponents

class BaseMetric(object):

class BaseMetric:
"""
:class:`BaseMetric` is the base class for most pyannote evaluation metrics.

Expand All @@ -43,44 +46,46 @@ class BaseMetric(object):
"""

@classmethod
def metric_name(cls):
def metric_name(cls) -> str:
raise NotImplementedError(
cls.__name__ + " is missing a 'metric_name' class method. "
"It should return the name of the metric as string."
"It should return the name of the metric as string."
)

@classmethod
def metric_components(cls):
def metric_components(cls) -> MetricComponents:
raise NotImplementedError(
cls.__name__ + " is missing a 'metric_components' class method. "
"It should return the list of names of metric components."
"It should return the list of names of metric components."
)

def __init__(self, **kwargs):
super(BaseMetric, self).__init__()
self.metric_name_ = self.__class__.metric_name()
self.components_ = set(self.__class__.metric_components())
self.components_: Set[str] = set(self.__class__.metric_components())
self.reset()

def init_components(self):
return {value: 0.0 for value in self.components_}

def reset(self):
"""Reset accumulated components and metric values"""
self.accumulated_ = dict()
self.results_ = list()
self.accumulated_: Details = dict()
self.results_: List = list()
for value in self.components_:
self.accumulated_[value] = 0.0

def __get_name(self):
return self.__class__.metric_name()

name = property(fget=__get_name, doc="Metric name.")
@property
def name(self):
"""Metric name."""
return self.metric_name()

# TODO: use joblib/locky to allow parallel processing?
# TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...)

def __call__(self, reference, hypothesis, detailed=False, uri=None, **kwargs):
def __call__(self, reference: Union[Timeline, Annotation],
hypothesis: Union[Timeline, Annotation],
detailed: bool = False, uri: Optional[str] = None, **kwargs):
"""Compute metric value and accumulate components

Parameters
Expand Down Expand Up @@ -123,7 +128,7 @@ def __call__(self, reference, hypothesis, detailed=False, uri=None, **kwargs):

return components[self.metric_name_]

def report(self, display=False):
def report(self, display: bool = False) -> pd.DataFrame:
"""Evaluation report

Parameters
Expand Down Expand Up @@ -217,7 +222,7 @@ def __abs__(self):
"""Compute metric value from accumulated components"""
return self.compute_metric(self.accumulated_)

def __getitem__(self, component):
def __getitem__(self, component: str) -> Union[float, Details]:
"""Get value of accumulated `component`.

Parameters
Expand All @@ -241,7 +246,10 @@ def __iter__(self):
for uri, component in self.results_:
yield uri, component

def compute_components(self, reference, hypothesis, **kwargs):
def compute_components(self,
reference: Union[Timeline, Annotation],
hypothesis: Union[Timeline, Annotation],
**kwargs) -> Details:
"""Compute metric components

Parameters
Expand All @@ -260,11 +268,11 @@ def compute_components(self, reference, hypothesis, **kwargs):
"""
raise NotImplementedError(
self.__class__.__name__ + " is missing a 'compute_components' method."
"It should return a dictionary where keys are component names "
"and values are component values."
"It should return a dictionary where keys are component names "
"and values are component values."
)

def compute_metric(self, components):
def compute_metric(self, components: Details):
"""Compute metric value from computed `components`

Parameters
Expand All @@ -280,11 +288,12 @@ def compute_metric(self, components):
"""
raise NotImplementedError(
self.__class__.__name__ + " is missing a 'compute_metric' method. "
"It should return the actual value of the metric based "
"on the precomputed component dictionary given as input."
"It should return the actual value of the metric based "
"on the precomputed component dictionary given as input."
)

def confidence_interval(self, alpha=0.9):
def confidence_interval(self, alpha: float = 0.9) \
-> Tuple[float, Tuple[float, float]]:
"""Compute confidence interval on accumulated metric values

Parameters
Expand Down Expand Up @@ -333,10 +342,10 @@ def metric_name(cls):
return PRECISION_NAME

@classmethod
def metric_components(cls):
def metric_components(cls) -> MetricComponents:
return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED]

def compute_metric(self, components):
def compute_metric(self, components: Details) -> float:
"""Compute precision from `components`"""
numerator = components[PRECISION_RELEVANT_RETRIEVED]
denominator = components[PRECISION_RETRIEVED]
Expand Down Expand Up @@ -371,10 +380,10 @@ def metric_name(cls):
return RECALL_NAME

@classmethod
def metric_components(cls):
def metric_components(cls) -> MetricComponents:
return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED]

def compute_metric(self, components):
def compute_metric(self, components: Details) -> float:
"""Compute recall from `components`"""
numerator = components[RECALL_RELEVANT_RETRIEVED]
denominator = components[RECALL_RELEVANT]
Expand All @@ -387,7 +396,7 @@ def compute_metric(self, components):
return numerator / denominator


def f_measure(precision, recall, beta=1.0):
def f_measure(precision: float, recall: float, beta=1.0) -> float:
"""Compute f-measure

f-measure is defined as follows:
Expand All @@ -398,4 +407,3 @@ def f_measure(precision, recall, beta=1.0):
if precision + recall == 0.0:
return 0
return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)

33 changes: 21 additions & 12 deletions pyannote/metrics/binary_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@
# AUTHORS
# Hervé BREDIN - http://herve.niderb.fr

import numpy as np
from collections import Counter
from typing import Tuple

import numpy as np
import sklearn.metrics
from numpy.typing import ArrayLike
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection._split import _CVIterableWrapper

from .types import CalibrationMethod

def det_curve(y_true, scores, distances=False):

def det_curve(y_true: ArrayLike, scores: ArrayLike, distances: bool = False) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
"""DET curve

Parameters
Expand Down Expand Up @@ -71,13 +77,16 @@ def det_curve(y_true, scores, distances=False):

# estimate equal error rate
eer_index = np.where(fpr > fnr)[0][0]
eer = .25 * (fpr[eer_index-1] + fpr[eer_index] +
fnr[eer_index-1] + fnr[eer_index])
eer = .25 * (fpr[eer_index - 1] + fpr[eer_index] +
fnr[eer_index - 1] + fnr[eer_index])

return fpr, fnr, thresholds, eer


def precision_recall_curve(y_true, scores, distances=False):
def precision_recall_curve(y_true: ArrayLike,
scores: ArrayLike,
distances: bool = False) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
"""Precision-recall curve

Parameters
Expand Down Expand Up @@ -120,18 +129,18 @@ class _Passthrough(BaseEstimator):
"""Dummy binary classifier used by score Calibration class"""

def __init__(self):
super(_Passthrough, self).__init__()
super().__init__()
self.classes_ = np.array([False, True], dtype=np.bool)

def fit(self, scores, y_true):
return self

def decision_function(self, scores):
def decision_function(self, scores: ArrayLike):
"""Returns the input scores unchanged"""
return scores


class Calibration(object):
class Calibration:
"""Probability calibration for binary classification tasks

Parameters
Expand All @@ -154,12 +163,12 @@ class Calibration(object):

"""

def __init__(self, equal_priors=False, method='isotonic'):
super(Calibration, self).__init__()
def __init__(self, equal_priors: bool = False,
method: CalibrationMethod = 'isotonic'):
self.method = method
self.equal_priors = equal_priors

def fit(self, scores, y_true):
def fit(self, scores: ArrayLike, y_true: ArrayLike):
"""Train calibration

Parameters
Expand Down Expand Up @@ -209,7 +218,7 @@ def fit(self, scores, y_true):

return self

def transform(self, scores):
def transform(self, scores: ArrayLike):
"""Calibrate scores into probabilities

Parameters
Expand Down
40 changes: 16 additions & 24 deletions pyannote/metrics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,52 +90,44 @@

"""

# command line parsing
from docopt import docopt

import sys
import functools
import json
import sys
import warnings
import functools

import numpy as np
import pandas as pd
from tabulate import tabulate

from pyannote.core import Timeline
# command line parsing
from docopt import docopt
from pyannote.core import Annotation
from pyannote.database.util import load_rttm

from pyannote.core import Timeline
# evaluation protocols
from pyannote.database import get_protocol
from pyannote.database.util import get_annotated
from pyannote.database.util import load_rttm
from tabulate import tabulate

from pyannote.metrics.detection import DetectionErrorRate
from pyannote.metrics.detection import DetectionAccuracy
from pyannote.metrics.detection import DetectionRecall
from pyannote.metrics.detection import DetectionErrorRate
from pyannote.metrics.detection import DetectionPrecision

from pyannote.metrics.segmentation import SegmentationPurity
from pyannote.metrics.segmentation import SegmentationCoverage
from pyannote.metrics.segmentation import SegmentationPrecision
from pyannote.metrics.segmentation import SegmentationRecall

from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.metrics.detection import DetectionRecall
from pyannote.metrics.diarization import DiarizationCoverage
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.metrics.diarization import DiarizationPurity
from pyannote.metrics.diarization import DiarizationCoverage

from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.metrics.identification import IdentificationErrorRate
from pyannote.metrics.identification import IdentificationPrecision
from pyannote.metrics.identification import IdentificationRecall

from pyannote.metrics.segmentation import SegmentationCoverage
from pyannote.metrics.segmentation import SegmentationPrecision
from pyannote.metrics.segmentation import SegmentationPurity
from pyannote.metrics.segmentation import SegmentationRecall
from pyannote.metrics.spotting import LowLatencySpeakerSpotting

showwarning_orig = warnings.showwarning


def showwarning(message, category, *args, **kwargs):
import sys

print(category.__name__ + ":", str(message))


Expand Down
Loading