diff --git a/CHANGELOG.md b/CHANGELOG.md index f7d97420e0a..c7770a90667 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added warning to `ClipScore` if long captions are detected and truncate ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001)) +- Added `CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931)) + + ### Changed - diff --git a/docs/source/links.rst b/docs/source/links.rst index 842b2c9bfb8..4ca837ccd64 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -144,6 +144,8 @@ .. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220 .. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa .. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816 +.. _CLIP-IQA: https://arxiv.org/abs/2207.12396 +.. _CLIP: https://arxiv.org/abs/2103.00020 .. _PPL : https://arxiv.org/pdf/1812.04948 .. _CIOU: https://arxiv.org/abs/2005.03572 .. _DIOU: https://arxiv.org/abs/1911.08287v1 diff --git a/docs/source/multimodal/clip_iqa.rst b/docs/source/multimodal/clip_iqa.rst new file mode 100644 index 00000000000..074f35a50bf --- /dev/null +++ b/docs/source/multimodal/clip_iqa.rst @@ -0,0 +1,24 @@ +.. customcarditem:: + :header: CLIP IQA + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +######################################## +CLIP Image Quality Assessment (CLIP-IQA) +######################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.multimodal.CLIPImageQualityAssessment + :noindex: + :exclude-members: update, compute + + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.multimodal.clip_image_quality_assessment + :noindex: diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt index 4cff71599c7..1448edaf8f5 100644 --- a/requirements/multimodal.txt +++ b/requirements/multimodal.txt @@ -2,3 +2,4 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment transformers >=4.10.0, <4.30.3 +piq <=0.8.0 diff --git a/src/torchmetrics/functional/multimodal/__init__.py b/src/torchmetrics/functional/multimodal/__init__.py index d4b26efa639..812f5e69805 100644 --- a/src/torchmetrics/functional/multimodal/__init__.py +++ b/src/torchmetrics/functional/multimodal/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 if _TRANSFORMERS_GREATER_EQUAL_4_10: + from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment from torchmetrics.functional.multimodal.clip_score import clip_score - __all__ = ["clip_score"] + __all__ = ["clip_score", "clip_image_quality_assessment"] diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py new file mode 100644 index 00000000000..fcf1f4e7ae7 --- /dev/null +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -0,0 +1,330 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List, Literal, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.multimodal.clip_score import _get_clip_model_and_processor
+from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
+from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10
+
+if _TRANSFORMERS_GREATER_EQUAL_4_10:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+
+    def _download_clip() -> None:
+        _CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+        _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
+        __doctest_skip__ = ["clip_image_quality_assessment"]
+
+else:
+    __doctest_skip__ = ["clip_image_quality_assessment"]
+    _CLIPModel = None
+    _CLIPProcessor = None
+
+if not _PIQ_GREATER_EQUAL_0_8:
+    __doctest_skip__ = ["clip_image_quality_assessment"]
+
+_PROMPTS: Dict[str, Tuple[str, str]] = {
+    "quality": ("Good photo.", "Bad photo."),
+    "brightness": ("Bright photo.", "Dark photo."),
+    "noisiness": ("Clean photo.", "Noisy photo."),
+    "colorfullness": ("Colorful photo.", "Dull photo."),
+    "sharpness": ("Sharp photo.", "Blurry photo."),
+    "contrast": ("High contrast photo.", "Low contrast photo."),
+    "complexity": ("Complex photo.", "Simple photo."),
+    "natural": ("Natural photo.", "Synthetic photo."),
+    "happy": ("Happy photo.", "Sad photo."),
+    "scary": ("Scary photo.", "Peaceful photo."),
+    "new": ("New photo.", "Old photo."),
+    "warm": ("Warm photo.", "Cold photo."),
+    "real": ("Real photo.", "Abstract photo."),
+    "beutiful": ("Beautiful photo.", "Ugly photo."),
+    "lonely": ("Lonely photo.", "Sociable photo."),
+    "relaxing": ("Relaxing photo.", "Stressful photo."),
+}
+
+
+def _get_clip_iqa_model_and_processor(
+    model_name_or_path: Literal[
+        "clip_iqa",
+        "openai/clip-vit-base-patch16",
+        "openai/clip-vit-base-patch32",
+        "openai/clip-vit-large-patch14-336",
+        "openai/clip-vit-large-patch14",
+    ]
+) -> Tuple[_CLIPModel, _CLIPProcessor]:
+    """Extract the CLIP model and processor from the model name or path."""
+    if model_name_or_path == "clip_iqa":
+        if not _PIQ_GREATER_EQUAL_0_8:
+            raise ValueError(
+                "For metric `clip_iqa` to work with argument `model_name_or_path` set to default value `'clip_iqa'`"
+                ", package `piq` version v0.8.0 or later must be installed. Either install with `pip install piq` or "
+                "`pip install torchmetrics[multimodal]`"
+            )
+
+        import piq
+
+        model = piq.clip_iqa.clip.load().eval()
+        # any model checkpoint can be used here because the tokenizer is the same for all
+        processor = _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+        return model, processor
+    return _get_clip_model_and_processor(model_name_or_path)
+
+
+def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]:
+    """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors.
+ + Args: + prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one + of the availble prompts (see above). Else the input is expected to be a tuple, where each element can be one + of two things: either a string or a tuple of strings. If a string is provided, it must be one of the + availble prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a + positive prompt and the second string must be a negative prompt. + + Returns: + Tuple containing a list of prompts and a list of the names of the prompts. The first list is double the length + of the second list. + + Examples:: + + >>> # single prompt + >>> _clip_iqa_format_prompts(("quality",)) + (['Good photo.', 'Bad photo.'], ['quality']) + >>> # multiple prompts + >>> _clip_iqa_format_prompts(("quality", "brightness")) + (['Good photo.', 'Bad photo.', 'Bright photo.', 'Dark photo.'], ['quality', 'brightness']) + >>> # Custom prompts + >>> _clip_iqa_format_prompts(("quality", ("Super good photo.", "Super bad photo."))) + (['Good photo.', 'Bad photo.', 'Super good photo.', 'Super bad photo.'], ['quality', 'user_defined_0']) + + """ + if not isinstance(prompts, tuple): + raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings") + + prompts_names: List[str] = [] + prompts_list: List[str] = [] + count = 0 + for p in prompts: + if not isinstance(p, (str, tuple)): + raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings") + if isinstance(p, str): + if p not in _PROMPTS: + raise ValueError( + f"All elements of `prompts` must be one of {_PROMPTS.keys()} if not custom tuple promts, got {p}." + ) + prompts_names.append(p) + prompts_list.extend(_PROMPTS[p]) + if isinstance(p, tuple) and len(p) != 2: + raise ValueError("If a tuple is provided in argument `prompts`, it must be of length 2") + if isinstance(p, tuple): + prompts_names.append(f"user_defined_{count}") + prompts_list.extend(p) + count += 1 + + return prompts_list, prompts_names + + +def _clip_iqa_get_anchor_vectors( + model_name_or_path: str, + model: _CLIPModel, + processor: _CLIPProcessor, + prompts_list: List[str], + device: Union[str, torch.device], +) -> Tensor: + """Calculates the anchor vectors for the CLIP IQA metric. + + Args: + model_name_or_path: string indicating the version of the CLIP model to use. 
+        model: The CLIP model
+        processor: The CLIP processor
+        prompts_list: A list of prompts
+        device: The device to use for the calculation
+
+    """
+    if model_name_or_path == "clip_iqa":
+        text_processed = processor(text=prompts_list)
+        anchors_text = torch.zeros(
+            len(prompts_list), processor.tokenizer.model_max_length, dtype=torch.long, device=device
+        )
+        for i, tp in enumerate(text_processed["input_ids"]):
+            anchors_text[i, : len(tp)] = torch.tensor(tp, dtype=torch.long, device=device)
+
+        anchors = model.encode_text(anchors_text).float()
+    else:
+        text_processed = processor(text=prompts_list, return_tensors="pt", padding=True)
+        anchors = model.get_text_features(
+            text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
+        )
+    return anchors / anchors.norm(p=2, dim=-1, keepdim=True)
+
+
+def _clip_iqa_update(
+    model_name_or_path: str,
+    images: Tensor,
+    model: _CLIPModel,
+    processor: _CLIPProcessor,
+    data_range: Union[int, float],
+    device: Union[str, torch.device],
+) -> Tensor:
+    """Update function for CLIP IQA."""
+    images = images / float(data_range)
+    if model_name_or_path == "clip_iqa":
+        # default mean and std from clip paper, see:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/utils/constants.py
+        default_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
+        default_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
+        images = (images - default_mean) / default_std
+        img_features = model.encode_image(images.float(), pos_embedding=False).float()
+    else:
+        processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
+        img_features = model.get_image_features(processed_input["pixel_values"].to(device))
+    return img_features / img_features.norm(p=2, dim=-1, keepdim=True)
+
+
+def _clip_iqa_compute(
+    img_features: Tensor,
+    anchors: Tensor,
+    prompts_names: List[str],
+    format_as_dict: bool = True,
+) -> Union[Tensor, Dict[str, Tensor]]:
+    """Final computation of CLIP IQA."""
+    logits_per_image = 100 * img_features @ anchors.t()
+    probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
+    if len(prompts_names) == 1:
+        return probs.squeeze()
+    if format_as_dict:
+        return {p: probs[:, i] for i, p in enumerate(prompts_names)}
+    return probs
+
+
+def clip_image_quality_assessment(
+    images: Tensor,
+    model_name_or_path: Literal[
+        "clip_iqa",
+        "openai/clip-vit-base-patch16",
+        "openai/clip-vit-base-patch32",
+        "openai/clip-vit-large-patch14-336",
+        "openai/clip-vit-large-patch14",
+    ] = "clip_iqa",
+    data_range: Union[int, float] = 1.0,
+    prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",),
+) -> Union[Tensor, Dict[str, Tensor]]:
+    """Calculates `CLIP-IQA`_, which can be used to measure the visual content of images.
+
+    The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
+    be able to generate a vector representation of the image and the text that is similar if the image and text are
+    semantically similar.
+
+    The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The
+    prompts always come in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
+    the similarity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
+    which prompt the image is more similar to. 
The metric then returns the probability that the image is more similar
+    to the first prompt than the second prompt.
+
+    Built-in prompts are:
+        * quality: "Good photo." vs "Bad photo."
+        * brightness: "Bright photo." vs "Dark photo."
+        * noisiness: "Clean photo." vs "Noisy photo."
+        * colorfullness: "Colorful photo." vs "Dull photo."
+        * sharpness: "Sharp photo." vs "Blurry photo."
+        * contrast: "High contrast photo." vs "Low contrast photo."
+        * complexity: "Complex photo." vs "Simple photo."
+        * natural: "Natural photo." vs "Synthetic photo."
+        * happy: "Happy photo." vs "Sad photo."
+        * scary: "Scary photo." vs "Peaceful photo."
+        * new: "New photo." vs "Old photo."
+        * warm: "Warm photo." vs "Cold photo."
+        * real: "Real photo." vs "Abstract photo."
+        * beutiful: "Beautiful photo." vs "Ugly photo."
+        * lonely: "Lonely photo." vs "Sociable photo."
+        * relaxing: "Relaxing photo." vs "Stressful photo."
+
+    Args:
+        images: Either a single ``[N, C, H, W]`` tensor or a list of ``[C, H, W]`` tensors
+        model_name_or_path: string indicating the version of the CLIP model to use. By default this argument is set to
+            ``clip_iqa`` which corresponds to the model used in the original paper. Other available models are
+            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
+            and `"openai/clip-vit-large-patch14"`
+        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
+            data_range should be 255. The images are normalized by this value.
+        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
+            of the available prompts (see above). Else the input is expected to be a tuple, where each element can be
+            one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the
+            available prompts (see above). If a tuple is provided, it must be of length 2 and the first string must be
+            a positive prompt and the second string must be a negative prompt.
+
+    .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
+        `pip install piq` or `pip install torchmetrics[multimodal]`.
+
+    Returns:
+        A tensor of shape ``(N,)`` if a single prompt is provided. If a list of prompts is provided, a dictionary is
+        returned with the prompts as keys and tensors of shape ``(N,)`` as values.
+
+    Raises:
+        ModuleNotFoundError:
+            If transformers package is not installed or version is lower than 4.10.0
+        ValueError:
+            If not all images have format [C, H, W]
+        ValueError:
+            If prompts is a tuple and it is not of length 2
+        ValueError:
+            If prompts is a string and it is not one of the available prompts
+        ValueError:
+            If prompts is a list of strings and not all strings are one of the available prompts
+
+    Example::
+        Single prompt:
+
+        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
+        >>> clip_image_quality_assessment(imgs, prompts=("quality",))
+        tensor([0.8894, 0.8902])
+
+    Example::
+        Multiple prompts:
+
+        >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
+        >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
+        {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
+
+    Example::
+        Custom prompts. 
Must always be a tuple of length 2, with a positive and negative prompt. + + >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment + >>> import torch + >>> _ = torch.manual_seed(42) + >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness")) + {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])} + + """ + prompts_list, prompts_names = _clip_iqa_format_prompts(prompts) + + model, processor = _get_clip_iqa_model_and_processor(model_name_or_path) + device = images.device + model = model.to(device) + + with torch.inference_mode(): + anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device) + img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device) + return _clip_iqa_compute(img_features, anchors, prompts_names) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 72fd179c5c2..42a4aea4f5c 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -87,7 +87,7 @@ def _clip_score_update( return score, len(text) -def _get_model_and_processor( +def _get_clip_model_and_processor( model_name_or_path: Literal[ "openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32", @@ -118,14 +118,14 @@ def clip_score( ) -> Tensor: r"""Calculate `CLIP Score`_ which is a text-to-image similarity metric. - CLIP is a reference free metric that can be used to evaluate the correlation between a generated caption for an - image and the actual content of the image. It has been found to be highly correlated with human judgement. The + CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for + an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as: .. math:: \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0) - which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and + which corresponds to the cosine similarity between visual `CLIP`_ embedding :math:`E_i` for an image :math:`i` and textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. 
@@ -155,7 +155,7 @@ def clip_score( tensor(24.4255) """ - model, processor = _get_model_and_processor(model_name_or_path) + model, processor = _get_clip_model_and_processor(model_name_or_path) device = images.device if isinstance(images, Tensor) else images[0].device score, _ = _clip_score_update(images, text, model.to(device), processor) score = score.mean(0) diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index ace4a50de22..1defa78bbf5 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -23,7 +23,10 @@ from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.image.vif import VisualInformationFidelity -from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE +from torchmetrics.utilities.imports import ( + _TORCH_FIDELITY_AVAILABLE, + _TORCHVISION_AVAILABLE, +) __all__ = [ "SpectralDistortionIndex", diff --git a/src/torchmetrics/multimodal/__init__.py b/src/torchmetrics/multimodal/__init__.py index 641ea8a244c..4a4c77d8baa 100644 --- a/src/torchmetrics/multimodal/__init__.py +++ b/src/torchmetrics/multimodal/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 if _TRANSFORMERS_GREATER_EQUAL_4_10: + from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment from torchmetrics.multimodal.clip_score import CLIPScore - __all__ = ["CLIPScore"] + __all__ = ["CLIPScore", "CLIPImageQualityAssessment"] diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py new file mode 100644 index 00000000000..86310b7892d --- /dev/null +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -0,0 +1,261 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.multimodal.clip_iqa import (
+    _clip_iqa_compute,
+    _clip_iqa_format_prompts,
+    _clip_iqa_get_anchor_vectors,
+    _clip_iqa_update,
+    _get_clip_iqa_model_and_processor,
+)
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
+from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.imports import (
+    _MATPLOTLIB_AVAILABLE,
+    _PIQ_GREATER_EQUAL_0_8,
+    _TRANSFORMERS_GREATER_EQUAL_4_10,
+)
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _PIQ_GREATER_EQUAL_0_8:
+    __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["CLIPImageQualityAssessment.plot"]
+
+if _TRANSFORMERS_GREATER_EQUAL_4_10:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+
+    def _download_clip() -> None:
+        _CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+        _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
+        __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
+else:
+    __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
+
+
+class CLIPImageQualityAssessment(Metric):
+    """Calculates `CLIP-IQA`_, which can be used to measure the visual content of images.
+
+    The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
+    be able to generate a vector representation of the image and the text that is similar if the image and text are
+    semantically similar.
+
+    The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The
+    prompts always come in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
+    the similarity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
+    which prompt the image is more similar to. The metric then returns the probability that the image is more similar
+    to the first prompt than the second prompt.
+
+    Built-in prompts are:
+        * quality: "Good photo." vs "Bad photo."
+        * brightness: "Bright photo." vs "Dark photo."
+        * noisiness: "Clean photo." vs "Noisy photo."
+        * colorfullness: "Colorful photo." vs "Dull photo."
+        * sharpness: "Sharp photo." vs "Blurry photo."
+        * contrast: "High contrast photo." vs "Low contrast photo."
+        * complexity: "Complex photo." vs "Simple photo."
+        * natural: "Natural photo." vs "Synthetic photo."
+        * happy: "Happy photo." vs "Sad photo."
+        * scary: "Scary photo." vs "Peaceful photo."
+        * new: "New photo." vs "Old photo."
+        * warm: "Warm photo." vs "Cold photo."
+        * real: "Real photo." vs "Abstract photo."
+        * beutiful: "Beautiful photo." vs "Ugly photo."
+        * lonely: "Lonely photo." vs "Sociable photo."
+        * relaxing: "Relaxing photo." vs "Stressful photo."
+
+    As input to ``forward`` and ``update`` the metric accepts the following input
+
+    - ``images`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor with shape ``(N,C,H,W)``
+
+    As output of `forward` and `compute` the metric returns the following output
+
+    - ``clip_iqa`` (:class:`~torch.Tensor` or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is
+      provided, a single tensor with shape ``(N,)`` is returned. If a list of prompts is provided, a dict of tensors
+      is returned with the prompt as key and the tensor with shape ``(N,)`` as value.
+
+    Args:
+        model_name_or_path: string indicating the version of the CLIP model to use. Available models are:
+
+            - `"clip_iqa"`, model corresponding to the CLIP-IQA paper.
+            - `"openai/clip-vit-base-patch16"`
+            - `"openai/clip-vit-base-patch32"`
+            - `"openai/clip-vit-large-patch14-336"`
+            - `"openai/clip-vit-large-patch14"`
+
+        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
+            data_range should be 255. The images are normalized by this value.
+        prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
+            of the available prompts (see above). Else the input is expected to be a tuple, where each element can be
+            one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the
+            available prompts (see above). If a tuple is provided, it must be of length 2 and the first string must be
+            a positive prompt and the second string must be a negative prompt.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
+        `pip install piq` or `pip install torchmetrics[multimodal]`.
+
+    Raises:
+        ModuleNotFoundError:
+            If transformers package is not installed or version is lower than 4.10.0
+        ValueError:
+            If `prompts` is a tuple and it is not of length 2
+        ValueError:
+            If `prompts` is a string and it is not one of the available prompts
+        ValueError:
+            If `prompts` is a list of strings and not all strings are one of the available prompts
+
+    Example::
+        Single prompt:
+
+        >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
+        >>> metric = CLIPImageQualityAssessment()
+        >>> metric(imgs)
+        tensor([0.8894, 0.8902])
+
+    Example::
+        Multiple prompts:
+
+        >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
+        >>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
+        >>> metric(imgs)
+        {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
+
+    Example::
+        Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.
+ + >>> from torchmetrics.multimodal import CLIPImageQualityAssessment + >>> import torch + >>> _ = torch.manual_seed(42) + >>> imgs = torch.randint(255, (2, 3, 224, 224)).float() + >>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness")) + >>> metric(imgs) + {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])} + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = True + plot_lower_bound = 0.0 + plot_upper_bound = 100.0 + + anchors: Tensor + probs_list: List[Tensor] + + def __init__( + self, + model_name_or_path: Literal[ + "clip_iqa", + "openai/clip-vit-base-patch16", + "openai/clip-vit-base-patch32", + "openai/clip-vit-large-patch14-336", + "openai/clip-vit-large-patch14", + ] = "clip_iqa", + data_range: Union[int, float] = 1.0, + prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",), + **kwargs: Any + ) -> None: + super().__init__(**kwargs) + if not (isinstance(data_range, (int, float)) and data_range > 0): + raise ValueError("Argument `data_range` should be a positive number.") + self.data_range = data_range + + prompts_list, prompts_name = _clip_iqa_format_prompts(prompts) + self.prompts_list = prompts_list + self.prompts_name = prompts_name + + self.model, self.processor = _get_clip_iqa_model_and_processor(model_name_or_path) + self.model_name_or_path = model_name_or_path + + with torch.inference_mode(): + anchors = _clip_iqa_get_anchor_vectors( + model_name_or_path, self.model, self.processor, self.prompts_list, self.device + ) + self.register_buffer("anchors", anchors) + + self.add_state("probs_list", [], dist_reduce_fx="cat") + + def update(self, images: Tensor) -> None: + """Update metric state with new data.""" + with torch.inference_mode(): + img_features = _clip_iqa_update( + self.model_name_or_path, images, self.model, self.processor, self.data_range, self.device + ) + probs = _clip_iqa_compute(img_features, self.anchors, self.prompts_name, format_as_dict=False) + if not isinstance(probs, Tensor): + raise ValueError("Output probs should be a tensor") + self.probs_list.append(probs) + + def compute(self) -> Union[Tensor, Dict[str, Tensor]]: + """Compute metric.""" + probs = dim_zero_cat(self.probs_list) + if len(self.prompts_name) == 1: + return probs.squeeze() + return {p: probs[:, i] for i, p in enumerate(self.prompts_name)} + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment + >>> metric = CLIPImageQualityAssessment() + >>> metric.update(torch.rand(1, 3, 224, 224)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment + >>> metric = CLIPImageQualityAssessment() + >>> values = [ ] + >>> for _ in range(10): + ... 
values.append(metric(torch.rand(1, 3, 224, 224))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 888dd6963dc..0daa853e374 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -18,7 +18,7 @@ from typing_extensions import Literal from torchmetrics import Metric -from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor +from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_clip_model_and_processor from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_10 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -26,15 +26,13 @@ if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["CLIPScore.plot"] -_DEFAULT_MODEL: str = "openai/clip-vit-large-patch14" - if _TRANSFORMERS_GREATER_EQUAL_4_10: from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor def _download_clip() -> None: - _CLIPModel.from_pretrained(_DEFAULT_MODEL) - _CLIPProcessor.from_pretrained(_DEFAULT_MODEL) + _CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip): __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"] @@ -45,19 +43,30 @@ def _download_clip() -> None: class CLIPScore(Metric): r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric. - CLIP is a reference free metric that can be used to evaluate the correlation between a generated caption for an - image and the actual content of the image. It has been found to be highly correlated with human judgement. The + CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for + an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as: .. math:: \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0) - which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and + which corresponds to the cosine similarity between visual `CLIP`_ embedding :math:`E_i` for an image :math:`i` and textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. .. note:: Metric is not scriptable + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``images`` (:class:`~torch.Tensor` or list of tensors): tensor with images feed to the feature extractor with. If + a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape + ``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image. + - ``text`` (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the images, one for each image. + + As output of `forward` and `compute` the metric returns the following output + + - ``clip_score`` (:class:`~torch.Tensor`): float scalar tensor with mean CLIP score over samples + Args: model_name_or_path: string indicating the version of the CLIP model to use. 
Available models are: @@ -87,10 +96,10 @@ class CLIPScore(Metric): higher_is_better: bool = True full_state_update: bool = True plot_lower_bound: float = 0.0 + plot_upper_bound = 100.0 score: Tensor n_samples: Tensor - plot_upper_bound = 100.0 def __init__( self, @@ -99,11 +108,11 @@ def __init__( "openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14-336", "openai/clip-vit-large-patch14", - ] = _DEFAULT_MODEL, # type: ignore[assignment] + ] = "openai/clip-vit-large-patch14", **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.model, self.processor = _get_model_and_processor(model_name_or_path) + self.model, self.processor = _get_clip_model_and_processor(model_name_or_path) self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum") diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 43764e05820..ee1dc505477 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -56,5 +56,6 @@ _SCIENCEPLOT_AVAILABLE: bool = package_available("scienceplots") _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing") _XLA_AVAILABLE: bool = package_available("torch_xla") +_PIQ_GREATER_EQUAL_0_8: Optional[bool] = compare_version("piq", operator.ge, "0.8.0") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 2bb78db9235..c9a37e48972 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -377,6 +377,7 @@ def run_class_metric_test( check_batch: bool = True, fragment_kwargs: bool = False, check_scriptable: bool = True, + check_state_dict: bool = True, atol: Optional[float] = None, **kwargs_update: Any, ): @@ -396,6 +397,7 @@ def run_class_metric_test( calculated across devices for each batch (and not just at the end) fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes check_scriptable: bool indicating if metric should also be tested if it can be scripted + check_state_dict: bool indicating if metric should be tested that its state_dict by default is empty atol: absolute tolerance used for comparison of results, if None will use self.atol kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. @@ -421,6 +423,7 @@ def run_class_metric_test( atol=atol, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, + check_state_dict=check_state_dict, **kwargs_update, ), [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)], @@ -443,6 +446,7 @@ def run_class_metric_test( device=device, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, + check_state_dict=check_state_dict, **kwargs_update, ) diff --git a/tests/unittests/image/__init__.py b/tests/unittests/image/__init__.py index e69de29bb2d..6b5cc79247c 100644 --- a/tests/unittests/image/__init__.py +++ b/tests/unittests/image/__init__.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from unittests import _PATH_ROOT + +_SAMPLE_IMAGE = os.path.join(_PATH_ROOT, "_data", "image", "i01_01_5.bmp") diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py new file mode 100644 index 00000000000..9af3d169450 --- /dev/null +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -0,0 +1,211 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import partial + +import matplotlib +import matplotlib.pyplot as plt +import piq +import pytest +import torch +from PIL import Image +from torch import Tensor +from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment +from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment +from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10 +from torchvision.transforms import PILToTensor + +from unittests.helpers.testers import MetricTester +from unittests.image import _SAMPLE_IMAGE +from unittests.text.helpers import skip_on_connection_issues + + +@pytest.mark.parametrize( + ("prompts", "match"), + [ + ("quality", "Argument `prompts` must be a tuple containing strings or tuples of strings"), + (("quality", 1), "Argument `prompts` must be a tuple containing strings or tuples of strings"), + ((("quality", "quality", "quality"),), "If a tuple is provided in argument `prompts`, it must be of length 2"), + (("quality", "something"), "All elements of `prompts` must be one of.*"), + ], +) +def test_raises_error_on_wrong_prompts(prompts, match): + """Test that the function raises an error if the prompts argument are not valid.""" + img = torch.rand(1, 3, 256, 256) + + with pytest.raises(ValueError, match=match): + clip_image_quality_assessment(img, prompts=prompts) + + +class CLIPTesterClass(CLIPImageQualityAssessment): + """Tester class for `CLIPImageQualityAssessment` metric overriding its update method.""" + + def update(self, preds, target): + """Override the update method to support two input arguments.""" + super().update(preds) + + def compute(self): + """Override the compute method.""" + return super().compute().sum() + + +def _clip_iqa_tester(preds, target): + """Tester function for `clip_image_quality_assessment` that supports two input arguments.""" + return clip_image_quality_assessment(preds) + + +def _reference(preds, target, reduce=False): + """Reference implementation of `CLIPImageQualityAssessment` metric.""" + res = piq.CLIPIQA()(preds).squeeze() + return res.sum() if reduce else res + + +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, 
reason="test requires piq>=0.8") +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") +class TestCLIPIQA(MetricTester): + """Test clip iqa metric.""" + + @skip_on_connection_issues() + @pytest.mark.parametrize("ddp", [False]) + def test_clip_iqa(self, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=torch.rand(2, 1, 3, 128, 128), + target=torch.rand(2, 1, 3, 128, 128), + metric_class=CLIPTesterClass, + reference_metric=partial(_reference, reduce=True), + check_scriptable=False, + check_state_dict=False, + ) + + @skip_on_connection_issues() + @pytest.mark.parametrize("shapes", [(2, 1, 3, 256, 256), (2, 2, 3, 256, 256), (2, 2, 3, 128, 128)]) + def test_clip_iqa_functional(self, shapes): + """Test functional implementation of metric.""" + img = torch.rand(shapes) + self.run_functional_metric_test( + preds=img, + target=img, + metric_functional=_clip_iqa_tester, + reference_metric=_reference, + ) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") +@pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found") +def test_for_correctness_sample_images(): + """Compare the output of the function with the output of the reference implementation.""" + img = Image.open(_SAMPLE_IMAGE) + img = PILToTensor()(img) + img = img.float()[None] + + reference = piq.CLIPIQA(data_range=255) + reference_score = reference(img) + + result = clip_image_quality_assessment(img, data_range=255) + assert torch.allclose(reference_score, result) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") +@pytest.mark.parametrize( + "model", + [ + "openai/clip-vit-base-patch16", + "openai/clip-vit-base-patch32", + "openai/clip-vit-large-patch14-336", + "openai/clip-vit-large-patch14", + ], +) +@pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found") +def test_other_models(model): + """Test that the function works with other models.""" + img = Image.open(_SAMPLE_IMAGE) + img = PILToTensor()(img) + img = img.float()[None] + + reference = piq.CLIPIQA(data_range=255) + reference_score = reference(img) + + result = clip_image_quality_assessment(img, data_range=255, model_name_or_path=model) + # allow large difference between scores due to different models, but still in the same ballpark + assert reference_score - 0.2 < result < reference_score + 0.2 + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") +@pytest.mark.parametrize( + "prompts", + [ + ("quality",), + ("brightness",), + ("noisiness",), + ("colorfullness",), + ("sharpness",), + ("contrast",), + ("complexity",), + ("natural",), + ("happy",), + ("scary",), + ("new",), + ("warm",), + ("real",), + ("beutiful",), + ("lonely",), + ("relaxing",), + # some random combinations + ("quality", "brightness"), + ("quality", "brightness", "noisiness"), + ("quality", "brightness", "noisiness", "colorfullness"), + # custom prompts + (("Photo of a cat", "Photo of a dog"),), + (("Photo of a cat", "Photo of a dog"), "quality"), + (("Photo of a cat", "Photo of 
a dog"), "quality", ("Colorful photo", "Black and white photo")), + ], +) +@pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found") +def test_prompt(prompts): + """Test that the function works with other prompts, and that output is as expected.""" + img = Image.open(_SAMPLE_IMAGE) + img = PILToTensor()(img) + img = img.float()[None] + + result = clip_image_quality_assessment(img, data_range=255, prompts=prompts) + if len(prompts) == 1: + assert isinstance(result, Tensor) + assert 0 < result < 1 + else: + assert isinstance(result, dict) + for i, (k, v) in enumerate(result.items()): + assert isinstance(k, str) + assert k == prompts[i] if isinstance(prompts[i], str) else "user_defined_" in k + assert isinstance(v, Tensor) + assert 0 < v < 1 + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8") +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10") +def test_plot_method(): + """Test the plot method of CLIPScore seperately in this file due to the skipping conditions.""" + metric = CLIPImageQualityAssessment() + metric.update(torch.rand(1, 3, 256, 256)) + fig, ax = metric.plot() + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes)