From 46b56fbfec622969770597d87367fb118283be4f Mon Sep 17 00:00:00 2001 From: almogbaku Date: Wed, 22 May 2024 13:28:22 +0300 Subject: [PATCH] fix: fix yaml settings not loaded closes #6 and #5 --- server/main.py | 11 +++++------ server/src/protocol.py | 21 +++++++++++++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/server/main.py b/server/main.py index eb3439d..6c1fef8 100644 --- a/server/main.py +++ b/server/main.py @@ -32,22 +32,21 @@ async def lifespan(app: FastAPI): organization=settings.openai_organization ) - if ((settings.openai_base_url is None) - and settings.openai_api_key - and (settings.models.oai_urls is None or (oai_default_base_url not in settings.models.oai_urls))): + if settings.openai_base_url is None and settings.openai_api_key and settings.models.oai_urls is None: if settings.models.oai_urls is None: settings.models.oai_urls = [] settings.models.oai_urls.append(oai_default_base_url) for url in settings.models.oai_urls or []: + base_url = url.rsplit("/models")[0] if url.startswith( + settings.openai_base_url or oai_default_base_url) else None + vendor = "OpenAI" if "openai" in url else urlparse(url).hostname.rsplit(".", 1)[0].rsplit(".", 1)[-1] + cli = AsyncOpenAI( api_key=settings.openai_api_key, base_url=url, ) resp = await cli.models.list() - base_url = url.rsplit("/models")[0] if url.startswith( - settings.openai_base_url or "https://api.openai.com/v1") else None - vendor = "OpenAI" if "openai" in url else urlparse(url).hostname.rsplit(".", 1)[0].rsplit(".", 1)[-1] models += [ Model(name=model.id, system_prompt=True, type='chat', vendor=vendor, base_url=base_url) diff --git a/server/src/protocol.py b/server/src/protocol.py index 5fdc234..d1f279c 100644 --- a/server/src/protocol.py +++ b/server/src/protocol.py @@ -1,8 +1,13 @@ -from typing import Literal, Optional, List, Union +from typing import Literal, Optional, List, Union, Type, Tuple from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, Field, SecretStr -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) class Model(BaseModel): @@ -33,6 +38,18 @@ class Settings(BaseSettings): dist_dir: Optional[str] = None + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + + return YamlConfigSettingsSource(settings_cls), env_settings, dotenv_settings, file_secret_settings + class ChatCompletionsRequest(BaseModel): model: str