feat: googleai llm addition #152

Open · wants to merge 11 commits into base: main
1 change: 1 addition & 0 deletions backend/.env.sample

@@ -18,6 +18,7 @@ SQLITE_DB_PATH=
 # LLM Integrations
 OPENAI_API_KEY=
 ANTHROPIC_API_KEY=
+GOOGLEAI_API_KEY=

 # Tools
 REPLICATE_API_TOKEN=
11 changes: 11 additions & 0 deletions backend/director/agents/base.py

@@ -39,6 +39,17 @@ def get_parameters(self):
             raise Exception(
                 "Failed to infer parameters, please define JSON instead of using this automated util."
             )
+
+        parameters["properties"].pop("args", None)
+        parameters["properties"].pop("kwargs", None)
+
+        if "required" in parameters:
+            parameters["required"] = [
+                param
+                for param in parameters["required"]
+                if param not in ["args", "kwargs"]
+            ]
+
         return parameters

     def to_llm_format(self):
9 changes: 8 additions & 1 deletion backend/director/agents/download.py

@@ -14,7 +14,14 @@ def __init__(self, session: Session, **kwargs):
         self.parameters = self.get_parameters()
         super().__init__(session=session, **kwargs)

-    def run(self, stream_link: str, name: str = None, *args, **kwargs) -> AgentResponse:
+    def run(
+        self,
+        stream_link: str,
+        name: str = None,
+        stream_name: str = None,
+        *args,
+        **kwargs,
+    ) -> AgentResponse:
         """
         Downloads the video from the given stream link.
6 changes: 0 additions & 6 deletions backend/director/agents/dubbing.py

@@ -35,10 +35,6 @@
             "description": "The dubbing engine to use. Default is 'elevenlabs'. Possible values include 'elevenlabs'.",
             "default": "elevenlabs",
         },
-        "engine_params": {
-            "type": "object",
-            "description": "Optional parameters for the dubbing engine.",
-        },
     },
     "required": [
         "video_id",
@@ -66,7 +62,6 @@ def run(
         target_language_code: str,
         collection_id: str,
         engine: str,
-        engine_params: dict = {},
         *args,
         **kwargs,
     ) -> AgentResponse:
@@ -77,7 +72,6 @@
         :param str target_language_code: The target language code for dubbing (e.g. es).
         :param str collection_id: The ID of the collection to process.
         :param str engine: The dubbing engine to use. Default is 'elevenlabs'.
-        :param dict engine_params: Optional parameters for the dubbing engine.
         :param args: Additional positional arguments.
         :param kwargs: Additional keyword arguments.
         :return: The response containing information about the dubbing operation.
2 changes: 1 addition & 1 deletion backend/director/agents/upload.py

@@ -41,7 +41,7 @@
             "description": "Collection ID to upload the content",
         },
     },
-    "required": ["url", "media_type", "collection_id"],
+    "required": ["source", "media_type", "collection_id"],
 }
2 changes: 2 additions & 0 deletions backend/director/constants.py

@@ -19,6 +19,7 @@ class LLMType(str, Enum):

     OPENAI = "openai"
     ANTHROPIC = "anthropic"
+    GOOGLEAI = "googleai"
     VIDEODB_PROXY = "videodb_proxy"

@@ -27,5 +28,6 @@ class EnvPrefix(str, Enum):

     OPENAI_ = "OPENAI_"
     ANTHROPIC_ = "ANTHROPIC_"
+    GOOGLEAI_ = "GOOGLEAI_"

 DOWNLOADS_PATH="director/downloads"
4 changes: 4 additions & 0 deletions backend/director/llm/__init__.py

@@ -4,6 +4,7 @@

 from director.llm.openai import OpenAI
 from director.llm.anthropic import AnthropicAI
+from director.llm.googleai import GoogleAI
 from director.llm.videodb_proxy import VideoDBProxy

@@ -12,12 +13,15 @@ def get_default_llm():

     openai = True if os.getenv("OPENAI_API_KEY") else False
     anthropic = True if os.getenv("ANTHROPIC_API_KEY") else False
+    googleai = True if os.getenv("GOOGLEAI_API_KEY") else False

     default_llm = os.getenv("DEFAULT_LLM")

     if openai or default_llm == LLMType.OPENAI:
         return OpenAI()
     elif anthropic or default_llm == LLMType.ANTHROPIC:
         return AnthropicAI()
+    elif googleai or default_llm == LLMType.GOOGLEAI:
+        return GoogleAI()
     else:
         return VideoDBProxy()
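For context, a short sketch (not part of the diff) of how the new branch resolves, assuming GOOGLEAI_API_KEY is the only provider key set; note that an OPENAI_API_KEY or ANTHROPIC_API_KEY in the environment takes precedence, since key presence is checked before DEFAULT_LLM.

    # Sketch: exercising the new GoogleAI branch of get_default_llm().
    # Assumes OPENAI_API_KEY and ANTHROPIC_API_KEY are unset; the key is a placeholder.
    import os

    os.environ["GOOGLEAI_API_KEY"] = "your-google-ai-api-key"

    from director.llm import get_default_llm

    llm = get_default_llm()  # returns GoogleAI() via the new elif branch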
191 changes: 191 additions & 0 deletions backend/director/llm/googleai.py

@@ -0,0 +1,191 @@
import json
from enum import Enum

from pydantic import Field, field_validator, FieldValidationInfo
from pydantic_settings import SettingsConfigDict


from director.llm.base import BaseLLM, BaseLLMConfig, LLMResponse, LLMResponseStatus
from director.constants import (
    LLMType,
    EnvPrefix,
)


class GoogleChatModel(str, Enum):
    """Enum for Google Gemini Chat models"""

    GEMINI_1_5_FLASH = "gemini-1.5-flash"
    GEMINI_1_5_FLASH_0_0_2 = "gemini-1.5-flash-002"
    GEMINI_1_5_PRO = "gemini-1.5-pro"
    GEMINI_1_5_PRO_0_0_2 = "gemini-1.5-pro-002"
    GEMINI_2_0_FLASH = "gemini-2.0-flash"
    GEMINI_2_0_FLASH_0_0_1 = "gemini-2.0-flash-001"
    GEMINI_2_0_PRO = "gemini-2.0-pro-exp"


class GoogleAIConfig(BaseLLMConfig):
    """GoogleAI Config"""

    model_config = SettingsConfigDict(
        env_prefix=EnvPrefix.GOOGLEAI_,
        extra="ignore",
    )

    llm_type: str = LLMType.GOOGLEAI
    api_key: str = ""
    api_base: str = "https://generativelanguage.googleapis.com/v1beta/openai/"
    chat_model: str = Field(default=GoogleChatModel.GEMINI_2_0_FLASH)
    max_tokens: int = 4096

    @field_validator("api_key")
    @classmethod
    def validate_non_empty(cls, v, info: FieldValidationInfo):
        if not v:
            raise ValueError(
                f"{info.field_name} must not be empty. Please set {EnvPrefix.GOOGLEAI_.value}{info.field_name.upper()} environment variable."
            )
        return v


class GoogleAI(BaseLLM):
    def __init__(self, config: GoogleAIConfig = None):
        """
        :param config: GoogleAI Config
        """
        if config is None:
            config = GoogleAIConfig()
        super().__init__(config=config)
        try:
            import openai
        except ImportError:
            raise ImportError("Please install OpenAI python library.")
Comment on lines +59 to +63 (Contributor):

🛠️ Refactor suggestion

Improve exception chaining and error message clarity. The error handling for the missing OpenAI library should use proper exception chaining with "from err" and provide a more specific error message about Google AI's reliance on the OpenAI client:

try:
    import openai
except ImportError as err:
-   raise ImportError("Please install OpenAI python library.")
+   raise ImportError("Please install OpenAI python library which is required for Google AI integration.") from err

🪛 Ruff (0.8.2), B904 at line 62: Within an except clause, raise exceptions with "raise ... from err" or "raise ... from None" to distinguish them from errors in exception handling.

        self.client = openai.OpenAI(
            api_key=self.api_key, base_url=self.api_base
        )

    def _format_messages(self, messages: list):
        """Format the messages to the format that Google Gemini expects."""
        formatted_messages = []

        for message in messages:
            if message["role"] == "assistant" and message.get("tool_calls"):
                formatted_messages.append(
                    {
                        "role": message["role"],
                        "content": message["content"]
                        if message["content"]
                        else "[Processing request...]",
                        "tool_calls": [
                            {
                                "id": tool_call["id"],
                                "function": {
                                    "name": tool_call["tool"]["name"],
                                    "arguments": json.dumps(
                                        tool_call["tool"]["arguments"]
                                    ),
                                },
                                "type": tool_call["type"],
                            }
                            for tool_call in message["tool_calls"]
                        ],
                    }
                )
            else:
                formatted_messages.append(message)

        return formatted_messages

    def _format_tools(self, tools: list):
        """Format the tools to the format that Gemini expects.

        **Example**::

            [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the weather in a given location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "The city and state, e.g. Chicago, IL"
                                },
                                "unit": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"]
                                }
                            },
                            "required": ["location"]
                        }
                    }
                }
            ]
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.get("name", ""),
                    "description": tool.get("description", ""),
                    "parameters": tool.get("parameters", {}),
                },
            }
            for tool in tools
            if tool.get("name")
        ]

    def chat_completions(
        self, messages: list, tools: list = [], stop=None, response_format=None
    ):
        """Get chat completions using Gemini.

        docs: https://ai.google.dev/gemini-api/docs/openai
        """
        params = {
            "model": self.chat_model,
            "messages": self._format_messages(messages),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "timeout": self.timeout,
        }

        if tools:
            params["tools"] = self._format_tools(tools)
            params["tool_choice"] = "auto"

        if response_format:
            params["response_format"] = response_format

        try:
            response = self.client.chat.completions.create(**params)
        except Exception as e:
            print(f"Error: {e}")
            return LLMResponse(content=f"Error: {e}")

        return LLMResponse(
            content=response.choices[0].message.content or "",
            tool_calls=[
                {
                    "id": tool_call.id,
                    "tool": {
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments),
                    },
                    "type": tool_call.type,
                }
                for tool_call in response.choices[0].message.tool_calls
            ]
            if response.choices[0].message.tool_calls
            else [],
            finish_reason=response.choices[0].finish_reason,
            send_tokens=response.usage.prompt_tokens,
            recv_tokens=response.usage.completion_tokens,
            total_tokens=response.usage.total_tokens,
            status=LLMResponseStatus.SUCCESS,
Comment on lines +172 to +190 (Contributor):

🛠️ Refactor suggestion

Add error handling for accessing response properties. The code assumes that response.choices[0] and other properties always exist; add guards to prevent potential exceptions:

+       if not response.choices:
+           return LLMResponse(
+               content="Error: No choices in response",
+               status=LLMResponseStatus.ERROR
+           )
+
        return LLMResponse(
            content=response.choices[0].message.content or "",
            tool_calls=[
                {
                    "id": tool_call.id,
                    "tool": {
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments),
                    },
                    "type": tool_call.type,
                }
                for tool_call in response.choices[0].message.tool_calls
            ]
            if response.choices[0].message.tool_calls
            else [],
            finish_reason=response.choices[0].finish_reason,
-           send_tokens=response.usage.prompt_tokens,
-           recv_tokens=response.usage.completion_tokens,
-           total_tokens=response.usage.total_tokens,
+           send_tokens=getattr(response.usage, 'prompt_tokens', 0),
+           recv_tokens=getattr(response.usage, 'completion_tokens', 0),
+           total_tokens=getattr(response.usage, 'total_tokens', 0),
            status=LLMResponseStatus.SUCCESS,
        )

        )
15 changes: 15 additions & 0 deletions docs/llm/googleai.md

@@ -0,0 +1,15 @@
## GoogleAI

GoogleAI extends the Base LLM and implements the Google Gemini API.

### GoogleAI Config

GoogleAI Config is the configuration object for Google Gemini. It is used to configure Google Gemini and is passed to GoogleAI when it is created.

::: director.llm.googleai.GoogleAIConfig
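
As a quick illustration (a sketch, not part of the PR's docs), the config can be populated from the GOOGLEAI_ environment prefix or constructed explicitly; the key below is a placeholder:

```python
# Minimal sketch: build the config directly instead of relying on the
# GOOGLEAI_API_KEY environment variable. Fields come from GoogleAIConfig above.
from director.llm.googleai import GoogleAI, GoogleAIConfig, GoogleChatModel

config = GoogleAIConfig(
    api_key="your-google-ai-api-key",  # placeholder
    chat_model=GoogleChatModel.GEMINI_2_0_FLASH,
)
llm = GoogleAI(config=config)
```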

### GoogleAI Interface

GoogleAI is the LLM used by the agents and tools. It is used to generate responses to messages.

::: director.llm.googleai.GoogleAI
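
For example, a minimal chat call might look like the sketch below; message dicts follow the OpenAI-style role/content shape that the integration expects:

```python
# Minimal sketch: one user turn through GoogleAI.chat_completions.
# Assumes `llm` was created as in the config example above.
response = llm.chat_completions(
    messages=[{"role": "user", "content": "Say hello in one sentence."}]
)
print(response.content)        # generated text ("" on tool-only turns)
print(response.finish_reason)  # e.g. "stop" or "tool_calls"
```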
1 change: 1 addition & 0 deletions mkdocs.yml

@@ -74,6 +74,7 @@ nav:
     - Integrations:
       - 'OpenAI': 'llm/openai.md'
       - 'AnthropicAI': 'llm/anthropic.md'
+     - 'GoogleAI': 'llm/googleai.md'
   - 'Database':
     - 'Interface': 'database/interface.md'
     - Integrations: