-
Notifications
You must be signed in to change notification settings - Fork 411
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial proposal * tests * add attribute to all necessary metrics * redo to use named features * add tests * docs * changelog * fix mypy * fix docs * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> --------- Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka <[email protected]>
- Loading branch information
1 parent
28025c1
commit 62adb40
Showing
13 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
.. customcarditem:: | ||
:header: Feature Sharing | ||
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg | ||
:tags: Wrappers | ||
|
||
.. include:: ../links.rst | ||
|
||
############### | ||
Feature Sharing | ||
############### | ||
|
||
Module Interface | ||
________________ | ||
|
||
.. autoclass:: torchmetrics.wrappers.FeatureShare | ||
:exclude-members: update, compute |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import lru_cache | ||
from typing import Any, Dict, Optional, Sequence, Union | ||
|
||
from torch.nn import Module | ||
|
||
from torchmetrics.collections import MetricCollection | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities import rank_zero_warn | ||
|
||
__doctest_requires__ = {("FeatureShare",): ["torch_fidelity"]} | ||
|
||
|
||
class NetworkCache(Module):
    """Memoizing wrapper around a network so that several metrics can share one backbone.

    Different metrics may run the same feature extractor on the same inputs; caching the
    input-output pairs lets every call after the first be answered from the cache.
    """

    def __init__(self, network: Module, max_size: int = 100) -> None:
        super().__init__()
        self.max_size = max_size
        # lru_cache keys on the call arguments; tensors hash by object identity, so a
        # cache hit requires the *same* tensor object to be passed again.
        cached_call = lru_cache(maxsize=self.max_size)(network)
        self.network = cached_call

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Evaluate the wrapped network, returning a cached result on repeated inputs."""
        return self.network(*args, **kwargs)
|
||
|
||
class FeatureShare(MetricCollection):
    """Specialized metric collection that facilitates sharing features between metrics.

    Some metrics depend on an expensive neural network for feature extraction. This wrapper
    lets multiple such metrics share a single instance of that network, and additionally
    caches the network's input-output pairs, so repeated calls with the same input are
    served from the cache. Both time and memory are saved compared to running the metrics
    independently.

    Args:
        metrics: One of the following:

            * list or tuple (sequence): the class name of each metric is used as the key in
              the output dict, so two metrics of the same class cannot be combined this way.
            * dict: each dict key is used as the key in the output dict. Use this format to
              combine several instances of the same metric class with different parameters.
              Note that the keys of the output dict are sorted alphabetically.

        max_cache_size: maximum number of input-output pairs cached per metric. When ``None``
            (the default) the cache size is set to the number of metrics in the collection,
            so all features for a batch are cached and shared across every metric.

    Example::
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics.wrappers import FeatureShare
        >>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance
        >>> # initialize the metrics
        >>> fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
        >>> # update metric
        >>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
        >>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
        >>> # compute metric
        >>> fs.compute()
        {'FrechetInceptionDistance': tensor(15.1700), 'KernelInceptionDistance': (tensor(-0.0012), tensor(0.0014))}
    """

    def __init__(
        self,
        metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
        max_cache_size: Optional[int] = None,
    ) -> None:
        # compute groups are disabled because feature sharing replaces that mechanism
        super().__init__(metrics=metrics, compute_groups=False)

        cache_size = len(self) if max_cache_size is None else max_cache_size
        if not isinstance(cache_size, int):
            raise TypeError(f"max_cache_size should be an integer, but got {cache_size}")

        # The first metric donates its feature network; every metric is expected to expose
        # the attribute name of its network via a `feature_network` attribute.
        try:
            first_metric = next(iter(self.values()))
            shared_network = getattr(first_metric, first_metric.feature_network)
        except AttributeError as err:
            raise AttributeError(
                "Tried to extract the network to share from the first metric, but it did not have a `feature_network`"
                " attribute. Please make sure that the metric has an attribute with that name,"
                " else it cannot be shared."
            ) from err
        shared_cache = NetworkCache(shared_network, max_size=cache_size)

        # replace every metric's feature network with the single cached instance
        for name, metric in self.items():
            if not hasattr(metric, "feature_network"):
                raise AttributeError(
                    "Tried to set the cached network to all metrics, but one of the metrics did not have a"
                    " `feature_network` attribute. Please make sure that all metrics have a attribute with that name,"
                    f" else it cannot be shared. Failed on metric {name}."
                )

            # warn if a metric's own network differs (by repr) from the shared one
            if str(getattr(metric, metric.feature_network)) != str(shared_network):
                rank_zero_warn(
                    "The network to share between the metrics is not the same for all metrics."
                    f" Metric {name} has a different network than the first metric."
                    " This may lead to unexpected behavior.",
                    UserWarning,
                )

            setattr(metric, metric.feature_network, shared_cache)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
import torch | ||
from torchmetrics import MetricCollection | ||
from torchmetrics.image import ( | ||
FrechetInceptionDistance, | ||
InceptionScore, | ||
KernelInceptionDistance, | ||
LearnedPerceptualImagePatchSimilarity, | ||
StructuralSimilarityIndexMeasure, | ||
) | ||
from torchmetrics.wrappers import FeatureShare | ||
|
||
|
||
@pytest.mark.parametrize(
    "metrics",
    [
        [FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()],
        {"fid": FrechetInceptionDistance(), "is": InceptionScore(), "kid": KernelInceptionDistance()},
    ],
)
def test_initialization(metrics):
    """Check that the wrapper can be constructed from both a sequence and a dict of metrics."""
    wrapper = FeatureShare(metrics)
    # the wrapper is a MetricCollection subclass holding all three metrics
    assert isinstance(wrapper, MetricCollection)
    assert len(wrapper) == 3
|
||
|
||
def test_error_on_missing_feature_network():
    """Check the error paths when a metric lacks the `feature_network` attribute.

    Both orderings are covered: the offending metric first (network extraction fails)
    and the offending metric later (setting the cached network fails).
    """
    with pytest.raises(AttributeError, match="Tried to extract the network to share from the first metric.*"):
        FeatureShare([StructuralSimilarityIndexMeasure(), FrechetInceptionDistance()])

    with pytest.raises(AttributeError, match="Tried to set the cached network to all metrics, but one of the.*"):
        FeatureShare([FrechetInceptionDistance(), StructuralSimilarityIndexMeasure()])
|
||
|
||
def test_warning_on_mixing_networks():
    """Check that mixing metrics with different feature networks emits a user warning."""
    # LPIPS uses a different backbone than the inception-based metrics
    with pytest.warns(UserWarning, match="The network to share between the metrics is not.*"):
        FeatureShare([FrechetInceptionDistance(), InceptionScore(), LearnedPerceptualImagePatchSimilarity()])
|
||
|
||
def test_feature_share_speed():
    """Test that the feature share wrapper is faster than the metric collection.

    Both collections consume the same number of random batches; with the shared, cached
    feature extractor the ``FeatureShare`` loop should finish faster than the plain
    ``MetricCollection`` loop, where every metric runs its own forward pass.
    """
    import time

    mc = MetricCollection([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])
    fs = FeatureShare([FrechetInceptionDistance(), InceptionScore(), KernelInceptionDistance()])

    # perf_counter is monotonic and high-resolution, unlike time.time, so it is the
    # right clock for measuring elapsed intervals.
    start = time.perf_counter()
    for _ in range(10):
        x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
        mc.update(x, real=True)
    mc_time = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(10):
        x = torch.randint(255, (1, 3, 64, 64), dtype=torch.uint8)
        fs.update(x, real=True)
    fs_time = time.perf_counter() - start

    assert fs_time < mc_time, "The feature share wrapper should be faster than the metric collection."
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_memory():
    """Test that the feature share wrapper uses less memory than the metric collection."""
    base_memory = torch.cuda.memory_allocated()

    fid = FrechetInceptionDistance().cuda()
    inception = InceptionScore().cuda()
    kid = KernelInceptionDistance().cuda()

    memory_before_fs = torch.cuda.memory_allocated()
    assert memory_before_fs > base_memory, "The memory usage should be higher after initializing the metrics."

    torch.cuda.empty_cache()

    # Constructing the wrapper mutates the metrics in place (their backbones are replaced
    # by one shared cached network), so the wrapper instance itself need not be kept.
    FeatureShare([fid, inception, kid]).cuda()
    memory_after_fs = torch.cuda.memory_allocated()

    assert (
        memory_after_fs > base_memory
    ), "The memory usage should be higher after initializing the feature share wrapper."
    # fixed assertion message: this checks that sharing *reduced* memory, not increased it
    assert (
        memory_after_fs < memory_before_fs
    ), "The memory usage should be lower with the feature share wrapper than with separate networks."
|
||
|
||
def test_same_result_as_individual():
    """Test that the feature share wrapper gives the same result as the individual metrics."""
    fid = FrechetInceptionDistance(feature=768)
    inception = InceptionScore(feature=768)
    kid = KernelInceptionDistance(feature=768, subset_size=10, subsets=2)

    fs = FeatureShare([fid, inception, kid])

    # feed one "real" and one "fake" batch to the wrapper and the individual metrics alike
    for real in (True, False):
        batch = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8)
        fs.update(batch, real=real)
        fid.update(batch, real=real)
        inception.update(batch)
        kid.update(batch, real=real)

    shared_res = fs.compute()
    fid_res = fid.compute()
    inception_res = inception.compute()
    kid_res = kid.compute()

    assert shared_res["FrechetInceptionDistance"] == fid_res
    assert shared_res["InceptionScore"][0] == inception_res[0]
    assert shared_res["InceptionScore"][1] == inception_res[1]
    assert shared_res["KernelInceptionDistance"][0] == kid_res[0]
    assert shared_res["KernelInceptionDistance"][1] == kid_res[1]