Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

search between multiple indices for team function #3079

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agent/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .tushare import TuShare, TuShareParam
from .akshare import AkShare, AkShareParam
from .crawler import Crawler, CrawlerParam
from .invoke import Invoke, InvokeParam


def component_class(class_name):
Expand Down
14 changes: 9 additions & 5 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial
import pandas as pd
from api.db import LLMType
from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from agent.component.base import ComponentBase, ComponentParamBase
Expand Down Expand Up @@ -112,7 +113,7 @@ def _run(self, history, **kwargs):

kwargs["input"] = input
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

downstreams = self._canvas.get_component(self._id)["downstream"]
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
Expand All @@ -124,8 +125,10 @@ def _run(self, history, **kwargs):
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
return pd.DataFrame([res])

ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf())
msg = self._canvas.get_history(self._param.message_history_window_size)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())

if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
res = self.set_cite(retrieval_res, ans)
return pd.DataFrame([res])
Expand All @@ -141,9 +144,10 @@ def stream_output(self, chat_mdl, prompt, retrieval_res):
self.set_output(res)
return

msg = self._canvas.get_history(self._param.message_history_window_size)
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
answer = ""
for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf()):
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}
answer = ans
yield res
Expand Down
25 changes: 22 additions & 3 deletions agent/component/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
#
import json
import re
from abc import ABC

import requests

from deepdoc.parser import HtmlParser
from agent.component.base import ComponentBase, ComponentParamBase


Expand All @@ -34,11 +34,13 @@ def __init__(self):
self.variables = []
self.url = ""
self.timeout = 60
self.clean_html = False

def check(self):
    """Validate the parameters of the Invoke component.

    Runs four checks in order: the lower-cased ``method`` must be one of
    ``get``, ``post`` or ``put``; ``url`` must not be empty; ``timeout``
    must be a positive integer; ``clean_html`` must be a boolean.

    Each check is delegated to a ``check_*`` helper defined outside this
    block. The second argument is a short label for the field being
    checked (presumably used in the helper's error message -- confirm
    against the base class).
    """
    # The label used to read "Type of content from the crawler", a leftover
    # that does not describe this field; the value checked is the HTTP method.
    self.check_valid_value(self.method.lower(), "HTTP method", ['get', 'post', 'put'])
    self.check_empty(self.url, "End point URL")
    self.check_positive_integer(self.timeout, "Timeout time in second")
    self.check_boolean(self.clean_html, "Clean HTML")


class Invoke(ComponentBase, ABC):
Expand All @@ -63,7 +65,7 @@ def _run(self, history, **kwargs):
if self._param.headers:
headers = json.loads(self._param.headers)
proxies = None
if self._param.proxy:
if re.sub(r"https?:?/?/?", "", self._param.proxy):
proxies = {"http": self._param.proxy, "https": self._param.proxy}

if method == 'get':
Expand All @@ -72,6 +74,10 @@ def _run(self, history, **kwargs):
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))

return Invoke.be_output(response.text)

if method == 'put':
Expand All @@ -80,5 +86,18 @@ def _run(self, history, **kwargs):
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)

if method == 'post':
response = requests.post(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)
4 changes: 3 additions & 1 deletion api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,

tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
Expand Down
2 changes: 2 additions & 0 deletions deepdoc/parser/html_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import html_text
import chardet


def get_encoding(file):
    """Return the character encoding detected for the file at path ``file``.

    The file is opened in binary mode, read in full, and the bytes are
    passed to ``chardet.detect``; the ``'encoding'`` entry of the result
    is returned unchanged.
    """
    with open(file, 'rb') as handle:
        raw_bytes = handle.read()
    detection = chardet.detect(raw_bytes)
    return detection['encoding']


class RAGFlowHtmlParser:
def __call__(self, fnm, binary=None):
txt = ""
Expand Down
18 changes: 12 additions & 6 deletions rag/nlp/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _add_filters(self, bqry, req):
Q("bool", must_not=Q("range", available_int={"lt": 1})))
return bqry

def search(self, req, idxnm, emb_mdl=None, highlight=False):
def search(self, req, idxnms, emb_mdl=None, highlight=False):
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst, min_match="30%")
bqry = self._add_filters(bqry, req)
Expand Down Expand Up @@ -134,7 +134,7 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False):
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
Expand All @@ -144,7 +144,7 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False):
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
es_logger.info("【Q】: {}".format(json.dumps(s)))

kwds = set([])
Expand Down Expand Up @@ -358,20 +358,26 @@ def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
rag_tokenizer.tokenize(ans).split(" "),
rag_tokenizer.tokenize(inst).split(" "))

def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks

RERANK_PAGE_LIMIT = 3
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}

if page > RERANK_PAGE_LIMIT:
req["page"] = page
req["size"] = page_size
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)

if isinstance(tenant_ids, str):
tenant_ids = tenant_ids.split(",")

sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
ranks["total"] = sres.total

if page <= RERANK_PAGE_LIMIT:
Expand Down Expand Up @@ -467,7 +473,7 @@ def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "co
s = Search()
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
s = s.to_dict()
es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
res = []
for index, chunk in enumerate(es_res['hits']['hits']):
res.append({fld: chunk['_source'].get(fld) for fld in fields})
Expand Down
6 changes: 4 additions & 2 deletions rag/utils/es_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,14 @@ def rm(self, d):

return False

def search(self, q, idxnm=None, src=False, timeout="2s"):
def search(self, q, idxnms=None, src=False, timeout="2s"):
if not isinstance(q, dict):
q = Search().query(q).to_dict()
if isinstance(idxnms, str):
idxnms = idxnms.split(",")
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
body=q,
timeout=timeout,
# search_type="dfs_query_then_fetch",
Expand Down