search between multiple indices for team function (#3079)
### What problem does this PR solve?

#2834 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh authored Oct 29, 2024
1 parent c5a3146 commit 2d1fbef
Showing 7 changed files with 53 additions and 17 deletions.
1 change: 1 addition & 0 deletions agent/component/__init__.py
```diff
@@ -29,6 +29,7 @@
 from .tushare import TuShare, TuShareParam
 from .akshare import AkShare, AkShareParam
 from .crawler import Crawler, CrawlerParam
+from .invoke import Invoke, InvokeParam


 def component_class(class_name):
```
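Registering `Invoke` in the package's `__init__.py` is what makes the component discoverable by name. A rough sketch of how such a name-based lookup typically works (the `importlib` body here is a guess at the pattern, not a copy of the repo's `component_class`):

```python
# Hypothetical sketch of a name-to-class lookup over the agent.component
# package; it resolves "Invoke" only because of the import added above.
import importlib

def component_class(class_name: str):
    module = importlib.import_module("agent.component")
    return getattr(module, class_name)  # AttributeError if not re-exported

invoke_cls = component_class("Invoke")
```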
14 changes: 9 additions & 5 deletions agent/component/generate.py
```diff
@@ -17,6 +17,7 @@
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api.settings import retrievaler
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -112,7 +113,7 @@ def _run(self, history, **kwargs):

         kwargs["input"] = input
         for n, v in kwargs.items():
-            prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
+            prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
@@ -124,8 +125,10 @@ def _run(self, history, **kwargs):
                 retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
             return pd.DataFrame([res])

-        ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                            self._param.gen_conf())
+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
+        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
+
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
             return pd.DataFrame([res])
@@ -141,9 +144,10 @@ def stream_output(self, chat_mdl, prompt, retrieval_res):
            self.set_output(res)
            return

+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                                          self._param.gen_conf()):
+        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
             answer = ans
             yield res
```
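Both the blocking and streaming paths now route history through `message_fit_in`, so the system prompt plus conversation stays within roughly 97% of the model's context window. A minimal sketch of that fitting idea, assuming a crude `count_tokens` approximation (not the repo's implementation):

```python
# Hedged sketch: keep the system message, drop the oldest turns until
# the token budget fits.
def count_tokens(msg: dict) -> int:
    return max(1, len(msg["content"]) // 4)  # rough 4-chars-per-token guess

def fit_in(messages: list, max_tokens: int):
    system, history = messages[0], messages[1:]
    total = count_tokens(system) + sum(count_tokens(m) for m in history)
    while history and total > max_tokens:
        total -= count_tokens(history.pop(0))  # oldest turn goes first
    return total, [system, *history]

# Usage mirrors generate.py: reserve 3% headroom under the model limit.
_, msg = fit_in([{"role": "system", "content": "You are an assistant."},
                 {"role": "user", "content": "Summarize the document."}],
                int(8192 * 0.97))
```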
25 changes: 22 additions & 3 deletions agent/component/invoke.py
```diff
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 import json
 import re
 from abc import ABC

 import requests

-
+from deepdoc.parser import HtmlParser
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -34,11 +34,13 @@ def __init__(self):
         self.variables = []
         self.url = ""
         self.timeout = 60
+        self.clean_html = False

     def check(self):
         self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
+        self.check_boolean(self.clean_html, "Clean HTML")
@@ -63,7 +65,7 @@ def _run(self, history, **kwargs):
         if self._param.headers:
             headers = json.loads(self._param.headers)
         proxies = None
-        if self._param.proxy:
+        if re.sub(r"https?:?/?/?", "", self._param.proxy):
             proxies = {"http": self._param.proxy, "https": self._param.proxy}

         if method == 'get':
@@ -72,6 +74,10 @@
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+
             return Invoke.be_output(response.text)

         if method == 'put':
@@ -80,5 +86,18 @@ def _run(self, history, **kwargs):
                                    headers=headers,
                                    proxies=proxies,
                                    timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
             return Invoke.be_output(response.text)

+        if method == 'post':
+            response = requests.post(url=url,
+                                     json=args,
+                                     headers=headers,
+                                     proxies=proxies,
+                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+            return Invoke.be_output(response.text)
```
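The new `clean_html` flag pipes the HTTP response through deepdoc's `HtmlParser`, so downstream components receive readable text rather than raw markup. A simplified sketch of the idea, assuming only the `html_text` dependency (the `strip_html` helper is a stand-in for `RAGFlowHtmlParser`, not its actual code):

```python
# Hedged sketch of clean_html: fetch a page, extract readable text sections.
import html_text
import requests

def strip_html(content: bytes) -> list:
    text = html_text.extract_text(content.decode("utf-8", errors="ignore"))
    return [line for line in text.split("\n") if line.strip()]

response = requests.get("https://example.com", timeout=60)
print("\n".join(strip_html(response.content)))
```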
4 changes: 3 additions & 1 deletion api/db/services/dialog_service.py
```diff
@@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+
+        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
                                  doc_ids=attachments,
```
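Instead of searching only the dialog owner's index, the chat flow now collects the distinct owners of every attached knowledge base, which is what lets a teammate's shared KB participate in retrieval. A toy illustration of the dedup step (the `KnowledgeBase` stand-in is hypothetical):

```python
# Knowledge bases in one dialog may belong to different team members;
# deduplicate their tenant ids before fanning the query out.
from dataclasses import dataclass

@dataclass
class KnowledgeBase:  # stand-in for RAGFlow's kb model
    id: str
    tenant_id: str

kbs = [KnowledgeBase("kb1", "alice"),
       KnowledgeBase("kb2", "bob"),
       KnowledgeBase("kb3", "alice")]
tenant_ids = list(set(kb.tenant_id for kb in kbs))
print(tenant_ids)  # ['alice', 'bob'], in some order
```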
2 changes: 2 additions & 0 deletions deepdoc/parser/html_parser.py
```diff
@@ -16,11 +16,13 @@
 import html_text
 import chardet

+
 def get_encoding(file):
     with open(file,'rb') as f:
         tmp = chardet.detect(f.read())
     return tmp['encoding']

+
 class RAGFlowHtmlParser:
     def __call__(self, fnm, binary=None):
         txt = ""
```
18 changes: 12 additions & 6 deletions rag/nlp/search.py
```diff
@@ -79,7 +79,7 @@ def _add_filters(self, bqry, req):
                     Q("bool", must_not=Q("range", available_int={"lt": 1})))
         return bqry

-    def search(self, req, idxnm, emb_mdl=None, highlight=False):
+    def search(self, req, idxnms, emb_mdl=None, highlight=False):
         qst = req.get("question", "")
         bqry, keywords = self.qryr.question(qst, min_match="30%")
         bqry = self._add_filters(bqry, req)
@@ -134,7 +134,7 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False):
             del s["highlight"]
         q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
-        res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
+        res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
@@ -144,7 +144,7 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False):
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
-            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+            res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
             es_logger.info("【Q】: {}".format(json.dumps(s)))

         kwds = set([])
@@ -358,20 +358,26 @@ def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
                                         rag_tokenizer.tokenize(ans).split(" "),
                                         rag_tokenizer.tokenize(inst).split(" "))

-    def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
+    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
+
         RERANK_PAGE_LIMIT = 3
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
                "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold,
                "available_int": 1}
+
         if page > RERANK_PAGE_LIMIT:
             req["page"] = page
             req["size"] = page_size
-        sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
+
+        if isinstance(tenant_ids, str):
+            tenant_ids = tenant_ids.split(",")
+
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
         ranks["total"] = sres.total

         if page <= RERANK_PAGE_LIMIT:
@@ -467,7 +473,7 @@ def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "co
         s = Search()
         s = s.query(Q("match", doc_id=doc_id))[0:max_count]
         s = s.to_dict()
-        es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
+        es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
         res = []
         for index, chunk in enumerate(es_res['hits']['hits']):
             res.append({fld: chunk['_source'].get(fld) for fld in fields})
```
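`retrieval` now accepts either a list of tenant ids or a comma-separated string and expands them into one Elasticsearch index per tenant. A small sketch of that fan-out, assuming RAGFlow's one-index-per-tenant naming convention:

```python
# Expand tenant ids into per-tenant index names, tolerating both the
# list form and the comma-separated string form that retrieval() accepts.
def index_name(tenant_id: str) -> str:
    return f"ragflow_{tenant_id}"  # assumed naming convention

def to_index_list(tenant_ids) -> list:
    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")
    return [index_name(tid) for tid in tenant_ids]

print(to_index_list("alice,bob"))  # ['ragflow_alice', 'ragflow_bob']
print(to_index_list(["carol"]))    # ['ragflow_carol']
```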
6 changes: 4 additions & 2 deletions rag/utils/es_conn.py
```diff
@@ -221,12 +221,14 @@ def rm(self, d):

         return False

-    def search(self, q, idxnm=None, src=False, timeout="2s"):
+    def search(self, q, idxnms=None, src=False, timeout="2s"):
         if not isinstance(q, dict):
             q = Search().query(q).to_dict()
+        if isinstance(idxnms, str):
+            idxnms = idxnms.split(",")
         for i in range(3):
             try:
-                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
+                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                      body=q,
                                      timeout=timeout,
                                      # search_type="dfs_query_then_fetch",
```
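The fan-out is cheap because Elasticsearch itself accepts a list of index names (or a comma-joined string) in the `index` argument, searching them in a single round trip. A sketch with elasticsearch-py 7.x; the URL and index names are illustrative assumptions:

```python
# Query two per-tenant indices at once; names and endpoint are assumed.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
res = es.search(index=["ragflow_alice", "ragflow_bob"],
                body={"query": {"match_all": {}}},
                timeout="2s")
print(res["hits"]["total"])
```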
