feat: ImageToText (caption generator) #3859

Merged · 15 commits · Jan 23, 2023
2 changes: 2 additions & 0 deletions haystack/nodes/image_to_text/__init__.py
@@ -0,0 +1,2 @@
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.nodes.image_to_text.transformers import TransformersImageToText
53 changes: 53 additions & 0 deletions haystack/nodes/image_to_text/base.py
@@ -0,0 +1,53 @@
from typing import List, Optional

from abc import abstractmethod

from haystack.schema import Document
from haystack.nodes.base import BaseComponent


class BaseImageToText(BaseComponent):
"""
Abstract base class for ImageToText nodes.
"""

outgoing_edges = 1

@abstractmethod
def generate_captions(
self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
) -> List[Document]:
"""
Abstract method for generating captions.

:param image_file_paths: Paths to the image files.
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
:param batch_size: Number of images to process at a time.
:return: List of Documents. Document.content is the caption. Document.meta["image_path"] contains the image file path.
"""
pass

def run(self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None): # type: ignore

if file_paths is None and documents is None:
raise ValueError("You must either specify documents or image file_paths to process.")

image_file_paths = []
if file_paths is not None:
image_file_paths.extend(file_paths)
if documents is not None:
if any((doc.content_type != "image" for doc in documents)):
raise ValueError("The ImageToText node only supports image documents.")
image_file_paths.extend([doc.content for doc in documents])

results: dict = {}
results["documents"] = self.generate_captions(image_file_paths=image_file_paths)

return results, "output_1"

def run_batch( # type: ignore
self, file_paths: Optional[List[str]] = None, documents: Optional[List[Document]] = None
):

return self.run(file_paths=file_paths, documents=documents)
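For context on the contract this base class defines: `run` funnels both raw file paths and image `Document`s into `generate_captions`, so a concrete node only has to implement that one method. Below is a minimal sketch of such a subclass, assuming only the imports already present in `base.py`; the dummy captioning logic is purely illustrative and not part of this PR.

```python
from typing import List, Optional

from haystack.schema import Document
from haystack.nodes.image_to_text.base import BaseImageToText


class EchoImageToText(BaseImageToText):
    """Toy subclass: uses the file path itself as a stand-in 'caption'."""

    def generate_captions(
        self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
    ) -> List[Document]:
        # A real implementation would run a captioning model here.
        return [
            Document(content=f"caption for {path}", content_type="text", meta={"image_path": path})
            for path in image_file_paths
        ]


# Both input styles route through generate_captions:
node = EchoImageToText()
results, _ = node.run(file_paths=["/path/to/images/apple.jpg"])
results, _ = node.run(documents=[Document(content="/path/to/images/cat.jpg", content_type="image")])
```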
141 changes: 141 additions & 0 deletions haystack/nodes/image_to_text/transformers.py
@@ -0,0 +1,141 @@
from typing import List, Optional, Union

import logging

import torch
from tqdm.auto import tqdm
from transformers import pipeline

from haystack.schema import Document
from haystack.nodes.image_to_text.base import BaseImageToText
from haystack.modeling.utils import initialize_device_settings
from haystack.utils.torch_utils import ListDataset

logger = logging.getLogger(__name__)


class TransformersImageToText(BaseImageToText):
"""
Transformer-based model that generates captions for images using Hugging Face's transformers framework.

See the up-to-date list of available models on
`huggingface.co/models <https://huggingface.co/models?pipeline_tag=image-to-text>`__

**Example**

```python
image_file_paths = ["/path/to/images/apple.jpg",
"/path/to/images/cat.jpg", ]

# Generate captions
documents = image_to_text.generate_captions(image_file_paths=image_file_paths)

# Show results (List of Documents, containing caption and image file_path)
print(documents)

[
{
"content": "a red apple is sitting on a pile of hay",
...
"meta": {
"image_path": "/path/to/images/apple.jpg",
...
},
...
},
...
]
```
"""

def __init__(
self,
model_name_or_path: str = "nlpconnect/vit-gpt2-image-captioning",
model_version: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
use_gpu: bool = True,
batch_size: int = 16,
progress_bar: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
):
"""
Load an Image To Text model from Transformers.
See the up-to-date list of available models at
https://huggingface.co/models?pipeline_tag=image-to-text

:param model_name_or_path: Directory of a saved model or the name of a public model.
See https://huggingface.co/models?pipeline_tag=image-to-text for a full list of available models.
:param model_version: The version of the model to use from the Hugging Face model hub. Can be a tag name, branch name, or commit hash.
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
:param use_gpu: Whether to use GPU (if available).
:param batch_size: Number of images to process at a time.
:param progress_bar: Whether to show a progress bar.
:param use_auth_token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
For more information, see
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:param devices: List of torch devices (for example, cuda, cpu, mps) to limit inference to specific devices.
A list containing torch device objects and/or strings is supported (for example,
[torch.device('cuda:0'), "mps", "cuda:1"]). When you specify `use_gpu=False`, the devices
parameter is not used and a single CPU device is used for inference.
"""
super().__init__()

self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
if len(self.devices) > 1:
logger.warning(
"Multiple devices are not supported in %s inference, using the first device %s.",
self.__class__.__name__,
self.devices[0],
)

self.model = pipeline(
task="image-to-text",
model=model_name_or_path,
revision=model_version,
device=self.devices[0],
use_auth_token=use_auth_token,
)
self.generation_kwargs = generation_kwargs
self.batch_size = batch_size
self.progress_bar = progress_bar

def generate_captions(
self, image_file_paths: List[str], generation_kwargs: Optional[dict] = None, batch_size: Optional[int] = None
) -> List[Document]:
"""
Generate captions for the provided image files.

:param image_file_paths: Paths to the image files.
:param generation_kwargs: Dictionary containing arguments for the generate method of the Hugging Face model.
See https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate
:param batch_size: Number of images to process at a time.
:return: List of Documents. Document.content is the caption. Document.meta["image_path"] contains the image file path.
"""
generation_kwargs = generation_kwargs or self.generation_kwargs
batch_size = batch_size or self.batch_size

if len(image_file_paths) == 0:
raise AttributeError("ImageToText needs at least one filepath to produce a caption.")

images_dataset = ListDataset(image_file_paths)

captions: List[str] = []

for captions_batch in tqdm(
self.model(images_dataset, generate_kwargs=generation_kwargs, batch_size=batch_size),
disable=not self.progress_bar,
total=len(images_dataset),
desc="Generating captions",
):
captions.append("".join([el["generated_text"] for el in captions_batch]).strip())

result: List[Document] = []
for caption, image_file_path in zip(captions, image_file_paths):
document = Document(content=caption, content_type="text", meta={"image_path": image_file_path})
result.append(document)

return result
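Reading `__init__` and `generate_captions` together, here is a usage sketch of the node defined above. The file path is a placeholder; the pipeline wiring at the end assumes the standard Haystack 1.x `Pipeline.add_node` pattern, which is not part of this diff.

```python
from haystack import Pipeline
from haystack.nodes.image_to_text import TransformersImageToText

# Standalone usage
image_to_text = TransformersImageToText(
    model_name_or_path="nlpconnect/vit-gpt2-image-captioning",
    generation_kwargs={"max_new_tokens": 50},
)
documents = image_to_text.generate_captions(image_file_paths=["/path/to/images/apple.jpg"])
print(documents[0].content)             # the generated caption
print(documents[0].meta["image_path"])  # the originating image file

# Usage inside a pipeline (assumed Haystack 1.x wiring, not introduced by this PR)
pipeline = Pipeline()
pipeline.add_node(component=image_to_text, name="image_to_text", inputs=["File"])
result = pipeline.run(file_paths=["/path/to/images/apple.jpg"])
print(result["documents"][0].content)
```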
62 changes: 62 additions & 0 deletions test/nodes/test_image_to_text.py
@@ -0,0 +1,62 @@
import os
import pytest

from PIL import UnidentifiedImageError

from haystack import Document
from haystack.nodes.image_to_text.transformers import TransformersImageToText
from haystack.nodes.image_to_text.base import BaseImageToText

from ..conftest import SAMPLES_PATH


IMAGE_FILE_NAMES = ["apple.jpg", "car.jpg", "cat.jpg", "galaxy.jpg", "paris.jpg"]
IMAGE_FILE_PATHS = [os.path.join(SAMPLES_PATH, "images", file_name) for file_name in IMAGE_FILE_NAMES]
IMAGE_DOCS = [Document(content=image_path, content_type="image") for image_path in IMAGE_FILE_PATHS]
INVALID_IMAGE_FILE_PATH = str(SAMPLES_PATH / "markdown" / "sample.md")

EXPECTED_CAPTIONS = [
"a red apple is sitting on a pile of hay",
"a white car parked in a parking lot",
"a cat laying in the grass",
"a blurry photo of a blurry shot of a black object",
Contributor comment: Seems like this model has no concept of galaxies 😆

"a city with a large building and a clock tower",
]


@pytest.fixture
def image_to_text():
return TransformersImageToText(
model_name_or_path="nlpconnect/vit-gpt2-image-captioning",
devices=["cpu"],
generation_kwargs={"max_new_tokens": 50},
)


@pytest.mark.integration
def test_image_to_text(image_to_text):
assert isinstance(image_to_text, BaseImageToText)

results_0 = image_to_text.run(file_paths=IMAGE_FILE_PATHS)
image_paths_0 = [doc.meta["image_path"] for doc in results_0[0]["documents"]]
assert image_paths_0 == IMAGE_FILE_PATHS
generated_captions_0 = [doc.content for doc in results_0[0]["documents"]]
assert generated_captions_0 == EXPECTED_CAPTIONS

results_1 = image_to_text.run(documents=IMAGE_DOCS)
image_paths_1 = [doc.meta["image_path"] for doc in results_1[0]["documents"]]
assert image_paths_1 == IMAGE_FILE_PATHS
generated_captions_1 = [doc.content for doc in results_1[0]["documents"]]
assert generated_captions_1 == EXPECTED_CAPTIONS

results_2 = image_to_text.run(file_paths=IMAGE_FILE_PATHS[:3], documents=IMAGE_DOCS[3:])
image_paths_2 = [doc.meta["image_path"] for doc in results_2[0]["documents"]]
assert image_paths_2 == IMAGE_FILE_PATHS
generated_captions_2 = [doc.content for doc in results_2[0]["documents"]]
assert generated_captions_2 == EXPECTED_CAPTIONS


@pytest.mark.integration
def test_image_to_text_invalid_image(image_to_text):
with pytest.raises(UnidentifiedImageError, match="cannot identify image file"):
image_to_text.run(file_paths=[INVALID_IMAGE_FILE_PATH])