Commit
Add Gemini chat to UI
tomasruizt committed Dec 23, 2024
1 parent 8e7b5ab commit e7d8e46
Showing 3 changed files with 35 additions and 10 deletions.
17 changes: 13 additions & 4 deletions llmlib/llmlib/gemini/media_description.py
@@ -176,7 +176,9 @@ class ResponseRefusedException(Exception):
 class GeminiAPI(LLM):
     model_id: str = Models.gemini_pro
     max_output_tokens: int = 1000
-    requires_gpu_exclusively: bool = False
+
+    requires_gpu_exclusively = False
+    model_ids = [Models.gemini_pro, Models.gemini_flash]
 
     def complete_msgs2(self, msgs: list[Message]) -> str:
         if len(msgs) != 1:
@@ -198,7 +200,7 @@ def complete_msgs2(self, msgs: list[Message]) -> str:

     @singledispatchmethod
     def video_prompt(self, video, prompt: str) -> str:
-        raise NotImplementedError
+        raise NotImplementedError(f"Unsupported video type: {type(video)}")
 
     @video_prompt.register
     def _(self, video: Path, prompt: str) -> str:
@@ -208,5 +210,12 @@ def _(self, video: Path, prompt: str) -> str:
     @video_prompt.register
     def _(self, video: BytesIO, prompt: str) -> str:
         path = tempfile.mktemp(suffix=".mp4")
-        video.save(path)
-        return self.video_prompt(path, prompt)
+        with open(path, "wb") as f:
+            f.write(video.getvalue())
+        return self.video_prompt(Path(path), prompt)
+
+    @classmethod
+    def get_warnings(cls) -> list[str]:
+        return [
+            "While Gemini supports multi-turn and multi-file chat, we have only implemented single-file and single-turn prompts at the moment."
+        ]
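
Aside: the BytesIO overload above relies on functools.singledispatchmethod, spilling the in-memory buffer to a temporary file and re-dispatching to the Path overload. Below is a minimal, self-contained sketch of that dispatch pattern; the VideoPrompter class and its return strings are illustrative stand-ins, not the library's API.

```python
import tempfile
from functools import singledispatchmethod
from io import BytesIO
from pathlib import Path


class VideoPrompter:
    """Stand-in for GeminiAPI, showing only the dispatch mechanics."""

    @singledispatchmethod
    def video_prompt(self, video, prompt: str) -> str:
        raise NotImplementedError(f"Unsupported video type: {type(video)}")

    @video_prompt.register
    def _(self, video: Path, prompt: str) -> str:
        # The real implementation would upload the file to Gemini here.
        return f"would send {video} with prompt {prompt!r}"

    @video_prompt.register
    def _(self, video: BytesIO, prompt: str) -> str:
        # Persist the buffer to disk, then re-dispatch on Path.
        path = tempfile.mktemp(suffix=".mp4")
        with open(path, "wb") as f:
            f.write(video.getvalue())
        return self.video_prompt(Path(path), prompt)


print(VideoPrompter().video_prompt(BytesIO(b"\x00\x00"), "Describe this video"))
```

Re-dispatching keeps the upload logic in one place: every input type is funneled into the Path case.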
24 changes: 18 additions & 6 deletions llmlib/llmlib/runtime.py
@@ -1,3 +1,5 @@
+from .base_llm import LLM
+from .gemini.media_description import GeminiAPI
 from .gemma import PaliGemma2
 from .minicpm import MiniCPM
 from .llama3 import LLama3Vision8B
@@ -13,11 +15,21 @@ def filled_model_registry() -> ModelRegistry:
         ModelEntry.from_cls_with_id(MiniCPM),
         ModelEntry.from_cls_with_id(LLama3Vision8B),
         ModelEntry.from_cls_with_id(PaliGemma2),
-        *[
-            ModelEntry(
-                model_id=id_, clazz=OpenAIModel, ctor=lambda: OpenAIModel(model=id_)
-            )
-            for id_ in OpenAIModel.model_ids
-        ],
+        *model_entries_from_mult_ids(OpenAIModel),
+        *model_entries_from_mult_ids(GeminiAPI),
     ]
 )
+
+
+def model_entries_from_mult_ids(cls: type[LLM]) -> list[ModelEntry]:
+    assert hasattr(cls, "model_ids")
+    entries = [
+        ModelEntry(
+            model_id=id_,
+            clazz=cls,
+            ctor=lambda id_=id_: cls(model_id=id_),  # bind id_ eagerly, not at call time
+            warnings=cls.get_warnings(),
+        )
+        for id_ in cls.model_ids
+    ]
+    return entries
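
One subtlety in model_entries_from_mult_ids: a lambda created inside a comprehension closes over the loop variable itself, so the bare form ctor=lambda: cls(model_id=id_) would construct every entry with the last id in model_ids. The default-argument binding shown above avoids this. A standalone demonstration, with placeholder ids:

```python
model_ids = ["gemini-pro", "gemini-flash"]

# Late binding: every closure reads id_ after the loop has finished.
late = [lambda: id_ for id_ in model_ids]
print([f() for f in late])   # ['gemini-flash', 'gemini-flash']

# Eager binding: the default argument snapshots id_ per iteration.
eager = [lambda id_=id_: id_ for id_ in model_ids]
print([f() for f in eager])  # ['gemini-pro', 'gemini-flash']
```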
4 changes: 4 additions & 0 deletions st_app.py
@@ -1,4 +1,5 @@
 from io import BytesIO
+import logging
 from PIL import Image
 import streamlit as st
 from llmlib.runtime import filled_model_registry
@@ -9,6 +10,9 @@
 from st_helpers import is_image, is_video
 from login_mask_simple import check_password
 
+fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+logging.basicConfig(level=logging.INFO, format=fmt)
+
 if not check_password():
     st.stop()
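
For reference, the new basicConfig call formats every log record in the process like the line below; the logger name and message here are made up for illustration:

```python
import logging

fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=fmt)

# Example record from a hypothetical library logger:
logging.getLogger("llmlib.runtime").info("model registry filled")
# -> 2024-12-23 10:15:30,123 - llmlib.runtime - INFO - model registry filled
```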
