From db3949a823fbab7b0ed06e842db7348b9d405141 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 23 Dec 2024 07:01:02 -0500 Subject: [PATCH] inject pending tasks to context --- agixt/Agent.py | 49 +++++++++++++++++++++++++++++++++++++++++++ agixt/Interactions.py | 3 +++ agixt/Task.py | 20 ------------------ 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index e17999604535..1c248ea8103e 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -17,6 +17,7 @@ get_session, UserOAuth, OAuthProvider, + TaskItem, ) from Providers import Providers from Extensions import Extensions @@ -750,3 +751,51 @@ def get_agent_id(self): return None session.close() return agent.id + + def get_conversation_tasks(self, conversation_id: str) -> str: + """Get all tasks assigned to an agent""" + try: + session = get_session() + tasks = ( + session.query(TaskItem) + .filter( + TaskItem.agent_id == self.agent_id, + TaskItem.user_id == self.user_id, + TaskItem.completed == False, + TaskItem.memory_collection == conversation_id, + ) + .all() + ) + + markdown_tasks = "## The Assistant's Scheduled Tasks\n**The assistant currently has the following tasks scheduled:**\n" + for task in tasks: + markdown_tasks += ( + f"### Task: {task.title}\n" + f"**Description:** {task.description}\n" + f"**Will be completed at:** {task.due_date}\n" + f"**Status:** {task.status}\n" + ) + session.close() + except Exception as e: + logging.error(f"Error getting tasks by agent: {str(e)}") + session.close() + return "" + + def get_all_pending_tasks(self) -> list: + """Get all tasks assigned to an agent""" + try: + session = get_session() + tasks = ( + session.query(TaskItem) + .filter( + TaskItem.agent_id == self.agent_id, + TaskItem.user_id == self.user_id, + TaskItem.completed == False, + ) + .all() + ) + session.close() + return tasks + except Exception as e: + logging.error(f"Error getting tasks by agent: {str(e)}") + return [] diff --git a/agixt/Interactions.py b/agixt/Interactions.py index de423c8fad46..404cda2e44b1 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -252,6 +252,9 @@ async def format_prompt( conversation_results = int(top_results) if top_results > 0 else 5 except: conversation_results = 5 + agent_tasks = self.agent.get_conversation_tasks(conversation_id=conversation_id) + if agent_tasks != "": + context.append(agent_tasks) conversation_history = "" conversation = c.get_conversation() if "interactions" in conversation: diff --git a/agixt/Task.py b/agixt/Task.py index 3df644d1b5a7..f796c5376d5b 100644 --- a/agixt/Task.py +++ b/agixt/Task.py @@ -183,26 +183,6 @@ async def get_tasks_by_category(self, category_name: str) -> list: session.close() return tasks - async def get_tasks_by_agent(self, agent_name: str) -> list: - """Get all tasks assigned to an agent""" - session = get_session() - agent = ( - session.query(Agent) - .filter(Agent.name == agent_name, Agent.user_id == self.user_id) - .first() - ) - if not agent: - session.close() - return [] - - tasks = ( - session.query(TaskItem) - .filter(TaskItem.agent_id == agent.id, TaskItem.user_id == self.user_id) - .all() - ) - session.close() - return tasks - async def update_task( self, task_id: str,