Skip to content

Commit

Permalink
修改openai接口,让openai可以使用function的方式获取表结构
Browse files Browse the repository at this point in the history
  • Loading branch information
hejianjun committed Feb 4, 2024
1 parent e707d64 commit 0b05cc9
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => {
};

const renderSelectTable = () => {
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props;
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props;
const options = (tables || []).map((t) => ({ value: t, label: t }));
return (
<div className={styles.aiSelectedTable}>
<Radio.Group
onChange={(v) => onSelectTableSyncModel(v.target.value)}
// value={syncTableModel}
value={SyncModelType.MANUAL}
value={syncTableModel}
style={{ marginBottom: '8px' }}
>
<Space direction="horizontal">
{/* <Radio value={SyncModelType.AUTO}>自动</Radio> */}
<Radio value={SyncModelType.AUTO}>自动</Radio>
<Radio value={SyncModelType.MANUAL}>手动</Radio>
</Space>
</Radio.Group>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ const SelectBoundInfo = memo((props: IProps) => {
boundInfo.databaseName,
boundInfo.schemaName,
);
setSelectedTables(tableNameListTemp.slice(0, 1));
//setSelectedTables(tableNameListTemp.slice(0, 1));
}
}, [allTableList, isActive]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,20 @@
import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse;
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.spi.MetaData;
import ai.chat2db.spi.model.Table;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.sql.ConnectInfo;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSON;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.chat.Parameters;
import com.unfbx.chatgpt.entity.chat.tool.Tools;
import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
Expand Down Expand Up @@ -171,7 +180,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc
/**
* 自定义模型非流式输出接口DEMO
* <p>
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
* Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
* </p>
*
* @param queryRequest
Expand Down Expand Up @@ -276,11 +285,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter
* @throws IOException
*/
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
String prompt = buildPrompt(queryRequest);
throws IOException {
String prompt = buildPrompt2(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}

Expand All @@ -290,9 +299,28 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);

OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
ConnectInfo connectInfo = Chat2DBContext.getConnectInfo();
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo, queryRequest);
ToolsFunction function = ToolsFunction.builder()
.name("get_table_columns")
.description("获取指定表的字段名,类型")
.parameters(Parameters.builder()
.type("object")
.properties(ImmutableMap.builder()
.put("table_name", ImmutableMap.builder()
.put("type", "string")
.put("description", "表名,例如```User```")
.build())
.build())
.required(List.of("table_name"))
.build())
.build();
ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-3.5-turbo-1106")
.tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)))
.toolChoice("auto")
.messages(messages).stream(true).build();
OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
Expand Down Expand Up @@ -630,6 +658,47 @@ private String buildPrompt(ChatQueryRequest queryRequest) {
return cleanedInput;
}

/**
* 构建prompt
*
* @param queryRequest
* @return
*/
private String buildPrompt2(ChatQueryRequest queryRequest) {
if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) {
return queryRequest.getMessage();
}

// 查询schema信息
String dataSourceType = queryDatabaseType(queryRequest);
String properties = "";
if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
properties = queryRequest.getTableNames().stream().collect(Collectors.joining(","));
} else {
properties = queryDatabaseSchema2(queryRequest);
}
String prompt = queryRequest.getMessage();
String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
: queryRequest.getPromptType();
PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
"### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# "
+ "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
pType.getDescription(), ext, prompt);
switch (pType) {
case SQL_2_SQL:
schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
"%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
"%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
default:
break;
}
String cleanedInput = schemaProperty.replaceAll("[\r\t]", "");
return cleanedInput;
}

/**
* query chat2db apikey
*
Expand Down Expand Up @@ -727,6 +796,28 @@ public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
}
}


/**
* query database schema
*
* @param queryRequest
* @return
* @throws IOException
*/
public String queryDatabaseSchema2(ChatQueryRequest queryRequest) {
MetaData metaSchema = Chat2DBContext.getMetaData();
try {
List<Table> tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null);
return tables.stream()
.map(table -> StringUtils.isBlank(table.getComment()) ? table.getName()
: table.getName() + "(" + table.getComment() + ")")
.collect(Collectors.joining(","));
} catch (Exception e) {
log.error("query table error:{}, do nothing", e.getMessage());
return "";
}
}

/**
* query database schema
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.client;

import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
Expand Down Expand Up @@ -93,7 +94,17 @@ public static void refresh() {
log.info("refresh openai apikey:{}", maskApiKey(apikey));
if (Objects.nonNull(host) && Objects.nonNull(port)) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port));
OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build();
OkHttpClient okHttpClient = new OkHttpClient.Builder()
// 设置连接超时为10秒
.connectTimeout(10, TimeUnit.SECONDS)
// 设置读取超时为30秒
.readTimeout(30, TimeUnit.SECONDS)
// 设置写入超时为15秒
.writeTimeout(15, TimeUnit.SECONDS)
// 设置整个调用的超时为1分钟
.callTimeout(1, TimeUnit.MINUTES)
.proxy(proxy)
.build();
OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build();
} else {
Expand Down
Loading

0 comments on commit 0b05cc9

Please sign in to comment.