Skip to content

Commit

Permalink
apply pep8 formalize (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinHuSh authored Mar 27, 2024
1 parent a02e836 commit fd7fcb5
Show file tree
Hide file tree
Showing 55 changed files with 1,573 additions and 758 deletions.
19 changes: 14 additions & 5 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def get():
"important_kwd")
def set():
req = request.json
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = huqie.qie(req["content_with_weight"])
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
Expand All @@ -140,10 +142,16 @@ def set():
return get_data_error_result(retmsg="Document not found!")

if doc.parser_id == ParserType.QA:
arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
if len(arr) != 2:
return get_data_error_result(
retmsg="Q&A must be separated by TAB/ENTER key.")
q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
d = beAdoc(d, arr[0], arr[1], not any(
[huqie.is_chinese(t) for t in q + a]))

v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
Expand Down Expand Up @@ -177,7 +185,8 @@ def switch():
def rm():
req = request.json
try:
if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
return get_data_error_result(retmsg="Index updating failure")
return get_json_result(data=True)
except Exception as e:
Expand Down
124 changes: 81 additions & 43 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def rm():
def list_convsersation():
dialog_id = request.args["dialog_id"]
try:
convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
convs = ConversationService.query(
dialog_id=dialog_id,
order_by=ConversationService.model.create_time,
reverse=True)
convs = [d.to_dict() for d in convs]
return get_json_result(data=convs)
except Exception as e:
Expand All @@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
def count():
nonlocal msg
tks_cnts = []
for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
for m in msg:
tks_cnts.append(
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
total = 0
for m in tks_cnts: total += m["count"]
for m in tks_cnts:
total += m["count"]
return total

c = count()
if c < max_length: return c, msg
if c < max_length:
return c, msg

msg_ = [m for m in msg[:-1] if m.role == "system"]
msg_.append(msg[-1])
msg = msg_
c = count()
if c < max_length: return c, msg
if c < max_length:
return c, msg

ll = num_tokens_from_string(msg_[0].content)
l = num_tokens_from_string(msg_[-1].content)
Expand All @@ -146,8 +154,10 @@ def completion():
req = request.json
msg = []
for m in req["messages"]:
if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, conv = ConversationService.get_by_id(req["conversation_id"])
Expand All @@ -160,7 +170,8 @@ def completion():
del req["conversation_id"]
del req["messages"]
ans = chat(dia, msg, **req)
if not conv.reference: conv.reference = []
if not conv.reference:
conv.reference = []
conv.reference.append(ans["reference"])
conv.message.append({"role": "assistant", "content": ans["answer"]})
ConversationService.update_by_id(conv.id, conv.to_dict())
Expand All @@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
## try to use sql if field mapping is good to go
# try to use sql if field mapping is good to go
if field_map:
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)

prompt_config = dialog.prompt_config
for p in prompt_config["parameters"]:
if p["key"] == "knowledge": continue
if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
if p["key"] == "knowledge":
continue
if p["key"] not in kwargs and not p["optional"]:
raise KeyError("Miss parameter: " + p["key"])
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
prompt_config["system"] = prompt_config["system"].replace(
"{%s}" % p["key"], " ")

for _ in range(len(questions)//2):
for _ in range(len(questions) // 2):
questions.append(questions[-1])
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

if not knowledges and prompt_config.get("empty_response"):
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
return {
"answer": prompt_config["empty_response"], "reference": kbinfos}

kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
msg = [{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"]
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
gen_conf["max_tokens"] = min(
gen_conf["max_tokens"],
llm.max_tokens - used_token_count)
answer = chat_mdl.chat(
prompt_config["system"].format(
**kwargs), msg, gen_conf)
chat_logger.info("User: {}|Assistant: {}".format(
msg[-1]["content"], answer))

if knowledges:
answer, idx = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
[ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
kbinfos["doc_aggs"] = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
for c in kbinfos["chunks"]:
if c.get("vector"): del c["vector"]
if c.get("vector"):
del c["vector"]
return {"answer": answer, "reference": kbinfos}


Expand All @@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
question
)
tried_times = 0

def get_table():
nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
"temperature": 0.06})
print(user_promt, sql)
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
Expand All @@ -262,8 +290,10 @@ def get_table():
else:
flds = []
for k in field_map.keys():
if k in forbidden_select_fields4resume:continue
if len(flds) > 11:break
if k in forbidden_select_fields4resume:
continue
if len(flds) > 11:
break
flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

Expand All @@ -284,13 +314,13 @@ def get_table():
问题如下:
{}
你上一次给出的错误SQL如下:
{}
后台报错如下:
{}
请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
""".format(
index_name(tenant_id),
Expand All @@ -302,16 +332,24 @@ def get_table():

chat_logger.info("GET table: {}".format(tbl))
print(tbl)
if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
if tbl.get("error") or len(tbl["rows"]) == 0:
return None, None

docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
docid_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(
tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(
len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

# compose markdown table
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
if not docid_idx or not docnm_idx:
chat_logger.warning("SQL missing field: " + sql)
return "\n".join([clmns, line, "\n".join(rows)]), []
Expand All @@ -328,5 +366,5 @@ def get_table():
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
}
43 changes: 30 additions & 13 deletions api/apps/dialog_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,30 @@ def set_dialog():
}
prompt_config = req.get("prompt_config", default_prompt)

if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
if not prompt_config["system"]:
prompt_config["system"] = default_prompt["system"]
# if len(prompt_config["parameters"]) < 1:
# prompt_config["parameters"] = default_prompt["parameters"]
# for p in prompt_config["parameters"]:
# if p["key"] == "knowledge":break
# else: prompt_config["parameters"].append(default_prompt["parameters"][0])

for p in prompt_config["parameters"]:
if p["optional"]: continue
if p["optional"]:
continue
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
return get_data_error_result(
retmsg="Parameter '{}' is not used".format(p["key"]))

try:
e, tenant = TenantService.get_by_id(current_user.id)
if not e: return get_data_error_result(retmsg="Tenant not found!")
if not e:
return get_data_error_result(retmsg="Tenant not found!")
llm_id = req.get("llm_id", tenant.llm_id)
if not dialog_id:
if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
if not req.get("kb_ids"):
return get_data_error_result(
retmsg="Fail! Please select knowledgebase!")
dia = {
"id": get_uuid(),
"tenant_id": current_user.id,
Expand All @@ -86,17 +92,21 @@ def set_dialog():
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight
}
if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
if not DialogService.save(**dia):
return get_data_error_result(retmsg="Fail to new a dialog!")
e, dia = DialogService.get_by_id(dia["id"])
if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
if not e:
return get_data_error_result(retmsg="Fail to new a dialog!")
return get_json_result(data=dia.to_json())
else:
del req["dialog_id"]
if "kb_names" in req: del req["kb_names"]
if "kb_names" in req:
del req["kb_names"]
if not DialogService.update_by_id(dialog_id, req):
return get_data_error_result(retmsg="Dialog not found!")
e, dia = DialogService.get_by_id(dialog_id)
if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
if not e:
return get_data_error_result(retmsg="Fail to update a dialog!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
Expand All @@ -110,7 +120,8 @@ def get():
dialog_id = request.args["dialog_id"]
try:
e, dia = DialogService.get_by_id(dialog_id)
if not e: return get_data_error_result(retmsg="Dialog not found!")
if not e:
return get_data_error_result(retmsg="Dialog not found!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
Expand All @@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
ids, nms = [], []
for kid in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kid)
if not e or kb.status != StatusEnum.VALID.value: continue
if not e or kb.status != StatusEnum.VALID.value:
continue
ids.append(kid)
nms.append(kb.name)
return ids, nms
Expand All @@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
@login_required
def list():
try:
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
diags = DialogService.query(
tenant_id=current_user.id,
status=StatusEnum.VALID.value,
reverse=True,
order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
Expand All @@ -147,7 +163,8 @@ def list():
def rm():
req = request.json
try:
DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
DialogService.update_many_by_id(
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
Loading

0 comments on commit fd7fcb5

Please sign in to comment.