Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix azure #2665

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add_deployment_name_to_llmprovider

Revision ID: e4334d5b33ba
Revises: 46b7a812670f
Create Date: 2024-10-04 09:52:34.896867

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "e4334d5b33ba"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"llm_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)


def downgrade() -> None:
op.drop_column("llm_provider", "deployment_name")
28 changes: 26 additions & 2 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
Expand Down Expand Up @@ -560,7 +564,26 @@ def stream_chat_message_objects(
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = llm.config
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
Expand All @@ -579,7 +602,7 @@ def stream_chat_message_objects(
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
Expand All @@ -591,6 +614,7 @@ def stream_chat_message_objects(
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
Expand Down
8 changes: 7 additions & 1 deletion backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"

POSTGRES_API_SERVER_POOL_SIZE = int(
Expand Down Expand Up @@ -413,6 +413,12 @@
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

# Azure DALL-E Configurations
AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION")
AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
AZURE_DEPLOYMENT_NAME = os.environ.get("AZURE_DEPLOYMENT_NAME")


MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/db/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
existing_llm_provider.display_model_names = llm_provider.display_model_names
existing_llm_provider.deployment_name = llm_provider.deployment_name

if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,8 @@ class LLMProvider(Base):
postgresql.ARRAY(String), nullable=True
)

deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
Expand Down
5 changes: 4 additions & 1 deletion backend/danswer/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(
model_name: str,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
max_output_tokens: int | None = None,
custom_llm_provider: str | None = None,
temperature: float = GEN_AI_TEMPERATURE,
Expand All @@ -215,6 +216,7 @@ def __init__(
self._model_version = model_name
self._temperature = temperature
self._api_key = api_key
self._deployment_name = deployment_name
self._api_base = api_base
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
Expand Down Expand Up @@ -290,7 +292,7 @@ def _completion(
try:
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
Expand Down Expand Up @@ -325,6 +327,7 @@ def config(self) -> LLMConfig:
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
deployment_name=self._deployment_name,
)

def _invoke_implementation(
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/llm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
Expand All @@ -103,6 +104,7 @@ def _create_llm(model: str) -> LLM:
def get_llm(
provider: str,
model: str,
deployment_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/llm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LLMConfig(BaseModel):
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None

deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/llm/llm_provider_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ class WellKnownLLMProviderDescriptor(BaseModel):
api_base_required: bool
api_version_required: bool
custom_config_keys: list[CustomConfigKey] | None = None
single_model_supported: bool = False
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Consider adding a comment explaining the purpose of this field


llm_names: list[str]
default_model: str | None = None
default_fast_model: str | None = None
deployment_name_required: bool = False
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Add a comment explaining when this field is used



OPENAI_PROVIDER_NAME = "openai"
Expand Down Expand Up @@ -108,6 +110,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=True,
custom_config_keys=[],
llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Azure provider not included in _PROVIDER_TO_MODELS_MAP, this may return an empty list

deployment_name_required=True,
single_model_supported=True,
),
WellKnownLLMProviderDescriptor(
name=BEDROCK_PROVIDER_NAME,
Expand Down
9 changes: 8 additions & 1 deletion backend/danswer/server/features/tool/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.utils import is_image_generation_available

router = APIRouter(prefix="/tool")
admin_router = APIRouter(prefix="/admin/tool")
Expand Down Expand Up @@ -127,4 +129,9 @@ def list_tools(
_: User | None = Depends(current_user),
) -> list[ToolSnapshot]:
tools = get_tools(db_session)
return [ToolSnapshot.from_model(tool) for tool in tools]
return [
ToolSnapshot.from_model(tool)
for tool in tools
if tool.in_code_tool_id != ImageGenerationTool.name
or is_image_generation_available(db_session=db_session)
]
15 changes: 15 additions & 0 deletions backend/danswer/server/manage/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_llm_configuration(
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
)

functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]

if (
Expand Down Expand Up @@ -141,6 +142,20 @@ def put_llm_provider(
detail=f"LLM Provider with name {llm_provider.name} already exists",
)

# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []

if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)

if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
Comment on lines +150 to +157
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: This logic might be better placed in the LLMProviderUpsertRequest model's validation or in a separate function for cleaner code organization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to clean upsert as clean and minimal as possible


try:
return upsert_llm_provider(
llm_provider=llm_provider,
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/server/manage/llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class LLMProvider(BaseModel):
is_public: bool = True
groups: list[int] = Field(default_factory=list)
display_model_names: list[str] | None = None
deployment_name: str | None = None


class LLMProviderUpsertRequest(LLMProvider):
Expand Down Expand Up @@ -100,4 +101,5 @@ def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider"
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
deployment_name=llm_provider_model.deployment_name,
)
21 changes: 21 additions & 0 deletions backend/danswer/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import json

from sqlalchemy.orm import Session

from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.db.connector import check_connectors_exist
from danswer.db.document import check_docs_exist
from danswer.db.models import LLMProvider
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.tools.tool import Tool

Expand All @@ -26,3 +32,18 @@ def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int:

def compute_all_tool_tokens(tools: list[Tool], llm_tokenizer: BaseTokenizer) -> int:
return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools)


def is_image_generation_available(db_session: Session) -> bool:
providers = db_session.query(LLMProvider).all()
for provider in providers:
if provider.name == "OpenAI":
return True

return bool(AZURE_DALLE_API_KEY)
Comment on lines +37 to +43
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: This function assumes OpenAI always supports image generation. Consider checking for specific OpenAI models that support DALL-E.



def is_document_search_available(db_session: Session) -> bool:
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
return docs_exist or connectors_exist
2 changes: 1 addition & 1 deletion deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
2 changes: 1 addition & 1 deletion deployment/docker_compose/docker-compose.gpu-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

Expand Down
10 changes: 3 additions & 7 deletions web/src/app/admin/assistants/AssistantEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,11 @@ export function AssistantEditor({
modelOptionsByProvider.set(llmProvider.name, providerOptions);
});

const providerSupportingImageGenerationExists =
providersContainImageGeneratingSupport(llmProviders);

const personaCurrentToolIds =
existingPersona?.tools.map((tool) => tool.id) || [];

const searchTool = findSearchTool(tools);
const imageGenerationTool = providerSupportingImageGenerationExists
? findImageGenerationTool(tools)
: undefined;
const imageGenerationTool = findImageGenerationTool(tools);
const internetSearchTool = findInternetSearchTool(tools);

const customTools = tools.filter(
Expand Down Expand Up @@ -997,7 +993,7 @@ export function AssistantEditor({
alignTop={tool.description != null}
key={tool.id}
name={`enabled_tools_map.${tool.id}`}
label={tool.name}
label={tool.display_name}
subtext={tool.description}
onChange={() => {
toggleToolInValues(tool.id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ function LLMProviderDisplay({
</Button>
</div>
</div>

{formIsVisible && (
<LLMProviderUpdateModal
llmProviderDescriptor={llmProviderDescriptor}
Expand Down
Loading
Loading