Skip to content

Commit

Permalink
add support for XunFei Spark (infiniflow#2017)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#1853  add support for XunFei Spark

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <[email protected]>
  • Loading branch information
hangters and aopstudio authored Aug 20, 2024
1 parent ce04352 commit 7743922
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 6 deletions.
5 changes: 4 additions & 1 deletion api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def add_llm():
elif factory == "OpenAI-API-Compatible":
llm_name = req["llm_name"]+"___OpenAI-API"
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
elif factory =="XunFei Spark":
llm_name = req["llm_name"]
api_key = req.get("spark_api_password","")
else:
llm_name = req["llm_name"]
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
Expand Down Expand Up @@ -165,7 +168,7 @@ def add_llm():
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
mdl = ChatModel[factory](
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"]
)
Expand Down
7 changes: 7 additions & 0 deletions conf/llm_factories.json
Original file line number Diff line number Diff line change
Expand Up @@ -3194,6 +3194,13 @@
"model_type": "image2text"
}
]
},
{
"name": "XunFei Spark",
"logo": "",
"tags": "LLM",
"status": "1",
"llm": []
}
]
}
3 changes: 2 additions & 1 deletion rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@
"SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat,
"Replicate": ReplicateChat,
"Tencent Hunyuan": HunyuanChat
"Tencent Hunyuan": HunyuanChat,
"XunFei Spark": SparkChat
}


Expand Down
21 changes: 19 additions & 2 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,12 +1133,12 @@ def chat_streamly(self, system, history, gen_conf):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)

_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})

if "temperature" in gen_conf:
_gen_conf["Temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
Expand Down Expand Up @@ -1168,3 +1168,20 @@ def chat_streamly(self, system, history, gen_conf):
yield ans + "\n**ERROR**: " + str(e)

yield total_tokens


class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
model2version = {
"Spark-Max": "generalv3.5",
"Spark-Lite": "general",
"Spark-Pro": "generalv3",
"Spark-Pro-128K": "pro-128k",
"Spark-4.0-Ultra": "4.0Ultra",
}
model_version = model2version[model_name]
super().__init__(key, model_version, base_url)
1 change: 1 addition & 0 deletions web/src/assets/svg/llm/spark.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions web/src/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ The above is the content you need to summarize.`,
HunyuanSIDMessage: 'Please input your Secret ID',
addHunyuanSK: 'Hunyuan Secret Key',
HunyuanSKMessage: 'Please input your Secret Key',
SparkModelNameMessage: 'Please select Spark model',
addSparkAPIPassword: 'Spark APIPassword',
SparkAPIPasswordMessage: 'please input your APIPassword',
},
message: {
registered: 'Registered!',
Expand Down
3 changes: 3 additions & 0 deletions web/src/locales/zh-traditional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,9 @@ export default {
HunyuanSIDMessage: '請輸入 Secret ID',
addHunyuanSK: '混元 Secret Key',
HunyuanSKMessage: '請輸入 Secret Key',
SparkModelNameMessage: '請選擇星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '請輸入 APIPassword',
},
message: {
registered: '註冊成功',
Expand Down
3 changes: 3 additions & 0 deletions web/src/locales/zh.ts
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,9 @@ export default {
HunyuanSIDMessage: '请输入 Secret ID',
addHunyuanSK: '混元 Secret Key',
HunyuanSKMessage: '请输入 Secret Key',
SparkModelNameMessage: '请选择星火模型!',
addSparkAPIPassword: '星火 APIPassword',
SparkAPIPasswordMessage: '请输入 APIPassword',
},
message: {
registered: '注册成功',
Expand Down
1 change: 1 addition & 0 deletions web/src/pages/user-setting/setting-model/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const IconMap = {
'01.AI': 'yi',
Replicate: 'replicate',
'Tencent Hunyuan': 'hunyuan',
'XunFei Spark': 'spark',
};

export const BedrockRegionList = [
Expand Down
27 changes: 27 additions & 0 deletions web/src/pages/user-setting/setting-model/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,33 @@ export const useSubmitHunyuan = () => {
};
};

export const useSubmitSpark = () => {
const { addLlm, loading } = useAddLlm();
const {
visible: SparkAddingVisible,
hideModal: hideSparkAddingModal,
showModal: showSparkAddingModal,
} = useSetModalState();

const onSparkAddingOk = useCallback(
async (payload: IAddLlmRequestBody) => {
const ret = await addLlm(payload);
if (ret === 0) {
hideSparkAddingModal();
}
},
[hideSparkAddingModal, addLlm],
);

return {
SparkAddingLoading: loading,
onSparkAddingOk,
SparkAddingVisible,
hideSparkAddingModal,
showSparkAddingModal,
};
};

export const useSubmitBedrock = () => {
const { addLlm, loading } = useAddLlm();
const {
Expand Down
28 changes: 26 additions & 2 deletions web/src/pages/user-setting/setting-model/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ import {
useSubmitBedrock,
useSubmitHunyuan,
useSubmitOllama,
useSubmitSpark,
useSubmitSystemModelSetting,
useSubmitVolcEngine,
} from './hooks';
import HunyuanModal from './hunyuan-modal';
import styles from './index.less';
import OllamaModal from './ollama-modal';
import SparkModal from './spark-modal';
import SystemModelSettingModal from './system-model-setting-modal';
import VolcEngineModal from './volcengine-modal';

Expand Down Expand Up @@ -92,7 +94,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
<Button onClick={handleApiKeyClick}>
{isLocalLlmFactory(item.name) ||
item.name === 'VolcEngine' ||
item.name === 'Tencent Hunyuan'
item.name === 'Tencent Hunyuan' ||
item.name === 'XunFei Spark'
? t('addTheModel')
: 'API-Key'}
<SettingOutlined />
Expand Down Expand Up @@ -174,6 +177,14 @@ const UserSettingModel = () => {
HunyuanAddingLoading,
} = useSubmitHunyuan();

const {
SparkAddingVisible,
hideSparkAddingModal,
showSparkAddingModal,
onSparkAddingOk,
SparkAddingLoading,
} = useSubmitSpark();

const {
bedrockAddingLoading,
onBedrockAddingOk,
Expand All @@ -187,8 +198,14 @@ const UserSettingModel = () => {
Bedrock: showBedrockAddingModal,
VolcEngine: showVolcAddingModal,
'Tencent Hunyuan': showHunyuanAddingModal,
'XunFei Spark': showSparkAddingModal,
}),
[showBedrockAddingModal, showVolcAddingModal, showHunyuanAddingModal],
[
showBedrockAddingModal,
showVolcAddingModal,
showHunyuanAddingModal,
showSparkAddingModal,
],
);

const handleAddModel = useCallback(
Expand Down Expand Up @@ -306,6 +323,13 @@ const UserSettingModel = () => {
loading={HunyuanAddingLoading}
llmFactory={'Tencent Hunyuan'}
></HunyuanModal>
<SparkModal
visible={SparkAddingVisible}
hideModal={hideSparkAddingModal}
onOk={onSparkAddingOk}
loading={SparkAddingLoading}
llmFactory={'XunFei Spark'}
></SparkModal>
<BedrockModal
visible={bedrockAddingVisible}
hideModal={hideBedrockAddingModal}
Expand Down
94 changes: 94 additions & 0 deletions web/src/pages/user-setting/setting-model/spark-modal/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Form, Input, Modal, Select } from 'antd';
import omit from 'lodash/omit';

type FieldType = IAddLlmRequestBody & {
vision: boolean;
spark_api_password: string;
};

const { Option } = Select;

const SparkModal = ({
visible,
hideModal,
onOk,
loading,
llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
const [form] = Form.useForm<FieldType>();

const { t } = useTranslate('setting');

const handleOk = async () => {
const values = await form.validateFields();
const modelType =
values.model_type === 'chat' && values.vision
? 'image2text'
: values.model_type;

const data = {
...omit(values, ['vision']),
model_type: modelType,
llm_factory: llmFactory,
};
console.info(data);

onOk?.(data);
};

return (
<Modal
title={t('addLlmTitle', { name: llmFactory })}
open={visible}
onOk={handleOk}
onCancel={hideModal}
okButtonProps={{ loading }}
confirmLoading={loading}
>
<Form
name="basic"
style={{ maxWidth: 600 }}
autoComplete="off"
layout={'vertical'}
form={form}
>
<Form.Item<FieldType>
label={t('modelType')}
name="model_type"
initialValue={'chat'}
rules={[{ required: true, message: t('modelTypeMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="chat">chat</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('modelName')}
name="llm_name"
initialValue={'Spark-Max'}
rules={[{ required: true, message: t('SparkModelNameMessage') }]}
>
<Select placeholder={t('modelTypeMessage')}>
<Option value="Spark-Max">Spark-Max</Option>
<Option value="Spark-Lite">Spark-Lite</Option>
<Option value="Spark-Pro">Spark-Pro</Option>
<Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
<Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
label={t('addSparkAPIPassword')}
name="spark_api_password"
rules={[{ required: true, message: t('SparkPasswordMessage') }]}
>
<Input placeholder={t('SparkSIDMessage')} />
</Form.Item>
</Form>
</Modal>
);
};

export default SparkModal;

0 comments on commit 7743922

Please sign in to comment.