New wrapper FeatureShare (#2120)
* initial proposal
* tests
* add attribute to all necessary metrics
* redo to use named features
* add tests
* docs
* changelog
* fix mypy
* fix docs
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
4 people authored Dec 22, 2023
1 parent 28025c1 commit 62adb40
Showing 13 changed files with 292 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))


- Added `FeatureShare` wrapper to share submodules containing feature extractors between metrics ([#2120](https://github.com/Lightning-AI/torchmetrics/pull/2120))


- Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260))


16 changes: 16 additions & 0 deletions docs/source/wrappers/feature_share.rst
@@ -0,0 +1,16 @@
.. customcarditem::
:header: Feature Sharing
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg
:tags: Wrappers

.. include:: ../links.rst

###############
Feature Sharing
###############

Module Interface
________________

.. autoclass:: torchmetrics.wrappers.FeatureShare
:exclude-members: update, compute
1 change: 1 addition & 0 deletions src/torchmetrics/image/fid.py
@@ -277,6 +277,7 @@ class FrechetInceptionDistance(Metric):
fake_features_num_samples: Tensor

inception: Module
feature_network: str = "inception"

def __init__(
self,
1 change: 1 addition & 0 deletions src/torchmetrics/image/inception.py
@@ -100,6 +100,7 @@ class InceptionScore(Metric):

features: List
inception: Module
feature_network: str = "inception"

def __init__(
self,
2 changes: 2 additions & 0 deletions src/torchmetrics/image/kid.py
@@ -163,6 +163,8 @@ class KernelInceptionDistance(Metric):

real_features: List[Tensor]
fake_features: List[Tensor]
inception: Module
feature_network: str = "inception"

def __init__(
self,
1 change: 1 addition & 0 deletions src/torchmetrics/image/lpip.py
@@ -98,6 +98,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):

sum_scores: Tensor
total: Tensor
feature_network: str = "net"

# due to the use of named tuple in the backbone the net variable cannot be scripted
__jit_ignored_attributes__: ClassVar[List[str]] = ["net"]
1 change: 1 addition & 0 deletions src/torchmetrics/image/mifid.py
@@ -150,6 +150,7 @@ class MemorizationInformedFrechetInceptionDistance(Metric):
fake_features: List[Tensor]

inception: Module
feature_network: str = "inception"

def __init__(
self,
3 changes: 3 additions & 0 deletions src/torchmetrics/image/perceptual_path_length.py
@@ -122,6 +122,9 @@ class PerceptualPathLength(Metric):
higher_is_better: Optional[bool] = True
full_state_update: bool = True

net: nn.Module
feature_network: str = "net"

def __init__(
self,
num_samples: int = 10_000,
1 change: 1 addition & 0 deletions src/torchmetrics/multimodal/clip_iqa.py
@@ -168,6 +168,7 @@ class CLIPImageQualityAssessment(Metric):

anchors: Tensor
probs_list: List[Tensor]
feature_network: str = "model"

def __init__(
self,
1 change: 1 addition & 0 deletions src/torchmetrics/multimodal/clip_score.py
@@ -99,6 +99,7 @@ class CLIPScore(Metric):

score: Tensor
n_samples: Tensor
feature_network: str = "model"

def __init__(
self,
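For context on the pattern in the diffs above: each metric now declares a `feature_network` class attribute naming the submodule that holds its expensive feature extractor, which is what the new `FeatureShare` wrapper looks up and later replaces with a cached, shared instance. A minimal sketch of a hypothetical custom metric following the same convention (the `MyFeatureMetric` name, the `backbone` attribute, and the toy extractor are illustrative assumptions, not part of this PR):

import torch
from torch import Tensor, nn
from torchmetrics import Metric


class MyFeatureMetric(Metric):
    """Hypothetical metric whose feature extractor can be swapped out by FeatureShare."""

    # FeatureShare reads this attribute name and replaces the submodule it points to
    feature_network: str = "backbone"

    def __init__(self) -> None:
        super().__init__()
        # toy stand-in for an expensive pretrained network
        self.backbone = nn.Sequential(nn.Flatten(), nn.LazyLinear(8))
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, imgs: Tensor) -> None:
        feats = self.backbone(imgs.float())
        self.total += feats.norm(dim=1).sum()
        self.count += imgs.shape[0]

    def compute(self) -> Tensor:
        return self.total / self.count

Metrics declaring `feature_network` this way and wrapped in `FeatureShare([...])` would then run the shared backbone only once per input batch.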
2 changes: 2 additions & 0 deletions src/torchmetrics/wrappers/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper
from torchmetrics.wrappers.classwise import ClasswiseWrapper
from torchmetrics.wrappers.feature_share import FeatureShare
from torchmetrics.wrappers.minmax import MinMaxMetric
from torchmetrics.wrappers.multioutput import MultioutputWrapper
from torchmetrics.wrappers.multitask import MultitaskWrapper
@@ -22,6 +23,7 @@
__all__ = [
"BootStrapper",
"ClasswiseWrapper",
"FeatureShare",
"MinMaxMetric",
"MultioutputWrapper",
"MultitaskWrapper",
126 changes: 126 additions & 0 deletions src/torchmetrics/wrappers/feature_share.py
@@ -0,0 +1,126 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache
from typing import Any, Dict, Optional, Sequence, Union

from torch.nn import Module

from torchmetrics.collections import MetricCollection
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn

__doctest_requires__ = {("FeatureShare",): ["torch_fidelity"]}


class NetworkCache(Module):
"""Create a cached version of a network to be shared between metrics.
Because the different metrics may invoke the same network multiple times, we can save time by caching the input-
output pairs of the network.
"""

def __init__(self, network: Module, max_size: int = 100) -> None:
super().__init__()
self.max_size = max_size
self.network = lru_cache(maxsize=self.max_size)(network)

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Call the network with the given arguments."""
return self.network(*args, **kwargs)


class FeatureShare(MetricCollection):
"""Specialized metric collection that facilitates sharing features between metrics.
Certain metrics rely on an underlying expensive neural network for feature extraction when computing the metric.
This wrapper allows sharing the feature extractor between multiple metrics, which can save a lot of time and
memory. This is achieved by using a single shared instance of the network across the metrics and by caching the
network's input-output pairs, such that subsequent calls to the network with the same input will be much
faster.
Args:
metrics: One of the following:
* list or tuple (sequence): if metrics are passed in as a list or tuple, each metric's class name will be used
as the key in the output dict. Therefore, two metrics of the same class cannot be chained this way.
* dict: if metrics are passed in as a dict, each key in the dict will be used as the key in the output dict.
Use this format if you want to chain together multiple instances of the same metric with different parameters.
Note that the keys in the output dict will be sorted alphabetically.
max_cache_size: maximum number of input-output pairs to cache per metric. By default, this is ``None``, which
means that the cache size will be set to the number of metrics in the collection, so all features are cached
and shared across all metrics for each batch.
Example::
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import FeatureShare
>>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance
>>> # initialize the metrics
>>> fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
>>> # update metric
>>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
>>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
>>> # compute metric
>>> fs.compute()
{'FrechetInceptionDistance': tensor(15.1700), 'KernelInceptionDistance': (tensor(-0.0012), tensor(0.0014))}
"""

def __init__(
self,
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
max_cache_size: Optional[int] = None,
) -> None:
# disable compute groups because the feature sharing is more custom
super().__init__(metrics=metrics, compute_groups=False)

if max_cache_size is None:
max_cache_size = len(self)
if not isinstance(max_cache_size, int):
raise TypeError(f"max_cache_size should be an integer, but got {max_cache_size}")

try:
first_net = next(iter(self.values()))
network_to_share = getattr(first_net, first_net.feature_network)
except AttributeError as err:
raise AttributeError(
"Tried to extract the network to share from the first metric, but it did not have a `feature_network`"
" attribute. Please make sure that the metric has an attribute with that name,"
" else it cannot be shared."
) from err
cached_net = NetworkCache(network_to_share, max_size=max_cache_size)

# set the cached network to all metrics
for metric_name, metric in self.items():
if not hasattr(metric, "feature_network"):
raise AttributeError(
"Tried to set the cached network to all metrics, but one of the metrics did not have a"
" `feature_network` attribute. Please make sure that all metrics have a attribute with that name,"
f" else it cannot be shared. Failed on metric {metric_name}."
)

# check if its the same network as the first metric
if str(getattr(metric, metric.feature_network)) != str(network_to_share):
rank_zero_warn(
f"The network to share between the metrics is not the same for all metrics."
f" Metric {metric_name} has a different network than the first metric."
" This may lead to unexpected behavior.",
UserWarning,
)

setattr(metric, metric.feature_network, cached_net)
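As a side note on why the `NetworkCache` above is effective (an illustrative sketch, not part of the committed code): `functools.lru_cache` keys on its call arguments, and `torch.Tensor` hashes by object identity, so when the collection's `update` forwards the very same input tensor to every wrapped metric, only the first metric actually runs the shared network and the remaining metrics are served from the cache.

from functools import lru_cache

import torch
from torch import Tensor

calls = 0


def expensive_backbone(x: Tensor) -> Tensor:
    """Toy stand-in for the shared feature extractor."""
    global calls
    calls += 1
    return x.float().mean(dim=(1, 2, 3))


cached = lru_cache(maxsize=3)(expensive_backbone)

imgs = torch.randint(255, (4, 3, 64, 64), dtype=torch.uint8)
feats_first = cached(imgs)   # first metric in the collection: runs the backbone
feats_second = cached(imgs)  # later metrics, same tensor object: cache hit, no extra forward pass
assert calls == 1
assert torch.equal(feats_first, feats_second)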
134 changes: 134 additions & 0 deletions tests/unittests/wrappers/test_feature_share.py
@@ -0,0 +1,134 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torchmetrics import MetricCollection
from torchmetrics.image import (
FrechetInceptionDistance,
InceptionScore,
KernelInceptionDistance,
LearnedPerceptualImagePatchSimilarity,
StructuralSimilarityIndexMeasure,
)
from torchmetrics.wrappers import FeatureShare


@pytest.mark.parametrize(
"metrics",
[
[FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()],
{"fid": FrechetInceptionDistance(), "is": InceptionScore(), "kid": KernelInceptionDistance()},
],
)
def test_initialization(metrics):
"""Test that the feature share wrapper can be initialized."""
fs = FeatureShare(metrics)
assert isinstance(fs, MetricCollection)
assert len(fs) == 3


def test_error_on_missing_feature_network():
"""Test that an error is raised when the feature network is missing."""
with pytest.raises(AttributeError, match="Tried to extract the network to share from the first metric.*"):
FeatureShare([StructuralSimilarityIndexMeasure(), FrechetInceptionDistance()])

with pytest.raises(AttributeError, match="Tried to set the cached network to all metrics, but one of the.*"):
FeatureShare([FrechetInceptionDistance(), StructuralSimilarityIndexMeasure()])


def test_warning_on_mixing_networks():
"""Test that a warning is raised when the metrics use different networks."""
with pytest.warns(UserWarning, match="The network to share between the metrics is not.*"):
FeatureShare([FrechetInceptionDistance(), InceptionScore(), LearnedPerceptualImagePatchSimilarity()])


def test_feature_share_speed():
"""Test that the feature share wrapper is faster than the metric collection."""
mc = MetricCollection([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
fs = FeatureShare([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)

import time

start = time.time()
for _ in range(10):
x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
mc.update(x, real=True)
end = time.time()
mc_time = end - start

start = time.time()
for _ in range(10):
x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
fs.update(x, real=True)
end = time.time()
fs_time = end - start

assert fs_time < mc_time, "The feature share wrapper should be faster than the metric collection."


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_memory():
"""Test that the feature share wrapper uses less memory than the metric collection."""
base_memory = torch.cuda.memory_allocated()

fid = FrechetInceptionDistance().cuda()
inception = InceptionScore().cuda()
kid = KernelInceptionDistance().cuda()

memory_before_fs = torch.cuda.memory_allocated()
assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."

torch.cuda.empty_cache()

FeatureShare([fid, inception, kid]).cuda()
memory_after_fs = torch.cuda.memory_allocated()

assert (
memory_after_fs > base_memory
), "The memory usage should be higher after initializing the feature share wrapper."
assert (
memory_after_fs < memory_before_fs
), "The memory usage should be higher after initializing the feature share wrapper."


def test_same_result_as_individual():
"""Test that the feature share wrapper gives the same result as the individual metrics."""
fid = FrechetInceptionDistance(feature=768)
inception = InceptionScore(feature=768)
kid = KernelInceptionDistance(feature=768, subset_size=10, subsets=2)

fs = FeatureShare([fid, inception, kid])

x = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8)
fs.update(x, real=True)
fid.update(x, real=True)
inception.update(x)
kid.update(x, real=True)
x = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8)
fs.update(x, real=False)
fid.update(x, real=False)
inception.update(x)
kid.update(x, real=False)

fs_res = fs.compute()
fid_res = fid.compute()
inception_res = inception.compute()
kid_res = kid.compute()

assert fs_res["FrechetInceptionDistance"] == fid_res
assert fs_res["InceptionScore"][0] == inception_res[0]
assert fs_res["InceptionScore"][1] == inception_res[1]
assert fs_res["KernelInceptionDistance"][0] == kid_res[0]
assert fs_res["KernelInceptionDistance"][1] == kid_res[1]
