diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 0ddfcc38ccf..100a893bf15 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -58,7 +58,7 @@ def set_api_key(): chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] msg = "" - for llm in LLMService.query(fid=factory)[:3]: + for llm in LLMService.query(fid=factory): if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: mdl = EmbeddingModel[factory]( req["api_key"], llm.llm_name, base_url=req.get("base_url")) @@ -77,10 +77,10 @@ def set_api_key(): {"temperature": 0.9,'max_tokens':50}) if m.find("**ERROR**") >=0: raise Exception(m) + chat_passed = True except Exception as e: msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( e) - chat_passed = True elif not rerank_passed and llm.model_type == LLMType.RERANK: mdl = RerankModel[factory]( req["api_key"], llm.llm_name, base_url=req.get("base_url")) @@ -88,10 +88,14 @@ def set_api_key(): arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) if len(arr) == 0 or tc == 0: raise Exception("Fail") + rerank_passed = True + print(f'passed model rerank{llm.llm_name}',flush=True) except Exception as e: msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( e) - rerank_passed = True + if any([embd_passed, chat_passed, rerank_passed]): + msg = '' + break if msg: return get_data_error_result(retmsg=msg) @@ -183,6 +187,10 @@ def apikey_json(keys): llm_name = req["llm_name"] api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"]) + elif factory == "Azure-OpenAI": + llm_name = req["llm_name"] + api_key = apikey_json(["api_key", "api_version"]) + else: llm_name = req["llm_name"] api_key = req.get("api_key", "xxxxxxxxxxxxxxx") diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 4daa014b02f..c7c52ddd460 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -619,13 +619,13 @@ "model_type": "chat,image2text" }, { - "llm_name": "gpt-35-turbo", + "llm_name": "gpt-3.5-turbo", "tags": "LLM,CHAT,4K", "max_tokens": 4096, "model_type": "chat" }, { - "llm_name": "gpt-35-turbo-16k", + "llm_name": "gpt-3.5-turbo-16k", "tags": "LLM,CHAT,16k", "max_tokens": 16385, "model_type": "chat" diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index fb9c6e22443..d18fc02e97f 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -114,7 +114,9 @@ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepse class AzureChat(Base): def __init__(self, key, model_name, **kwargs): - self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") + api_key = json.loads(key).get('api_key', '') + api_version = json.loads(key).get('api_version', '2024-02-01') + self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index aeaeefffad7..97e02911ffd 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -160,7 +160,9 @@ def describe(self, image, max_tokens=300): class AzureGptV4(Base): def __init__(self, key, model_name, lang="Chinese", **kwargs): - self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") + api_key = json.loads(key).get('api_key', '') + api_version = json.loads(key).get('api_version', '2024-02-01') + self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name self.lang = lang diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index ba73cdfba89..c7af5c5069b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -137,7 +137,9 @@ def encode_queries(self, text): class AzureEmbed(OpenAIEmbed): def __init__(self, key, model_name, **kwargs): from openai.lib.azure import AzureOpenAI - self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") + api_key = json.loads(key).get('api_key', '') + api_version = json.loads(key).get('api_version', '2024-02-01') + self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 8f292a95377..48e40fb79fb 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -581,6 +581,8 @@ The above is the content you need to summarize.`, GoogleRegionMessage: 'Please input Google Cloud Region', modelProvidersWarn: 'Please add both embedding model and LLM in Settings > Model providers firstly.', + apiVersion: 'API-Version', + apiVersionMessage: 'Please input API version', }, message: { registered: 'Registered!', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index c2126a04226..f72d8512d41 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -557,6 +557,8 @@ export default { GoogleRegionMessage: '请输入 Google Cloud 区域', modelProvidersWarn: '请首先在 设置 > 模型提供商 中添加嵌入模型和 LLM。', + apiVersion: 'API版本', + apiVersionMessage: '请输入API版本!', }, message: { registered: '注册成功', diff --git a/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx new file mode 100644 index 00000000000..f9fab8ab2c4 --- /dev/null +++ b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx @@ -0,0 +1,128 @@ +import { useTranslate } from '@/hooks/common-hooks'; +import { IModalProps } from '@/interfaces/common'; +import { IAddLlmRequestBody } from '@/interfaces/request/llm'; +import { Form, Input, Modal, Select, Switch } from 'antd'; +import omit from 'lodash/omit'; + +type FieldType = IAddLlmRequestBody & { + api_version: string; + vision: boolean; +}; + +const { Option } = Select; + +const AzureOpenAIModal = ({ + visible, + hideModal, + onOk, + loading, + llmFactory, +}: IModalProps & { llmFactory: string }) => { + const [form] = Form.useForm(); + + const { t } = useTranslate('setting'); + + const handleOk = async () => { + const values = await form.validateFields(); + const modelType = + values.model_type === 'chat' && values.vision + ? 'image2text' + : values.model_type; + + const data = { + ...omit(values, ['vision']), + model_type: modelType, + llm_factory: llmFactory, + }; + console.info(data); + + onOk?.(data); + }; + const optionsMap = { + Default: [ + { value: 'chat', label: 'chat' }, + { value: 'embedding', label: 'embedding' }, + { value: 'image2text', label: 'image2text' }, + ], + }; + const getOptions = (factory: string) => { + return optionsMap.Default; + }; + return ( + +
+ + label={t('modelType')} + name="model_type" + initialValue={'embedding'} + rules={[{ required: true, message: t('modelTypeMessage') }]} + > + + + + label={t('addLlmBaseUrl')} + name="api_base" + rules={[{ required: true, message: t('baseUrlNameMessage') }]} + > + + + + label={t('apiKey')} + name="api_key" + rules={[{ required: false, message: t('apiKeyMessage') }]} + > + + + + label={t('modelName')} + name="llm_name" + initialValue="gpt-3.5-turbo" + rules={[{ required: true, message: t('modelNameMessage') }]} + > + + + + label={t('apiVersion')} + name="api_version" + initialValue="2024-02-01" + rules={[{ required: false, message: t('apiVersionMessage') }]} + > + + + + {({ getFieldValue }) => + getFieldValue('model_type') === 'chat' && ( + + + + ) + } + + +
+ ); +}; + +export default AzureOpenAIModal; diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts index a53159f15e6..29cc76f9176 100644 --- a/web/src/pages/user-setting/setting-model/hooks.ts +++ b/web/src/pages/user-setting/setting-model/hooks.ts @@ -353,6 +353,33 @@ export const useSubmitBedrock = () => { }; }; +export const useSubmitAzure = () => { + const { addLlm, loading } = useAddLlm(); + const { + visible: AzureAddingVisible, + hideModal: hideAzureAddingModal, + showModal: showAzureAddingModal, + } = useSetModalState(); + + const onAzureAddingOk = useCallback( + async (payload: IAddLlmRequestBody) => { + const ret = await addLlm(payload); + if (ret === 0) { + hideAzureAddingModal(); + } + }, + [hideAzureAddingModal, addLlm], + ); + + return { + AzureAddingLoading: loading, + onAzureAddingOk, + AzureAddingVisible, + hideAzureAddingModal, + showAzureAddingModal, + }; +}; + export const useHandleDeleteLlm = (llmFactory: string) => { const { deleteLlm } = useDeleteLlm(); const showDeleteConfirm = useShowDeleteConfirm(); diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx index 14287d78561..9199c7aa989 100644 --- a/web/src/pages/user-setting/setting-model/index.tsx +++ b/web/src/pages/user-setting/setting-model/index.tsx @@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title'; import { isLocalLlmFactory } from '../utils'; import TencentCloudModal from './Tencent-modal'; import ApiKeyModal from './api-key-modal'; +import AzureOpenAIModal from './azure-openai-modal'; import BedrockModal from './bedrock-modal'; import { IconMap } from './constant'; import FishAudioModal from './fish-audio-modal'; @@ -37,6 +38,7 @@ import { useHandleDeleteFactory, useHandleDeleteLlm, useSubmitApiKey, + useSubmitAzure, useSubmitBedrock, useSubmitFishAudio, useSubmitGoogle, @@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => { item.name === 'BaiduYiyan' || item.name === 'Fish Audio' || item.name === 'Tencent Cloud' || - item.name === 'Google Cloud' + item.name === 'Google Cloud' || + item.name === 'Azure OpenAI' ? t('addTheModel') : 'API-Key'} @@ -242,6 +245,14 @@ const UserSettingModel = () => { showBedrockAddingModal, } = useSubmitBedrock(); + const { + AzureAddingVisible, + hideAzureAddingModal, + showAzureAddingModal, + onAzureAddingOk, + AzureAddingLoading, + } = useSubmitAzure(); + const ModalMap = useMemo( () => ({ Bedrock: showBedrockAddingModal, @@ -252,6 +263,7 @@ const UserSettingModel = () => { 'Fish Audio': showFishAudioAddingModal, 'Tencent Cloud': showTencentCloudAddingModal, 'Google Cloud': showGoogleAddingModal, + 'Azure-OpenAI': showAzureAddingModal, }), [ showBedrockAddingModal, @@ -262,6 +274,7 @@ const UserSettingModel = () => { showyiyanAddingModal, showFishAudioAddingModal, showGoogleAddingModal, + showAzureAddingModal, ], ); @@ -435,6 +448,13 @@ const UserSettingModel = () => { loading={bedrockAddingLoading} llmFactory={'Bedrock'} > + ); }; diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx index c880ec254bc..c372b2f3fa6 100644 --- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx @@ -101,7 +101,7 @@ const OllamaModal = ({ label={t('modelType')} name="model_type" - initialValue={'chat'} + initialValue={'embedding'} rules={[{ required: true, message: t('modelTypeMessage') }]} >