From 405fd6e53a972bd4c899067018bd0e114b60ce25 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 8 Oct 2024 14:51:00 -0700 Subject: [PATCH 1/5] fix azure --- ...33ba_add_deployment_name_to_llmprovider.py | 26 +++++++ backend/danswer/chat/process_message.py | 28 +++++++- backend/danswer/configs/app_configs.py | 8 ++- backend/danswer/db/llm.py | 1 + backend/danswer/db/models.py | 2 + backend/danswer/llm/chat_llm.py | 5 +- backend/danswer/llm/factory.py | 2 + backend/danswer/llm/interfaces.py | 2 +- backend/danswer/llm/llm_provider_options.py | 4 ++ backend/danswer/server/features/tool/api.py | 9 ++- backend/danswer/server/manage/llm/api.py | 15 ++++ backend/danswer/server/manage/llm/models.py | 2 + backend/danswer/tools/utils.py | 21 ++++++ .../docker_compose/docker-compose.dev.yml | 2 +- .../docker_compose/docker-compose.gpu-dev.yml | 2 +- .../docker-compose.search-testing.yml | 2 +- .../app/admin/assistants/AssistantEditor.tsx | 10 +-- .../llm/ConfiguredLLMProviderDisplay.tsx | 1 + .../llm/CustomLLMProviderUpdateForm.tsx | 71 +++++++++++-------- .../llm/LLMProviderUpdateForm.tsx | 68 +++++++++++------- .../app/admin/configuration/llm/interfaces.ts | 5 +- .../initialSetup/welcome/WelcomeModal.tsx | 2 +- web/src/lib/chat/fetchChatData.ts | 9 ++- 23 files changed, 222 insertions(+), 75 deletions(-) create mode 100644 backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py diff --git a/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py b/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py new file mode 100644 index 00000000000..19b14427e34 --- /dev/null +++ b/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py @@ -0,0 +1,26 @@ +"""add_deployment_name_to_llmprovider + +Revision ID: e4334d5b33ba +Revises: 46b7a812670f +Create Date: 2024-10-04 09:52:34.896867 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "e4334d5b33ba" +down_revision = "46b7a812670f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "llm_provider", sa.Column("deployment_name", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("llm_provider", "deployment_name") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index f09ac18f32a..ad7a38bf796 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -18,6 +18,10 @@ from danswer.chat.models import MessageSpecificCitations from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError +from danswer.configs.app_configs import AZURE_DALLE_API_BASE +from danswer.configs.app_configs import AZURE_DALLE_API_KEY +from danswer.configs.app_configs import AZURE_DALLE_API_VERSION +from danswer.configs.app_configs import AZURE_DEPLOYMENT_NAME from danswer.configs.chat_configs import BING_API_KEY from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH @@ -560,7 +564,26 @@ def stream_chat_message_objects( and llm.config.api_key and llm.config.model_provider == "openai" ): - img_generation_llm_config = llm.config + img_generation_llm_config = LLMConfig( + model_provider=llm.config.model_provider, + model_name="dall-e-3", + temperature=GEN_AI_TEMPERATURE, + api_key=llm.config.api_key, + api_base=llm.config.api_base, + api_version=llm.config.api_version, + ) + elif ( + llm.config.model_provider == "azure" + and AZURE_DALLE_API_KEY is not None + ): + img_generation_llm_config = LLMConfig( + model_provider="azure", + model_name=f"azure/{AZURE_DEPLOYMENT_NAME}", + temperature=GEN_AI_TEMPERATURE, + api_key=AZURE_DALLE_API_KEY, + api_base=AZURE_DALLE_API_BASE, + api_version=AZURE_DALLE_API_VERSION, + ) else: llm_providers = fetch_existing_llm_providers(db_session) openai_provider = next( @@ -579,7 +602,7 @@ def stream_chat_message_objects( ) img_generation_llm_config = LLMConfig( model_provider=openai_provider.provider, - model_name=openai_provider.default_model_name, + model_name="dall-e-3", temperature=GEN_AI_TEMPERATURE, api_key=openai_provider.api_key, api_base=openai_provider.api_base, @@ -591,6 +614,7 @@ def stream_chat_message_objects( api_base=img_generation_llm_config.api_base, api_version=img_generation_llm_config.api_version, additional_headers=litellm_additional_headers, + model=img_generation_llm_config.model_name, ) ] elif tool_cls.__name__ == InternetSearchTool.__name__: diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 4559fed6b87..fc0f05435b0 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -135,7 +135,7 @@ os.environ.get("POSTGRES_PASSWORD") or "password" ) POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" -POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" +POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_API_SERVER_POOL_SIZE = int( @@ -413,6 +413,12 @@ os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true" ) +# Azure DALL-E Configurations +AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION") +AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY") +AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE") +AZURE_DEPLOYMENT_NAME = 
os.environ.get("AZURE_DEPLOYMENT_NAME") + MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "") diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index af2ded9562a..c03ed99e412 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -83,6 +83,7 @@ def upsert_llm_provider( existing_llm_provider.model_names = llm_provider.model_names existing_llm_provider.is_public = llm_provider.is_public existing_llm_provider.display_model_names = llm_provider.display_model_names + existing_llm_provider.deployment_name = llm_provider.deployment_name if not existing_llm_provider.id: # If its not already in the db, we need to generate an ID by flushing diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 4777577d0fd..392c7a28b2e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1143,6 +1143,8 @@ class LLMProvider(Base): postgresql.ARRAY(String), nullable=True ) + deployment_name: Mapped[str | None] = mapped_column(String, nullable=True) + # should only be set for a single provider is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True) # EE only diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 1021f82abc6..90136a76bd8 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -204,6 +204,7 @@ def __init__( model_name: str, api_base: str | None = None, api_version: str | None = None, + deployment_name: str | None = None, max_output_tokens: int | None = None, custom_llm_provider: str | None = None, temperature: float = GEN_AI_TEMPERATURE, @@ -215,6 +216,7 @@ def __init__( self._model_version = model_name self._temperature = temperature self._api_key = api_key + self._deployment_name = deployment_name self._api_base = api_base self._api_version = api_version self._custom_llm_provider = custom_llm_provider @@ -290,7 +292,7 @@ def _completion( try: return litellm.completion( # model choice - model=f"{self.config.model_provider}/{self.config.model_name}", + model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}", # NOTE: have to pass in None instead of empty string for these # otherwise litellm can have some issues with bedrock api_key=self._api_key or None, @@ -325,6 +327,7 @@ def config(self) -> LLMConfig: api_key=self._api_key, api_base=self._api_base, api_version=self._api_version, + deployment_name=self._deployment_name, ) def _invoke_implementation( diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index f57bfb524b9..904735d5ffe 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -88,6 +88,7 @@ def _create_llm(model: str) -> LLM: return get_llm( provider=llm_provider.provider, model=model, + deployment_name=llm_provider.deployment_name, api_key=llm_provider.api_key, api_base=llm_provider.api_base, api_version=llm_provider.api_version, @@ -103,6 +104,7 @@ def _create_llm(model: str) -> LLM: def get_llm( provider: str, model: str, + deployment_name: str | None = None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 5e39792c393..6cb58e46c6b 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -24,7 +24,7 @@ class LLMConfig(BaseModel): api_key: str | None = None api_base: str | None = None api_version: str | 
None = None - + deployment_name: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/danswer/llm/llm_provider_options.py b/backend/danswer/llm/llm_provider_options.py index 8fc1de73955..9fb55365f6c 100644 --- a/backend/danswer/llm/llm_provider_options.py +++ b/backend/danswer/llm/llm_provider_options.py @@ -16,10 +16,12 @@ class WellKnownLLMProviderDescriptor(BaseModel): api_base_required: bool api_version_required: bool custom_config_keys: list[CustomConfigKey] | None = None + single_model_supported: bool = False llm_names: list[str] default_model: str | None = None default_fast_model: str | None = None + deployment_name_required: bool = False OPENAI_PROVIDER_NAME = "openai" @@ -108,6 +110,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: api_version_required=True, custom_config_keys=[], llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME), + deployment_name_required=True, + single_model_supported=True, ), WellKnownLLMProviderDescriptor( name=BEDROCK_PROVIDER_NAME, diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 1d441593784..7e15c048826 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -21,6 +21,8 @@ from danswer.tools.custom.openapi_parsing import MethodSpec from danswer.tools.custom.openapi_parsing import openapi_to_method_specs from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.utils import is_image_generation_available router = APIRouter(prefix="/tool") admin_router = APIRouter(prefix="/admin/tool") @@ -127,4 +129,9 @@ def list_tools( _: User | None = Depends(current_user), ) -> list[ToolSnapshot]: tools = get_tools(db_session) - return [ToolSnapshot.from_model(tool) for tool in tools] + return [ + ToolSnapshot.from_model(tool) + for tool in tools + if tool.in_code_tool_id != ImageGenerationTool.name + or is_image_generation_available(db_session=db_session) + ] diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 23f16047e91..06501d6834c 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -55,6 +55,7 @@ def test_llm_configuration( api_version=test_llm_request.api_version, custom_config=test_llm_request.custom_config, ) + functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))] if ( @@ -141,6 +142,20 @@ def put_llm_provider( detail=f"LLM Provider with name {llm_provider.name} already exists", ) + # Ensure default_model_name and fast_default_model_name are in display_model_names + # This is necessary for custom models and Bedrock/Azure models + if llm_provider.display_model_names is None: + llm_provider.display_model_names = [] + + if llm_provider.default_model_name not in llm_provider.display_model_names: + llm_provider.display_model_names.append(llm_provider.default_model_name) + + if ( + llm_provider.fast_default_model_name + and llm_provider.fast_default_model_name not in llm_provider.display_model_names + ): + llm_provider.display_model_names.append(llm_provider.fast_default_model_name) + try: return upsert_llm_provider( llm_provider=llm_provider, diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index 3ef66971003..2e3b3844807 100644 --- 
a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -66,6 +66,7 @@ class LLMProvider(BaseModel): is_public: bool = True groups: list[int] = Field(default_factory=list) display_model_names: list[str] | None = None + deployment_name: str | None = None class LLMProviderUpsertRequest(LLMProvider): @@ -100,4 +101,5 @@ def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider" ), is_public=llm_provider_model.is_public, groups=[group.id for group in llm_provider_model.groups], + deployment_name=llm_provider_model.deployment_name, ) diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py index 9e20105edef..157d4bb6ec9 100644 --- a/backend/danswer/tools/utils.py +++ b/backend/danswer/tools/utils.py @@ -1,5 +1,11 @@ import json +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import AZURE_DALLE_API_KEY +from danswer.db.connector import check_connectors_exist +from danswer.db.document import check_docs_exist +from danswer.db.models import LLMProvider from danswer.natural_language_processing.utils import BaseTokenizer from danswer.tools.tool import Tool @@ -26,3 +32,18 @@ def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int: def compute_all_tool_tokens(tools: list[Tool], llm_tokenizer: BaseTokenizer) -> int: return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools) + + +def is_image_generation_available(db_session: Session) -> bool: + providers = db_session.query(LLMProvider).all() + for provider in providers: + if provider.name == "OpenAI": + return True + + return bool(AZURE_DALLE_API_KEY) + + +def is_document_search_available(db_session: Session) -> bool: + docs_exist = check_docs_exist(db_session) + connectors_exist = check_connectors_exist(db_session) + return docs_exist or connectors_exist diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 86d988e7d90..52dda002b2b 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -298,7 +298,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5432:5432" + - "5433:5432" volumes: - db_volume:/var/lib/postgresql/data diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index ebce01eadb2..55038dd9eb2 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -308,7 +308,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5432:5432" + - "5433:5432" volumes: - db_volume:/var/lib/postgresql/data diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml index fab950c064e..2afd54e029c 100644 --- a/deployment/docker_compose/docker-compose.search-testing.yml +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -157,7 +157,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5432" + - "5433" volumes: - db_volume:/var/lib/postgresql/data diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 8c295b31a48..7ca1d087641 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ 
b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -192,15 +192,11 @@ export function AssistantEditor({ modelOptionsByProvider.set(llmProvider.name, providerOptions); }); - const providerSupportingImageGenerationExists = - providersContainImageGeneratingSupport(llmProviders); - const personaCurrentToolIds = existingPersona?.tools.map((tool) => tool.id) || []; + const searchTool = findSearchTool(tools); - const imageGenerationTool = providerSupportingImageGenerationExists - ? findImageGenerationTool(tools) - : undefined; + const imageGenerationTool = findImageGenerationTool(tools); const internetSearchTool = findInternetSearchTool(tools); const customTools = tools.filter( @@ -997,7 +993,7 @@ export function AssistantEditor({ alignTop={tool.description != null} key={tool.id} name={`enabled_tools_map.${tool.id}`} - label={tool.name} + label={tool.display_name} subtext={tool.description} onChange={() => { toggleToolInValues(tool.id); diff --git a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx index aa8c0f9725d..850eb9690b7 100644 --- a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx +++ b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx @@ -133,6 +133,7 @@ function LLMProviderDisplay({ + {formIsVisible && ( + {existingLlmProvider?.deployment_name && ( + + )} + - - List the individual models that you want to make available as - a part of this provider. At least one must be specified. For - the best experience your [Provider Name]/[Model Name] should - match one of the pairs listed{" "} - - here - - . - - } - /> + {!existingLlmProvider?.deployment_name && ( + + List the individual models that you want to make available + as a part of this provider. At least one must be specified. + For the best experience your [Provider Name]/[Model Name] + should match one of the pairs listed{" "} + + here + + . + + } + /> + )} @@ -395,14 +408,16 @@ export function CustomLLMProviderUpdateForm({ placeholder="E.g. gpt-4" /> - + label="[Optional] Fast Model" + placeholder="E.g. gpt-4" + /> + )} diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index 70a3ce7ff99..2857920520e 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -68,6 +68,7 @@ export function LLMProviderUpdateForm({ existingLlmProvider?.display_model_names || defaultModelsByProvider[llmProviderDescriptor.name] || [], + deployment_name: existingLlmProvider?.deployment_name, }; // Setup validation schema if required @@ -99,6 +100,9 @@ export function LLMProviderUpdateForm({ ), } : {}), + deployment_name: llmProviderDescriptor.deployment_name_required + ? Yup.string().required("Deployment Name is required") + : Yup.string(), default_model_name: Yup.string().required("Model name is required"), fast_default_model_name: Yup.string().nullable(), // EE Only @@ -289,38 +293,50 @@ export function LLMProviderUpdateForm({ /> )} - {llmProviderDescriptor.llm_names.length > 0 ? ( - ({ - name: getDisplayNameForModel(name), - value: name, - }))} - includeDefault - maxHeight="max-h-56" - /> - ) : ( + {llmProviderDescriptor.deployment_name_required && ( )} - + {!llmProviderDescriptor.single_model_supported && + (llmProviderDescriptor.llm_names.length > 0 ? 
( + ({ + name: getDisplayNameForModel(name), + value: name, + }))} + includeDefault + maxHeight="max-h-56" + /> + ) : ( + + ))} {llmProviderDescriptor.name != "azure" && ( - + <> + + + + )} {showAdvancedOptions && ( diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 33fa94d7f15..61f81311ecf 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -19,11 +19,13 @@ export interface WellKnownLLMProviderDescriptor { name: string; display_name: string; + deployment_name_required: boolean; api_key_required: boolean; api_base_required: boolean; api_version_required: boolean; - custom_config_keys: CustomConfigKey[] | null; + single_model_supported: boolean; + custom_config_keys: CustomConfigKey[] | null; llm_names: string[]; default_model: string | null; default_fast_model: string | null; @@ -43,6 +45,7 @@ export interface LLMProvider { is_public: boolean; groups: number[]; display_model_names: string[] | null; + deployment_name: string | null; } export interface FullLLMProvider extends LLMProvider { diff --git a/web/src/components/initialSetup/welcome/WelcomeModal.tsx b/web/src/components/initialSetup/welcome/WelcomeModal.tsx index 1c94ae22961..e2689d2f7ab 100644 --- a/web/src/components/initialSetup/welcome/WelcomeModal.tsx +++ b/web/src/components/initialSetup/welcome/WelcomeModal.tsx @@ -92,4 +92,4 @@ export function _WelcomeModal({ user }: { user: User | null }) { ); -} +} \ No newline at end of file diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index d17c3da01b8..12181dc31cd 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -28,6 +28,7 @@ import { import { hasCompletedWelcomeFlowSS } from "@/components/initialSetup/welcome/WelcomeModalWrapper"; import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS"; import { NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN } from "../constants"; +import { checkLLMSupportsImageInput } from "../llm/utils"; interface FetchChatDataResult { user: User | null; @@ -195,10 +196,12 @@ export async function fetchChatData(searchParams: { assistants = assistants.filter((assistant) => assistant.num_chunks === 0); } - const hasOpenAIProvider = llmProviders.some( - (provider) => provider.provider === "openai" + const hasImageCompatibleModel = llmProviders.some( + (provider) => + provider.provider === "openai" || + provider.model_names.some((model) => checkLLMSupportsImageInput(model)) ); - if (!hasOpenAIProvider) { + if (!hasImageCompatibleModel) { assistants = assistants.filter( (assistant) => !assistant.tools.some( From 1c939f10b5333b527442f542f6c1c0b9f2791a24 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 8 Oct 2024 15:15:57 -0700 Subject: [PATCH 2/5] nit --- ...4334d5b33ba_add_deployment_name_to_llmprovider.py | 4 ++-- backend/danswer/configs/app_configs.py | 2 +- backend/danswer/llm/llm_provider_options.py | 5 +++-- backend/danswer/server/features/persona/api.py | 7 ++++++- deployment/docker_compose/docker-compose.dev.yml | 2 +- deployment/docker_compose/docker-compose.gpu-dev.yml | 2 +- .../docker_compose/docker-compose.search-testing.yml | 2 +- .../configuration/llm/LLMProviderUpdateForm.tsx | 2 +- web/src/lib/chat/fetchChatData.ts | 1 + web/src/lib/chat/fetchSomeChatData.ts | 12 ++++++++---- 10 files changed, 25 insertions(+), 14 deletions(-) diff --git a/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py 
b/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py index 19b14427e34..e837b87e3e0 100644 --- a/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py +++ b/backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py @@ -1,7 +1,7 @@ """add_deployment_name_to_llmprovider Revision ID: e4334d5b33ba -Revises: 46b7a812670f +Revises: ac5eaac849f9 Create Date: 2024-10-04 09:52:34.896867 """ @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. revision = "e4334d5b33ba" -down_revision = "46b7a812670f" +down_revision = "ac5eaac849f9" branch_labels = None depends_on = None diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index fc0f05435b0..f3e398c0eb7 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -135,7 +135,7 @@ os.environ.get("POSTGRES_PASSWORD") or "password" ) POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" -POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433" +POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_API_SERVER_POOL_SIZE = int( diff --git a/backend/danswer/llm/llm_provider_options.py b/backend/danswer/llm/llm_provider_options.py index 9fb55365f6c..3cb6157d6da 100644 --- a/backend/danswer/llm/llm_provider_options.py +++ b/backend/danswer/llm/llm_provider_options.py @@ -16,12 +16,13 @@ class WellKnownLLMProviderDescriptor(BaseModel): api_base_required: bool api_version_required: bool custom_config_keys: list[CustomConfigKey] | None = None - single_model_supported: bool = False - llm_names: list[str] default_model: str | None = None default_fast_model: str | None = None + # set for providers like Azure, which require a deployment name. deployment_name_required: bool = False + # set for providers like Azure, which support a single model per deployment. 
+ single_model_supported: bool = False OPENAI_PROVIDER_NAME = "openai" diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index bcc4800b860..8b4305755dc 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -31,9 +31,9 @@ from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse from danswer.server.models import DisplayPriorityRequest +from danswer.tools.utils import is_image_generation_available from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -226,6 +226,11 @@ def list_personas( get_editable=False, joinedload_all=True, ) + # If the persona has an image generation tool and it's not available, don't include it + if not ( + any(tool.in_code_tool_id == "ImageGenerationTool" for tool in persona.tools) + and not is_image_generation_available(db_session=db_session) + ) ] diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 52dda002b2b..86d988e7d90 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -298,7 +298,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5433:5432" + - "5432:5432" volumes: - db_volume:/var/lib/postgresql/data diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index 55038dd9eb2..ebce01eadb2 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -308,7 +308,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5433:5432" + - "5432:5432" volumes: - db_volume:/var/lib/postgresql/data diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml index 2afd54e029c..fab950c064e 100644 --- a/deployment/docker_compose/docker-compose.search-testing.yml +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -157,7 +157,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} ports: - - "5433" + - "5432" volumes: - db_volume:/var/lib/postgresql/data diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index 2857920520e..b072083662f 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -298,7 +298,7 @@ export function LLMProviderUpdateForm({ small={hideAdvanced} name="deployment_name" label="Deployment Name" - placeholder="E.g. 
gpt-4-mycompanyname-1" + placeholder="Deployment Name" /> )} diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index 12181dc31cd..1416f787cc7 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -201,6 +201,7 @@ export async function fetchChatData(searchParams: { provider.provider === "openai" || provider.model_names.some((model) => checkLLMSupportsImageInput(model)) ); + if (!hasImageCompatibleModel) { assistants = assistants.filter( (assistant) => diff --git a/web/src/lib/chat/fetchSomeChatData.ts b/web/src/lib/chat/fetchSomeChatData.ts index 827cf0c21fc..fdfa55b5ea1 100644 --- a/web/src/lib/chat/fetchSomeChatData.ts +++ b/web/src/lib/chat/fetchSomeChatData.ts @@ -28,6 +28,7 @@ import { import { hasCompletedWelcomeFlowSS } from "@/components/initialSetup/welcome/WelcomeModalWrapper"; import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS"; import { NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN } from "../constants"; +import { checkLLMSupportsImageInput } from "../llm/utils"; interface FetchChatDataResult { user?: User | null; @@ -178,10 +179,13 @@ export async function fetchSomeChatData( ); } - const hasOpenAIProvider = - result.llmProviders && - result.llmProviders.some((provider) => provider.provider === "openai"); - if (!hasOpenAIProvider) { + const hasImageCompatibleModel = result.llmProviders?.some( + (provider) => + provider.provider === "openai" || + provider.model_names.some((model) => checkLLMSupportsImageInput(model)) + ); + + if (!hasImageCompatibleModel) { result.assistants = result.assistants.filter( (assistant) => !assistant.tools.some( From 6e66d7d7f274dda378c09274db8062219c026dc9 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 8 Oct 2024 15:17:42 -0700 Subject: [PATCH 3/5] nit --- backend/danswer/chat/process_message.py | 4 ++-- backend/danswer/configs/app_configs.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index ad7a38bf796..19787545e4a 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -21,7 +21,7 @@ from danswer.configs.app_configs import AZURE_DALLE_API_BASE from danswer.configs.app_configs import AZURE_DALLE_API_KEY from danswer.configs.app_configs import AZURE_DALLE_API_VERSION -from danswer.configs.app_configs import AZURE_DEPLOYMENT_NAME +from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME from danswer.configs.chat_configs import BING_API_KEY from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH @@ -578,7 +578,7 @@ def stream_chat_message_objects( ): img_generation_llm_config = LLMConfig( model_provider="azure", - model_name=f"azure/{AZURE_DEPLOYMENT_NAME}", + model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}", temperature=GEN_AI_TEMPERATURE, api_key=AZURE_DALLE_API_KEY, api_base=AZURE_DALLE_API_BASE, diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index f3e398c0eb7..4c4ee44535f 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -417,6 +417,8 @@ AZURE_DALLE_API_VERSION = os.environ.get("AZURE_DALLE_API_VERSION") AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY") AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE") +AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME") + AZURE_DEPLOYMENT_NAME = 
os.environ.get("AZURE_DEPLOYMENT_NAME") From cfd19dd7caf68288aeb76f799b47e00b8c914ed5 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 8 Oct 2024 15:18:12 -0700 Subject: [PATCH 4/5] nit --- backend/danswer/configs/app_configs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 4c4ee44535f..1174c8d060f 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -419,8 +419,6 @@ AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE") AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME") -AZURE_DEPLOYMENT_NAME = os.environ.get("AZURE_DEPLOYMENT_NAME") - MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "") From 52f83081938e688756b4de24c6b0f3a7e9280da9 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 8 Oct 2024 15:26:28 -0700 Subject: [PATCH 5/5] nit pretty --- web/src/components/initialSetup/welcome/WelcomeModal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/initialSetup/welcome/WelcomeModal.tsx b/web/src/components/initialSetup/welcome/WelcomeModal.tsx index e2689d2f7ab..1c94ae22961 100644 --- a/web/src/components/initialSetup/welcome/WelcomeModal.tsx +++ b/web/src/components/initialSetup/welcome/WelcomeModal.tsx @@ -92,4 +92,4 @@ export function _WelcomeModal({ user }: { user: User | null }) { ); -} \ No newline at end of file +}