🚀 Add VLM based Anomaly Model (#2344)
* [Draft] Llm on (#2165)
  * Add TaskType Explanation
    Signed-off-by: Bepitic <[email protected]>
  * Add llm model
    Signed-off-by: Bepitic <[email protected]>
  * add ollama
    Signed-off-by: Bepitic <[email protected]>
  * better description for descr in title
    Signed-off-by: Bepitic <[email protected]>
  * add text of llm into imageResult visualization
  * add text of llm into imageResult visualization
    Signed-off-by: Bepitic <[email protected]>
  * latest changes
    Signed-off-by: Bepitic <[email protected]>
  * add wip llava/llava_next
    Signed-off-by: Bepitic <[email protected]>
  * add init
    Signed-off-by: Bepitic <[email protected]>
  * add text of llm into imageResult visualization
    Signed-off-by: Bepitic <[email protected]>
  * latest changes
    Signed-off-by: Bepitic <[email protected]>
  * upd Lint
    Signed-off-by: Bepitic <[email protected]>
  * fix visualization with description
    Signed-off-by: Bepitic <[email protected]>
  * show the images every batch
    Signed-off-by: Bepitic <[email protected]>
  * fix docstring and error management
    Signed-off-by: Bepitic <[email protected]>
  * Add compatibility for TaskType.EXPLANATION.
    Signed-off-by: Bepitic <[email protected]>
  * Remove, show in the engine-Visualization.
  * fix visualization and llm openai multishot.
  * fix Circular import problem
  * Add HugginFace To LLavaNext
    Signed-off-by: Bepitic <[email protected]>
  ---------
  Signed-off-by: Bepitic <[email protected]>
* 🔨 Scaffold for refactor (#2340)
  * initial scafold
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Apply PR comments
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * rename dir
    Signed-off-by: Ashwin Vaidya <[email protected]>
  ---------
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Add ChatGPT (#2341)
  * initial scafold
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Apply PR comments
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * rename dir
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * delete llm_ollama
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Add ChatGPT
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Add ChatGPT
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Remove LLM model
    Signed-off-by: Ashwin Vaidya <[email protected]>
  ---------
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Add Huggingface (#2343)
  * initial scafold
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Apply PR comments
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * rename dir
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * delete llm_ollama
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Add ChatGPT
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Add ChatGPT
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Remove LLM model
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Add transformers
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * Remove llava
    Signed-off-by: Ashwin Vaidya <[email protected]>
  ---------
  Signed-off-by: Ashwin Vaidya <[email protected]>
* 🔨 Minor Refactor (#2345)
  Refactor
  Signed-off-by: Ashwin Vaidya <[email protected]>
* undo changes
  Signed-off-by: Ashwin Vaidya <[email protected]>
* undo changes
  Signed-off-by: Ashwin Vaidya <[email protected]>
* undo changes to image.py
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Add explanation visualizer (#2351)
  * Add explanation visualizer
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * bug-fix
    Signed-off-by: Ashwin Vaidya <[email protected]>
  ---------
  Signed-off-by: Ashwin Vaidya <[email protected]>
* 🔨 Allow setting API keys from env (#2353)
  Allow setting API keys from env
  Signed-off-by: Ashwin Vaidya <[email protected]>
* 🧪 Add tests (#2355)
  * Add tests
    Signed-off-by: Ashwin Vaidya <[email protected]>
  * remove explanation task type
    Signed-off-by: Ashwin Vaidya <[email protected]>
  ---------
  Signed-off-by: Ashwin Vaidya <[email protected]>
* minor fixes
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Update changelog
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Fix tests
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Address PR comments
  Signed-off-by: Ashwin Vaidya <[email protected]>
* update name
  Signed-off-by: Ashwin Vaidya <[email protected]>
* Update src/anomalib/models/image/vlm_ad/lightning_model.py
  Co-authored-by: Samet Akcay <[email protected]>
* update name
  Signed-off-by: Ashwin Vaidya <[email protected]>
---------
Signed-off-by: Bepitic <[email protected]>
Signed-off-by: Ashwin Vaidya <[email protected]>
Co-authored-by: Paco <[email protected]>
Co-authored-by: Samet Akcay <[email protected]>
1 parent 6eeb7f6 commit 3a403ae
Showing 17 changed files with 603 additions and 9 deletions.
@@ -0,0 +1,8 @@
"""Visual Anomaly Model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import VlmAd

__all__ = ["VlmAd"]
@@ -0,0 +1,11 @@
"""VLM backends."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .base import Backend
from .chat_gpt import ChatGPT
from .huggingface import Huggingface
from .ollama import Ollama

__all__ = ["Backend", "ChatGPT", "Huggingface", "Ollama"]
@@ -0,0 +1,30 @@
"""Base backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from pathlib import Path

from anomalib.models.image.vlm_ad.utils import Prompt


class Backend(ABC):
    """Base backend."""

    @abstractmethod
    def __init__(self, model_name: str) -> None:
        """Initialize the backend."""

    @abstractmethod
    def add_reference_images(self, image: str | Path) -> None:
        """Add reference images for k-shot."""

    @abstractmethod
    def predict(self, image: str | Path, prompt: Prompt) -> str:
        """Predict the anomaly label."""

    @property
    @abstractmethod
    def num_reference_images(self) -> int:
        """Get the number of reference images."""
@@ -0,0 +1,109 @@
"""ChatGPT backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import base64
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

from dotenv import load_dotenv
from lightning_utilities.core.imports import package_available

from anomalib.models.image.vlm_ad.utils import Prompt

from .base import Backend

if package_available("openai"):
    from openai import OpenAI
else:
    OpenAI = None

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletion

logger = logging.getLogger(__name__)


class ChatGPT(Backend):
    """ChatGPT backend."""

    def __init__(self, model_name: str, api_key: str | None = None) -> None:
        """Initialize the ChatGPT backend."""
        self._ref_images_encoded: list[str] = []
        self.model_name: str = model_name
        self._client: OpenAI | None = None
        self.api_key = self._get_api_key(api_key)

    @property
    def client(self) -> OpenAI:
        """Get the OpenAI client."""
        if OpenAI is None:
            msg = "OpenAI is not installed. Please install it to use ChatGPT backend."
            raise ImportError(msg)
        if self._client is None:
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    def add_reference_images(self, image: str | Path) -> None:
        """Add reference images for k-shot."""
        self._ref_images_encoded.append(self._encode_image_to_url(image))

    @property
    def num_reference_images(self) -> int:
        """Get the number of reference images."""
        return len(self._ref_images_encoded)

    def predict(self, image: str | Path, prompt: Prompt) -> str:
        """Predict the anomaly label."""
        image_encoded = self._encode_image_to_url(image)
        messages = []

        # few-shot
        if len(self._ref_images_encoded) > 0:
            messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded))

        messages.append(self._generate_message(content=prompt.predict, images=[image_encoded]))

        response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name)
        return response.choices[0].message.content

    @staticmethod
    def _generate_message(content: str, images: list[str] | None) -> dict:
        """Generate a message."""
        message: dict[str, list[dict] | str] = {"role": "user"}
        if images is not None:
            _content: list[dict[str, str | dict]] = [{"type": "text", "text": content}]
            _content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images])
            message["content"] = _content
        else:
            message["content"] = content
        return message

    def _encode_image_to_url(self, image: str | Path) -> str:
        """Encode the image to base64 and embed in url string."""
        image_path = Path(image)
        extension = image_path.suffix.lstrip(".")  # drop the leading dot so the type reads e.g. image/png
        base64_encoded = self._encode_image_to_base_64(image_path)
        return f"data:image/{extension};base64,{base64_encoded}"

    @staticmethod
    def _encode_image_to_base_64(image: str | Path) -> str:
        """Encode the image to base64."""
        image = Path(image)
        return base64.b64encode(image.read_bytes()).decode("utf-8")

    def _get_api_key(self, api_key: str | None = None) -> str:
        if api_key is None:
            load_dotenv()
            api_key = os.getenv("OPENAI_API_KEY")
        if api_key is None:
            msg = (
                f"OpenAI API key must be provided to use {self.model_name}."
                " Please provide the API key in the constructor, or set the OPENAI_API_KEY environment variable"
                " or in a `.env` file."
            )
            raise ValueError(msg)
        return api_key
@@ -0,0 +1,96 @@
"""Huggingface backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path

from lightning_utilities.core.imports import package_available
from PIL import Image

from anomalib.models.image.vlm_ad.utils import Prompt

from .base import Backend

if package_available("transformers"):
    import transformers
    from transformers.modeling_utils import PreTrainedModel
    from transformers.processing_utils import ProcessorMixin
else:
    # Fallbacks keep the annotations below resolvable when transformers is missing.
    transformers = PreTrainedModel = ProcessorMixin = None


logger = logging.getLogger(__name__)


class Huggingface(Backend):
    """Huggingface backend."""

    def __init__(
        self,
        model_name: str,
    ) -> None:
        """Initialize the Huggingface backend."""
        self.model_name: str = model_name
        self._ref_images: list[Image.Image] = []
        self._processor: ProcessorMixin | None = None
        self._model: PreTrainedModel | None = None

    @property
    def processor(self) -> ProcessorMixin:
        """Get the Huggingface processor."""
        if self._processor is None:
            if transformers is None:
                msg = "transformers is not installed."
                raise ValueError(msg)
            self._processor = transformers.LlavaNextProcessor.from_pretrained(self.model_name)
        return self._processor

    @property
    def model(self) -> PreTrainedModel:
        """Get the Huggingface model."""
        if self._model is None:
            if transformers is None:
                msg = "transformers is not installed."
                raise ValueError(msg)
            self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name)
        return self._model

    @staticmethod
    def _generate_message(content: str, images: list[Image.Image] | None) -> dict:
        """Generate a message."""
        message: dict[str, str | list[dict]] = {"role": "user"}
        _content: list[dict[str, str]] = [{"type": "text", "text": content}]
        if images is not None:
            _content.extend([{"type": "image"} for _ in images])
        message["content"] = _content
        return message

    def add_reference_images(self, image: str | Path) -> None:
        """Add reference images for k-shot."""
        self._ref_images.append(Image.open(image))

    @property
    def num_reference_images(self) -> int:
        """Get the number of reference images."""
        return len(self._ref_images)

    def predict(self, image_path: str | Path, prompt: Prompt) -> str:
        """Predict the anomaly label."""
        image = Image.open(image_path)
        messages: list[dict] = []

        if len(self._ref_images) > 0:
            messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images))

        messages.append(self._generate_message(content=prompt.predict, images=[image]))
        processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)]

        images = [*self._ref_images, image]
        inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=100)
        result = self.processor.decode(outputs[0], skip_special_tokens=True)
        logger.debug(result)
        return result
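In this backend the chat messages hold only `{"type": "image"}` placeholders, while the actual images are passed to the processor as a separate flat list, so the two have to agree in count and order. The sketch below checks that invariant in plain Python. The helper names `build_inputs` and `count_placeholders` are hypothetical, and file-name strings stand in for the PIL images so that neither `transformers` nor `PIL` is needed to run it.

```python
def generate_message(content: str, num_images: int) -> dict:
    """Build one user message with a text part and `num_images` image placeholders."""
    parts: list[dict[str, str]] = [{"type": "text", "text": content}]
    parts.extend({"type": "image"} for _ in range(num_images))
    return {"role": "user", "content": parts}


def build_inputs(ref_images: list[str], image: str, few_shot: str, predict: str) -> tuple[list[dict], list[str]]:
    """Return the chat messages plus the flat image list in matching order."""
    messages: list[dict] = []
    if ref_images:
        messages.append(generate_message(few_shot, len(ref_images)))
    messages.append(generate_message(predict, 1))
    # Reference images come first, then the query image, mirroring the message order.
    return messages, [*ref_images, image]


def count_placeholders(messages: list[dict]) -> int:
    """Count the image placeholders across all messages."""
    return sum(part["type"] == "image" for message in messages for part in message["content"])


messages, images = build_inputs(
    ref_images=["good_0.png", "good_1.png"],
    image="test.png",
    few_shot="These are normal.",
    predict="Is this anomalous?",
)
```

With two reference images, the result is two messages and three placeholders, matching the three entries in the image list.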