Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Oct 10, 2024
2 parents 51062d5 + 6bfb775 commit c32e25c
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))


- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))


## [1.4.2] - 2022-09-12

### Added
Expand Down
74 changes: 73 additions & 1 deletion docs/source/pages/lightning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ also manually log the output of the metrics.

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
Expand Down Expand Up @@ -157,13 +157,85 @@ Additionally, we highly recommend that the two ways of logging are not mixed as
self.valid_acc.update(logits, y)
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

In general if you are logging multiple metrics we highly recommend that you combine them into a single metric object
using the :class:`~torchmetrics.MetricCollection` class and then replacing the ``self.log`` calls with ``self.log_dict``,
assuming that all metrics receive the same input.

.. testcode:: python

class MyModule(LightningModule):

def __init__(self):
...
self.train_metrics = torchmetrics.MetricCollection(
{
"accuracy": torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes),
"f1": torchmetrics.classification.F1(task="multiclass", num_classes=num_classes),
},
prefix="train_",
)
self.valid_metrics = self.train_metrics.clone(prefix="valid_")

def training_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
batch_value = self.train_metrics(preds, y)
self.log_dict(batch_value)

def on_train_epoch_end(self):
self.train_metrics.reset()

def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_metrics.update(logits, y)

def on_validation_epoch_end(self, outputs):
self.log_dict(self.valid_metrics.compute())
self.valid_metrics.reset()

***************
Common Pitfalls
***************

The following contains a list of pitfalls to be aware of:

* Logging a `MetricCollection` object directly using ``self.log_dict`` is only supported if all metrics in the
collection returns a scalar tensor. If any of the metrics in the collection returns a non-scalar tensor,
the logging will fail. This can especially happen when either nesting multiple ``MetricCollection`` objects or when
using wrapper metrics such as :class:`~torchmetrics.wrappers.ClasswiseWrapper`,
:class:`~torchmetrics.wrappers.MinMaxMetric` etc. inside a ``MetricCollection`` since all these wrappers return
dicts or lists of tensors. It is still possible to log such nested metrics manually because the ``MetricCollection``
object will try to flatten everything into a single dict. Example:

.. testcode:: python

class MyModule(LightningModule):

def __init__(self):
super().__init__()
self.train_metrics = MetricCollection(
{
"macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")),
"weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")),
},
prefix="train_",
)

def training_step(self, batch, batch_idx):
...
# logging the MetricCollection object directly will fail
self.log_dict(self.train_metrics(preds, target))

# manually computing the result and then logging will work
batch_values = self.train_metrics(preds, target)
self.log_dict(batch_values, on_step=True, on_epoch=False)
...

def on_train_epoch_end(self):
self.train_metrics.reset()

* Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders,
it is recommended to initialize a separate modular metric instances for each DataLoader and use them separately. The same holds
for using separate metrics for training, validation and testing.
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/functional/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -83,6 +85,11 @@ def perceptual_evaluation_speech_quality(
)
import pesq as pesq_backend

def _issubtype_number(x: Any) -> bool:
return np.issubdtype(type(x), np.number)

_filter_error_msg = np.vectorize(_issubtype_number)

if fs not in (8000, 16000):
raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
if mode not in ("wb", "nb"):
Expand All @@ -103,8 +110,8 @@ def perceptual_evaluation_speech_quality(
pesq_val_np = np.empty(shape=(preds_np.shape[0]))
for b in range(preds_np.shape[0]):
pesq_val_np[b] = pesq_backend.pesq(fs, target_np[b, :], preds_np[b, :], mode)
pesq_val = torch.from_numpy(pesq_val_np)
pesq_val = pesq_val.reshape(preds.shape[:-1])
pesq_val = torch.from_numpy(pesq_val_np[_filter_error_msg(pesq_val_np)].astype(np.float32))
pesq_val = pesq_val.reshape(len(pesq_val))

if keep_same_device:
return pesq_val.to(preds.device)
Expand Down
75 changes: 62 additions & 13 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.utilities.prints import rank_zero_only
from torchmetrics.wrappers import ClasswiseWrapper, MultitaskWrapper
from torchmetrics.wrappers import ClasswiseWrapper, MinMaxMetric, MultitaskWrapper

from integrations.lightning.boring_model import BoringModel

Expand Down Expand Up @@ -523,43 +523,34 @@ def __init__(self) -> None:
},
prefix="train_",
)

self.val_metrics = MetricCollection(
{
"macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
"classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
},
prefix="val_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")

def training_step(self, batch, batch_idx):
loss = self(batch).sum()

preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
self.log_dict(batch_values, on_step=True, on_epoch=False)

return {"loss": loss}

def validation_step(self, batch, batch_idx):
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.val_metrics.update(preds, target)

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)
self.val_metrics.reset()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
max_epochs=2,
log_every_n_steps=1,
)
trainer.fit(model)
Expand All @@ -572,3 +563,61 @@ def on_validation_epoch_end(self):
for i in range(5):
assert f"train_multiclassaccuracy_{i}" in logged
assert f"val_multiclassaccuracy_{i}" in logged


def test_collection_minmax_lightning_integration(tmpdir):
"""Check the integration of MinMaxWrapper, MetricCollection and LightningModule.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2763
"""

class TestModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.train_metrics = MetricCollection(
{
"macro_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="macro")),
"weighted_accuracy": MinMaxMetric(MulticlassAccuracy(num_classes=5, average="weighted")),
},
prefix="train_",
)
self.val_metrics = self.train_metrics.clone(prefix="val_")

def training_step(self, batch, batch_idx):
loss = self(batch).sum()
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
self.log_dict(batch_values, on_step=True, on_epoch=False)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)
self.val_metrics.update(preds, target)

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)
self.val_metrics.reset()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
)
trainer.fit(model)

logged = trainer.logged_metrics

# check that all metrics are logged
for prefix in ["train_", "val_"]:
for metric in ["macro_accuracy", "weighted_accuracy"]:
for key in ["max", "min", "raw"]:
assert f"{prefix}{metric}_{key}" in logged

0 comments on commit c32e25c

Please sign in to comment.