Skip to content

Commit

Permalink
Fix errors detected by Ruff (infiniflow#3918)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Fix errors detected by Ruff

### Type of change

- [x] Refactoring
  • Loading branch information
yuzhichang authored Dec 8, 2024
1 parent e267a02 commit 0d68a6c
Show file tree
Hide file tree
Showing 97 changed files with 2,560 additions and 1,978 deletions.
36 changes: 23 additions & 13 deletions agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def __str__(self):
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
if k in ["components"]:
continue
dsl[k] = deepcopy(self.dsl[k])

for k, cpn in self.components.items():
Expand All @@ -158,7 +159,8 @@ def reset(self):

def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
if cid == n["id"]:
return n["data"]["name"]
return ""

def run(self, **kwargs):
Expand All @@ -173,7 +175,8 @@ def run(self, **kwargs):
if kwargs.get("stream"):
for an in ans():
yield an
else: yield ans
else:
yield ans
return

if not self.path:
Expand All @@ -188,7 +191,8 @@ def run(self, **kwargs):
def prepare2run(cpns):
nonlocal ran, ans
for c in cpns:
if self.path[-1] and c == self.path[-1][-1]: continue
if self.path[-1] and c == self.path[-1][-1]:
continue
cpn = self.components[c]["obj"]
if cpn.component_name == "Answer":
self.answer.append(c)
Expand All @@ -197,7 +201,8 @@ def prepare2run(cpns):
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
if c not in waiting:
waiting.append(c)
continue
yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs)
Expand All @@ -211,10 +216,12 @@ def prepare2run(cpns):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break
if not cpn["downstream"]:
break

loop = self._find_loop()
if loop: raise OverflowError(f"Too much loops: {loop}")
if loop:
raise OverflowError(f"Too much loops: {loop}")

if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
Expand Down Expand Up @@ -283,27 +290,30 @@ def get_embedding_model(self):

def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2: return False
if len(path) < 2:
return False

for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break

if len(path) < 2: return False
if len(path) < 2:
return False

for l in range(2, len(path) // 2):
pat = ",".join(path[0:l])
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
if len(pat)+1 >= len(path_str):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:l]])
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat

return False
Expand Down
70 changes: 70 additions & 0 deletions agent/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,73 @@ def component_class(class_name):
m = importlib.import_module("agent.component")
c = getattr(m, class_name)
return c

__all__ = [
"Begin",
"BeginParam",
"Generate",
"GenerateParam",
"Retrieval",
"RetrievalParam",
"Answer",
"AnswerParam",
"Categorize",
"CategorizeParam",
"Switch",
"SwitchParam",
"Relevant",
"RelevantParam",
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
"ConcentratorParam",
"Baidu",
"BaiduParam",
"DuckDuckGo",
"DuckDuckGoParam",
"Wikipedia",
"WikipediaParam",
"PubMed",
"PubMedParam",
"ArXiv",
"ArXivParam",
"Google",
"GoogleParam",
"Bing",
"BingParam",
"GoogleScholar",
"GoogleScholarParam",
"DeepL",
"DeepLParam",
"GitHub",
"GitHubParam",
"BaiduFanyi",
"BaiduFanyiParam",
"QWeather",
"QWeatherParam",
"ExeSQL",
"ExeSQLParam",
"YahooFinance",
"YahooFinanceParam",
"WenCai",
"WenCaiParam",
"Jin10",
"Jin10Param",
"TuShare",
"TuShareParam",
"AkShare",
"AkShareParam",
"Crawler",
"CrawlerParam",
"Invoke",
"InvokeParam",
"Template",
"TemplateParam",
"Email",
"EmailParam",
"component_class"
]
15 changes: 10 additions & 5 deletions agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def _run(self, history, **kwargs):
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
if not isinstance(o, list):
o = [o]
o = pd.DataFrame(o)

if allow_partial or not isinstance(o, partial):
Expand All @@ -440,7 +441,8 @@ def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
else: outs = oo
else:
outs = oo
return self._param.output_var_name, outs

def reset(self):
Expand Down Expand Up @@ -482,13 +484,15 @@ def get_input(self):
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
return df

upstream_outs = []

for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.get_component_name(u) in ["switch", "concentrator"]:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
Expand Down Expand Up @@ -532,7 +536,8 @@ def get_stream_input(self):
reversed_cpnts.extend(self._canvas.path[-1])

for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]: continue
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]

@staticmethod
Expand Down
13 changes: 8 additions & 5 deletions agent/component/categorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ def check(self):
super().check()
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k: raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
if not k:
raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

def get_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for l in desc.get("examples", "").split("\n"):
if not l: continue
cate_lines.append("Question: {}\tCategory: {}".format(l, c))
for line in desc.get("examples", "").split("\n"):
if not line:
continue
cate_lines.append("Question: {}\tCategory: {}".format(line, c))
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):
Expand Down
1 change: 0 additions & 1 deletion agent/component/deepl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
#
from abc import ABC
import re
from agent.component.base import ComponentBase, ComponentParamBase
import deepl

Expand Down
6 changes: 4 additions & 2 deletions agent/component/exesql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def check(self):
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
if self.host == "ragflow-mysql":
raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow":
raise ValueError("The host is not accessible.")


class ExeSQL(ComponentBase, ABC):
Expand Down
36 changes: 24 additions & 12 deletions agent/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,16 @@ def check(self):

def gen_conf(self):
conf = {}
if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
if self.temperature > 0: conf["temperature"] = self.temperature
if self.top_p > 0: conf["top_p"] = self.top_p
if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf


Expand Down Expand Up @@ -83,7 +88,8 @@ def set_cite(self, retrieval_res, answer):
recall_docs = []
for i in idx:
did = retrieval_res.loc[int(i), "doc_id"]
if did in doc_ids: continue
if did in doc_ids:
continue
doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

Expand All @@ -108,7 +114,8 @@ def _run(self, history, **kwargs):
retrieval_res = []
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
Expand Down Expand Up @@ -142,7 +149,8 @@ def _run(self, history, **kwargs):

if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])

for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
Expand All @@ -164,9 +172,11 @@ def _run(self, history, **kwargs):
return pd.DataFrame([res])

msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())

if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
Expand All @@ -185,9 +195,11 @@ def stream_output(self, chat_mdl, prompt, retrieval_res):
return

msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}
Expand Down
3 changes: 2 additions & 1 deletion agent/component/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def _run(self, history, **kwargs):
hist = self._canvas.get_history(4)
conv = []
for m in hist:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)

Expand Down
Loading

0 comments on commit 0d68a6c

Please sign in to comment.