diff --git a/CHANGELOG.md b/CHANGELOG.md
index dfcc70232ed..0ed1a49c9a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))
+
+
 - Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346))
 
 
diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py
index 0cfd1eb2f11..d1cf70bce81 100644
--- a/src/torchmetrics/wrappers/multitask.py
+++ b/src/torchmetrics/wrappers/multitask.py
@@ -103,17 +103,49 @@ def __init__(
         super().__init__()
         self.task_metrics = nn.ModuleDict(task_metrics)
 
-    def items(self) -> Iterable[Tuple[str, nn.Module]]:
-        """Iterate over task and task metrics."""
-        return self.task_metrics.items()
+    def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]:
+        """Iterate over task and task metrics.
 
-    def keys(self) -> Iterable[str]:
-        """Iterate over task names."""
-        return self.task_metrics.keys()
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for task_name, metric in self.task_metrics.items():
+            if flatten and isinstance(metric, MetricCollection):
+                for sub_metric_name, sub_metric in metric.items():
+                    yield f"{task_name}_{sub_metric_name}", sub_metric
+            else:
+                yield task_name, metric
+
+    def keys(self, flatten: bool = True) -> Iterable[str]:
+        """Iterate over task names.
+
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for task_name, metric in self.task_metrics.items():
+            if flatten and isinstance(metric, MetricCollection):
+                for sub_metric_name in metric:
+                    yield f"{task_name}_{sub_metric_name}"
+            else:
+                yield task_name
 
-    def values(self) -> Iterable[nn.Module]:
-        """Iterate over task metrics."""
-        return self.task_metrics.values()
+    def values(self, flatten: bool = True) -> Iterable[nn.Module]:
+        """Iterate over task metrics.
+
+        Args:
+            flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
+                If False, will iterate over the task names and the corresponding metrics.
+
+        """
+        for metric in self.task_metrics.values():
+            if flatten and isinstance(metric, MetricCollection):
+                yield from metric.values()
+            else:
+                yield metric
 
     @staticmethod
     def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None:
diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py
index 6b5523b6de0..469c0dad800 100644
--- a/tests/integrations/test_lightning.py
+++ b/tests/integrations/test_lightning.py
@@ -28,7 +28,7 @@
 from torchmetrics import MetricCollection
 from torchmetrics.aggregation import SumMetric
 from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
-from torchmetrics.regression import MeanSquaredError
+from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
 from torchmetrics.wrappers import MultitaskWrapper
 
 from integrations.helpers import no_warning_call
@@ -366,22 +366,34 @@ def test_task_wrapper_lightning_logging(tmpdir):
     class TestModel(BoringModel):
         def __init__(self) -> None:
             super().__init__()
-            self.metric = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
+            self.multitask = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
+            self.multitask_collection = MultitaskWrapper(
+                {
+                    "classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
+                    "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
+                }
+            )
+
             self.accuracy = BinaryAccuracy()
             self.mse = MeanSquaredError()
 
         def training_step(self, batch, batch_idx):
             preds = torch.rand(10)
             target = torch.rand(10)
-            self.metric(
-                {"classification": preds.round(), "regression": preds},
-                {"classification": target.round(), "regression": target},
+            self.multitask(
+                {"classification": preds, "regression": preds},
+                {"classification": target.round().int(), "regression": target},
+            )
+            self.multitask_collection(
+                {"classification": preds, "regression": preds},
+                {"classification": target.round().int(), "regression": target},
             )
             self.accuracy(preds.round(), target.round())
             self.mse(preds, target)
             self.log("accuracy", self.accuracy, on_epoch=True)
             self.log("mse", self.mse, on_epoch=True)
-            self.log_dict(self.metric, on_epoch=True)
+            self.log_dict(self.multitask, on_epoch=True)
+            self.log_dict(self.multitask_collection, on_epoch=True)
             return self.step(batch)
 
     model = TestModel()
@@ -404,6 +416,10 @@ def training_step(self, batch, batch_idx):
     assert torch.allclose(logged["accuracy_epoch"], logged["classification_epoch"])
     assert torch.allclose(logged["mse_step"], logged["regression_step"])
     assert torch.allclose(logged["mse_epoch"], logged["regression_epoch"])
+    assert "regression_MeanAbsoluteError_epoch" in logged
+    assert "regression_MeanSquaredError_epoch" in logged
+    assert "classification_BinaryAccuracy_epoch" in logged
+    assert "classification_BinaryAveragePrecision_epoch" in logged
 
 
 def test_scriptable(tmpdir):
diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py
index 57ff8f8efaa..43240255d96 100644
--- a/tests/unittests/wrappers/test_multitask.py
+++ b/tests/unittests/wrappers/test_multitask.py
@@ -209,6 +209,55 @@ def test_nested_multitask_wrapper():
     assert _dict_results_same_as_individual_results(classification_results, regression_results, multitask_results)
 
 
+@pytest.mark.parametrize("method", ["keys", "items", "values"])
+@pytest.mark.parametrize("flatten", [True, False])
+def test_key_value_items_method(method, flatten):
+    """Test the keys, items, and values methods of the MultitaskWrapper."""
+    multitask = MultitaskWrapper(
+        {
+            "classification": MetricCollection([BinaryAccuracy(), BinaryF1Score()]),
+            "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
+        }
+    )
+    if method == "keys":
+        output = list(multitask.keys(flatten=flatten))
+    elif method == "items":
+        output = list(multitask.items(flatten=flatten))
+    elif method == "values":
+        output = list(multitask.values(flatten=flatten))
+
+    if flatten:
+        assert len(output) == 4
+        if method == "keys":
+            assert output == [
+                "classification_BinaryAccuracy",
+                "classification_BinaryF1Score",
+                "regression_MeanSquaredError",
+                "regression_MeanAbsoluteError",
+            ]
+        elif method == "items":
+            assert output == [
+                ("classification_BinaryAccuracy", BinaryAccuracy()),
+                ("classification_BinaryF1Score", BinaryF1Score()),
+                ("regression_MeanSquaredError", MeanSquaredError()),
+                ("regression_MeanAbsoluteError", MeanAbsoluteError()),
+            ]
+        elif method == "values":
+            assert output == [BinaryAccuracy(), BinaryF1Score(), MeanSquaredError(), MeanAbsoluteError()]
+    else:
+        assert len(output) == 2
+        if method == "keys":
+            assert output == ["classification", "regression"]
+        elif method == "items":
+            assert output[0][0] == "classification"
+            assert output[1][0] == "regression"
+            assert isinstance(output[0][1], MetricCollection)
+            assert isinstance(output[1][1], MetricCollection)
+        elif method == "values":
+            assert isinstance(output[0], MetricCollection)
+            assert isinstance(output[1], MetricCollection)
+
+
 def test_clone_with_prefix_and_postfix():
     """Check that the clone method works with prefix and postfix arguments."""
     multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()})
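
Usage note (not part of the patch): a minimal sketch, assuming a torchmetrics build with this change applied, of how the new `flatten` argument exposes one key per sub-metric of a nested `MetricCollection`, which is what allows Lightning's `self.log_dict(wrapper)` to log each sub-metric separately.

# Minimal sketch; `wrapper` is an illustrative name, the imports are taken from the diff above.
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.wrappers import MultitaskWrapper

wrapper = MultitaskWrapper(
    {
        "classification": BinaryAccuracy(),
        "regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
    }
)

# flatten=True (the default) expands MetricCollections into per-metric entries,
# e.g. ["classification", "regression_MeanSquaredError", "regression_MeanAbsoluteError"],
# so each sub-metric gets its own key when the wrapper is passed to log_dict.
print(list(wrapper.keys(flatten=True)))

# flatten=False keeps the task-level view: ["classification", "regression"].
print(list(wrapper.keys(flatten=False)))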