Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸ”¨ Allow setting API keys from env #2353

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ core = [
"open-clip-torch>=2.23.0,<2.26.1",
]
openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
vlm = ["ollama", "openai", "transformers"]
vlm = ["ollama", "openai", "python-dotenv","transformers"]
loggers = [
"comet-ml>=3.31.7",
"gradio>=4",
Expand Down
20 changes: 18 additions & 2 deletions src/anomalib/models/image/vlm_ad/backends/chat_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

import base64
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

from dotenv import load_dotenv

from anomalib.models.image.vlm_ad.utils import Prompt
from anomalib.utils.exceptions import try_import

Expand All @@ -27,12 +30,12 @@
class ChatGPT(Backend):
"""ChatGPT backend."""

def __init__(self, api_key: str, model_name: str) -> None:
def __init__(self, model_name: str, api_key: str | None = None) -> None:
"""Initialize the ChatGPT backend."""
self.api_key = api_key
self._ref_images_encoded: list[str] = []
self.model_name: str = model_name
self._client: OpenAI | None = None
self.api_key = self._get_api_key(api_key)

@property
def client(self) -> OpenAI:
Expand Down Expand Up @@ -86,3 +89,16 @@ def _encode_image_to_base_64(image: str | Path) -> str:
"""Encode the image to base64."""
image = Path(image)
return base64.b64encode(image.read_bytes()).decode("utf-8")

def _get_api_key(self, api_key: str | None = None) -> str:
if api_key is None:
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
msg = (
f"OpenAI API key must be provided to use {self.model_name}."
" Please provide the API key in the constructor, or set the OPENAI_API_KEY environment variable"
" or in a `.env` file."
)
raise ValueError(msg)
return api_key
3 changes: 0 additions & 3 deletions src/anomalib/models/image/vlm_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ def _setup_vlm(model: VLMModel, api_key: str | None) -> Backend:
if model == VLMModel.LLAMA_OLLAMA:
return Ollama(model_name=model.value)
if model == VLMModel.GPT_4O_MINI:
if api_key is None:
msg = f"ChatGPT API key is required to use {model.value} model."
raise ValueError(msg)
return ChatGPT(api_key=api_key, model_name=model.value)
if model in {VLMModel.VICUNA_7B_HF, VLMModel.VICUNA_13B_HF, VLMModel.MISTRAL_7B_HF}:
return Huggingface(model_name=model.value)
Expand Down
Loading