
New metric: CLIP IQA #1931

Merged: 51 commits, merged Aug 21, 2023 (changes shown from 34 commits)

Commits
9197f0f
starting point
SkafteNicki Jul 20, 2023
0a70ff7
links
SkafteNicki Jul 20, 2023
dcd713f
small correction to clip score
SkafteNicki Jul 20, 2023
6256ba2
implementation update
SkafteNicki Jul 20, 2023
d94ab1e
some testing
SkafteNicki Jul 21, 2023
5294fba
fix implementation
SkafteNicki Jul 21, 2023
d8d09e2
update
SkafteNicki Jul 25, 2023
c39cc68
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Jul 28, 2023
e53da97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2023
6a59e89
changelog
SkafteNicki Jul 28, 2023
86af666
merge master
SkafteNicki Jul 28, 2023
7e60e4c
fix ruff
SkafteNicki Jul 28, 2023
d566614
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2023
b15ca90
working module
SkafteNicki Jul 28, 2023
f39c554
improve testing
SkafteNicki Jul 29, 2023
baf2b0a
merge master
SkafteNicki Jul 29, 2023
d1880b3
change transformers check
SkafteNicki Jul 31, 2023
04ef55a
fix doc mistake
SkafteNicki Jul 31, 2023
23e3025
add plotting functionality
SkafteNicki Jul 31, 2023
659607f
fix conditional imports
SkafteNicki Jul 31, 2023
06ad2b9
add requirement
SkafteNicki Jul 31, 2023
aa4a2aa
fix typing issues
SkafteNicki Jul 31, 2023
5238d83
change requirement and checks on import
SkafteNicki Jul 31, 2023
5edafa6
fix
SkafteNicki Jul 31, 2023
9cc8604
Merge branch 'master' into newmetric/clip_iqa
Borda Aug 1, 2023
0b5e455
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 1, 2023
e93a694
add link ref
SkafteNicki Aug 1, 2023
3e1d746
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 1, 2023
ca4ecb4
skip on missing
SkafteNicki Aug 1, 2023
518c5bb
F401
SkafteNicki Aug 1, 2023
0dbdd2e
skip on older versions
SkafteNicki Aug 1, 2023
67ad04e
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 3, 2023
4a48b8f
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 7, 2023
6e3957b
Merge branch 'master' into newmetric/clip_iqa
Borda Aug 7, 2023
f0c85a3
Merge branch 'master' into newmetric/clip_iqa
Borda Aug 7, 2023
8255256
Merge branch 'master' into newmetric/clip_iqa
Borda Aug 8, 2023
65665d8
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 8, 2023
9fea319
fix mistake
SkafteNicki Aug 8, 2023
39654ce
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 8, 2023
c0dfc97
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 9, 2023
3b2d10e
another skip
SkafteNicki Aug 9, 2023
d68944f
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 9, 2023
3ebdc23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2023
5e0797c
Merge branch 'master' into newmetric/clip_iqa
mergify[bot] Aug 18, 2023
9dc4327
move files
SkafteNicki Aug 18, 2023
df6fbd2
small corrections after moving files
SkafteNicki Aug 18, 2023
68d6fd3
Uodate cmd
Borda Aug 19, 2023
6396b67
Apply suggestions from code review
Borda Aug 19, 2023
04fe382
Merge branch 'master' into newmetric/clip_iqa
SkafteNicki Aug 21, 2023
5ebe5e9
improve based on suggestions
SkafteNicki Aug 21, 2023
71c5dac
change import in doctests
SkafteNicki Aug 21, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


- Added `CLIPImageQualityAssessment` to image package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931))


- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))


24 changes: 24 additions & 0 deletions docs/source/image/clip_iqa.rst
@@ -0,0 +1,24 @@
.. customcarditem::
    :header: CLIP IQA
    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
    :tags: Image

.. include:: ../links.rst

########################################
CLIP Image Quality Assessment (CLIP-IQA)
########################################

Module Interface
________________

.. autoclass:: torchmetrics.image.CLIPImageQualityAssessment
    :noindex:
    :exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.clip_image_quality_assessment
    :noindex:
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -144,3 +144,5 @@
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
.. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
.. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
.. _CLIP-IQA: https://arxiv.org/abs/2207.12396
.. _CLIP: https://arxiv.org/abs/2103.00020
1 change: 1 addition & 0 deletions requirements/image.txt
@@ -5,3 +5,4 @@ scipy >1.0.0, <1.11.0
torchvision >=0.8, <=0.15.2
torch-fidelity <=0.3.0
lpips <=0.1.4
piq <=0.8.0
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.clip_iqa import clip_image_quality_assessment
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
@@ -28,6 +29,7 @@
from torchmetrics.functional.image.vif import visual_information_fidelity

__all__ = [
    "clip_image_quality_assessment",
    "spectral_distortion_index",
    "error_relative_global_dimensionless_synthesis",
    "image_gradients",
326 changes: 326 additions & 0 deletions src/torchmetrics/functional/image/clip_iqa.py
@@ -0,0 +1,326 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Literal, Tuple, Union

import torch
from torch import Tensor

from torchmetrics.functional.multimodal.clip_score import _get_clip_model_and_processor
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10

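# Note: the doctest examples in this module are skipped when transformers is missing, when the CLIP weights
# cannot be downloaded within the doctest timeout, or when piq < 0.8 is installed (the default "clip_iqa"
# model depends on piq).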
if _TRANSFORMERS_GREATER_EQUAL_4_10:
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor

    def _download_clip() -> None:
        _CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
        __doctest_skip__ = ["clip_image_quality_assessment"]

else:
    __doctest_skip__ = ["clip_image_quality_assessment"]
    _CLIPModel = None
    _CLIPProcessor = None

if not _PIQ_GREATER_EQUAL_0_8:
    __doctest_skip__ = ["clip_image_quality_assessment"]

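# Built-in prompts: each name maps to a (positive, negative) prompt pair.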
_PROMPTS: Dict[str, Tuple[str, str]] = {
    "quality": ("Good photo.", "Bad photo."),
    "brightness": ("Bright photo.", "Dark photo."),
    "noisiness": ("Clean photo.", "Noisy photo."),
    "colorfullness": ("Colorful photo.", "Dull photo."),
    "sharpness": ("Sharp photo.", "Blurry photo."),
    "contrast": ("High contrast photo.", "Low contrast photo."),
    "complexity": ("Complex photo.", "Simple photo."),
    "natural": ("Natural photo.", "Synthetic photo."),
    "happy": ("Happy photo.", "Sad photo."),
    "scary": ("Scary photo.", "Peaceful photo."),
    "new": ("New photo.", "Old photo."),
    "warm": ("Warm photo.", "Cold photo."),
    "real": ("Real photo.", "Abstract photo."),
    "beutiful": ("Beautiful photo.", "Ugly photo."),
    "lonely": ("Lonely photo.", "Sociable photo."),
    "relaxing": ("Relaxing photo.", "Stressful photo."),
}


def _get_clip_iqa_model_and_processor(
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ]
) -> Tuple[_CLIPModel, _CLIPProcessor]:
    """Extract the CLIP model and processor from the model name or path."""
    if model_name_or_path == "clip_iqa":
        if not _PIQ_GREATER_EQUAL_0_8:
            raise ValueError(
                "For metric `clip_iqa` to work with argument `model_name_or_path` set to default value"
                " `'clip_iqa'`, package `piq` version v0.8.0 or later must be installed. Either install with"
                " `pip install piq` or `pip install torchmetrics[image]`."
            )

        import piq

        model = piq.clip_iqa.clip.load().eval()
        # any model checkpoint can be used here because the tokenizer is the same for all
        processor = _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        return model, processor
    return _get_clip_model_and_processor(model_name_or_path)


def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",)) -> Tuple[List[str], List[str]]:
    """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors.

    Args:
        prompts: A string, list of strings or tuple of strings. If a string is provided, it must be one of the
            available prompts. If a list of strings is provided, all strings must be one of the available prompts.
            If a tuple of strings is provided, it must be of length 2 and the first string must be a positive
            prompt and the second string must be a negative prompt.

    Returns:
        Tuple containing a list of prompts and a list of the names of the prompts. The first list is double the
        length of the second list.

    Examples::

        >>> # single prompt
        >>> _clip_iqa_format_prompts(("quality",))
        (['Good photo.', 'Bad photo.'], ['quality'])
        >>> # multiple prompts
        >>> _clip_iqa_format_prompts(("quality", "brightness"))
        (['Good photo.', 'Bad photo.', 'Bright photo.', 'Dark photo.'], ['quality', 'brightness'])
        >>> # Custom prompts
        >>> _clip_iqa_format_prompts(("quality", ("Super good photo.", "Super bad photo.")))
        (['Good photo.', 'Bad photo.', 'Super good photo.', 'Super bad photo.'], ['quality', 'user_defined_0'])

    """
    if not isinstance(prompts, tuple):
        raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")

    prompts_names: List[str] = []
    prompts_list: List[str] = []
    count = 0
    for p in prompts:
        if not isinstance(p, (str, tuple)):
            raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")
        if isinstance(p, str):
            if p not in _PROMPTS:
                raise ValueError(
                    f"All elements of `prompts` must be one of {_PROMPTS.keys()} if not custom tuple prompts,"
                    f" got {p}."
                )
            prompts_names.append(p)
            prompts_list.extend(_PROMPTS[p])
        if isinstance(p, tuple) and len(p) != 2:
            raise ValueError("If a tuple is provided in argument `prompts`, it must be of length 2")
        if isinstance(p, tuple):
            prompts_names.append(f"user_defined_{count}")
            prompts_list.extend(p)
            count += 1

    return prompts_list, prompts_names


def _clip_iqa_get_anchor_vectors(
    model_name_or_path: str,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    prompts_list: List[str],
    device: Union[str, torch.device],
) -> Tensor:
    """Calculates the anchor vectors for the CLIP IQA metric.

    Args:
        model_name_or_path: string indicating the version of the CLIP model to use.
        model: The CLIP model
        processor: The CLIP processor
        prompts_list: A list of prompts
        device: The device to use for the calculation

    """
    if model_name_or_path == "clip_iqa":
        text_processed = processor(text=prompts_list)
        anchors_text = torch.zeros(
            len(prompts_list), processor.tokenizer.model_max_length, dtype=torch.long, device=device
        )
        for i, tp in enumerate(text_processed["input_ids"]):
            anchors_text[i, : len(tp)] = torch.tensor(tp, dtype=torch.long, device=device)

        anchors = model.encode_text(anchors_text).float()
    else:
        text_processed = processor(text=prompts_list, return_tensors="pt", padding=True)
        anchors = model.get_text_features(
            text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
        )
    return anchors / anchors.norm(p=2, dim=-1, keepdim=True)


def _clip_iqa_update(
    model_name_or_path: str,
    images: Tensor,
    model: _CLIPModel,
    processor: _CLIPProcessor,
    data_range: Union[int, float],
    device: Union[str, torch.device],
) -> Tensor:
    """Update function for CLIP IQA."""
    images = images / float(data_range)
    if model_name_or_path == "clip_iqa":
        default_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
        default_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
        images = (images - default_mean) / default_std
        img_features = model.encode_image(images.float(), pos_embedding=False).float()
    else:
        processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
        img_features = model.get_image_features(processed_input["pixel_values"].to(device))
    return img_features / img_features.norm(p=2, dim=-1, keepdim=True)


def _clip_iqa_compute(
    img_features: Tensor,
    anchors: Tensor,
    prompts_names: List[str],
    format_as_dict: bool = True,
) -> Union[Tensor, Dict[str, Tensor]]:
    """Final computation of CLIP IQA."""
    logits_per_image = 100 * img_features @ anchors.t()
    probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
    if len(prompts_names) == 1:
        return probs.squeeze()
    if format_as_dict:
        return {p: probs[:, i] for i, p in enumerate(prompts_names)}
    return probs


def clip_image_quality_assessment(
    images: Tensor,
    model_name_or_path: Literal[
        "clip_iqa",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ] = "clip_iqa",
    data_range: Union[int, float] = 1.0,
    prompts: Tuple[Union[str, Tuple[str, str]]] = ("quality",),
) -> Union[Tensor, Dict[str, Tensor]]:
"""Calculates `CLIP-IQA`_, that can be used to measure the visual content of images.

The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
be able to generate a vector representation of the image and the text that is similar if the image and text are
semantically similar.

The metric works by calculating the cosine similarity between user provided images and pre-defined promts. The
promts always comes in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
the similartity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
which prompt the image is more similar to. The metric then returns the probability that the image is more similar
to the first prompt than the second prompt.

Build in promts are:
* quality: "Good photo." vs "Bad photo."
* brightness: "Bright photo." vs "Dark photo."
* noisiness: "Clean photo." vs "Noisy photo."
* colorfullness: "Colorful photo." vs "Dull photo."
* sharpness: "Sharp photo." vs "Blurry photo."
* contrast: "High contrast photo." vs "Low contrast photo."
* complexity: "Complex photo." vs "Simple photo."
* natural: "Natural photo." vs "Synthetic photo."
* happy: "Happy photo." vs "Sad photo."
* scary: "Scary photo." vs "Peaceful photo."
* new: "New photo." vs "Old photo."
* warm: "Warm photo." vs "Cold photo."
* real: "Real photo." vs "Abstract photo."
* beutiful: "Beautiful photo." vs "Ugly photo."
* lonely: "Lonely photo." vs "Sociable photo."
* relaxing: "Relaxing photo." vs "Stressful photo."

    Args:
        images: Either a single ``[N, C, H, W]`` tensor or a list of ``[C, H, W]`` tensors
        model_name_or_path: string indicating the version of the CLIP model to use. By default this argument is
            set to ``clip_iqa`` which corresponds to the model used in the original paper. Other available models
            are `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`,
            `"openai/clip-vit-large-patch14-336"` and `"openai/clip-vit-large-patch14"`
        data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
            data_range should be 255. The images are normalized by this value.
        prompts: A string, tuple of strings or nested tuple of strings. If a string is provided, it must be one of
            the available prompts. If a tuple of strings is provided, all strings must be one of the available
            prompts. If a nested tuple of strings is provided, it must be of length 2 and the first string must be
            a positive prompt and the second string must be a negative prompt.

    .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
        `pip install piq` or `pip install torchmetrics[image]`.

    Returns:
        A tensor of shape ``(N,)`` if a single prompt is provided. If a list of prompts is provided, a dictionary
        with the prompts as keys and tensors of shape ``(N,)`` as values.

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0
        ValueError:
            If not all images have format [C, H, W]
        ValueError:
            If prompts is a tuple and it is not of length 2
        ValueError:
            If prompts is a string and it is not one of the available prompts
        ValueError:
            If prompts is a list of strings and not all strings are one of the available prompts

    Example::
        Single prompt:

        >>> from torchmetrics.functional.image import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality",))
        tensor([0.8894, 0.8902])

    Example::
        Multiple prompts:

        >>> from torchmetrics.functional.image import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
        {'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}

    Example::
        Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.

        >>> from torchmetrics.functional.image import clip_image_quality_assessment
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
        >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
        {'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}

    """
    prompts_list, prompts_names = _clip_iqa_format_prompts(prompts)

    model, processor = _get_clip_iqa_model_and_processor(model_name_or_path)
    device = images.device
    model = model.to(device)

    with torch.inference_mode():
        anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device)
        img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device)
        return _clip_iqa_compute(img_features, anchors, prompts_names)
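For reviewers who want to try the new functional interface locally, here is a minimal usage sketch. It mirrors the doctest examples in the diff above; the explicit `data_range=255` argument follows the docstring recommendation for images in the [0, 255] range, and it assumes `transformers >= 4.10` and `piq >= 0.8` are installed so the default `clip_iqa` checkpoint can be loaded.

    import torch

    from torchmetrics.functional.image import clip_image_quality_assessment

    _ = torch.manual_seed(42)
    # A fake batch of two RGB images with values in [0, 255]
    imgs = torch.randint(255, (2, 3, 224, 224)).float()

    # A single built-in prompt pair returns a tensor of shape (N,)
    quality = clip_image_quality_assessment(imgs, data_range=255, prompts=("quality",))

    # Mixing a built-in prompt with a custom (positive, negative) pair returns a dict keyed by
    # the prompt name and "user_defined_0"
    scores = clip_image_quality_assessment(
        imgs,
        data_range=255,
        prompts=("brightness", ("Super good photo.", "Super bad photo.")),
    )
    print(quality.shape, sorted(scores))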