Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinHuSh authored Dec 12, 2024
2 parents 1d907f4 + d6c74ff commit cec26a6
Show file tree
Hide file tree
Showing 16 changed files with 459 additions and 353 deletions.
12 changes: 10 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,16 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
apt purge -y nodejs npm && \
apt autoremove && \
apt update && \
apt install -y nodejs cargo
apt install -y nodejs cargo

# Add msssql17
RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
curl https://packages.microsoft.com/keys/microsoft.asc | tee /etc/apt/trusted.gpg.d/microsoft.asc && \
curl https://packages.microsoft.com/config/ubuntu/22.04/prod.list | tee /etc/apt/sources.list.d/mssql-release.list && \
apt update && \
ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17



# Add dependencies of selenium
RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/chrome-linux64-121-0-6167-85,target=/chrome-linux64.zip \
Expand Down Expand Up @@ -170,5 +179,4 @@ RUN chmod +x ./entrypoint.sh
COPY --from=builder /ragflow/web/dist /ragflow/web/dist

COPY --from=builder /ragflow/VERSION /ragflow/VERSION

ENTRYPOINT ["./entrypoint.sh"]
13 changes: 11 additions & 2 deletions agent/component/exesql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pymysql
import psycopg2
from agent.component.base import ComponentBase, ComponentParamBase
import pyodbc


class ExeSQLParam(ComponentParamBase):
Expand All @@ -38,7 +39,7 @@ def __init__(self):
self.top_n = 30

def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb'])
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
Expand Down Expand Up @@ -77,7 +78,15 @@ def _run(self, history, **kwargs):
elif self._param.db_type == 'postgresql':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)

elif self._param.db_type == 'mssql':
conn_str = (
r'DRIVER={ODBC Driver 17 for SQL Server};'
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
r'DATABASE=' + self._param.database + ';'
r'UID=' + self._param.username + ';'
r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
try:
cursor = db.cursor()
except Exception as e:
Expand Down
20 changes: 19 additions & 1 deletion api/apps/canvas_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,26 @@ def test_db_connect():
elif req["db_type"] == 'postgresql':
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
db.connect()
elif req["db_type"] == 'mssql':
import pyodbc
connection_string = (
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
f"SERVER={req['host']},{req['port']};"
f"DATABASE={req['database']};"
f"UID={req['username']};"
f"PWD={req['password']};"
)
db = pyodbc.connect(connection_string)
cursor = db.cursor()
cursor.execute("SELECT 1")
cursor.close()
else:
return server_error_response("Unsupported database type.")
if req["db_type"] != 'mssql':
db.connect()
db.close()

return get_json_result(data="Database Connection Successful!")
except Exception as e:
return server_error_response(e)

6 changes: 2 additions & 4 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from api.db.services.document_service import DocumentService
from api import settings
from api.utils.api_utils import get_json_result
import hashlib
import xxhash
import re


Expand Down Expand Up @@ -208,9 +208,7 @@ def rm():
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
md5 = hashlib.md5()
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
Expand Down
2 changes: 2 additions & 0 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

for ref in conv.reference:
if isinstance(ref, list):
continue
ref["chunks"] = [{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
Expand Down
7 changes: 2 additions & 5 deletions api/apps/sdk/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService
from api import settings
import hashlib
import xxhash
import re
from api.utils.api_utils import token_required
from api.db.db_models import Task
Expand Down Expand Up @@ -984,10 +984,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
return get_error_data_result(
"`questions` is required to be a list"
)
md5 = hashlib.md5()
md5.update((req["content"] + document_id).encode("utf-8"))

chunk_id = md5.hexdigest()
chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest()
d = {
"id": chunk_id,
"content_ltks": rag_tokenizer.tokenize(req["content"]),
Expand Down
7 changes: 2 additions & 5 deletions api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import logging
import hashlib
import xxhash
import json
import random
import re
Expand Down Expand Up @@ -508,10 +508,7 @@ def dummy(prog=None, msg=""):
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["id"] = md5.hexdigest()
d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
Expand Down
20 changes: 8 additions & 12 deletions api/db/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,13 @@
from rag.nlp import search

def trim_header_by_lines(text: str, max_length) -> str:
if len(text) <= max_length:
len_text = len(text)
if len_text <= max_length:
return text
lines = text.split("\n")
total = 0
idx = len(lines) - 1
for i in range(len(lines)-1, -1, -1):
if total + len(lines[i]) > max_length:
break
idx = i
text2 = "\n".join(lines[idx:])
return text2
for i in range(len_text):
if text[i] == '\n' and len_text - i <= max_length:
return text[i+1:]
return text

class TaskService(CommonService):
model = Task
Expand Down Expand Up @@ -183,7 +179,7 @@ def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
Expand All @@ -194,7 +190,7 @@ def update_progress(cls, id, info):
with DB.lock("update_progress", -1):
if info["progress_msg"]:
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 1000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
Expand Down
Loading

0 comments on commit cec26a6

Please sign in to comment.