diff --git a/llmlib/llmlib/gemini/media_description.py b/llmlib/llmlib/gemini/media_description.py
index 8a65d2a..f7c8542 100644
--- a/llmlib/llmlib/gemini/media_description.py
+++ b/llmlib/llmlib/gemini/media_description.py
@@ -176,7 +176,9 @@ class ResponseRefusedException(Exception):
 class GeminiAPI(LLM):
     model_id: str = Models.gemini_pro
     max_output_tokens: int = 1000
-    requires_gpu_exclusively: bool = False
+
+    requires_gpu_exclusively = False
+    model_ids = [Models.gemini_pro, Models.gemini_flash]
 
     def complete_msgs2(self, msgs: list[Message]) -> str:
         if len(msgs) != 1:
@@ -198,7 +200,7 @@ def complete_msgs2(self, msgs: list[Message]) -> str:
 
     @singledispatchmethod
     def video_prompt(self, video, prompt: str) -> str:
-        raise NotImplementedError
+        raise NotImplementedError(f"Unsupported video type: {type(video)}")
 
     @video_prompt.register
     def _(self, video: Path, prompt: str) -> str:
@@ -208,5 +210,12 @@ def _(self, video: Path, prompt: str) -> str:
     @video_prompt.register
     def _(self, video: BytesIO, prompt: str) -> str:
         path = tempfile.mktemp(suffix=".mp4")
-        video.save(path)
-        return self.video_prompt(path, prompt)
+        with open(path, "wb") as f:
+            f.write(video.getvalue())
+        return self.video_prompt(Path(path), prompt)
+
+    @classmethod
+    def get_warnings(cls) -> list[str]:
+        return [
+            "While Gemini supports multi-turn and multi-file chat, we have only implemented single-file, single-turn prompts at the moment."
+        ]
diff --git a/llmlib/llmlib/runtime.py b/llmlib/llmlib/runtime.py
index 08c14ff..27e24e2 100644
--- a/llmlib/llmlib/runtime.py
+++ b/llmlib/llmlib/runtime.py
@@ -1,3 +1,5 @@
+from .base_llm import LLM
+from .gemini.media_description import GeminiAPI
 from .gemma import PaliGemma2
 from .minicpm import MiniCPM
 from .llama3 import LLama3Vision8B
@@ -13,11 +15,21 @@ def filled_model_registry() -> ModelRegistry:
         ModelEntry.from_cls_with_id(MiniCPM),
         ModelEntry.from_cls_with_id(LLama3Vision8B),
         ModelEntry.from_cls_with_id(PaliGemma2),
-        *[
-            ModelEntry(
-                model_id=id_, clazz=OpenAIModel, ctor=lambda: OpenAIModel(model=id_)
-            )
-            for id_ in OpenAIModel.model_ids
-        ],
+        *model_entries_from_mult_ids(OpenAIModel),
+        *model_entries_from_mult_ids(GeminiAPI),
     ]
 )
+
+
+def model_entries_from_mult_ids(cls: type[LLM]) -> list[ModelEntry]:
+    assert hasattr(cls, "model_ids")
+    entries = [
+        ModelEntry(
+            model_id=id_,
+            clazz=cls,
+            ctor=lambda id_=id_: cls(model_id=id_),
+            warnings=cls.get_warnings(),
+        )
+        for id_ in cls.model_ids
+    ]
+    return entries
diff --git a/st_app.py b/st_app.py
index 9cfcdc8..cff49f1 100644
--- a/st_app.py
+++ b/st_app.py
@@ -1,4 +1,5 @@
 from io import BytesIO
+import logging
 from PIL import Image
 import streamlit as st
 from llmlib.runtime import filled_model_registry
@@ -9,6 +10,9 @@
 from st_helpers import is_image, is_video
 from login_mask_simple import check_password
 
+fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+logging.basicConfig(level=logging.INFO, format=fmt)
+
 if not check_password():
     st.stop()
 
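
Note on the video_prompt hunk: functools.singledispatchmethod dispatches on the runtime type of the first argument after self, which is why the BytesIO overload writes its bytes to a temp file and re-dispatches with an explicit Path (a plain str path would fall through to the NotImplementedError fallback). A minimal runnable sketch of that pattern, with a stand-in body in place of the real Gemini call:

    # Sketch only: mirrors the dispatch structure of GeminiAPI.video_prompt;
    # the Path overload's return value is a stand-in for the real API call.
    from functools import singledispatchmethod
    from io import BytesIO
    from pathlib import Path
    import tempfile

    class VideoPrompter:
        @singledispatchmethod
        def video_prompt(self, video, prompt: str) -> str:
            raise NotImplementedError(f"Unsupported video type: {type(video)}")

        @video_prompt.register
        def _(self, video: Path, prompt: str) -> str:
            return f"described {video.name} using prompt {prompt!r}"

        @video_prompt.register
        def _(self, video: BytesIO, prompt: str) -> str:
            # Spill the in-memory bytes to disk, then reuse the Path overload.
            path = tempfile.mktemp(suffix=".mp4")
            with open(path, "wb") as f:
                f.write(video.getvalue())
            return self.video_prompt(Path(path), prompt)

    prompter = VideoPrompter()
    print(prompter.video_prompt(BytesIO(b"\x00\x00"), "What happens in this video?"))

tempfile.mktemp is kept here to mirror the patch, though it is deprecated; tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") would avoid the filename race.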
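
model_entries_from_mult_ids calls cls.get_warnings() on every class it registers, including OpenAIModel, so the LLM base class presumably provides an empty default. A sketch of that assumed hook (the base-class code is not part of this diff):

    class LLM:
        @classmethod
        def get_warnings(cls) -> list[str]:
            # Assumed default: subclasses such as GeminiAPI override this to
            # surface caveats next to their registry entries.
            return []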
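
The ctor in model_entries_from_mult_ids binds the loop variable as a default argument (ctor=lambda id_=id_: ...) rather than closing over it: a bare lambda: cls(model_id=id_) would capture the variable itself, and every registry entry would construct the last model id once the comprehension finishes. A self-contained demonstration, using hypothetical make_ctors helpers:

    def make_ctors_buggy(model_ids: list[str]):
        # Every lambda shares the same id_ cell; after the comprehension
        # finishes, that cell holds the final value.
        return [lambda: id_ for id_ in model_ids]

    def make_ctors_fixed(model_ids: list[str]):
        # Default-argument binding snapshots the value at each iteration.
        return [lambda id_=id_: id_ for id_ in model_ids]

    ids = ["pro-model", "flash-model"]
    print([f() for f in make_ctors_buggy(ids)])  # ['flash-model', 'flash-model']
    print([f() for f in make_ctors_fixed(ids)])  # ['pro-model', 'flash-model']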