From ff3e3540d94dfccd62f3e772e256fd9857f6c67f Mon Sep 17 00:00:00 2001 From: H <43509927+guoyuhao2330@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:54:29 +0800 Subject: [PATCH] Add agent api (#1888) ### What problem does this PR solve? #1842 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/api_app.py | 208 ++++++++++++++++++++++++++++++++------------ api/db/db_models.py | 16 ++++ 2 files changed, 168 insertions(+), 56 deletions(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 8d7f00e8bc..eb9bd6c520 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -39,6 +39,10 @@ from api.utils.file_utils import filename_type, thumbnail from rag.utils.minio_conn import MINIO +from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService +from agent.canvas import Canvas +from functools import partial + def generate_confirmation_token(tenent_id): serializer = URLSafeTimedSerializer(tenent_id) @@ -46,7 +50,6 @@ def generate_confirmation_token(tenent_id): @manager.route('/new_token', methods=['POST']) -@validate_request("dialog_id") @login_required def new_token(): req = request.json @@ -57,12 +60,17 @@ def new_token(): tenant_id = tenants[0].tenant_id obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), - "dialog_id": req["dialog_id"], "create_time": current_timestamp(), "create_date": datetime_format(datetime.now()), "update_time": None, "update_date": None } + if req.get("canvas_id"): + obj["dialog_id"] = req["canvas_id"] + obj["source"] = "agent" + else: + obj["dialog_id"] = req["dialog_id"] + if not APITokenService.save(**obj): return get_data_error_result(retmsg="Fail to new a dialog!") @@ -112,15 +120,15 @@ def stats(): "from_date", (datetime.now() - timedelta( - days=7)).strftime("%Y-%m-%d 24:00:00")), + days=7)).strftime("%Y-%m-%d 24:00:00")), request.args.get( "to_date", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) res = { "pv": [(o["dt"], o["pv"]) for o in objs], "uv": [(o["dt"], o["uv"]) for o in objs], - "speed": [(o["dt"], float(o["tokens"])/(float(o["duration"]+0.1))) for o in objs], - "tokens": [(o["dt"], float(o["tokens"])/1000.) for o in objs], + "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs], + "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs], "round": [(o["dt"], o["round"]) for o in objs], "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] } @@ -138,21 +146,31 @@ def set_conversation(): data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) req = request.json try: - e, dia = DialogService.get_by_id(objs[0].dialog_id) - if not e: - return get_data_error_result(retmsg="Dialog not found") - conv = { - "id": get_uuid(), - "dialog_id": dia.id, - "user_id": request.args.get("user_id", ""), - "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] - } - API4ConversationService.save(**conv) - e, conv = API4ConversationService.get_by_id(conv["id"]) - if not e: - return get_data_error_result(retmsg="Fail to new a conversation!") - conv = conv.to_dict() - return get_json_result(data=conv) + if objs[0].source == "agent": + e, c = UserCanvasService.get_by_id(objs[0].dialog_id) + if not e: + return server_error_response("canvas not found.") + conv = { + "id": get_uuid(), + "dialog_id": c.id, + "user_id": request.args.get("user_id", ""), + "message": [{"role": "assistant", "content": "Hi there!"}], + "source": "agent" + } + API4ConversationService.save(**conv) + return get_json_result(data=conv) + else: + e, dia = DialogService.get_by_id(objs[0].dialog_id) + if not e: + return get_data_error_result(retmsg="Dialog not found") + conv = { + "id": get_uuid(), + "dialog_id": dia.id, + "user_id": request.args.get("user_id", ""), + "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] + } + API4ConversationService.save(**conv) + return get_json_result(data=conv) except Exception as e: return server_error_response(e) @@ -161,7 +179,8 @@ def set_conversation(): @validate_request("conversation_id", "messages") def completion(): token = request.headers.get('Authorization').split()[1] - if not APIToken.query(token=token): + objs = APIToken.query(token=token) + if not objs: return get_json_result( data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) req = request.json @@ -178,7 +197,100 @@ def completion(): continue msg.append({"role": m["role"], "content": m["content"]}) + def fillin_conv(ans): + nonlocal conv + if not conv.reference: + conv.reference.append(ans["reference"]) + else: + conv.reference[-1] = ans["reference"] + conv.message[-1] = {"role": "assistant", "content": ans["answer"]} + + def rename_field(ans): + reference = ans['reference'] + if not isinstance(reference, dict): + return + for chunk_i in reference.get('chunks', []): + if 'docnm_kwd' in chunk_i: + chunk_i['doc_name'] = chunk_i['docnm_kwd'] + chunk_i.pop('docnm_kwd') + try: + if conv.source == "agent": + stream = req.get("stream", True) + conv.message.append(msg[-1]) + e, cvs = UserCanvasService.get_by_id(conv.dialog_id) + if not e: + return server_error_response("canvas not found.") + del req["conversation_id"] + del req["messages"] + + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + + if not conv.reference: + conv.reference = [] + conv.message.append({"role": "assistant", "content": ""}) + conv.reference.append({"chunks": [], "doc_aggs": []}) + + final_ans = {"reference": [], "content": ""} + canvas = Canvas(cvs.dsl, objs[0].tenant_id) + + canvas.messages.append(msg[-1]) + canvas.add_user_input(msg[-1]["content"]) + answer = canvas.run(stream=stream) + + assert answer is not None, "Nothing. Is it over?" + + if stream: + assert isinstance(answer, partial), "Nothing. Is it over?" + + def sse(): + nonlocal answer, cvs, conv + try: + for ans in answer(): + for k in ans.keys(): + final_ans[k] = ans[k] + ans = {"answer": ans["content"], "reference": ans.get("reference", [])} + fillin_conv(ans) + rename_field(ans) + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, + ensure_ascii=False) + "\n\n" + + canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + cvs.dsl = json.loads(str(canvas)) + API4ConversationService.append_message(conv.id, conv.to_dict()) + except Exception as e: + yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" + + resp = Response(sse(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" + canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) + if final_ans.get("reference"): + canvas.reference.append(final_ans["reference"]) + cvs.dsl = json.loads(str(canvas)) + + result = None + for ans in answer(): + ans = {"answer": ans["content"], "reference": ans.get("reference", [])} + result = ans + fillin_conv(ans) + API4ConversationService.append_message(conv.id, conv.to_dict()) + break + rename_field(result) + return get_json_result(data=result) + + #******************For dialog****************** conv.message.append(msg[-1]) e, dia = DialogService.get_by_id(conv.dialog_id) if not e: @@ -191,35 +303,20 @@ def completion(): conv.message.append({"role": "assistant", "content": ""}) conv.reference.append({"chunks": [], "doc_aggs": []}) - def fillin_conv(ans): - nonlocal conv - if not conv.reference: - conv.reference.append(ans["reference"]) - else: conv.reference[-1] = ans["reference"] - conv.message[-1] = {"role": "assistant", "content": ans["answer"]} - - def rename_field(ans): - reference = ans['reference'] - if not isinstance(reference, dict): - return - for chunk_i in reference.get('chunks', []): - if 'docnm_kwd' in chunk_i: - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - def stream(): nonlocal dia, msg, req, conv try: for ans in chat(dia, msg, True, **req): fillin_conv(ans) rename_field(ans) - yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, + ensure_ascii=False) + "\n\n" API4ConversationService.append_message(conv.id, conv.to_dict()) except Exception as e: yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), - "data": {"answer": "**ERROR**: "+str(e), "reference": []}}, + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" - yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" if req.get("stream", True): resp = Response(stream(), mimetype="text/event-stream") @@ -228,16 +325,15 @@ def stream(): resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - else: - answer = None - for ans in chat(dia, msg, **req): - answer = ans - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - break - - rename_field(answer) - return get_json_result(data=answer) + + answer = None + for ans in chat(dia, msg, **req): + answer = ans + fillin_conv(ans) + API4ConversationService.append_message(conv.id, conv.to_dict()) + break + rename_field(answer) + return get_json_result(data=answer) except Exception as e: return server_error_response(e) @@ -332,7 +428,7 @@ def upload(): "thumbnail": thumbnail(filename, blob) } - form_data=request.form + form_data = request.form if "parser_id" in form_data.keys(): if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]: doc["parser_id"] = request.form.get("parser_id").strip() @@ -361,7 +457,7 @@ def upload(): if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - #e, doc = DocumentService.get_by_id(doc["id"]) + # e, doc = DocumentService.get_by_id(doc["id"]) TaskService.filter_delete([Task.doc_id == doc["id"]]) e, doc = DocumentService.get_by_id(doc["id"]) doc = doc.to_dict() @@ -369,7 +465,7 @@ def upload(): bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"]) queue_tasks(doc, bucket, name) except Exception as e: - return server_error_response(e) + return server_error_response(e) return get_json_result(data=doc_result.to_json()) @@ -448,7 +544,7 @@ def list_kb_docs(): docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs] return get_json_result(data={"total": tol, "docs": docs}) - + except Exception as e: return server_error_response(e) @@ -549,7 +645,8 @@ def fillin_conv(ans): nonlocal conv if not conv.reference: conv.reference.append(ans["reference"]) - else: conv.reference[-1] = ans["reference"] + else: + conv.reference[-1] = ans["reference"] conv.message[-1] = {"role": "assistant", "content": ans["answer"]} data_type_picture = { @@ -638,4 +735,3 @@ def retrieval(): return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', retcode=RetCode.DATA_ERROR) return server_error_response(e) - diff --git a/api/db/db_models.py b/api/db/db_models.py index 267c34b7be..9df1580d35 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -858,6 +858,7 @@ class APIToken(DataBaseModel): tenant_id = CharField(max_length=32, null=False, index=True) token = CharField(max_length=255, null=False, index=True) dialog_id = CharField(max_length=32, null=False, index=True) + source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True) class Meta: db_table = "api_token" @@ -871,6 +872,7 @@ class API4Conversation(DataBaseModel): message = JSONField(null=True) reference = JSONField(null=True, default=[]) tokens = IntegerField(default=0) + source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True) duration = FloatField(default=0, index=True) round = IntegerField(default=0, index=True) @@ -949,3 +951,17 @@ def migrate_db(): ) except Exception as e: pass + try: + migrate( + migrator.add_column('api_token', 'source', + CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)) + ) + except Exception as e: + pass + try: + migrate( + migrator.add_column('api_4_conversation', 'source', + CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)) + ) + except Exception as e: + pass