diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 0ddfcc38ccf..100a893bf15 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -58,7 +58,7 @@ def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
msg = ""
- for llm in LLMService.query(fid=factory)[:3]:
+ for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -77,10 +77,10 @@ def set_api_key():
{"temperature": 0.9,'max_tokens':50})
if m.find("**ERROR**") >=0:
raise Exception(m)
+ chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
- chat_passed = True
elif not rerank_passed and llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -88,10 +88,14 @@ def set_api_key():
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
+ rerank_passed = True
+                print(f'passed model rerank {llm.llm_name}', flush=True)  # TODO(review): debug print — remove before merge
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
- rerank_passed = True
+ if any([embd_passed, chat_passed, rerank_passed]):
+ msg = ''
+ break
if msg:
return get_data_error_result(retmsg=msg)
@@ -183,6 +187,10 @@ def apikey_json(keys):
llm_name = req["llm_name"]
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
+ elif factory == "Azure-OpenAI":
+ llm_name = req["llm_name"]
+ api_key = apikey_json(["api_key", "api_version"])
+
else:
llm_name = req["llm_name"]
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 4daa014b02f..c7c52ddd460 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -619,13 +619,13 @@
"model_type": "chat,image2text"
},
{
- "llm_name": "gpt-35-turbo",
+ "llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat"
},
{
- "llm_name": "gpt-35-turbo-16k",
+ "llm_name": "gpt-3.5-turbo-16k",
"tags": "LLM,CHAT,16k",
"max_tokens": 16385,
"model_type": "chat"
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index fb9c6e22443..d18fc02e97f 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -114,7 +114,9 @@ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepse
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
+ api_key = json.loads(key).get('api_key', '')
+ api_version = json.loads(key).get('api_version', '2024-02-01')
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index aeaeefffad7..97e02911ffd 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -160,7 +160,9 @@ def describe(self, image, max_tokens=300):
class AzureGptV4(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
+ api_key = json.loads(key).get('api_key', '')
+ api_version = json.loads(key).get('api_version', '2024-02-01')
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index ba73cdfba89..c7af5c5069b 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -137,7 +137,9 @@ def encode_queries(self, text):
class AzureEmbed(OpenAIEmbed):
def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
- self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
+ api_key = json.loads(key).get('api_key', '')
+ api_version = json.loads(key).get('api_version', '2024-02-01')
+ self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts
index 8f292a95377..48e40fb79fb 100644
--- a/web/src/locales/en.ts
+++ b/web/src/locales/en.ts
@@ -581,6 +581,8 @@ The above is the content you need to summarize.`,
GoogleRegionMessage: 'Please input Google Cloud Region',
modelProvidersWarn:
'Please add both embedding model and LLM in Settings > Model providers firstly.',
+ apiVersion: 'API-Version',
+ apiVersionMessage: 'Please input API version',
},
message: {
registered: 'Registered!',
diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts
index c2126a04226..f72d8512d41 100644
--- a/web/src/locales/zh.ts
+++ b/web/src/locales/zh.ts
@@ -557,6 +557,8 @@ export default {
GoogleRegionMessage: '请输入 Google Cloud 区域',
modelProvidersWarn:
'请首先在 设置 > 模型提供商 中添加嵌入模型和 LLM。',
+ apiVersion: 'API版本',
+      apiVersionMessage: '请输入API版本',
},
message: {
registered: '注册成功',
diff --git a/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
new file mode 100644
index 00000000000..f9fab8ab2c4
--- /dev/null
+++ b/web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
@@ -0,0 +1,128 @@
+import { useTranslate } from '@/hooks/common-hooks';
+import { IModalProps } from '@/interfaces/common';
+import { IAddLlmRequestBody } from '@/interfaces/request/llm';
+import { Form, Input, Modal, Select, Switch } from 'antd';
+import omit from 'lodash/omit';
+
+type FieldType = IAddLlmRequestBody & {
+ api_version: string;
+ vision: boolean;
+};
+
+const { Option } = Select;
+
+const AzureOpenAIModal = ({
+ visible,
+ hideModal,
+ onOk,
+ loading,
+ llmFactory,
+}: IModalProps & { llmFactory: string }) => {
+ const [form] = Form.useForm();
+
+ const { t } = useTranslate('setting');
+
+ const handleOk = async () => {
+ const values = await form.validateFields();
+ const modelType =
+ values.model_type === 'chat' && values.vision
+ ? 'image2text'
+ : values.model_type;
+
+ const data = {
+ ...omit(values, ['vision']),
+ model_type: modelType,
+ llm_factory: llmFactory,
+ };
+ console.info(data);
+
+ onOk?.(data);
+ };
+ const optionsMap = {
+ Default: [
+ { value: 'chat', label: 'chat' },
+ { value: 'embedding', label: 'embedding' },
+ { value: 'image2text', label: 'image2text' },
+ ],
+ };
+ const getOptions = (factory: string) => {
+ return optionsMap.Default;
+ };
+ return (
+
+
+ label={t('modelType')}
+ name="model_type"
+ initialValue={'embedding'}
+ rules={[{ required: true, message: t('modelTypeMessage') }]}
+ >
+
+
+
+ label={t('addLlmBaseUrl')}
+ name="api_base"
+ rules={[{ required: true, message: t('baseUrlNameMessage') }]}
+ >
+
+
+
+ label={t('apiKey')}
+ name="api_key"
+ rules={[{ required: false, message: t('apiKeyMessage') }]}
+ >
+
+
+
+ label={t('modelName')}
+ name="llm_name"
+ initialValue="gpt-3.5-turbo"
+ rules={[{ required: true, message: t('modelNameMessage') }]}
+ >
+
+
+
+ label={t('apiVersion')}
+ name="api_version"
+ initialValue="2024-02-01"
+ rules={[{ required: false, message: t('apiVersionMessage') }]}
+ >
+
+
+
+ {({ getFieldValue }) =>
+ getFieldValue('model_type') === 'chat' && (
+
+
+
+ )
+ }
+
+
+
+ );
+};
+
+export default AzureOpenAIModal;
diff --git a/web/src/pages/user-setting/setting-model/hooks.ts b/web/src/pages/user-setting/setting-model/hooks.ts
index a53159f15e6..29cc76f9176 100644
--- a/web/src/pages/user-setting/setting-model/hooks.ts
+++ b/web/src/pages/user-setting/setting-model/hooks.ts
@@ -353,6 +353,33 @@ export const useSubmitBedrock = () => {
};
};
+export const useSubmitAzure = () => {
+ const { addLlm, loading } = useAddLlm();
+ const {
+ visible: AzureAddingVisible,
+ hideModal: hideAzureAddingModal,
+ showModal: showAzureAddingModal,
+ } = useSetModalState();
+
+ const onAzureAddingOk = useCallback(
+ async (payload: IAddLlmRequestBody) => {
+ const ret = await addLlm(payload);
+ if (ret === 0) {
+ hideAzureAddingModal();
+ }
+ },
+ [hideAzureAddingModal, addLlm],
+ );
+
+ return {
+ AzureAddingLoading: loading,
+ onAzureAddingOk,
+ AzureAddingVisible,
+ hideAzureAddingModal,
+ showAzureAddingModal,
+ };
+};
+
export const useHandleDeleteLlm = (llmFactory: string) => {
const { deleteLlm } = useDeleteLlm();
const showDeleteConfirm = useShowDeleteConfirm();
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index 14287d78561..9199c7aa989 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title';
import { isLocalLlmFactory } from '../utils';
import TencentCloudModal from './Tencent-modal';
import ApiKeyModal from './api-key-modal';
+import AzureOpenAIModal from './azure-openai-modal';
import BedrockModal from './bedrock-modal';
import { IconMap } from './constant';
import FishAudioModal from './fish-audio-modal';
@@ -37,6 +38,7 @@ import {
useHandleDeleteFactory,
useHandleDeleteLlm,
useSubmitApiKey,
+ useSubmitAzure,
useSubmitBedrock,
useSubmitFishAudio,
useSubmitGoogle,
@@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
item.name === 'BaiduYiyan' ||
item.name === 'Fish Audio' ||
item.name === 'Tencent Cloud' ||
- item.name === 'Google Cloud'
+ item.name === 'Google Cloud' ||
+ item.name === 'Azure OpenAI'
? t('addTheModel')
: 'API-Key'}
@@ -242,6 +245,14 @@ const UserSettingModel = () => {
showBedrockAddingModal,
} = useSubmitBedrock();
+ const {
+ AzureAddingVisible,
+ hideAzureAddingModal,
+ showAzureAddingModal,
+ onAzureAddingOk,
+ AzureAddingLoading,
+ } = useSubmitAzure();
+
const ModalMap = useMemo(
() => ({
Bedrock: showBedrockAddingModal,
@@ -252,6 +263,7 @@ const UserSettingModel = () => {
'Fish Audio': showFishAudioAddingModal,
'Tencent Cloud': showTencentCloudAddingModal,
'Google Cloud': showGoogleAddingModal,
+ 'Azure-OpenAI': showAzureAddingModal,
}),
[
showBedrockAddingModal,
@@ -262,6 +274,7 @@ const UserSettingModel = () => {
showyiyanAddingModal,
showFishAudioAddingModal,
showGoogleAddingModal,
+ showAzureAddingModal,
],
);
@@ -435,6 +448,13 @@ const UserSettingModel = () => {
loading={bedrockAddingLoading}
llmFactory={'Bedrock'}
>
+
);
};
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index c880ec254bc..c372b2f3fa6 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -101,7 +101,7 @@ const OllamaModal = ({
label={t('modelType')}
name="model_type"
- initialValue={'chat'}
+ initialValue={'embedding'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>