Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK for session #2312

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions api/apps/sdk/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from flask import request

from api.db import StatusEnum
from api.db.db_models import TenantLLM
from api.db.services.dialog_service import DialogService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
Expand All @@ -30,7 +31,6 @@
@token_required
def save(tenant_id):
req = request.json
id = req.get("id")
# dataset
if req.get("knowledgebases") == []:
return get_data_error_result(retmsg="knowledgebases can not be empty list")
Expand All @@ -41,8 +41,8 @@ def save(tenant_id):
return get_data_error_result(retmsg="knowledgebase needs id")
if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
return get_data_error_result(retmsg="you do not own the knowledgebase")
if not DocumentService.query(kb_id=kb["id"]):
return get_data_error_result(retmsg="There is a invalid knowledgebase")
# if not DocumentService.query(kb_id=kb["id"]):
# return get_data_error_result(retmsg="There is a invalid knowledgebase")
kb_list.append(kb["id"])
req["kb_ids"] = kb_list
# llm
Expand Down Expand Up @@ -72,18 +72,22 @@ def save(tenant_id):
req[key] = prompt.pop(key)
req["prompt_config"] = req.pop("prompt")
# create
if not id:
if "id" not in req:
# dataset
if not kb_list:
return get_data_error_result(retmsg="knowledgebase is required!")
return get_data_error_result(retmsg="knowledgebases are required!")
# init
req["id"] = get_uuid()
req["description"] = req.get("description", "A helpful Assistant")
req["icon"] = req.get("avatar", "")
req["top_n"] = req.get("top_n", 6)
req["top_k"] = req.get("top_k", 1024)
req["rerank_id"] = req.get("rerank_id", "")
req["llm_id"] = req.get("llm_id", tenant.llm_id)
if req.get("llm_id"):
if not TenantLLMService.query(llm_name=req["llm_id"]):
return get_data_error_result(retmsg="the model_name does not exist.")
else:
req["llm_id"] = tenant.llm_id
if not req.get("name"):
return get_data_error_result(retmsg="name is required.")
if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
Expand Down Expand Up @@ -149,14 +153,20 @@ def save(tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value):
return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR)
# prompt
if not req["id"]:
return get_data_error_result(retmsg="id can not be empty")
e, res = DialogService.get_by_id(req["id"])
res = res.to_json()
if "llm_id" in req:
if not TenantLLMService.query(llm_name=req["llm_id"]):
return get_data_error_result(retmsg="the model_name does not exist.")
if "name" in req:
if not req.get("name"):
return get_data_error_result(retmsg="name is not empty.")
if req["name"].lower() != res["name"].lower() \
and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.")
and len(
DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.")
if "prompt_config" in req:
res["prompt_config"].update(req["prompt_config"])
for p in res["prompt_config"]["parameters"]:
Expand Down Expand Up @@ -186,7 +196,7 @@ def delete(tenant_id):
if "id" not in req:
return get_data_error_result(retmsg="id is required")
id = req['id']
if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value):
if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR)

temp_dict = {"status": StatusEnum.INVALID.value}
Expand All @@ -200,21 +210,22 @@ def get(tenant_id):
req = request.args
if "id" in req:
id = req["id"]
ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value)
ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value)
if not ass:
return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR)
if "name" in req:
name = req["name"]
if ass[0].name != name:
return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR)
res=ass[0].to_json()
res = ass[0].to_json()
else:
if "name" in req:
name = req["name"]
ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value)
ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value)
if not ass:
return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR)
res=ass[0].to_json()
return get_json_result(data=False, retmsg='You do not own the assistant.',
retcode=RetCode.OPERATING_ERROR)
res = ass[0].to_json()
else:
return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.")
renamed_dict = {}
Expand Down Expand Up @@ -258,7 +269,7 @@ def list_assistants(tenant_id):
reverse=True,
order_by=DialogService.model.create_time)
assts = [d.to_dict() for d in assts]
list_assts=[]
list_assts = []
renamed_dict = {}
key_mapping = {"parameters": "variables",
"prologue": "opener",
Expand Down
6 changes: 5 additions & 1 deletion api/apps/sdk/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def save(tenant_id):
req.update(mapped_keys)
if not KnowledgebaseService.save(**req):
return get_data_error_result(retmsg="Create dataset error.(Database error)")
renamed_data={}
renamed_data = {}
e, k = KnowledgebaseService.get_by_id(req["id"])
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
Expand Down Expand Up @@ -88,6 +88,9 @@ def save(tenant_id):
data=False, retmsg='You do not own the dataset.',
retcode=RetCode.OPERATING_ERROR)

if not req["id"]:
return get_data_error_result(
retmsg="id can not be empty.")
e, kb = KnowledgebaseService.get_by_id(req["id"])

if "chunk_count" in req:
Expand All @@ -108,6 +111,7 @@ def save(tenant_id):
retmsg="If chunk count is not 0, parse method is not changable.")
req['parser_id'] = req.pop('parse_method')
if "name" in req:
req["name"] = req["name"].strip()
if req["name"].lower() != kb.name.lower() \
and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
status=StatusEnum.VALID.value)) > 0:
Expand Down
168 changes: 168 additions & 0 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy
from uuid import uuid4

from flask import request, Response

from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result
from api.utils.api_utils import get_json_result, token_required


@manager.route('/save', methods=['POST'])
@token_required
def set_conversation(tenant_id):
req = request.json
conv_id = req.get("id")
if "messages" in req:
req["message"] = req.pop("messages")
if req["message"]:
for message in req["message"]:
if "reference" in message:
req["reference"] = message.pop("reference")
if "assistant_id" in req:
req["dialog_id"] = req.pop("assistant_id")
if "id" in req:
del req["id"]
conv = ConversationService.query(id=conv_id)
if not conv:
return get_data_error_result(retmsg="Session does not exist")
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
if req.get("dialog_id"):
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
return get_data_error_result(retmsg="You do not own the assistant")
if "dialog_id" in req and not req.get("dialog_id"):
return get_data_error_result(retmsg="assistant_id can not be empty.")
if "name" in req and not req.get("name"):
return get_data_error_result(retmsg="name can not be empty.")
if "message" in req and not req.get("message"):
return get_data_error_result(retmsg="messages can not be empty")
if not ConversationService.update_by_id(conv_id, req):
return get_data_error_result(retmsg="Session updates error")
return get_json_result(data=True)

if not req.get("dialog_id"):
return get_data_error_result(retmsg="assistant_id is required.")
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
if not dia:
return get_data_error_result(retmsg="You do not own the assistant")
conv = {
"id": get_uuid(),
"dialog_id": req["dialog_id"],
"name": req.get("name", "New session"),
"message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]),
"reference": req.get("reference", [])
}
if not conv.get("name"):
return get_data_error_result(retmsg="name can not be empty.")
if not conv.get("message"):
return get_data_error_result(retmsg="messages can not be empty")
ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"])
if not e:
return get_data_error_result(retmsg="Fail to new session!")
conv = conv.to_dict()
conv["messages"] = conv.pop("message")
conv["assistant_id"] = conv.pop("dialog_id")
for message in conv["messages"]:
message["reference"] = conv.get("reference")
del conv["reference"]
return get_json_result(data=conv)


@manager.route('/completion', methods=['POST'])
@token_required
def completion(tenant_id):
req = request.json
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"}
# ]}
msg = []
question = {
"content": req.get("question"),
"role": "user",
"id": str(uuid4())
}
req["messages"].append(question)
for m in req["messages"]:
if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue
m["id"] = m.get("id", str(uuid4()))
msg.append(m)
message_id = msg[-1].get("id")
conv = ConversationService.query(id=req["id"])
conv = conv[0]
if not conv:
return get_data_error_result(retmsg="Session does not exist")
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_data_error_result(retmsg="You do not own the session")
conv.message = deepcopy(req["messages"])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["id"]
del req["messages"]

if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
conv.reference.append({"chunks": [], "doc_aggs": []})

def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id

def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, **req):
fillin_conv(ans)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp

else:
answer = None
for ans in chat(dia, msg, **req):
answer = ans
fillin_conv(ans)
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
21 changes: 18 additions & 3 deletions sdk/python/ragflow/modules/chat_assistant.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import List

from .base import Base
from .session import Session, Message


class Assistant(Base):
def __init__(self, rag, res_dict):
self.id=""
self.id = ""
self.name = "assistant"
self.avatar = "path/to/avatar"
self.knowledgebases = ["kb1"]
Expand Down Expand Up @@ -41,8 +44,8 @@ def __init__(self, rag, res_dict):

def save(self) -> bool:
res = self.post('/assistant/save',
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases,
"llm":self.llm.to_json(),"prompt":self.prompt.to_json()
{"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases,
"llm": self.llm.to_json(), "prompt": self.prompt.to_json()
})
res = res.json()
if res.get("retmsg") == "success": return True
Expand All @@ -54,3 +57,15 @@ def delete(self) -> bool:
res = res.json()
if res.get("retmsg") == "success": return True
raise Exception(res["retmsg"])

def create_session(self, name: str = "New session", messages: List[Message] = [
{"role": "assistant", "reference": [],
"content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session:
res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, })
res = res.json()
if res.get("retmsg") == "success":
return Session(self.rag, res['data'])
raise Exception(res["retmsg"])

def get_prologue(self):
return self.prompt.opener
Loading