Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support api-version and change default-model in adding azure-openai and openai #2799

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def set_api_key():
chat_passed, embd_passed, rerank_passed = False, False, False
factory = req["llm_factory"]
msg = ""
for llm in LLMService.query(fid=factory)[:3]:
for llm in LLMService.query(fid=factory):
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
Expand All @@ -77,21 +77,25 @@ def set_api_key():
{"temperature": 0.9,'max_tokens':50})
if m.find("**ERROR**") >=0:
raise Exception(m)
chat_passed = True
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
chat_passed = True
elif not rerank_passed and llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
print(f'passed model rerank{llm.llm_name}',flush=True)
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
rerank_passed = True
if any([embd_passed, chat_passed, rerank_passed]):
msg = ''
break

if msg:
return get_data_error_result(retmsg=msg)
Expand Down Expand Up @@ -183,6 +187,9 @@ def apikey_json(keys):
llm_name = req["llm_name"]
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])

elif factory == "Azure-OpenAI":
llm_name = req["llm_name"]
api_key = apikey_json(["api_key", "api_version"])
KevinHuSh marked this conversation as resolved.
Show resolved Hide resolved
else:
llm_name = req["llm_name"]
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
Expand Down
4 changes: 2 additions & 2 deletions conf/llm_factories.json
Original file line number Diff line number Diff line change
Expand Up @@ -619,13 +619,13 @@
"model_type": "chat,image2text"
},
{
"llm_name": "gpt-35-turbo",
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
"max_tokens": 4096,
"model_type": "chat"
},
{
"llm_name": "gpt-35-turbo-16k",
"llm_name": "gpt-3.5-turbo-16k",
"tags": "LLM,CHAT,16k",
"max_tokens": 16385,
"model_type": "chat"
Expand Down
4 changes: 3 additions & 1 deletion rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepse

class AzureChat(Base):
    """Chat-model client backed by an Azure OpenAI deployment."""

    def __init__(self, key, model_name, **kwargs):
        # *key* arrives as a JSON string, e.g. '{"api_key": "...", "api_version": "..."}'.
        # Parse it once instead of calling json.loads() twice.
        config = json.loads(key)
        api_key = config.get('api_key', '')
        # Fall back to the 2024-02-01 GA API version when the caller omits one.
        api_version = config.get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name


Expand Down
4 changes: 3 additions & 1 deletion rag/llm/cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def describe(self, image, max_tokens=300):

class AzureGptV4(Base):
    """Vision (image-to-text) client backed by an Azure OpenAI deployment."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        # *key* arrives as a JSON string, e.g. '{"api_key": "...", "api_version": "..."}'.
        # Parse it once instead of calling json.loads() twice.
        config = json.loads(key)
        api_key = config.get('api_key', '')
        # Fall back to the 2024-02-01 GA API version when the caller omits one.
        api_version = config.get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        # Language used for describe() prompts elsewhere in the class hierarchy.
        self.lang = lang

Expand Down
4 changes: 3 additions & 1 deletion rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def encode_queries(self, text):
class AzureEmbed(OpenAIEmbed):
    """Embedding client backed by an Azure OpenAI deployment."""

    def __init__(self, key, model_name, **kwargs):
        # Local import keeps the openai dependency lazy, matching the original file.
        from openai.lib.azure import AzureOpenAI
        # *key* arrives as a JSON string, e.g. '{"api_key": "...", "api_version": "..."}'.
        # Parse it once instead of calling json.loads() twice.
        config = json.loads(key)
        api_key = config.get('api_key', '')
        # Fall back to the 2024-02-01 GA API version when the caller omits one.
        api_version = config.get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name


Expand Down
2 changes: 2 additions & 0 deletions web/src/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ The above is the content you need to summarize.`,
GoogleRegionMessage: 'Please input Google Cloud Region',
modelProvidersWarn:
'Please add both embedding model and LLM in <b>Settings > Model providers</b> firstly.',
apiVersion: 'API-Version',
apiVersionMessage: 'Please input API version',
},
message: {
registered: 'Registered!',
Expand Down
2 changes: 2 additions & 0 deletions web/src/locales/zh.ts
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ export default {
GoogleRegionMessage: '请输入 Google Cloud 区域',
modelProvidersWarn:
'请首先在 <b>设置 > 模型提供商</b> 中添加嵌入模型和 LLM。',
apiVersion: 'API版本',
apiVersionMessage: '请输入API版本!',
},
message: {
registered: '注册成功',
Expand Down
128 changes: 128 additions & 0 deletions web/src/pages/user-setting/setting-model/azure-openai-modal/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Form, Input, Modal, Select, Switch } from 'antd';
import omit from 'lodash/omit';

// Form values: the generic add-LLM payload plus Azure-specific fields.
type FieldType = IAddLlmRequestBody & {
  api_version: string;
  vision: boolean;
};

const { Option } = Select;

// Modal for registering an Azure-OpenAI model (deployment name, endpoint,
// key and api-version). Submits an IAddLlmRequestBody via onOk.
const AzureOpenAIModal = ({
  visible,
  hideModal,
  onOk,
  loading,
  llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
  const [form] = Form.useForm<FieldType>();

  const { t } = useTranslate('setting');

  const handleOk = async () => {
    const values = await form.validateFields();
    // A chat model with the vision switch enabled is submitted as image2text.
    const modelType =
      values.model_type === 'chat' && values.vision
        ? 'image2text'
        : values.model_type;

    // "vision" is UI-only state; strip it before sending the payload.
    const data = {
      ...omit(values, ['vision']),
      model_type: modelType,
      llm_factory: llmFactory,
    };

    onOk?.(data);
  };
  const optionsMap = {
    Default: [
      { value: 'chat', label: 'chat' },
      { value: 'embedding', label: 'embedding' },
      { value: 'image2text', label: 'image2text' },
    ],
  };
  // NOTE(review): the factory argument is currently unused — every factory
  // gets the Default option set. Kept for parity with sibling modals.
  const getOptions = (factory: string) => {
    return optionsMap.Default;
  };
  return (
    <Modal
      title={t('addLlmTitle', { name: llmFactory })}
      open={visible}
      onOk={handleOk}
      onCancel={hideModal}
      okButtonProps={{ loading }}
    >
      <Form
        name="basic"
        style={{ maxWidth: 600 }}
        autoComplete="off"
        layout={'vertical'}
        form={form}
      >
        <Form.Item<FieldType>
          label={t('modelType')}
          name="model_type"
          initialValue={'embedding'}
          rules={[{ required: true, message: t('modelTypeMessage') }]}
        >
          <Select placeholder={t('modelTypeMessage')}>
            {getOptions(llmFactory).map((option) => (
              <Option key={option.value} value={option.value}>
                {option.label}
              </Option>
            ))}
          </Select>
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addLlmBaseUrl')}
          name="api_base"
          rules={[{ required: true, message: t('baseUrlNameMessage') }]}
        >
          <Input placeholder={t('baseUrlNameMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('apiKey')}
          name="api_key"
          rules={[{ required: false, message: t('apiKeyMessage') }]}
        >
          <Input placeholder={t('apiKeyMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('modelName')}
          name="llm_name"
          initialValue="gpt-3.5-turbo"
          rules={[{ required: true, message: t('modelNameMessage') }]}
        >
          <Input placeholder={t('modelNameMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('apiVersion')}
          name="api_version"
          initialValue="2024-02-01"
          rules={[{ required: false, message: t('apiVersionMessage') }]}
        >
          <Input placeholder={t('apiVersionMessage')} />
        </Form.Item>
        {/* Vision toggle is only meaningful for chat models. */}
        <Form.Item noStyle dependencies={['model_type']}>
          {({ getFieldValue }) =>
            getFieldValue('model_type') === 'chat' && (
              <Form.Item
                label={t('vision')}
                valuePropName="checked"
                name={'vision'}
              >
                <Switch />
              </Form.Item>
            )
          }
        </Form.Item>
      </Form>
    </Modal>
  );
};

export default AzureOpenAIModal;
27 changes: 27 additions & 0 deletions web/src/pages/user-setting/setting-model/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,33 @@ export const useSubmitBedrock = () => {
};
};

// Modal state and submit handler for adding an Azure-OpenAI model.
// Mirrors the sibling useSubmit* hooks (e.g. useSubmitBedrock) in this file.
export const useSubmitAzure = () => {
  const { addLlm, loading } = useAddLlm();
  const {
    visible: AzureAddingVisible,
    hideModal: hideAzureAddingModal,
    showModal: showAzureAddingModal,
  } = useSetModalState();

  const onAzureAddingOk = useCallback(
    async (payload: IAddLlmRequestBody) => {
      const ret = await addLlm(payload);
      // addLlm resolves to 0 on success; only then close the modal so the
      // user can retry on failure without reopening it.
      if (ret === 0) {
        hideAzureAddingModal();
      }
    },
    [hideAzureAddingModal, addLlm],
  );

  return {
    AzureAddingLoading: loading,
    onAzureAddingOk,
    AzureAddingVisible,
    hideAzureAddingModal,
    showAzureAddingModal,
  };
};

export const useHandleDeleteLlm = (llmFactory: string) => {
const { deleteLlm } = useDeleteLlm();
const showDeleteConfirm = useShowDeleteConfirm();
Expand Down
22 changes: 21 additions & 1 deletion web/src/pages/user-setting/setting-model/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title';
import { isLocalLlmFactory } from '../utils';
import TencentCloudModal from './Tencent-modal';
import ApiKeyModal from './api-key-modal';
import AzureOpenAIModal from './azure-openai-modal';
import BedrockModal from './bedrock-modal';
import { IconMap } from './constant';
import FishAudioModal from './fish-audio-modal';
Expand All @@ -37,6 +38,7 @@ import {
useHandleDeleteFactory,
useHandleDeleteLlm,
useSubmitApiKey,
useSubmitAzure,
useSubmitBedrock,
useSubmitFishAudio,
useSubmitGoogle,
Expand Down Expand Up @@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
item.name === 'BaiduYiyan' ||
item.name === 'Fish Audio' ||
item.name === 'Tencent Cloud' ||
item.name === 'Google Cloud'
item.name === 'Google Cloud' ||
item.name === 'Azure OpenAI'
? t('addTheModel')
: 'API-Key'}
<SettingOutlined />
Expand Down Expand Up @@ -242,6 +245,14 @@ const UserSettingModel = () => {
showBedrockAddingModal,
} = useSubmitBedrock();

const {
AzureAddingVisible,
hideAzureAddingModal,
showAzureAddingModal,
onAzureAddingOk,
AzureAddingLoading,
} = useSubmitAzure();

const ModalMap = useMemo(
() => ({
Bedrock: showBedrockAddingModal,
Expand All @@ -252,6 +263,7 @@ const UserSettingModel = () => {
'Fish Audio': showFishAudioAddingModal,
'Tencent Cloud': showTencentCloudAddingModal,
'Google Cloud': showGoogleAddingModal,
'Azure-OpenAI': showAzureAddingModal,
}),
[
showBedrockAddingModal,
Expand All @@ -262,6 +274,7 @@ const UserSettingModel = () => {
showyiyanAddingModal,
showFishAudioAddingModal,
showGoogleAddingModal,
showAzureAddingModal,
],
);

Expand Down Expand Up @@ -435,6 +448,13 @@ const UserSettingModel = () => {
loading={bedrockAddingLoading}
llmFactory={'Bedrock'}
></BedrockModal>
<AzureOpenAIModal
visible={AzureAddingVisible}
hideModal={hideAzureAddingModal}
onOk={onAzureAddingOk}
loading={AzureAddingLoading}
llmFactory={'Azure-OpenAI'}
></AzureOpenAIModal>
</section>
);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ const OllamaModal = ({
<Form.Item<FieldType>
label={t('modelType')}
name="model_type"
initialValue={'chat'}
initialValue={'embedding'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
Expand Down