From 6462a7e508831666aed1a14cca5cc9dd5f59fc4f Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Mon, 7 Oct 2024 14:29:40 +0200 Subject: [PATCH] Allow setting API keys from env Signed-off-by: Ashwin Vaidya --- pyproject.toml | 2 +- .../models/image/vlm_ad/backends/chat_gpt.py | 20 +++++++++++++++++-- .../models/image/vlm_ad/lightning_model.py | 3 --- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 453a41be89..268544ad2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ core = [ "open-clip-torch>=2.23.0,<2.26.1", ] openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"] -vlm = ["ollama", "openai", "transformers"] +vlm = ["ollama", "openai", "python-dotenv", "transformers"] loggers = [ "comet-ml>=3.31.7", "gradio>=4", diff --git a/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py index def43674ae..a6a0414012 100644 --- a/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py +++ b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py @@ -5,9 +5,12 @@ import base64 import logging +import os from pathlib import Path from typing import TYPE_CHECKING +from dotenv import load_dotenv + from anomalib.models.image.vlm_ad.utils import Prompt from anomalib.utils.exceptions import try_import @@ -27,12 +30,12 @@ class ChatGPT(Backend): """ChatGPT backend.""" - def __init__(self, api_key: str, model_name: str) -> None: + def __init__(self, model_name: str, api_key: str | None = None) -> None: """Initialize the ChatGPT backend.""" - self.api_key = api_key self._ref_images_encoded: list[str] = [] self.model_name: str = model_name self._client: OpenAI | None = None + self.api_key = self._get_api_key(api_key) @property def client(self) -> OpenAI: @@ -86,3 +89,16 @@ def _encode_image_to_base_64(image: str | Path) -> str: """Encode the image to base64.""" image = Path(image) return base64.b64encode(image.read_bytes()).decode("utf-8") + + def 
_get_api_key(self, api_key: str | None = None) -> str: + if api_key is None: + load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + msg = ( + f"OpenAI API key must be provided to use {self.model_name}." + " Please provide the API key in the constructor, set the OPENAI_API_KEY environment variable," + " or add it to a `.env` file." + ) + raise ValueError(msg) + return api_key diff --git a/src/anomalib/models/image/vlm_ad/lightning_model.py b/src/anomalib/models/image/vlm_ad/lightning_model.py index 92797c948b..929ac0e591 100644 --- a/src/anomalib/models/image/vlm_ad/lightning_model.py +++ b/src/anomalib/models/image/vlm_ad/lightning_model.py @@ -36,9 +36,6 @@ def _setup_vlm(model: VLMModel, api_key: str | None) -> Backend: if model == VLMModel.LLAMA_OLLAMA: return Ollama(model_name=model.value) if model == VLMModel.GPT_4O_MINI: - if api_key is None: - msg = f"ChatGPT API key is required to use {model.value} model." - raise ValueError(msg) return ChatGPT(api_key=api_key, model_name=model.value) if model in {VLMModel.VICUNA_7B_HF, VLMModel.VICUNA_13B_HF, VLMModel.MISTRAL_7B_HF}: return Huggingface(model_name=model.value)