From e512f1cf05dbf4a26508a8cd0c868b767f049eb8 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 27 Sep 2024 10:22:13 +0800 Subject: [PATCH] fix generate bug (#2614) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/generate.py | 6 +++--- agent/component/switch.py | 33 +++++++-------------------------- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/agent/component/generate.py b/agent/component/generate.py index b5b21bb9314..63ac7dd3e1a 100644 --- a/agent/component/generate.py +++ b/agent/component/generate.py @@ -122,13 +122,13 @@ def _run(self, history, **kwargs): if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]): res = {"content": "\n- ".join(retrieval_res["empty_response"]) if "\n- ".join( retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []} - return Generate.be_output(res) + return pd.DataFrame([res]) ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()) if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: - df = self.set_cite(retrieval_res, ans) - return pd.DataFrame(df) + res = self.set_cite(retrieval_res, ans) + return pd.DataFrame([res]) return Generate.be_output(ans) diff --git a/agent/component/switch.py b/agent/component/switch.py index fe2e2452ffb..bb3a15b52c9 100644 --- a/agent/component/switch.py +++ b/agent/component/switch.py @@ -49,34 +49,15 @@ class Switch(ComponentBase, ABC): def _run(self, history, **kwargs): for cond in self._param.conditions: - - if len(cond["items"]) == 1: - out = self._canvas.get_component(cond["items"][0]["cpn_id"])["obj"].output()[1] - cpn_input = "" if "content" not in out.columns else " ".join(out["content"]) - if self.process_operator(cpn_input, cond["items"][0]["operator"], cond["items"][0]["value"]): - return Switch.be_output(cond["to"]) - continue - - if cond["logical_operator"] == "and": - res = True - for item in cond["items"]: - out = self._canvas.get_component(item["cpn_id"])["obj"].output()[1] - cpn_input = "" if "content" not in out.columns else " ".join(out["content"]) - if not self.process_operator(cpn_input, item["operator"], item["value"]): - res = False - break - if res: - return Switch.be_output(cond["to"]) - continue - - res = False + res = [] for item in cond["items"]: out = self._canvas.get_component(item["cpn_id"])["obj"].output()[1] cpn_input = "" if "content" not in out.columns else " ".join(out["content"]) - if self.process_operator(cpn_input, item["operator"], item["value"]): - res = True - break - if res: + res.append(self.process_operator(cpn_input, item["operator"], item["value"])) + if cond["logical_operator"] != "and" and any(res): + return Switch.be_output(cond["to"]) + + if all(res): return Switch.be_output(cond["to"]) return Switch.be_output(self._param.end_cpn_id) @@ -122,4 +103,4 @@ def process_operator(self, input: str, operator: str, value: str) -> bool: except Exception as e: return True if input <= value else False - raise ValueError('Not supported operator' + operator) + raise ValueError('Not supported operator' + operator) \ No newline at end of file