
Commit 9fe9fc4

add dockerfile for cuda environment. Refine table search strategy, (in…
KevinHuSh authored Mar 14, 2024
1 parent f1ced48 commit 9fe9fc4
Showing 18 changed files with 260 additions and 85 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -14,6 +14,7 @@ ADD ./rag ./rag
 ENV PYTHONPATH=/ragflow/
 ENV HF_ENDPOINT=https://hf-mirror.com

+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
 ADD docker/entrypoint.sh ./entrypoint.sh
 RUN chmod +x ./entrypoint.sh
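Note: the HF_ENDPOINT variable above reroutes Hugging Face downloads through the hf-mirror.com mirror. A minimal sketch of the effect, assuming the image's Python environment fetches models via huggingface_hub (the model id is illustrative):

```python
import os

# HF_ENDPOINT must be set before huggingface_hub is imported;
# the Dockerfile bakes it into the image environment.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download

# The snapshot is fetched from the mirror instead of huggingface.co.
local_dir = snapshot_download("BAAI/bge-large-zh-v1.5")  # illustrative model id
print(local_dir)
```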
26 changes: 26 additions & 0 deletions Dockerfile.cuda
@@ -0,0 +1,26 @@
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
+USER root
+
+WORKDIR /ragflow
+
+## for cuda > 12.0
+RUN /root/miniconda3/envs/py11/bin/pip uninstall -y onnxruntime-gpu
+RUN /root/miniconda3/envs/py11/bin/pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
+
+ADD ./web ./web
+RUN cd ./web && npm i && npm run build
+
+ADD ./api ./api
+ADD ./conf ./conf
+ADD ./deepdoc ./deepdoc
+ADD ./rag ./rag
+
+ENV PYTHONPATH=/ragflow/
+ENV HF_ENDPOINT=https://hf-mirror.com
+
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
+ADD docker/entrypoint.sh ./entrypoint.sh
+RUN chmod +x ./entrypoint.sh
+
+ENTRYPOINT ["./entrypoint.sh"]
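The two pip RUN lines swap the stock onnxruntime-gpu wheel (built against CUDA 11.x at the time) for a build from Microsoft's onnxruntime-cuda-12 feed, matching the image's CUDA > 12.0 runtime. A quick sanity-check sketch to run inside the container's py11 environment:

```python
import onnxruntime as ort

print(ort.__version__)
print(ort.get_device())  # expected: "GPU" when the CUDA build is active

providers = ort.get_available_providers()
print(providers)
# If the wheel and the driver/runtime match, CUDAExecutionProvider is listed;
# otherwise sessions silently fall back to CPUExecutionProvider.
assert "CUDAExecutionProvider" in providers
```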
101 changes: 69 additions & 32 deletions api/apps/conversation_app.py
@@ -21,7 +21,7 @@
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger, retrievaler
+from api.settings import access_logger, stat_logger, retrievaler, chat_logger
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     ## try to use sql if field mapping is good to go
     if field_map:
-        stat_logger.info("Use SQL to retrieval.")
-        markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
+        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
+        markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
         if markdown_tbl:
-            return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
+            return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}

     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
@@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
                                     dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     if not knowledges and prompt_config.get("empty_response"):
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
@@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
     answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
-    stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
+    chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))

     if knowledges:
         answer, idx = retrievaler.insert_citations(answer,
@@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
 问题如下:
 {}
-请写出SQL且只要SQL,不要有其他说明及文字。
+请写出SQL, 且只要SQL,不要有其他说明及文字。
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
         question
     )
-    sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
-    stat_logger.info(f"“{question}” get SQL: {sql}")
-    sql = re.sub(r"[\r\n]+", " ", sql.lower())
-    sql = re.sub(r".*?select ", "select ", sql.lower())
-    sql = re.sub(r" +", " ", sql)
-    sql = re.sub(r"([;;]|```).*", "", sql)
-    if sql[:len("select ")] != "select ":
-        return None, None
-    if sql[:len("select *")] != "select *":
-        sql = "select doc_id,docnm_kwd," + sql[6:]
-    else:
-        flds = []
-        for k in field_map.keys():
-            if k in forbidden_select_fields4resume:continue
-            if len(flds) > 11:break
-            flds.append(k)
-        sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-
-    stat_logger.info(f"“{question}” get SQL(refined): {sql}")
-    tbl = retrievaler.sql_retrieval(sql, format="json")
-    if not tbl or len(tbl["rows"]) == 0: return None, None
+    tried_times = 0
+    def get_table():
+        nonlocal sys_prompt, user_promt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
+        print(user_promt, sql)
+        chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
+        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r" +", " ", sql)
+        sql = re.sub(r"([;;]|```).*", "", sql)
+        if sql[:len("select ")] != "select ":
+            return None, None
+        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+            if sql[:len("select *")] != "select *":
+                sql = "select doc_id,docnm_kwd," + sql[6:]
+            else:
+                flds = []
+                for k in field_map.keys():
+                    if k in forbidden_select_fields4resume:continue
+                    if len(flds) > 11:break
+                    flds.append(k)
+                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+
+        print(f"“{question}” get SQL(refined): {sql}")
+
+        chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+        tried_times += 1
+        return retrievaler.sql_retrieval(sql, format="json"), sql
+
+    tbl, sql = get_table()
+    if tbl.get("error") and tried_times <= 2:
+        user_promt = """
+表名:{};
+数据库表字段说明如下:
+{}
+问题如下:
+{}
+你上一次给出的错误SQL如下:
+{}
+后台报错如下:
+{}
+请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
+""".format(
+            index_name(tenant_id),
+            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+            question, sql, tbl["error"]
+        )
+        tbl, sql = get_table()
+        chat_logger.info("TRY it again: {}".format(sql))
+
+    chat_logger.info("GET table: {}".format(tbl))
+    print(tbl)
+    if tbl.get("error") or len(tbl["rows"]) == 0: return None, None

     docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
     docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
     clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

     # compose markdown table
-    clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
-    line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
-    rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
+    clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docnm_idx else "|")
+    line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docnm_idx else "")
+    rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
     if not docid_idx or not docnm_idx:
-        access_logger.error("SQL missing field: " + sql)
+        chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []

-    rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
+    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     docid_idx = list(docid_idx)[0]
     docnm_idx = list(docnm_idx)[0]
     return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
4 changes: 2 additions & 2 deletions api/db/db_models.py
@@ -502,7 +502,7 @@ class Document(DataBaseModel):
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")
     process_begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")


 class Dialog(DataBaseModel):
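progress_msg accumulates parser progress lines and could overflow the old VARCHAR(4096) column, hence the switch to TextField (peewee's TextField ignores max_length, so the Task change is cosmetic while the Document change alters the column type). Existing MySQL databases would need the column re-typed; a sketch using peewee's playhouse migrator, with hypothetical connection details:

```python
from peewee import MySQLDatabase, TextField
from playhouse.migrate import MySQLMigrator, migrate

# Hypothetical credentials; ragflow reads its own from service configuration.
db = MySQLDatabase("rag_flow", host="127.0.0.1", user="root", password="***")
migrator = MySQLMigrator(db)

migrate(
    # Re-type document.progress_msg from VARCHAR(4096) to TEXT.
    migrator.alter_column_type("document", "progress_msg",
                               TextField(null=True, default="")),
)
```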
57 changes: 57 additions & 0 deletions api/db/init_data.py
@@ -90,6 +90,17 @@ def init_llm_factory():
             "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
             "status": "1",
         },
+        {
+            "name": "Local",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "status": "0",
+        },{
+            "name": "Moonshot",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING",
+            "status": "1",
+        }
         # {
         #     "name": "文心一言",
         #     "logo": "",
@@ -155,6 +166,12 @@ def init_llm_factory():
             "tags": "LLM,CHAT,32K",
             "max_tokens": 32768,
             "model_type": LLMType.CHAT.value
+        },{
+            "fid": factory_infos[1]["name"],
+            "llm_name": "qwen-max-1201",
+            "tags": "LLM,CHAT,6K",
+            "max_tokens": 5899,
+            "model_type": LLMType.CHAT.value
         },{
             "fid": factory_infos[1]["name"],
             "llm_name": "text-embedding-v2",
@@ -201,6 +218,46 @@ def init_llm_factory():
             "max_tokens": 512,
             "model_type": LLMType.EMBEDDING.value
         },
+        # ---------------------- 本地 ----------------------
+        {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "qwen-14B-chat",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 8191,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[3]["name"],
+            "llm_name": "flag-enbedding",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.EMBEDDING.value
+        },
+        # ------------------------ Moonshot -----------------------
+        {
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-8k",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 7900,
+            "model_type": LLMType.CHAT.value
+        }, {
+            "fid": factory_infos[4]["name"],
+            "llm_name": "flag-enbedding",
+            "tags": "TEXT EMBEDDING,",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.EMBEDDING.value
+        },{
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-32k",
+            "tags": "LLM,CHAT,",
+            "max_tokens": 32768,
+            "model_type": LLMType.CHAT.value
+        },{
+            "fid": factory_infos[4]["name"],
+            "llm_name": "moonshot-v1-128k",
+            "tags": "LLM,CHAT",
+            "max_tokens": 128 * 1000,
+            "model_type": LLMType.CHAT.value
+        },
     ]
     for info in factory_infos:
         LLMFactoriesService.save(**info)
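The registered max_tokens values sit a little below the nominal context sizes (7900 for moonshot-v1-8k, 5899 for the "6K" qwen-max-1201), presumably leaving headroom when chat() clamps the completion budget against tokens already spent on the prompt. The clamp from api/apps/conversation_app.py, with illustrative numbers:

```python
llm_max_tokens = 7900    # as registered above for moonshot-v1-8k
used_token_count = 6200  # hypothetical prompt + history size
gen_conf = {"max_tokens": 4096}

# The same clamp chat() applies before calling the model:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm_max_tokens - used_token_count)
print(gen_conf["max_tokens"])  # 1700: the completion shrinks to fit the context window
```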
15 changes: 11 additions & 4 deletions api/settings.py
@@ -29,6 +29,7 @@
 stat_logger = getLogger("stat")
 access_logger = getLogger("access")
 database_logger = getLogger("database")
+chat_logger = getLogger("chat")

 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
@@ -69,9 +70,15 @@
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
-    "local": {
-        "chat_model": "",
-        "embedding_model": "",
+    "Local": {
+        "chat_model": "qwen-14B-chat",
+        "embedding_model": "flag-enbedding",
+        "image2text_model": "",
+        "asr_model": "",
+    },
+    "Moonshot": {
+        "chat_model": "moonshot-v1-8k",
+        "embedding_model": "flag-enbedding",
         "image2text_model": "",
         "asr_model": "",
     }
@@ -86,7 +93,7 @@
 ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]

-API_KEY = LLM.get("api_key", "infiniflow API Key")
+API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")

 # distribution
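The new "Local" and "Moonshot" entries matter because this module resolves the default models by factory name at import time; a factory selected via LLM_FACTORY but missing from default_llm would raise KeyError. A trimmed sketch of that lookup:

```python
# Trimmed copy of the mapping above; only two factories shown.
default_llm = {
    "Local": {"chat_model": "qwen-14B-chat", "embedding_model": "flag-enbedding"},
    "Moonshot": {"chat_model": "moonshot-v1-8k", "embedding_model": "flag-enbedding"},
}

LLM_FACTORY = "Moonshot"  # hypothetical value from service configuration

CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"]            # "moonshot-v1-8k"
EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]  # "flag-enbedding"
```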
2 changes: 1 addition & 1 deletion deepdoc/parser/excel_parser.py
@@ -34,7 +34,7 @@ def row_number(fnm, binary):
         total = 0
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
-            total += len(ws.rows)
+            total += len(list(ws.rows))
         return total

     if fnm.split(".")[-1].lower() in ["csv", "txt"]:
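openpyxl's Worksheet.rows is a generator, so len(ws.rows) raises TypeError: object of type 'generator' has no len(); wrapping it in list() materializes the rows first. A minimal reproduction, assuming any workbook on disk:

```python
from openpyxl import load_workbook

wb = load_workbook("book.xlsx")  # hypothetical file
total = 0
for sheetname in wb.sheetnames:
    ws = wb[sheetname]
    # len(ws.rows) would raise TypeError: ws.rows is a generator.
    total += len(list(ws.rows))
    # Cheaper alternative that avoids building every row object: ws.max_row
print(total)
```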
4 changes: 2 additions & 2 deletions deepdoc/parser/pdf_parser.py
@@ -655,14 +655,14 @@ def nearest(tbls):
             #if min(tv, fv) > 2000:
             #    i += 1
             #    continue
-            if tv < fv:
+            if tv < fv and tk:
                 tables[tk].insert(0, c)
                 logging.debug(
                     "TABLE:" +
                     self.boxes[i]["text"] +
                     "; Cap: " +
                     tk)
-            else:
+            elif fk:
                 figures[fk].insert(0, c)
                 logging.debug(
                     "FIGURE:" +
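nearest() can hand back an empty key when a page has no table (or no figure) candidate; unguarded, tables[tk] or figures[fk] would then raise KeyError. The new tk/fk checks attach the caption box only when the nearer candidate actually exists, otherwise the box is left unattached. A distilled sketch with hypothetical values:

```python
tables, figures = {"tbl-1": []}, {}
c = {"text": "Table 1: results"}  # a caption-like box

tk, tv = "tbl-1", 120.0  # nearest table key and its distance
fk, fv = None, 300.0     # no figure candidate on this page -> falsy key

if tv < fv and tk:
    tables[tk].insert(0, c)
elif fk:
    figures[fk].insert(0, c)
# else: neither candidate exists; previously this path raised KeyError
```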
4 changes: 2 additions & 2 deletions deepdoc/parser/ppt_parser.py
@@ -31,7 +31,7 @@ def __extract(self, shape):

         if shape.shape_type == 6:
             texts = []
-            for p in shape.shapes:
+            for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)):
                 t = self.__extract(p)
                 if t: texts.append(t)
             return "\n".join(texts)
@@ -46,7 +46,7 @@ def __call__(self, fnm, from_page, to_page, callback=None):
             if i < from_page: continue
             if i >= to_page:break
             texts = []
-            for shape in slide.shapes:
+            for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)):
                 txt = self.__extract(shape)
                 if txt: texts.append(txt)
             txts.append("\n".join(texts))
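python-pptx yields shapes in z-order (insertion order), not visual order; sorting on (top // 10, left) orders shapes top-to-bottom and, within near-identical tops, left-to-right, so the extracted text follows reading order. A sketch of the same idea, assuming python-pptx and an arbitrary deck (the `or 0` guard for shapes without an explicit position is an addition of this sketch):

```python
from pptx import Presentation

prs = Presentation("deck.pptx")  # hypothetical file
for slide in prs.slides:
    for shape in sorted(slide.shapes, key=lambda s: ((s.top or 0) // 10, s.left or 0)):
        if shape.has_text_frame:
            print(shape.text_frame.text)
```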