Skip to content

Commit

Permalink
fix user login issue (infiniflow#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinHuSh authored Feb 29, 2024
1 parent 4873964 commit f666f56
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 81 deletions.
107 changes: 47 additions & 60 deletions api/apps/user_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,49 +33,14 @@

@manager.route('/login', methods=['POST', 'GET'])
def login():
userinfo = None
login_channel = "password"
if session.get("access_token"):
login_channel = session["access_token_from"]
if session["access_token_from"] == "github":
userinfo = user_info_from_github(session["access_token"])
elif not request.json:
if not request.json:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg='Unautherized!')

email = request.json.get('email') if not userinfo else userinfo["email"]
email = request.json.get('email', "")
users = UserService.query(email=email)
if not users:
if request.json is not None:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
avatar = ""
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
user_id = get_uuid()
try:
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": login_channel,
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
login_user(users[0])
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')

password = request.json.get('password')
try:
Expand All @@ -97,28 +62,50 @@ def login():

@manager.route('/github_callback', methods=['GET'])
def github_callback():
try:
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
},headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])

if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')

session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
return redirect(url_for("user.login"), code=307)
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
}, headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])

except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')

session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
userinfo = user_info_from_github(session["access_token"])
users = UserService.query(email=userinfo["email"])
user_id = get_uuid()
if not users:
try:
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
avatar = ""
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": "github",
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)

return redirect("/knowledge")


def user_info_from_github(access_token):
Expand Down Expand Up @@ -208,7 +195,7 @@ def user_register(user_id, user):
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})

if not UserService.insert(**user):return
if not UserService.save(**user):return
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
Expand Down
1 change: 0 additions & 1 deletion api/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class TaskStatus(StrEnum):


class ParserType(StrEnum):
GENERAL = "general"
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"
Expand Down
2 changes: 1 addition & 1 deletion api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)

parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

Expand Down
4 changes: 2 additions & 2 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def init_superuser():
"password": "admin",
"nickname": "admin",
"is_superuser": True,
"email": "[email protected]",
"email": "admin@ragflow.io",
"creator": "system",
"status": "1",
}
Expand Down Expand Up @@ -61,7 +61,7 @@ def init_superuser():
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")

chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
Expand Down
13 changes: 10 additions & 3 deletions api/db/services/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

import peewee
from werkzeug.security import generate_password_hash, check_password_hash

from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum


Expand Down Expand Up @@ -53,6 +55,11 @@ def save(cls, **kwargs):
kwargs["id"] = get_uuid()
if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"]))

kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model(**kwargs).save(force_insert=True)
return obj

Expand All @@ -66,10 +73,10 @@ def delete_user(cls, user_ids, update_user_dict):
@classmethod
@DB.connection_context()
def update_user(cls, user_id, user_dict):
date_time = get_format_time()
with DB.atomic():
if user_dict:
user_dict["update_time"] = date_time
user_dict["update_time"] = current_timestamp()
user_dict["update_date"] = datetime_format(datetime.now())
cls.model.update(user_dict).where(cls.model.id == user_id).execute()


Expand Down
2 changes: 1 addition & 1 deletion api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]

API_KEY = LLM.get("api_key", "infiniflow API Key")
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")

# distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
Expand Down
2 changes: 1 addition & 1 deletion deepdoc/parser/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class HuParser:
def __init__(self):
self.ocr = OCR()
if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
self.model_speciess = ParserType.NAIVE.value
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
self.tbl_det = TableStructureRecognizer()

Expand Down
3 changes: 1 addition & 2 deletions deepdoc/vision/layout_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
"Equation",
]
def __init__(self, domain):
super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b):
Expand Down
3 changes: 1 addition & 2 deletions deepdoc/vision/table_structure_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
]

def __init__(self):
super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr)
Expand Down
6 changes: 6 additions & 0 deletions rag/app/manual.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import copy
import re

from api.db import ParserType
from rag.nlp import huqie, tokenize
from deepdoc.parser import PdfParser
from rag.utils import num_tokens_from_string


class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.MANUAL.value
super().__init__()

def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
Expand Down
35 changes: 29 additions & 6 deletions rag/app/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ def __call__(self, filename, binary=None, from_page=0,

from timeit import default_timer as timer
start = timer()
start = timer()
self._layouts_rec(zoomin)
callback(0.77, "Layout analysis finished")
callback(0.5, "Layout analysis finished.")
print("paddle layouts:", timer() - start)
self._table_transformer_job(zoomin)
callback(0.7, "Table analysis finished.")
self._text_merge()
self._concat_downward(concat_between_pages=False)
self._filter_forpages()
callback(0.77, "Text merging finished")
tbls = self._extract_table_figure(True, zoomin, False)

cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
self._naive_vertical_merge()
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
#self._naive_vertical_merge()
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls


def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
Expand All @@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
Successive text will be sliced into pieces using 'delimiter'.
Next, these successive pieces are merged into chunks whose token number is no more than 'Max token number'.
"""

eng = lang.lower() == "english"#is_english(cks)
doc = {
"docnm_kwd": filename,
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
res = []
pdf_parser = None
sections = []
if re.search(r"\.docx?$", filename, re.IGNORECASE):
Expand All @@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
sections = pdf_parser(filename if not binary else binary,
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
# add tables
for img, rows in tbls:
bs = 10
de = ";" if eng else ";"
for i in range(0, len(rows), bs):
d = copy.deepcopy(doc)
r = de.join(rows[i:i + bs])
r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
tokenize(d, r, eng)
d["image"] = img
res.append(d)
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = ""
Expand All @@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca

parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
eng = lang.lower() == "english"#is_english(cks)
res = []

# wrap up to es documents
for ck in cks:
print("--", ck)
Expand Down
4 changes: 2 additions & 2 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from io import BytesIO
import pandas as pd

from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
Expand All @@ -48,7 +48,7 @@
BATCH_SIZE = 64

FACTORY = {
ParserType.GENERAL.value: laws,
ParserType.NAIVE.value: naive,
ParserType.PAPER.value: paper,
ParserType.BOOK.value: book,
ParserType.PRESENTATION.value: presentation,
Expand Down

0 comments on commit f666f56

Please sign in to comment.