diff --git a/agent/component/__init__.py b/agent/component/__init__.py
index 2e85a3e093c..a977f03f441 100644
--- a/agent/component/__init__.py
+++ b/agent/component/__init__.py
@@ -29,6 +29,7 @@
 from .tushare import TuShare, TuShareParam
 from .akshare import AkShare, AkShareParam
 from .crawler import Crawler, CrawlerParam
+from .invoke import Invoke, InvokeParam


 def component_class(class_name):
diff --git a/agent/component/generate.py b/agent/component/generate.py
index 9297cdca0bb..588613e582b 100644
--- a/agent/component/generate.py
+++ b/agent/component/generate.py
@@ -17,6 +17,7 @@
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api.settings import retrievaler
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -112,7 +113,7 @@ def _run(self, history, **kwargs):
         kwargs["input"] = input
         for n, v in kwargs.items():
-            prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
+            prompt = re.sub(r"\{%s\}" % re.escape(n), lambda _: str(v), prompt)

         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
@@ -124,8 +125,10 @@ def _run(self, history, **kwargs):
                     retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
                 return pd.DataFrame([res])

-        ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                            self._param.gen_conf())
+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
+        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
+
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
             return pd.DataFrame([res])
@@ -141,9 +144,10 @@ def stream_output(self, chat_mdl, prompt, retrieval_res):
             self.set_output(res)
             return

+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                                          self._param.gen_conf()):
+        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
             answer = ans
             yield res
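The generate.py hunks above route the system prompt plus the windowed chat history through message_fit_in, so the combined payload stays within roughly 97% of the model's context window before chat/chat_streamly are called. A minimal sketch of that trimming pattern, under stated assumptions: count_tokens is an illustrative stand-in for the project's tokenizer, and the (token_count, messages) return shape mirrors how message_fit_in is used above.

```python
# Hypothetical sketch of fitting a system prompt plus chat history into a
# token budget, mirroring the new call sites in generate.py.
# `count_tokens` is an assumed stand-in, not RAGFlow's actual tokenizer.

def count_tokens(text: str) -> int:
    # Crude approximation; the real service would use a proper tokenizer.
    return max(1, len(text) // 4)

def fit_messages(messages: list[dict], max_tokens: int) -> tuple[int, list[dict]]:
    """Keep the system message, then drop the oldest history turns
    until the total estimated token count fits the budget."""
    system, history = messages[0], messages[1:]
    total = count_tokens(system["content"]) + sum(count_tokens(m["content"]) for m in history)
    while history and total > max_tokens:
        dropped = history.pop(0)  # discard the oldest turn first
        total -= count_tokens(dropped["content"])
    return total, [system, *history]

# Usage, mirroring the call site in generate.py:
# msg = canvas.get_history(window_size)
# _, msg = fit_messages([{"role": "system", "content": prompt}, *msg],
#                       int(chat_mdl.max_length * 0.97))
# ans = chat_mdl.chat(msg[0]["content"], msg[1:], gen_conf)
```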
diff --git a/agent/component/invoke.py b/agent/component/invoke.py
index 1078d35a978..d497be93f8a 100644
--- a/agent/component/invoke.py
+++ b/agent/component/invoke.py
@@ -14,10 +14,10 @@
 #  limitations under the License.
 #
 import json
+import re
 from abc import ABC
-
 import requests
-
+from deepdoc.parser import HtmlParser
 from agent.component.base import ComponentBase, ComponentParamBase
@@ -34,11 +34,13 @@ def __init__(self):
         self.variables = []
         self.url = ""
         self.timeout = 60
+        self.clean_html = False

     def check(self):
         self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
+        self.check_boolean(self.clean_html, "Clean HTML")


 class Invoke(ComponentBase, ABC):
@@ -63,7 +65,7 @@ def _run(self, history, **kwargs):
         if self._param.headers:
             headers = json.loads(self._param.headers)
         proxies = None
-        if self._param.proxy:
+        if self._param.proxy and re.sub(r"^https?:/{0,2}", "", self._param.proxy):
             proxies = {"http": self._param.proxy, "https": self._param.proxy}

         if method == 'get':
@@ -72,6 +74,10 @@ def _run(self, history, **kwargs):
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+
             return Invoke.be_output(response.text)

         if method == 'put':
@@ -80,5 +86,18 @@ def _run(self, history, **kwargs):
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+
             return Invoke.be_output(response.text)
+
+        if method == 'post':
+            response = requests.post(url=url,
+                                     json=args,
+                                     headers=headers,
+                                     proxies=proxies,
+                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
             return Invoke.be_output(response.text)
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 13a395f85d7..4245a284f51 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+
+        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                         dialog.similarity_threshold,
                         dialog.vector_similarity_weight,
                         doc_ids=attachments,
diff --git a/deepdoc/parser/html_parser.py b/deepdoc/parser/html_parser.py
index d5cde78c2ea..e02aaa1f3e5 100644
--- a/deepdoc/parser/html_parser.py
+++ b/deepdoc/parser/html_parser.py
@@ -16,11 +16,13 @@
 import html_text
 import chardet

+
 def get_encoding(file):
     with open(file,'rb') as f:
         tmp = chardet.detect(f.read())
         return tmp['encoding']

+
 class RAGFlowHtmlParser:
     def __call__(self, fnm, binary=None):
         txt = ""
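The repeated clean_html branch in invoke.py converts an HTML response body into plain-text sections via HtmlParser, which wraps the html_text library imported in html_parser.py above. A minimal sketch of that fetch-then-clean flow, assuming the requests and html_text libraries; the helper name fetch_as_text is illustrative, not part of the codebase.

```python
# Illustrative sketch of the Invoke component's clean_html path.
# `fetch_as_text` is a made-up helper name for this example.
import requests
import html_text

def fetch_as_text(url: str, clean_html: bool = False,
                  timeout: int = 60, proxy: str = "") -> str:
    proxies = {"http": proxy, "https": proxy} if proxy else None
    response = requests.get(url, proxies=proxies, timeout=timeout)
    if clean_html:
        # Strip tags and scripts, keep readable text -- roughly what
        # HtmlParser does before splitting the result into sections.
        return html_text.extract_text(response.text)
    return response.text

# print(fetch_as_text("https://example.com", clean_html=True))
```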
s["highlight"] q_vec = s["knn"]["query_vector"] es_logger.info("【Q】: {}".format(json.dumps(s))) - res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) + res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src) es_logger.info("TOTAL: {}".format(self.es.getTotal(res))) if self.es.getTotal(res) == 0 and "knn" in s: bqry, _ = self.qryr.question(qst, min_match="10%") @@ -144,7 +144,7 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False): s["query"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict() s["knn"]["similarity"] = 0.17 - res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) + res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src) es_logger.info("【Q】: {}".format(json.dumps(s))) kwds = set([]) @@ -358,20 +358,26 @@ def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): rag_tokenizer.tokenize(ans).split(" "), rag_tokenizer.tokenize(inst).split(" ")) - def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, + def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2, vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks + RERANK_PAGE_LIMIT = 3 req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128), "question": question, "vector": True, "topk": top, "similarity": similarity_threshold, "available_int": 1} + if page > RERANK_PAGE_LIMIT: req["page"] = page req["size"] = page_size - sres = self.search(req, index_name(tenant_id), embd_mdl, highlight) + + if isinstance(tenant_ids, str): + tenant_ids = tenant_ids.split(",") + + sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight) ranks["total"] = sres.total if page <= RERANK_PAGE_LIMIT: @@ -467,7 +473,7 @@ def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "co s = Search() s = s.query(Q("match", doc_id=doc_id))[0:max_count] s = s.to_dict() - es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields) + es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields) res = [] for index, chunk in enumerate(es_res['hits']['hits']): res.append({fld: chunk['_source'].get(fld) for fld in fields}) diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 8b07be312c3..d39e263f7f8 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -221,12 +221,14 @@ def rm(self, d): return False - def search(self, q, idxnm=None, src=False, timeout="2s"): + def search(self, q, idxnms=None, src=False, timeout="2s"): if not isinstance(q, dict): q = Search().query(q).to_dict() + if isinstance(idxnms, str): + idxnms = idxnms.split(",") for i in range(3): try: - res = self.es.search(index=(self.idxnm if not idxnm else idxnm), + res = self.es.search(index=(self.idxnm if not idxnms else idxnms), body=q, timeout=timeout, # search_type="dfs_query_then_fetch",