From 62adb4022b30e0fe5c47589d9ecd82d3ff041117 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 23 Dec 2023 00:45:49 +0100 Subject: [PATCH] New wrapper `FeatureShare` (#2120) * initial proposal * tests * add attribute to all nessesary metrics * redo to use named features * add tests * docs * changelog * fix mypy * fix docs * Apply suggestions from code review Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka --- CHANGELOG.md | 3 + docs/source/wrappers/feature_share.rst | 16 +++ src/torchmetrics/image/fid.py | 1 + src/torchmetrics/image/inception.py | 1 + src/torchmetrics/image/kid.py | 2 + src/torchmetrics/image/lpip.py | 1 + src/torchmetrics/image/mifid.py | 1 + .../image/perceptual_path_length.py | 3 + src/torchmetrics/multimodal/clip_iqa.py | 1 + src/torchmetrics/multimodal/clip_score.py | 1 + src/torchmetrics/wrappers/__init__.py | 2 + src/torchmetrics/wrappers/feature_share.py | 126 ++++++++++++++++ .../unittests/wrappers/test_feature_share.py | 134 ++++++++++++++++++ 13 files changed, 292 insertions(+) create mode 100644 docs/source/wrappers/feature_share.rst create mode 100644 src/torchmetrics/wrappers/feature_share.py create mode 100644 tests/unittests/wrappers/test_feature_share.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 99397ce93aa..afb32ac9dc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220)) +- Added `FeatureShare` wrapper to share submodules containing feature extractors between metrics ([#2120](https://github.com/Lightning-AI/torchmetrics/pull/2120)) + + - Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260)) diff --git a/docs/source/wrappers/feature_share.rst b/docs/source/wrappers/feature_share.rst new file mode 100644 index 00000000000..8220eb70825 --- /dev/null +++ b/docs/source/wrappers/feature_share.rst @@ -0,0 +1,16 @@ +.. customcarditem:: + :header: Feature Sharing + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg + :tags: Wrappers + +.. include:: ../links.rst + +############### +Feature Sharing +############### + +Module Interface +________________ + +.. 
autoclass:: torchmetrics.wrappers.FeatureShare + :exclude-members: update, compute diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 0b7a3fb0cb7..9e148bc94ad 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -277,6 +277,7 @@ class FrechetInceptionDistance(Metric): fake_features_num_samples: Tensor inception: Module + feature_network: str = "inception" def __init__( self, diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index cc40d91d71e..57d67c20c91 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -100,6 +100,7 @@ class InceptionScore(Metric): features: List inception: Module + feature_network: str = "inception" def __init__( self, diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index cafa1464b6e..58503530e97 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -163,6 +163,8 @@ class KernelInceptionDistance(Metric): real_features: List[Tensor] fake_features: List[Tensor] + inception: Module + feature_network: str = "inception" def __init__( self, diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index c40cc0abdc2..77b87f61344 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -98,6 +98,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric): sum_scores: Tensor total: Tensor + feature_network: str = "net" # due to the use of named tuple in the backbone the net variable cannot be scripted __jit_ignored_attributes__: ClassVar[List[str]] = ["net"] diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index c5c16aab666..0ba3f2f338b 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -150,6 +150,7 @@ class MemorizationInformedFrechetInceptionDistance(Metric): fake_features: List[Tensor] inception: Module + feature_network: str = "inception" def __init__( self, diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index bb7561f7c4b..58e5fa3ae19 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -122,6 +122,9 @@ class PerceptualPathLength(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = True + net: nn.Module + feature_network: str = "net" + def __init__( self, num_samples: int = 10_000, diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index 252d7ad5931..5d0bd4866e1 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -168,6 +168,7 @@ class CLIPImageQualityAssessment(Metric): anchors: Tensor probs_list: List[Tensor] + feature_network: str = "model" def __init__( self, diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index a5a8201893a..cd133c87e44 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -99,6 +99,7 @@ class CLIPScore(Metric): score: Tensor n_samples: Tensor + feature_network: str = "model" def __init__( self, diff --git a/src/torchmetrics/wrappers/__init__.py b/src/torchmetrics/wrappers/__init__.py index 37116cdfcbe..aa38ef577d0 100644 --- a/src/torchmetrics/wrappers/__init__.py +++ b/src/torchmetrics/wrappers/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
 from torchmetrics.wrappers.bootstrapping import BootStrapper
 from torchmetrics.wrappers.classwise import ClasswiseWrapper
+from torchmetrics.wrappers.feature_share import FeatureShare
 from torchmetrics.wrappers.minmax import MinMaxMetric
 from torchmetrics.wrappers.multioutput import MultioutputWrapper
 from torchmetrics.wrappers.multitask import MultitaskWrapper
@@ -22,6 +23,7 @@ __all__ = [
     "BootStrapper",
     "ClasswiseWrapper",
+    "FeatureShare",
     "MinMaxMetric",
     "MultioutputWrapper",
     "MultitaskWrapper",
diff --git a/src/torchmetrics/wrappers/feature_share.py b/src/torchmetrics/wrappers/feature_share.py
new file mode 100644
index 00000000000..14a4c051c64
--- /dev/null
+++ b/src/torchmetrics/wrappers/feature_share.py
@@ -0,0 +1,126 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import lru_cache
+from typing import Any, Dict, Optional, Sequence, Union
+
+from torch.nn import Module
+
+from torchmetrics.collections import MetricCollection
+from torchmetrics.metric import Metric
+from torchmetrics.utilities import rank_zero_warn
+
+__doctest_requires__ = {("FeatureShare",): ["torch_fidelity"]}
+
+
+class NetworkCache(Module):
+    """Create a cached version of a network to be shared between metrics.
+
+    Because the different metrics may invoke the same network multiple times, we can save time by caching the
+    input-output pairs of the network.
+
+    """
+
+    def __init__(self, network: Module, max_size: int = 100) -> None:
+        super().__init__()
+        self.max_size = max_size
+        self.network = lru_cache(maxsize=self.max_size)(network)
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Call the network with the given arguments."""
+        return self.network(*args, **kwargs)
+
+
+class FeatureShare(MetricCollection):
+    """Specialized metric collection that facilitates sharing features between metrics.
+
+    Certain metrics rely on an underlying expensive neural network for feature extraction when computing the metric.
+    This wrapper allows sharing the feature extraction between multiple metrics, which can save a lot of time and
+    memory. This is achieved by sharing a single instance of the network between the metrics and by caching the
+    input-output pairs of the network, such that subsequent calls to the network with the same input will be much
+    faster.
+
+    Args:
+        metrics: One of the following:
+
+            * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
+              as key for output dict. Therefore, two metrics of the same class cannot be chained this way.
+
+
+            * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
+              Use this format if you want to chain together multiple of the same metric with different parameters.
+              Note that the keys in the output dict will be sorted alphabetically.
+
+        max_cache_size: maximum number of input-output pairs to cache per metric. By default, this is ``None``,
+            which means the cache size is set to the number of metrics in the collection, so that all features are
+            cached and shared across all metrics for a given batch.
+
+    Example::
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> from torchmetrics.wrappers import FeatureShare
+        >>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance
+        >>> # initialize the metrics
+        >>> fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
+        >>> # update metric
+        >>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
+        >>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
+        >>> # compute metric
+        >>> fs.compute()
+        {'FrechetInceptionDistance': tensor(15.1700), 'KernelInceptionDistance': (tensor(-0.0012), tensor(0.0014))}
+
+    """
+
+    def __init__(
+        self,
+        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
+        max_cache_size: Optional[int] = None,
+    ) -> None:
+        # disable compute groups because the feature sharing is more custom
+        super().__init__(metrics=metrics, compute_groups=False)
+
+        if max_cache_size is None:
+            max_cache_size = len(self)
+        if not isinstance(max_cache_size, int):
+            raise TypeError(f"max_cache_size should be an integer, but got {max_cache_size}")
+
+        try:
+            first_net = next(iter(self.values()))
+            network_to_share = getattr(first_net, first_net.feature_network)
+        except AttributeError as err:
+            raise AttributeError(
+                "Tried to extract the network to share from the first metric, but it did not have a `feature_network`"
+                " attribute. Please make sure that the metric has an attribute with that name,"
+                " else it cannot be shared."
+            ) from err
+        cached_net = NetworkCache(network_to_share, max_size=max_cache_size)
+
+        # set the cached network to all metrics
+        for metric_name, metric in self.items():
+            if not hasattr(metric, "feature_network"):
+                raise AttributeError(
+                    "Tried to set the cached network to all metrics, but one of the metrics did not have a"
+                    " `feature_network` attribute. Please make sure that all metrics have an attribute with that"
+                    f" name, else it cannot be shared. Failed on metric {metric_name}."
+                )
+
+            # check if it is the same network as the first metric
+            if str(getattr(metric, metric.feature_network)) != str(network_to_share):
+                rank_zero_warn(
+                    "The network to share between the metrics is not the same for all metrics."
+                    f" Metric {metric_name} has a different network than the first metric."
+                    " This may lead to unexpected behavior.",
+                    UserWarning,
+                )
+
+            setattr(metric, metric.feature_network, cached_net)
diff --git a/tests/unittests/wrappers/test_feature_share.py b/tests/unittests/wrappers/test_feature_share.py
new file mode 100644
index 00000000000..96ba7a67b1e
--- /dev/null
+++ b/tests/unittests/wrappers/test_feature_share.py
@@ -0,0 +1,134 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+from torchmetrics import MetricCollection
+from torchmetrics.image import (
+    FrechetInceptionDistance,
+    InceptionScore,
+    KernelInceptionDistance,
+    LearnedPerceptualImagePatchSimilarity,
+    StructuralSimilarityIndexMeasure,
+)
+from torchmetrics.wrappers import FeatureShare
+
+
+@pytest.mark.parametrize(
+    "metrics",
+    [
+        [FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()],
+        {"fid": FrechetInceptionDistance(), "is": InceptionScore(), "kid": KernelInceptionDistance()},
+    ],
+)
+def test_initialization(metrics):
+    """Test that the feature share wrapper can be initialized."""
+    fs = FeatureShare(metrics)
+    assert isinstance(fs, MetricCollection)
+    assert len(fs) == 3
+
+
+def test_error_on_missing_feature_network():
+    """Test that an error is raised when the feature network is missing."""
+    with pytest.raises(AttributeError, match="Tried to extract the network to share from the first metric.*"):
+        FeatureShare([StructuralSimilarityIndexMeasure(), FrechetInceptionDistance()])
+
+    with pytest.raises(AttributeError, match="Tried to set the cached network to all metrics, but one of the.*"):
+        FeatureShare([FrechetInceptionDistance(), StructuralSimilarityIndexMeasure()])
+
+
+def test_warning_on_mixing_networks():
+    """Test that a warning is raised when the metrics use different networks."""
+    with pytest.warns(UserWarning, match="The network to share between the metrics is not.*"):
+        FeatureShare([FrechetInceptionDistance(), InceptionScore(), LearnedPerceptualImagePatchSimilarity()])
+
+
+def test_feature_share_speed():
+    """Test that the feature share wrapper is faster than the metric collection."""
+    mc = MetricCollection([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
+    fs = FeatureShare([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
+    x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
+
+    import time
+
+    start = time.time()
+    for _ in range(10):
+        x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
+        mc.update(x, real=True)
+    end = time.time()
+    mc_time = end - start
+
+    start = time.time()
+    for _ in range(10):
+        x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
+        fs.update(x, real=True)
+    end = time.time()
+    fs_time = end - start
+
+    assert fs_time < mc_time, "The feature share wrapper should be faster than the metric collection."
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_memory():
+    """Test that the feature share wrapper uses less memory than the metric collection."""
+    base_memory = torch.cuda.memory_allocated()
+
+    fid = FrechetInceptionDistance().cuda()
+    inception = InceptionScore().cuda()
+    kid = KernelInceptionDistance().cuda()
+
+    memory_before_fs = torch.cuda.memory_allocated()
+    assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."
+
+    torch.cuda.empty_cache()
+
+    FeatureShare([fid, inception, kid]).cuda()
+    memory_after_fs = torch.cuda.memory_allocated()
+
+    assert (
+        memory_after_fs > base_memory
+    ), "The memory usage should be higher after initializing the feature share wrapper."
+    assert (
+        memory_after_fs < memory_before_fs
+    ), "The memory usage should be lower when the feature extractor is shared between the metrics."
+ + +def test_same_result_as_individual(): + """Test that the feature share wrapper gives the same result as the individual metrics.""" + fid = FrechetInceptionDistance(feature=768) + inception = InceptionScore(feature=768) + kid = KernelInceptionDistance(feature=768, subset_size=10, subsets=2) + + fs = FeatureShare([fid, inception, kid]) + + x = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8) + fs.update(x, real=True) + fid.update(x, real=True) + inception.update(x) + kid.update(x, real=True) + x = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8) + fs.update(x, real=False) + fid.update(x, real=False) + inception.update(x) + kid.update(x, real=False) + + fs_res = fs.compute() + fid_res = fid.compute() + inception_res = inception.compute() + kid_res = kid.compute() + + assert fs_res["FrechetInceptionDistance"] == fid_res + assert fs_res["InceptionScore"][0] == inception_res[0] + assert fs_res["InceptionScore"][1] == inception_res[1] + assert fs_res["KernelInceptionDistance"][0] == kid_res[0] + assert fs_res["KernelInceptionDistance"][1] == kid_res[1]
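
A quick sketch of the sharing behavior (illustrative only, not part of the patch; it relies only on names the patch itself introduces: the class-name keys produced for a list input and the `inception` submodule that FID/IS/KID expose through their new `feature_network` attribute):

import torch
from torchmetrics.image import FrechetInceptionDistance, InceptionScore, KernelInceptionDistance
from torchmetrics.wrappers import FeatureShare

# Wrapping replaces each metric's feature extractor (the submodule named by its
# `feature_network` attribute) with one shared, lru_cache-backed NetworkCache.
fs = FeatureShare([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])

# All three metrics now hold the very same module instance ...
fid = fs["FrechetInceptionDistance"]
inception = fs["InceptionScore"]
kid = fs["KernelInceptionDistance"]
assert fid.inception is inception.inception is kid.inception

# ... so a batch passed to the collection should run through Inception once, with the
# cached features reused by the other metrics.
imgs = torch.randint(255, (16, 3, 64, 64), dtype=torch.uint8)
fs.update(imgs, real=True)

The identity assertion mirrors the `setattr(metric, metric.feature_network, cached_net)` loop in the wrapper above; the cache reuse works because `MetricCollection.update` hands the same tensor object to every metric, so the `lru_cache` lookup hits on the repeated forward calls.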