From 1668d0a8b14c8fea3eda8767ed85fda65efb58fd Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Tue, 10 Dec 2024 16:37:51 +0800
Subject: [PATCH] Support iframe chatbot.

---
 agent/canvas.py                         |   5 +-
 agent/component/base.py                 |  17 +++-
 agent/component/generate.py             |   8 ++
 api/apps/canvas_app.py                  |  20 ++++
 api/apps/conversation_app.py            |  52 ++++++----
 api/apps/sdk/session.py                 |  10 ++
 api/db/services/canvas_service.py       |  44 ++++-----
 api/db/services/conversation_service.py |  38 ++++----
 api/db/services/dialog_service.py       | 122 +++++++++---------
 9 files changed, 174 insertions(+), 142 deletions(-)

diff --git a/agent/canvas.py b/agent/canvas.py
index d8b04a9835..2f545f4ec3 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -330,4 +330,7 @@ def set_global_param(self, **kwargs):
                 q["value"] = v
 
     def get_preset_param(self):
-        return self.components["begin"]["obj"]._param.query
\ No newline at end of file
+        return self.components["begin"]["obj"]._param.query
+
+    def get_component_input_elements(self, cpnnm):
+        return self.components["begin"]["obj"].get_input_elements()
\ No newline at end of file
diff --git a/agent/component/base.py b/agent/component/base.py
index 2dc0cd49be..3b6a9fce32 100644
--- a/agent/component/base.py
+++ b/agent/component/base.py
@@ -476,7 +476,7 @@ def get_input(self):
                     self._param.inputs.append({"component_id": q["component_id"],
                                                "content": "\n".join(
                                                    [str(d["content"]) for d in outs[-1].to_dict('records')])})
-            elif q["value"]:
+            elif q.get("value"):
                 self._param.inputs.append({"component_id": None, "content": q["value"]})
                 outs.append(pd.DataFrame([{"content": q["value"]}]))
         if outs:
@@ -526,6 +526,21 @@ def get_input(self):
 
         return df
 
+    def get_input_elements(self):
+        assert self._param.query, "Please identify input parameters firstly."
+        eles = []
+        for q in self._param.query:
+            if q.get("component_id"):
+                if q["component_id"].split("@")[0].lower().find("begin") >= 0:
+                    cpn_id, key = q["component_id"].split("@")
+                    eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
+                    continue
+
+                eles.append({"key": q["key"], "component_id": q["component_id"]})
+            else:
+                eles.append({"key": q["key"]})
+        return eles
+
     def get_stream_input(self):
         reversed_cpnts = []
         if len(self._canvas.path) > 1:
diff --git a/agent/component/generate.py b/agent/component/generate.py
index 27f1ce2fdf..fc5672d3c1 100644
--- a/agent/component/generate.py
+++ b/agent/component/generate.py
@@ -17,6 +17,7 @@
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.conversation_service import structure_answer
 from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api import settings
@@ -104,9 +105,16 @@ def set_cite(self, retrieval_res, answer):
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
         res = {"content": answer, "reference": reference}
+        res = structure_answer(None, res, "", "")
         return res
 
+    def get_input_elements(self):
+        if self._param.parameters:
+            return self._param.parameters
+
+        return [{"key": "input"}]
+
     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 679947476e..a36789bed3 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -186,6 +186,26 @@ def reset():
         return server_error_response(e)
 
 
+@manager.route('/input_elements', methods=['GET'])  # noqa: F821
+@validate_request("id", "component_id")
+@login_required
+def input_elements():
+    req = request.json
+    try:
+        e, user_canvas = UserCanvasService.get_by_id(req["id"])
+        if not e:
+            return get_data_error_result(message="canvas not found.")
+        if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
+            return get_json_result(
+                data=False, message='Only owner of canvas authorized for this operation.',
+                code=RetCode.OPERATING_ERROR)
+
+        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
+        return get_json_result(data=canvas.get_component_input_elements(req["component_id"]))
+    except Exception as e:
+        return server_error_response(e)
+
+
 @manager.route('/test_db_connect', methods=['POST'])  # noqa: F821
 @validate_request("db_type", "database", "username", "host", "port", "password")
 @login_required
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index d787e5f1f6..345f71f74d 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -18,7 +18,7 @@
 import traceback
 from copy import deepcopy
 
-from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import ConversationService, structure_answer
 from api.db.services.user_service import UserTenantService
 from flask import request, Response
 from flask_login import login_required, current_user
@@ -90,6 +90,21 @@ def get():
             return get_json_result(
                 data=False, message='Only owner of conversation authorized for this operation.',
                 code=settings.RetCode.OPERATING_ERROR)
+
+        def get_value(d, k1, k2):
+            return d.get(k1, d.get(k2))
+
+        for ref in conv.reference:
+            ref["chunks"] = [{
+                "id": get_value(ck, "chunk_id", "id"),
+                "content": get_value(ck, "content", "content_with_weight"),
+                "document_id": get_value(ck, "doc_id", "document_id"),
+                "document_name": get_value(ck, "docnm_kwd", "document_name"),
+                "dataset_id": get_value(ck, "kb_id", "dataset_id"),
+                "image_id": get_value(ck, "image_id", "img_id"),
+                "positions": get_value(ck, "positions", "position_int"),
+            } for ck in ref.get("chunks", [])]
+
         conv = conv.to_dict()
         return get_json_result(data=conv)
     except Exception as e:
@@ -132,6 +147,7 @@ def list_convsersation():
             dialog_id=dialog_id,
             order_by=ConversationService.model.create_time,
             reverse=True)
+        convs = [d.to_dict() for d in convs]
 
         return get_json_result(data=convs)
     except Exception as e:
@@ -164,24 +180,29 @@ def completion():
 
         if not conv.reference:
             conv.reference = []
-        conv.message.append({"role": "assistant", "content": "", "id": message_id})
-        conv.reference.append({"chunks": [], "doc_aggs": []})
-
-        def fillin_conv(ans):
-            nonlocal conv, message_id
-            if not conv.reference:
-                conv.reference.append(ans["reference"])
-            else:
-                conv.reference[-1] = ans["reference"]
-            conv.message[-1] = {"role": "assistant", "content": ans["answer"],
-                                "id": message_id, "prompt": ans.get("prompt", "")}
-            ans["id"] = message_id
+        else:
+            def get_value(d, k1, k2):
+                return d.get(k1, d.get(k2))
+
+            for ref in conv.reference:
+                ref["chunks"] = [{
+                    "id": get_value(ck, "chunk_id", "id"),
+                    "content": get_value(ck, "content", "content_with_weight"),
+                    "document_id": get_value(ck, "doc_id", "document_id"),
+                    "document_name": get_value(ck, "docnm_kwd", "document_name"),
+                    "dataset_id": get_value(ck, "kb_id", "dataset_id"),
+                    "image_id": get_value(ck, "image_id", "img_id"),
+                    "positions": get_value(ck, "positions", "position_int"),
+                } for ck in ref.get("chunks", [])]
+        if not conv.reference:
+            conv.reference = []
+        conv.reference.append({"chunks": [], "doc_aggs": []})
 
         def stream():
             nonlocal dia, msg, req, conv
             try:
                 for ans in chat(dia, msg, True, **req):
-                    fillin_conv(ans)
+                    ans = structure_answer(conv, ans, message_id, conv.id)
                     yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                 ConversationService.update_by_id(conv.id, conv.to_dict())
             except Exception as e:
@@ -202,8 +223,7 @@ def stream():
     else:
         answer = None
         for ans in chat(dia, msg, **req):
-            answer = ans
-            fillin_conv(ans)
+            answer = structure_answer(conv, ans, message_id, req["conversation_id"])
             ConversationService.update_by_id(conv.id, conv.to_dict())
             break
         return get_json_result(data=answer)
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index a779fe9c70..5307a2a676 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -112,6 +112,11 @@ def update(tenant_id, chat_id, session_id):
 @token_required
 def chat_completion(tenant_id, chat_id):
     req = request.json
+    if not DialogService.query(tenant_id=tenant_id,id=chat_id,status=StatusEnum.VALID.value):
+        return get_error_data_result(f"You don't own the chat {chat_id}")
+    if req.get("session_id"):
+        if not ConversationService.query(id=req["session_id"],dialog_id=chat_id):
+            return get_error_data_result(f"You don't own the session {req['session_id']}")
     if req.get("stream", True):
         resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
@@ -133,6 +138,11 @@ def chat_completion(tenant_id, chat_id):
 @token_required
 def agent_completions(tenant_id, agent_id):
     req = request.json
+    if not UserCanvasService.query(user_id=tenant_id,id=agent_id):
+        return get_error_data_result(f"You don't own the agent {agent_id}")
+    if req.get("session_id"):
+        if not API4ConversationService.query(id=req["session_id"],dialog_id=agent_id):
+            return get_error_data_result(f"You don't own the session {req['session_id']}")
     if req.get("stream", True):
         resp = Response(agent_completion(tenant_id, agent_id, **req), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py
index c51ffeee57..ac86733460 100644
--- a/api/db/services/canvas_service.py
+++ b/api/db/services/canvas_service.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import json
+import traceback
 from uuid import uuid4
 from agent.canvas import Canvas
 from api.db.db_models import DB, CanvasTemplate, UserCanvas, API4Conversation
@@ -58,6 +59,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
     canvas = Canvas(cvs.dsl, tenant_id)
+    canvas.reset()
+    message_id = str(uuid4())
     if not session_id:
         session_id = get_uuid()
@@ -84,40 +87,24 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
             return
         conv = API4Conversation(**conv)
     else:
-        session_id = session_id
         e, conv = API4ConversationService.get_by_id(session_id)
         assert e, "Session not found!"
         canvas = Canvas(json.dumps(conv.dsl), tenant_id)
-
-    if not conv.message:
-        conv.message = []
-    messages = conv.message
-    question = {
-        "role": "user",
-        "content": question,
-        "id": str(uuid4())
-    }
-    messages.append(question)
-    msg = []
-    for m in messages:
-        if m["role"] == "system":
-            continue
-        if m["role"] == "assistant" and not msg:
-            continue
-        msg.append(m)
-    if not msg[-1].get("id"):
-        msg[-1]["id"] = get_uuid()
-    message_id = msg[-1]["id"]
-
-    if not conv.reference:
-        conv.reference = []
-    conv.message.append({"role": "assistant", "content": "", "id": message_id})
-    conv.reference.append({"chunks": [], "doc_aggs": []})
+        canvas.messages.append({"role": "user", "content": question, "id": message_id})
+        canvas.add_user_input(question)
+        if not conv.message:
+            conv.message = []
+        conv.message.append({
+            "role": "user",
+            "content": question,
+            "id": message_id
+        })
+        if not conv.reference:
+            conv.reference = []
+        conv.reference.append({"chunks": [], "doc_aggs": []})
 
     final_ans = {"reference": [], "content": ""}
-    canvas.add_user_input(msg[-1]["content"])
-
     if stream:
         try:
             for ans in canvas.run(stream=stream):
@@ -141,6 +128,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
             conv.dsl = json.loads(str(canvas))
             API4ConversationService.append_message(conv.id, conv.to_dict())
         except Exception as e:
+            traceback.print_exc()
             conv.dsl = json.loads(str(canvas))
             API4ConversationService.append_message(conv.id, conv.to_dict())
             yield "data:" + json.dumps({"code": 500, "message": str(e),
diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py
index 3cfa42a6b8..7518561782 100644
--- a/api/db/services/conversation_service.py
+++ b/api/db/services/conversation_service.py
@@ -49,30 +49,35 @@ def structure_answer(conv, ans, message_id, session_id):
     reference = ans["reference"]
     if not isinstance(reference, dict):
         reference = {}
-    temp_reference = deepcopy(ans["reference"])
-    if not conv.reference:
-        conv.reference.append(temp_reference)
-    else:
-        conv.reference[-1] = temp_reference
-    conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+        ans["reference"] = {}
+    def get_value(d, k1, k2):
+        return d.get(k1, d.get(k2))
 
     chunk_list = [{
-        "id": chunk["chunk_id"],
-        "content": chunk.get("content") if chunk.get("content") else chunk.get("content_with_content"),
-        "document_id": chunk["doc_id"],
-        "document_name": chunk["docnm_kwd"],
-        "dataset_id": chunk["kb_id"],
-        "image_id": chunk["image_id"],
-        "similarity": chunk["similarity"],
-        "vector_similarity": chunk["vector_similarity"],
-        "term_similarity": chunk["term_similarity"],
-        "positions": chunk["positions"],
+        "id": get_value(chunk, "chunk_id", "id"),
+        "content": get_value(chunk, "content", "content_with_weight"),
+        "document_id": get_value(chunk, "doc_id", "document_id"),
+        "document_name": get_value(chunk, "docnm_kwd", "document_name"),
+        "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
+        "image_id": get_value(chunk, "image_id", "img_id"),
+        "positions": get_value(chunk, "positions", "position_int"),
     } for chunk in reference.get("chunks", [])]
 
     reference["chunks"] = chunk_list
     ans["id"] = message_id
     ans["session_id"] = session_id
+    if not conv:
+        return ans
+
+    if not conv.message:
+        conv.message = []
+    if not conv.message or conv.message[-1].get("role", "") != "assistant":
+        conv.message.append({"role": "assistant", "content": ans["answer"], "id": message_id})
+    else:
+        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+    if conv.reference:
+        conv.reference[-1] = reference
     return ans
@@ -199,7 +204,6 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg
 
     if not conv.reference:
         conv.reference = []
-    conv.message.append({"role": "assistant", "content": "", "id": message_id})
     conv.reference.append({"chunks": [], "doc_aggs": []})
 
     if stream:
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 36daadddac..b68b8e65fe 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -18,6 +18,7 @@
 import os
 import json
 import re
+from collections import defaultdict
 from copy import deepcopy
 from timeit import default_timer as timer
 import datetime
@@ -108,6 +109,32 @@ def llm_id2llm_type(llm_id):
                 return llm["model_type"].strip(",")[-1]
 
 
+def kb_prompt(kbinfos, max_tokens):
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    used_token_count = 0
+    chunks_num = 0
+    for i, c in enumerate(knowledges):
+        used_token_count += num_tokens_from_string(c)
+        chunks_num += 1
+        if max_tokens * 0.97 < used_token_count:
+            knowledges = knowledges[:i]
+            break
+
+    doc2chunks = defaultdict(list)
+    for i, ck in enumerate(kbinfos["chunks"]):
+        if i >= chunks_num:
+            break
+        doc2chunks["docnm_kwd"].append(ck["content_with_weight"])
+
+    knowledges = []
+    for nm, chunks in doc2chunks.items():
+        txt = f"Document: {nm} \nContains the following relevant fragments:\n"
+        for i, chunk in enumerate(chunks, 1):
+            txt += f"{i}. {chunk}\n"
+        knowledges.append(txt)
+    return knowledges
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
@@ -195,32 +222,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                                      dialog.vector_similarity_weight,
                                      doc_ids=attachments,
                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
-
-        # Group chunks by document ID
-        doc_chunks = {}
-        for ck in kbinfos["chunks"]:
-            doc_id = ck["doc_id"]
-            if doc_id not in doc_chunks:
-                doc_chunks[doc_id] = []
-            doc_chunks[doc_id].append(ck["content_with_weight"])
-
-        # Create knowledges list with grouped chunks
-        knowledges = []
-        for doc_id, chunks in doc_chunks.items():
-            # Find the corresponding document name
-            doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id)
-
-            # Create a header for the document
-            doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n"
-
-            # Add numbered fragments
-            for i, chunk in enumerate(chunks, 1):
-                doc_knowledge += f"{i}. {chunk}\n"
{chunk}\n" - - knowledges.append(doc_knowledge) - - - + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) retrieval_tm = timer() @@ -603,7 +605,6 @@ def tts(tts_mdl, text): def ask(question, kb_ids, tenant_id): kbs = KnowledgebaseService.get_by_ids(kb_ids) - tenant_ids = [kb.tenant_id for kb in kbs] embd_nms = list(set([kb.embd_id for kb in kbs])) is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) @@ -612,45 +613,9 @@ def ask(question, kb_ids, tenant_id): embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) max_tokens = chat_mdl.max_length - + tenant_ids = list(set([kb.tenant_id for kb in kbs])) kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False) - knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] - - used_token_count = 0 - chunks_num = 0 - for i, c in enumerate(knowledges): - used_token_count += num_tokens_from_string(c) - if max_tokens * 0.97 < used_token_count: - knowledges = knowledges[:i] - chunks_num = chunks_num + 1 - break - - # Group chunks by document ID - doc_chunks = {} - counter_chunks = 0 - for ck in kbinfos["chunks"]: - if counter_chunks < chunks_num: - counter_chunks = counter_chunks + 1 - doc_id = ck["doc_id"] - if doc_id not in doc_chunks: - doc_chunks[doc_id] = [] - doc_chunks[doc_id].append(ck["content_with_weight"]) - - # Create knowledges list with grouped chunks - knowledges = [] - for doc_id, chunks in doc_chunks.items(): - # Find the corresponding document name - doc_name = next((d["doc_name"] for d in kbinfos.get("doc_aggs", []) if d["doc_id"] == doc_id), doc_id) - - # Create a header for the document - doc_knowledge = f"Document: {doc_name} \nContains the following relevant fragments:\n" - - # Add numbered fragments - for i, chunk in enumerate(chunks, 1): - doc_knowledge += f"{i}. {chunk}\n" - - knowledges.append(doc_knowledge) - + knowledges = kb_prompt(kbinfos, max_tokens) prompt = """ Role: You're a smart assistant. Your name is Miss R. Task: Summarize the information from knowledge bases and answer user's question. @@ -660,30 +625,29 @@ def ask(question, kb_ids, tenant_id): - Answer with markdown format text. - Answer in language of user's question. - DO NOT make things up, especially for numbers. - + ### Information from knowledge bases %s - + The above is information from knowledge bases. 
- - """%"\n".join(knowledges) + + """ % "\n".join(knowledges) msg = [{"role": "user", "content": question}] def decorate_answer(answer): nonlocal knowledges, kbinfos, prompt answer, idx = retr.insert_citations(answer, - [ck["content_ltks"] - for ck in kbinfos["chunks"]], - [ck["vector"] - for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=0.7, - vtweight=0.3) + [ck["content_ltks"] + for ck in kbinfos["chunks"]], + [ck["vector"] + for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=0.7, + vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [ d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] - if not recall_docs: - recall_docs = kbinfos["doc_aggs"] + if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs refs = deepcopy(kbinfos) for c in refs["chunks"]: @@ -691,7 +655,7 @@ def decorate_answer(answer): del c["vector"] if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: - answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" + answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" return {"answer": answer, "reference": refs} answer = ""