Skip to content

Commit

Permalink
Improve CI stability for clustering examples (#2785)
Browse files Browse the repository at this point in the history
* Update src/torchmetrics/functional/clustering/calinski_harabasz_score.py
* Update src/torchmetrics/functional/clustering/davies_bouldin_score.py
  • Loading branch information
SkafteNicki authored Oct 15, 2024
1 parent b1f4db2 commit 801cec8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/torchmetrics/clustering/calinski_harabasz_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class CalinskiHarabaszScore(Metric):
Example::
>>> from torch import randn, randint
>>> from torchmetrics.clustering import CalinskiHarabaszScore
>>> data = randn(10, 3)
>>> labels = randint(3, (10,))
>>> data = randn(20, 3)
>>> labels = randint(3, (20,))
>>> metric = CalinskiHarabaszScore()
>>> metric(data, labels)
tensor(3.0053)
tensor(2.2128)
"""

Expand Down Expand Up @@ -108,7 +108,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> import torch
>>> from torchmetrics.clustering import CalinskiHarabaszScore
>>> metric = CalinskiHarabaszScore()
>>> metric.update(torch.randn(10, 3), torch.randint(0, 2, (10,)))
>>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
Expand All @@ -120,7 +120,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> metric = CalinskiHarabaszScore()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,))))
... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def calinski_harabasz_score(data: Tensor, labels: Tensor) -> Tensor:
Example:
>>> from torch import randn, randint
>>> from torchmetrics.functional.clustering import calinski_harabasz_score
>>> data = randn(10, 3)
>>> labels = randint(0, 2, (10,))
>>> data = randn(20, 3)
>>> labels = randint(0, 3, (20,))
>>> calinski_harabasz_score(data, labels)
tensor(3.4998)
tensor(2.2128)
"""
_validate_intrinsic_cluster_data(data, labels)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def davies_bouldin_score(data: Tensor, labels: Tensor) -> Tensor:
Example:
>>> from torch import randn, randint
>>> from torchmetrics.functional.clustering import davies_bouldin_score
>>> data = randn(10, 3)
>>> labels = randint(0, 2, (10,))
>>> data = randn(20, 3)
>>> labels = randint(0, 3, (20,))
>>> davies_bouldin_score(data, labels)
tensor(1.3249)
tensor(2.7418)
"""
_validate_intrinsic_cluster_data(data, labels)
Expand Down

0 comments on commit 801cec8

Please sign in to comment.