Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename hinge to hinge_loss #734

Merged
merged 34 commits into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
935b005
change hinge to hinge_loss
Jan 10, 2022
d7dc64f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
4564b04
indentation
Jan 10, 2022
687e1cb
Merge remote-tracking branch 'upstream/master'
Jan 10, 2022
1844263
indentation
Jan 10, 2022
ed72497
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
2ec1206
hingeLoss import
Jan 10, 2022
613caa3
Merge branch 'master' of https://github.com/getgaurav2/metrics
Jan 10, 2022
67c3f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
b57cb39
hingeLoss import __init__
Jan 10, 2022
5c12358
hingeLoss correct function
Jan 10, 2022
8462fd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
e0c0a5e
hingeLoss correct import
Jan 10, 2022
758f485
hingeLoss correct import
Jan 10, 2022
aa0bd85
correct import stmt
Jan 10, 2022
589a835
flake8
Jan 10, 2022
48bbf28
warn
Borda Jan 10, 2022
04f88c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
1081a8c
warn
Borda Jan 10, 2022
3604d1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
bbaacfe
docs
Borda Jan 10, 2022
cf620e7
chlog
Borda Jan 10, 2022
1c360f2
fix
Borda Jan 10, 2022
9288222
Add `MultiScaleStructuralSimilarityIndexMeasure` (#679)
stancld Jan 10, 2022
839d981
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
1bc351c
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
eb7f348
disable deepsource
Borda Jan 10, 2022
863adfb
Merge branch 'master' into master
Borda Jan 10, 2022
33a9a4d
unify si_ssim & ssim
Borda Jan 10, 2022
db80337
Merge branch 'master' into master
Borda Jan 10, 2022
e8efb78
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
ff6e703
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
18af9e2
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
92f82c1
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tests/classification/test_hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.classification.inputs import Input
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
from torchmetrics import Hinge
from torchmetrics.functional import hinge
from torchmetrics import HingeLoss
from torchmetrics.functional import hinge_loss
from torchmetrics.functional.classification.hinge import MulticlassMode

torch.manual_seed(42)
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multi
ddp=ddp,
preds=preds,
target=target,
metric_class=Hinge,
metric_class=HingeLoss,
sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
dist_sync_on_step=dist_sync_on_step,
metric_args={
Expand All @@ -108,16 +108,16 @@ def test_hinge_fn(self, preds, target, squared, multiclass_mode):
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode),
metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode),
sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode),
)

def test_hinge_differentiability(self, preds, target, squared, multiclass_mode):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=Hinge,
metric_functional=partial(hinge, squared=squared, multiclass_mode=multiclass_mode),
metric_module=HingeLoss,
metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode),
)


Expand Down Expand Up @@ -148,9 +148,9 @@ def test_hinge_differentiability(self, preds, target, squared, multiclass_mode):
)
def test_bad_inputs_fn(preds, target, multiclass_mode):
with pytest.raises(ValueError):
_ = hinge(preds, target, multiclass_mode=multiclass_mode)
_ = hinge_loss(preds, target, multiclass_mode=multiclass_mode)


def test_bad_inputs_class():
with pytest.raises(ValueError):
Hinge(multiclass_mode="invalid_mode")
HingeLoss(multiclass_mode="invalid_mode")
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FBeta,
HammingDistance,
Hinge,
HingeLoss,
IoU,
JaccardIndex,
KLDivergence,
Expand Down Expand Up @@ -119,6 +120,7 @@
"FBeta",
"HammingDistance",
"Hinge",
"HingeLoss",
"JaccardIndex",
"KLDivergence",
"MatthewsCorrcoef",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.f_beta import F1, F1Score, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.hinge import Hinge, HingeLoss # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401
from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401
Expand Down
61 changes: 56 additions & 5 deletions torchmetrics/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from warnings import warn

from torch import Tensor, tensor

from torchmetrics.functional.classification.hinge import MulticlassMode, _hinge_compute, _hinge_update
from torchmetrics.metric import Metric


class Hinge(Metric):
class HingeLoss(Metric):
r"""
Computes the mean `Hinge loss`_, typically used for Support Vector
Machines (SVMs). In the binary case it is defined as:
Expand Down Expand Up @@ -62,24 +63,24 @@ class Hinge(Metric):

Example (binary case):
>>> import torch
>>> from torchmetrics import Hinge
>>> from torchmetrics import HingeLoss
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge = Hinge()
>>> hinge = HingeLoss()
>>> hinge(preds, target)
tensor(0.3000)

Example (default / multiclass case):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge()
>>> hinge = HingeLoss()
>>> hinge(preds, target)
tensor(2.9000)

Example (multiclass example, one vs all mode):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge(multiclass_mode="one-vs-all")
>>> hinge = HingeLoss(multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])

Expand Down Expand Up @@ -126,3 +127,53 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore

def compute(self) -> Tensor:
return _hinge_compute(self.measure, self.total)


class Hinge(HingeLoss):
r"""
Computes the mean `Hinge loss`_, typically used for Support Vector Machines (SVMs).

.. deprecated:: v0.7
Use :class:`torchmetrics.HingeLoss`. Will be removed in v0.8.

Example (binary case):
>>> import torch
>>> from torchmetrics import Hinge
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(0.3000)

Example (default / multiclass case):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(2.9000)

Example (multiclass example, one vs all mode):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge(multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])

"""

def __init__(
self,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
) -> None:
warn("`Hinge` was renamed to `HingeLoss` in v0.7 and it will be removed in v0.8", DeprecationWarning)
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
3 changes: 2 additions & 1 deletion torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torchmetrics.functional.classification.dice import dice_score
from torchmetrics.functional.classification.f_beta import f1, f1_score, fbeta
from torchmetrics.functional.classification.hamming_distance import hamming_distance
from torchmetrics.functional.classification.hinge import hinge
from torchmetrics.functional.classification.hinge import hinge, hinge_loss
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.jaccard import jaccard_index
from torchmetrics.functional.classification.kl_divergence import kl_divergence
Expand Down Expand Up @@ -97,6 +97,7 @@
"fbeta",
"hamming_distance",
"hinge",
"hinge_loss",
"image_gradients",
"jaccard_index",
"kl_divergence",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchmetrics.functional.classification.dice import dice_score # noqa: F401
from torchmetrics.functional.classification.f_beta import f1, f1_score, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.hinge import hinge, hinge_loss # noqa: F401
from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401
from torchmetrics.functional.classification.kl_divergence import kl_divergence # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
Expand Down
82 changes: 80 additions & 2 deletions torchmetrics/functional/classification/hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor, tensor
Expand Down Expand Up @@ -154,6 +155,83 @@ def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor:
return measure / total


def hinge_loss(
preds: Tensor,
target: Tensor,
squared: bool = False,
multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tensor:
r"""
Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs).

In the binary case it is defined as:

.. math::
\text{Hinge loss} = \max(0, 1 - y \times \hat{y})

Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction.

In the multi-class case, when ``multiclass_mode=None`` (default), ``multiclass_mode=MulticlassMode.CRAMMER_SINGER``
or ``multiclass_mode="crammer-singer"``, this metric will compute the multi-class hinge loss defined by Crammer and
Singer as:

.. math::
\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes),
and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class.

In the multi-class case when ``multiclass_mode=MulticlassMode.ONE_VS_ALL`` or ``multiclass_mode='one-vs-all'``, this
metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits
that class against all remaining classes.

This metric can optionally output the mean of the squared hinge loss by setting ``squared=True``

Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N).

Args:
preds: Predictions from model (as float outputs from decision function).
target: Ground truth labels.
squared:
If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default).
multiclass_mode:
Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default),
``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss.
``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion.

Raises:
ValueError:
If preds shape is not of size (N) or (N, C).
ValueError:
If target shape is not of size (N).
ValueError:
If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``,
``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``.

Example (binary case):
>>> import torch
>>> from torchmetrics.functional import hinge_loss
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge_loss(preds, target)
tensor(0.3000)

Example (default / multiclass case):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target)
tensor(2.9000)

Example (multiclass example, one vs all mode):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, multiclass_mode="one-vs-all")
tensor([2.2333, 1.5000, 1.2333])
"""
measure, total = _hinge_update(preds, target, squared=squared, multiclass_mode=multiclass_mode)
return _hinge_compute(measure, total)


def hinge(
preds: Tensor,
target: Tensor,
Expand Down Expand Up @@ -227,5 +305,5 @@ def hinge(
>>> hinge(preds, target, multiclass_mode="one-vs-all")
tensor([2.2333, 1.5000, 1.2333])
"""
measure, total = _hinge_update(preds, target, squared=squared, multiclass_mode=multiclass_mode)
return _hinge_compute(measure, total)
warn("`hinge` was renamed to `hinge_loss` in v0.7 and it will be removed in v0.8", DeprecationWarning)
return hinge_loss(preds, target, squared, multiclass_mode)