diff --git a/agixt/Agent.py b/agixt/Agent.py index d55685d8f525..0c7b7dd553c2 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -16,6 +16,9 @@ from Providers import Providers from Extensions import Extensions from Globals import getenv, DEFAULT_SETTINGS, DEFAULT_USER +from MagicalAuth import get_user_id, is_agixt_admin +from agixtsdk import AGiXTSDK +from fastapi import HTTPException from datetime import datetime, timezone, timedelta import logging import json @@ -29,9 +32,9 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_USER): - session = get_session() if not agent_name: return {"message": "Agent name cannot be empty."} + session = get_session() # Check if agent already exists agent = ( session.query(AgentModel) @@ -39,6 +42,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US .first() ) if agent: + session.close() return {"message": f"Agent {agent_name} already exists."} agent = ( session.query(AgentModel) @@ -46,6 +50,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US .first() ) if agent: + session.close() return {"message": f"Agent {agent_name} already exists."} user_data = session.query(User).filter(User.email == user).first() user_id = user_data.id @@ -80,7 +85,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US ) session.add(agent_command) session.commit() - + session.close() return {"message": f"Agent {agent_name} created."} @@ -94,6 +99,7 @@ def delete_agent(agent_name, user=DEFAULT_USER): .first() ) if not agent: + session.close() return {"message": f"Agent {agent_name} not found."}, 404 # Delete associated chain steps @@ -125,7 +131,7 @@ def delete_agent(agent_name, user=DEFAULT_USER): # Delete the agent session.delete(agent) session.commit() - + session.close() return {"message": f"Agent {agent_name} deleted."}, 200 @@ -139,11 +145,11 @@ def rename_agent(agent_name, new_name, user=DEFAULT_USER): .first() ) if not agent: + session.close() return {"message": f"Agent {agent_name} not found."}, 404 - agent.name = new_name session.commit() - + session.close() return {"message": f"Agent {agent_name} renamed to {new_name}."}, 200 @@ -162,21 +168,16 @@ def get_agents(user=DEFAULT_USER): if agent.name in [a["name"] for a in output]: continue output.append({"name": agent.name, "id": agent.id, "status": False}) + session.close() return output class Agent: - def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): + def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient: AGiXTSDK = None): self.agent_name = agent_name if agent_name is not None else "AGiXT" - self.session = get_session() user = user if user is not None else DEFAULT_USER self.user = user.lower() - try: - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id - except Exception as e: - logging.error(f"User {self.user} not found.") - raise + self.user_id = get_user_id(user=self.user) self.AGENT_CONFIG = self.get_agent_config() self.load_config_keys() if "settings" not in self.AGENT_CONFIG: @@ -286,8 +287,9 @@ def load_config_keys(self): setattr(self, key, self.AGENT_CONFIG[key]) def get_agent_config(self): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -295,26 +297,23 @@ def get_agent_config(self): ) if not agent: # Check if it is a global agent - global_user = ( - self.session.query(User).filter(User.email == DEFAULT_USER).first() - ) + global_user = session.query(User).filter(User.email == DEFAULT_USER).first() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == global_user.id, ) .first() ) - config = {"settings": {}, "commands": {}} if agent: - all_commands = self.session.query(Command).all() + all_commands = session.query(Command).all() agent_settings = ( - self.session.query(AgentSettingModel).filter_by(agent_id=agent.id).all() + session.query(AgentSettingModel).filter_by(agent_id=agent.id).all() ) agent_commands = ( - self.session.query(AgentCommand) + session.query(AgentCommand) .join(Command) .filter( AgentCommand.agent_id == agent.id, @@ -331,7 +330,10 @@ def get_agent_config(self): ) for setting in agent_settings: config["settings"][setting.name] = setting.value + session.commit() + session.close() return config + session.close() return {"settings": DEFAULT_SETTINGS, "commands": {}} async def inference(self, prompt: str, tokens: int = 0, images: list = []): @@ -387,24 +389,68 @@ def get_commands_string(self): return verbose_commands def update_agent_config(self, new_config, config_key): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) .first() ) if not agent: - logging.error(f"Agent '{self.agent_name}' not found in the database.") - return + if self.user == DEFAULT_USER: + return f"Agent {self.agent_name} not found." + # Check if it is a global agent. + global_user = session.query(User).filter(User.email == DEFAULT_USER).first() + global_agent = ( + session.query(AgentModel) + .filter( + AgentModel.name == self.agent_name, + AgentModel.user_id == global_user.id, + ) + .first() + ) + # if it is a global agent, copy it to the user's agents. + if global_agent: + agent = AgentModel( + name=self.agent_name, + user_id=self.user_id, + provider_id=global_agent.provider_id, + ) + session.add(agent) + agent_settings = ( + session.query(AgentSettingModel) + .filter_by(agent_id=global_agent.id) + .all() + ) + for setting in agent_settings: + agent_setting = AgentSettingModel( + agent_id=agent.id, + name=setting.name, + value=setting.value, + ) + session.add(agent_setting) + agent_commands = ( + session.query(AgentCommand) + .filter_by(agent_id=global_agent.id) + .all() + ) + for agent_command in agent_commands: + agent_command = AgentCommand( + agent_id=agent.id, + command_id=agent_command.command_id, + state=agent_command.state, + ) + session.add(agent_command) + session.commit() + session.close() + return f"Agent {self.agent_name} configuration updated successfully." if config_key == "commands": for command_name, enabled in new_config.items(): - command = ( - self.session.query(Command).filter_by(name=command_name).first() - ) + command = session.query(Command).filter_by(name=command_name).first() if command: agent_command = ( - self.session.query(AgentCommand) + session.query(AgentCommand) .filter_by(agent_id=agent.id, command_id=command.id) .first() ) @@ -414,12 +460,12 @@ def update_agent_config(self, new_config, config_key): agent_command = AgentCommand( agent_id=agent.id, command_id=command.id, state=enabled ) - self.session.add(agent_command) + session.add(agent_command) else: for setting_name, setting_value in new_config.items(): logging.info(f"Setting {setting_name} to {setting_value}.") agent_setting = ( - self.session.query(AgentSettingModel) + session.query(AgentSettingModel) .filter_by(agent_id=agent.id, name=setting_name) .first() ) @@ -429,15 +475,18 @@ def update_agent_config(self, new_config, config_key): agent_setting = AgentSettingModel( agent_id=agent.id, name=setting_name, value=str(setting_value) ) - self.session.add(agent_setting) + session.add(agent_setting) try: - self.session.commit() + session.commit() + session.close() logging.info(f"Agent {self.agent_name} configuration updated successfully.") except Exception as e: - self.session.rollback() + session.rollback() + session.close() logging.error(f"Error updating agent configuration: {str(e)}") - raise - + raise HTTPException( + status_code=500, detail=f"Error updating agent configuration: {str(e)}" + ) return f"Agent {self.agent_name} configuration updated." def get_browsed_links(self, conversation_id=None): @@ -447,21 +496,24 @@ def get_browsed_links(self, conversation_id=None): Returns: list: The list of URLs that have been browsed by the agent. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) .first() ) if not agent: + session.close() return [] browsed_links = ( - self.session.query(AgentBrowsedLink) + session.query(AgentBrowsedLink) .filter_by(agent_id=agent.id, conversation_id=conversation_id) .order_by(AgentBrowsedLink.id.desc()) .all() ) + session.close() if not browsed_links: return [] return browsed_links @@ -495,8 +547,9 @@ def add_browsed_link(self, url, conversation_id=None): Returns: str: The response message. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -507,8 +560,9 @@ def add_browsed_link(self, url, conversation_id=None): browsed_link = AgentBrowsedLink( agent_id=agent.id, url=url, conversation_id=conversation_id ) - self.session.add(browsed_link) - self.session.commit() + session.add(browsed_link) + session.commit() + session.close() return f"Link {url} added to browsed links." def delete_browsed_link(self, url, conversation_id=None): @@ -521,8 +575,9 @@ def delete_browsed_link(self, url, conversation_id=None): Returns: str: The response message. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id, @@ -532,19 +587,21 @@ def delete_browsed_link(self, url, conversation_id=None): if not agent: return f"Agent {self.agent_name} not found." browsed_link = ( - self.session.query(AgentBrowsedLink) + session.query(AgentBrowsedLink) .filter_by(agent_id=agent.id, url=url, conversation_id=conversation_id) .first() ) if not browsed_link: return f"Link {url} not found." - self.session.delete(browsed_link) - self.session.commit() + session.delete(browsed_link) + session.commit() + session.close() return f"Link {url} deleted from browsed links." def get_agent_id(self): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -552,13 +609,15 @@ def get_agent_id(self): ) if not agent: agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user.has(email=DEFAULT_USER), ) .first() ) + session.close() if not agent: return None + session.close() return agent.id diff --git a/agixt/Chain.py b/agixt/Chain.py index 174571c8bffe..d62f81bbaa81 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -16,6 +16,7 @@ from Globals import getenv, DEFAULT_USER from Prompts import Prompts from Extensions import Extensions +from MagicalAuth import get_user_id import logging import asyncio @@ -27,29 +28,22 @@ class Chain: def __init__(self, user=DEFAULT_USER, ApiClient=None): - self.session = get_session() self.user = user self.ApiClient = ApiClient - try: - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id - except: - user_data = ( - self.session.query(User).filter(User.email == DEFAULT_USER).first() - ) - self.user_id = user_data.id + self.user_id = get_user_id(self.user) def get_chain(self, chain_name): + session = get_session() chain_name = chain_name.replace("%20", " ") - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.user_id == user_data.id, ChainDB.name == chain_name) .first() ) if chain_db is None: chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == chain_name, ChainDB.user_id == self.user_id, @@ -57,9 +51,10 @@ def get_chain(self, chain_name): .first() ) if chain_db is None: + session.close() return [] chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_db.id) .order_by(ChainStep.step_number) .all() @@ -67,24 +62,24 @@ def get_chain(self, chain_name): steps = [] for step in chain_steps: - agent_name = self.session.query(Agent).get(step.agent_id).name + agent_name = session.query(Agent).get(step.agent_id).name prompt = {} if step.target_chain_id: prompt["chain_name"] = ( - self.session.query(ChainDB).get(step.target_chain_id).name + session.query(ChainDB).get(step.target_chain_id).name ) elif step.target_command_id: prompt["command_name"] = ( - self.session.query(Command).get(step.target_command_id).name + session.query(Command).get(step.target_command_id).name ) elif step.target_prompt_id: prompt["prompt_name"] = ( - self.session.query(Prompt).get(step.target_prompt_id).name + session.query(Prompt).get(step.target_prompt_id).name ) # Retrieve argument data for the step arguments = ( - self.session.query(Argument, ChainStepArgument) + session.query(Argument, ChainStepArgument) .join(ChainStepArgument, ChainStepArgument.argument_id == Argument.id) .filter(ChainStepArgument.chain_step_id == step.id) .all() @@ -109,38 +104,42 @@ def get_chain(self, chain_name): "chain_name": chain_db.name, "steps": steps, } - + session.close() return chain_data def get_chains(self): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_chains = ( - self.session.query(ChainDB).filter(ChainDB.user_id == user_data.id).all() - ) - chains = ( - self.session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all() + session.query(ChainDB).filter(ChainDB.user_id == user_data.id).all() ) + chains = session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all() chain_list = [] for chain in chains: chain_list.append(chain.name) for chain in global_chains: chain_list.append(chain.name) + session.close() return chain_list def add_chain(self, chain_name): + session = get_session() chain = ChainDB(name=chain_name, user_id=self.user_id) - self.session.add(chain) - self.session.commit() + session.add(chain) + session.commit() + session.close() def rename_chain(self, chain_name, new_name): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain: chain.name = new_name - self.session.commit() + session.commit() + session.close() def add_chain_step( self, @@ -150,13 +149,14 @@ def add_chain_step( prompt_type: str, prompt: dict, ): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) @@ -168,7 +168,7 @@ def add_chain_step( if prompt_type.lower() == "prompt": argument_key = "prompt_name" target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt["prompt_name"], Prompt.user_id == self.user_id, @@ -183,7 +183,7 @@ def add_chain_step( if argument_key not in prompt: argument_key = "chain" target_id = ( - self.session.query(Chain) + session.query(Chain) .filter( Chain.name == prompt["chain_name"], Chain.user_id == self.user_id ) @@ -194,7 +194,7 @@ def add_chain_step( elif prompt_type.lower() == "command": argument_key = "command_name" target_id = ( - self.session.query(Command) + session.query(Command) .filter(Command.name == prompt["command_name"]) .first() .id @@ -216,7 +216,7 @@ def add_chain_step( del prompt["input"] argument_key = "prompt_name" target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt["prompt_name"], Prompt.user_id == self.user_id, @@ -229,7 +229,6 @@ def add_chain_step( argument_value = prompt[argument_key] prompt_arguments = prompt.copy() del prompt_arguments[argument_key] - chain_step = ChainStep( chain_id=chain.id, step_number=step_number, @@ -240,14 +239,12 @@ def add_chain_step( target_command_id=target_id if target_type == "command" else None, target_prompt_id=target_id if target_type == "prompt" else None, ) - self.session.add(chain_step) - self.session.commit() + session.add(chain_step) + session.commit() for argument_name, argument_value in prompt_arguments.items(): argument = ( - self.session.query(Argument) - .filter(Argument.name == argument_name) - .first() + session.query(Argument).filter(Argument.name == argument_name).first() ) if not argument: # Handle the case where argument not found based on argument_name @@ -259,40 +256,39 @@ def add_chain_step( argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == step_number ) .first() ) - agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) agent_id = agent.id if agent else None - target_chain_id = None target_command_id = None target_prompt_id = None - if prompt_type == "Command": command_name = prompt.get("command_name") command_args = prompt.copy() del command_args["command_name"] command = ( - self.session.query(Command).filter(Command.name == command_name).first() + session.query(Command).filter(Command.name == command_name).first() ) if command: target_command_id = command.id @@ -302,7 +298,7 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): prompt_args = prompt.copy() del prompt_args["prompt_name"] prompt_obj = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.prompt_category.has(name=prompt_category), @@ -317,32 +313,26 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): chain_args = prompt.copy() del chain_args["chain_name"] chain_obj = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain_obj: target_chain_id = chain_obj.id - chain_step.agent_id = agent_id chain_step.prompt_type = prompt_type chain_step.prompt = prompt.get("prompt_name", None) chain_step.target_chain_id = target_chain_id chain_step.target_command_id = target_command_id chain_step.target_prompt_id = target_prompt_id - - self.session.commit() - + session.commit() # Update the arguments for the step - self.session.query(ChainStepArgument).filter( + session.query(ChainStepArgument).filter( ChainStepArgument.chain_step_id == chain_step.id ).delete() - for argument_name, argument_value in prompt_args.items(): argument = ( - self.session.query(Argument) - .filter(Argument.name == argument_name) - .first() + session.query(Argument).filter(Argument.name == argument_name).first() ) if argument: chain_step_argument = ChainStepArgument( @@ -350,56 +340,59 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() def delete_step(self, chain_name, step_number): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) - if chain: chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == step_number ) .first() ) if chain_step: - self.session.delete( - chain_step - ) # Remove the chain step from the session - self.session.commit() + session.delete(chain_step) # Remove the chain step from the session + session.commit() else: logging.info( f"No step found with number {step_number} in chain '{chain_name}'" ) else: logging.info(f"No chain found with name '{chain_name}'") + session.close() def delete_chain(self, chain_name): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) - self.session.delete(chain) - self.session.commit() + session.delete(chain) + session.commit() + session.close() def get_steps(self, chain_name): + session = get_session() chain_name = chain_name.replace("%20", " ") - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.user_id == user_data.id, ChainDB.name == chain_name) .first() ) if chain_db is None: chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == chain_name, ChainDB.user_id == self.user_id, @@ -407,13 +400,15 @@ def get_steps(self, chain_name): .first() ) if chain_db is None: + session.close() return [] chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_db.id) .order_by(ChainStep.step_number) .all() ) + session.close() return chain_steps def get_step(self, chain_name, step_number): @@ -426,13 +421,14 @@ def get_step(self, chain_name, step_number): return chain_step def move_step(self, chain_name, current_step_number, new_step_number): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == current_step_number, @@ -441,7 +437,7 @@ def move_step(self, chain_name, current_step_number, new_step_number): ) chain_step.step_number = new_step_number if new_step_number < current_step_number: - self.session.query(ChainStep).filter( + session.query(ChainStep).filter( ChainStep.chain_id == chain.id, ChainStep.step_number >= new_step_number, ChainStep.step_number < current_step_number, @@ -449,22 +445,24 @@ def move_step(self, chain_name, current_step_number, new_step_number): {"step_number": ChainStep.step_number + 1}, synchronize_session=False ) else: - self.session.query(ChainStep).filter( + session.query(ChainStep).filter( ChainStep.chain_id == chain.id, ChainStep.step_number > current_step_number, ChainStep.step_number <= new_step_number, ).update( {"step_number": ChainStep.step_number - 1}, synchronize_session=False ) - self.session.commit() + session.commit() + session.close() def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): if chain_run_id is None: chain_run_id = self.get_last_chain_run_id(chain_name=chain_name) chain_data = self.get_chain(chain_name=chain_name) + session = get_session() if step_number == "all": chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_data["id"]) .order_by(ChainStep.step_number) .all() @@ -473,7 +471,7 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): responses = {} for step in chain_steps: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -483,12 +481,12 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): ) step_responses = [response.content for response in chain_step_responses] responses[str(step.step_number)] = step_responses - + session.close() return responses else: step_number = int(step_number) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain_data["id"], ChainStep.step_number == step_number, @@ -498,7 +496,7 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): if chain_step: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == chain_step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -507,49 +505,53 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): .all() ) step_responses = [response.content for response in chain_step_responses] + session.close() return step_responses else: + session.close() return None def get_chain_responses(self, chain_name): chain_steps = self.get_steps(chain_name=chain_name) responses = {} + session = get_session() for step in chain_steps: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter(ChainStepResponse.chain_step_id == step.id) .order_by(ChainStepResponse.timestamp) .all() ) step_responses = [response.content for response in chain_step_responses] responses[str(step.step_number)] = step_responses + session.close() return responses def import_chain(self, chain_name: str, steps: dict): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain: + session.close() return None chain = ChainDB(name=chain_name, user_id=self.user_id) - self.session.add(chain) - self.session.commit() + session.add(chain) + session.commit() steps = steps["steps"] if "steps" in steps else steps for step_data in steps: agent_name = step_data["agent_name"] agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) if not agent: # Use the first agent in the database agent = ( - self.session.query(Agent) - .filter(Agent.user_id == self.user_id) - .first() + session.query(Agent).filter(Agent.user_id == self.user_id).first() ) prompt = step_data["prompt"] if "prompt_type" not in step_data: @@ -559,7 +561,7 @@ def import_chain(self, chain_name: str, steps: dict): argument_key = "prompt_name" prompt_category = prompt.get("prompt_category", "Default") target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt[argument_key], Prompt.user_id == self.user_id, @@ -574,7 +576,7 @@ def import_chain(self, chain_name: str, steps: dict): if "chain" in prompt: argument_key = "chain" target_id = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == prompt[argument_key], ChainDB.user_id == self.user_id, @@ -586,7 +588,7 @@ def import_chain(self, chain_name: str, steps: dict): elif prompt_type == "command": argument_key = "command_name" target_id = ( - self.session.query(Command) + session.query(Command) .filter(Command.name == prompt[argument_key]) .first() .id @@ -609,11 +611,11 @@ def import_chain(self, chain_name: str, steps: dict): target_command_id=target_id if target_type == "command" else None, target_prompt_id=target_id if target_type == "prompt" else None, ) - self.session.add(chain_step) - self.session.commit() + session.add(chain_step) + session.commit() for argument_name, argument_value in prompt_arguments.items(): argument = ( - self.session.query(Argument) + session.query(Argument) .filter(Argument.name == argument_name) .first() ) @@ -627,8 +629,9 @@ def import_chain(self, chain_name: str, steps: dict): argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() return f"Imported chain: {chain_name}" def get_chain_step_dependencies(self, chain_name): @@ -771,8 +774,9 @@ async def update_step_response( ): chain_step = self.get_step(chain_name=chain_name, step_number=step_number) if chain_step: + session = get_session() existing_response = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == chain_step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -785,48 +789,55 @@ async def update_step_response( response, dict ): existing_response.content.update(response) - self.session.commit() + session.commit() elif isinstance(existing_response.content, list) and isinstance( response, list ): existing_response.content.extend(response) - self.session.commit() + session.commit() else: chain_step_response = ChainStepResponse( chain_step_id=chain_step.id, chain_run_id=chain_run_id, content=response, ) - self.session.add(chain_step_response) - self.session.commit() + session.add(chain_step_response) + session.commit() else: chain_step_response = ChainStepResponse( chain_step_id=chain_step.id, chain_run_id=chain_run_id, content=response, ) - self.session.add(chain_step_response) - self.session.commit() + session.add(chain_step_response) + session.commit() + session.close() async def get_chain_run_id(self, chain_name): + session = get_session() chain_run = ChainRun( chain_id=self.get_chain(chain_name=chain_name)["id"], user_id=self.user_id, ) - self.session.add(chain_run) - self.session.commit() - return chain_run.id + session.add(chain_run) + session.commit() + chain_id = chain_run.id + session.close() + return chain_id async def get_last_chain_run_id(self, chain_name): chain_data = self.get_chain(chain_name=chain_name) + session = get_session() chain_run = ( - self.session.query(ChainRun) + session.query(ChainRun) .filter(ChainRun.chain_id == chain_data["id"]) .order_by(ChainRun.timestamp.desc()) .first() ) if chain_run: - return chain_run.id + chain_run_id = chain_run.id + session.close() + return chain_run_id else: return await self.get_chain_run_id(chain_name=chain_name) @@ -876,8 +887,9 @@ def new_task( task_description, estimated_hours, ): + session = get_session() task_category = ( - self.session.query(TaskCategory) + session.query(TaskCategory) .filter( TaskCategory.name == task_category, TaskCategory.user_id == self.user_id ) @@ -885,8 +897,8 @@ def new_task( ) if not task_category: task_category = TaskCategory(name=task_category, user_id=self.user_id) - self.session.add(task_category) - self.session.commit() + session.add(task_category) + session.commit() task = TaskItem( user_id=self.user_id, category_id=task_category.id, @@ -895,6 +907,8 @@ def new_task( estimated_hours=estimated_hours, memory_collection=str(conversation_id), ) - self.session.add(task) - self.session.commit() - return task.id + session.add(task) + session.commit() + task_id = task.id + session.close() + return task_id diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index b32ca85176eb..a5e01891766e 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -11,8 +11,6 @@ from fastapi import Header, HTTPException from Globals import getenv from datetime import datetime, timedelta -from Agent import add_agent -from agixtsdk import AGiXTSDK from fastapi import HTTPException from sendgrid import SendGridAPIClient from sendgrid.helpers.mail import ( @@ -58,51 +56,6 @@ def is_agixt_admin(email: str = "", api_key: str = ""): return False -def webhook_create_user( - api_key: str, - email: str, - role: str = "user", - agent_name: str = "", - settings: dict = {}, - commands: dict = {}, - training_urls: list = [], - github_repos: list = [], - ApiClient: AGiXTSDK = AGiXTSDK(), -): - if not is_agixt_admin(email=email, api_key=api_key): - return {"error": "Access Denied"}, 403 - session = get_session() - email = email.lower() - user_exists = session.query(User).filter_by(email=email).first() - if user_exists: - session.close() - return {"error": "User already exists"}, 400 - admin = True if role.lower() == "admin" else False - user = User( - email=email, - admin=admin, - first_name="", - last_name="", - ) - session.add(user) - session.commit() - session.close() - if agent_name != "" and agent_name is not None: - add_agent( - agent_name=agent_name, - provider_settings=settings, - commands=commands, - user=email, - ) - if training_urls != []: - for url in training_urls: - ApiClient.learn_url(agent_name=agent_name, url=url) - if github_repos != []: - for repo in github_repos: - ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) - return {"status": "Success"}, 200 - - def verify_api_key(authorization: str = Header(None)): AGIXT_API_KEY = getenv("AGIXT_API_KEY") if getenv("AUTH_PROVIDER") == "magicalauth": @@ -133,6 +86,21 @@ def verify_api_key(authorization: str = Header(None)): return authorization +def get_user_id(user: str): + session = get_session() + user_data = session.query(User).filter(User.email == user).first() + if user_data is None: + session.close() + raise HTTPException(status_code=404, detail=f"User {user} not found.") + try: + user_id = user_data.id + except Exception as e: + session.close() + raise HTTPException(status_code=404, detail=f"User {user} not found.") + session.close() + return user_id + + def send_email( email: str, subject: str, diff --git a/agixt/Models.py b/agixt/Models.py index 224295d7c5bd..7b8a2dd2a2b6 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -316,6 +316,7 @@ class WebhookUser(BaseModel): commands: Optional[Dict[str, Any]] = {} training_urls: Optional[List[str]] = [] github_repos: Optional[List[str]] = [] + zip_file_content: Optional[str] = "" # Auth user models diff --git a/agixt/Prompts.py b/agixt/Prompts.py index 85782efdc26e..eb97a9f93a31 100644 --- a/agixt/Prompts.py +++ b/agixt/Prompts.py @@ -1,21 +1,20 @@ from DB import Prompt, PromptCategory, Argument, User, get_session from Globals import DEFAULT_USER +from MagicalAuth import get_user_id import os class Prompts: def __init__(self, user=DEFAULT_USER): - self.session = get_session() self.user = user - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id + self.user_id = get_user_id(user) def add_prompt(self, prompt_name, prompt, prompt_category="Default"): + session = get_session() if not prompt_category: prompt_category = "Default" - prompt_category = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter( PromptCategory.name == prompt_category, PromptCategory.user_id == self.user_id, @@ -28,8 +27,8 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): description=f"{prompt_category} category", user_id=self.user_id, ) - self.session.add(prompt_category) - self.session.commit() + session.add(prompt_category) + session.commit() prompt_obj = Prompt( name=prompt_name, @@ -38,8 +37,8 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_category=prompt_category, user_id=self.user_id, ) - self.session.add(prompt_obj) - self.session.commit() + session.add(prompt_obj) + session.commit() # Populate prompt arguments prompt_args = self.get_prompt_args(prompt) @@ -48,13 +47,15 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_id=prompt_obj.id, name=arg, ) - self.session.add(argument) - self.session.commit() + session.add(argument) + session.commit() + session.close() def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == user_data.id, @@ -66,7 +67,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): ) if not prompt: prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -82,7 +83,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): if not prompt and prompt_category != "Default": # Prompt not found in specified category, try the default category prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -110,7 +111,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): prompt_category="Default", ) prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -124,13 +125,17 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): .first() ) if prompt: - return prompt.content + prompt_content = prompt.content + session.close() + return prompt_content + session.close() return None def get_prompts(self, prompt_category="Default"): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_prompts = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.user_id == user_data.id, Prompt.prompt_category.has(name=prompt_category), @@ -142,7 +147,7 @@ def get_prompts(self, prompt_category="Default"): .all() ) user_prompts = ( - self.session.query(Prompt) + session.query(Prompt) .join(PromptCategory) .filter( PromptCategory.name == prompt_category, Prompt.user_id == self.user_id @@ -154,6 +159,7 @@ def get_prompts(self, prompt_category="Default"): prompts.append(prompt.name) for prompt in user_prompts: prompts.append(prompt.name) + session.close() return prompts def get_prompt_args(self, prompt_text): @@ -169,8 +175,9 @@ def get_prompt_args(self, prompt_text): return prompt_args def delete_prompt(self, prompt_name, prompt_category="Default"): + session = get_session() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter_by(name=prompt_name) .join(PromptCategory) .filter( @@ -179,12 +186,14 @@ def delete_prompt(self, prompt_name, prompt_category="Default"): .first() ) if prompt: - self.session.delete(prompt) - self.session.commit() + session.delete(prompt) + session.commit() + session.close() def update_prompt(self, prompt_name, prompt, prompt_category="Default"): + session = get_session() prompt_obj = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -195,7 +204,7 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): if prompt_obj: if prompt_category: prompt_category = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter( PromptCategory.name == prompt_category, PromptCategory.user_id == self.user_id, @@ -208,25 +217,21 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): description=f"{prompt_category} category", user_id=self.user_id, ) - self.session.add(prompt_category) - self.session.commit() + session.add(prompt_category) + session.commit() prompt_obj.prompt_category = prompt_category - prompt_obj.content = prompt - self.session.commit() - + session.commit() # Update prompt arguments prompt_args = self.get_prompt_args(prompt) existing_args = ( - self.session.query(Argument).filter_by(prompt_id=prompt_obj.id).all() + session.query(Argument).filter_by(prompt_id=prompt_obj.id).all() ) existing_arg_names = {arg.name for arg in existing_args} - # Delete removed arguments for arg in existing_args: if arg.name not in prompt_args: - self.session.delete(arg) - + session.delete(arg) # Add new arguments for arg in prompt_args: if arg not in existing_arg_names: @@ -234,13 +239,14 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_id=prompt_obj.id, name=arg, ) - self.session.add(argument) - - self.session.commit() + session.add(argument) + session.commit() + session.close() def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default"): + session = get_session() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -254,17 +260,19 @@ def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default") ) if prompt: prompt.name = new_prompt_name - self.session.commit() + session.commit() + session.close() def get_prompt_categories(self): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_prompt_categories = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter(PromptCategory.user_id == user_data.id) .all() ) user_prompt_categories = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter(PromptCategory.user_id == self.user_id) .all() ) @@ -273,4 +281,5 @@ def get_prompt_categories(self): prompt_categories.append(prompt_category.name) for prompt_category in user_prompt_categories: prompt_categories.append(prompt_category.name) + session.close() return prompt_categories diff --git a/agixt/XT.py b/agixt/XT.py index 3f2ef71ac2e9..33401ba0c16d 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -630,10 +630,14 @@ async def learn_from_file( stderr=subprocess.PIPE, ) file_path = pdf_file_path - if conversation_name != "" and conversation_name != None: + if ( + conversation_name != "" + and conversation_name != None + and file_type not in ["jpg", "jpeg", "png", "gif"] + ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Reading file `{file_name}` into memory.", + message=f"[ACTIVITY] Reading `{file_name}` into memory.", ) if user_input == "": user_input = "Describe each stage of this image." @@ -785,7 +789,7 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Viewing image at {file_url} .", + message=f"[ACTIVITY] [Uploaded {file_name}]({file_url}) .", ) try: vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\nThe uploaded image is `{file_name}`.\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." @@ -1347,7 +1351,9 @@ async def chat_completions(self, prompt: ChatCompletions): self.agent_workspace, audio_file_info["file_name"], ) - if url.startswith(self.agent_workspace): + if os.path.normpath(audio_file_path).startswith( + self.agent_workspace + ): wav_file = os.path.join( self.agent_workspace, f"{uuid.uuid4().hex}.wav", diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 224e71b510aa..91aa3059be01 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Request, Header, Depends, HTTPException from Models import Detail, Login, UserInfo, Register -from MagicalAuth import MagicalAuth, verify_api_key, webhook_create_user +from MagicalAuth import MagicalAuth, verify_api_key, is_agixt_admin +from DB import get_session, User +from Agent import add_agent from ApiClient import get_api_client, is_admin from Models import WebhookUser from Globals import getenv @@ -99,18 +101,50 @@ async def createuser( account: WebhookUser, authorization: str = Header(None), ): + if not is_agixt_admin(email=email, api_key=authorization): + raise HTTPException(status_code=403, detail="Unauthorized") ApiClient = get_api_client(authorization=authorization) - return webhook_create_user( - api_key=authorization, - email=account.email, - role="user", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, + session = get_session() + email = account.email.lower() + agent_name = account.agent_name + settings = account.settings + commands = account.commands + training_urls = account.training_urls + github_repos = account.github_repos + zip_file_content = account.zip_file_content + user_exists = session.query(User).filter_by(email=email).first() + if user_exists: + session.close() + return {"status": "User already exists"}, 200 + user = User( + email=email, + admin=False, + first_name="", + last_name="", ) + session.add(user) + session.commit() + session.close() + if agent_name != "" and agent_name is not None: + add_agent( + agent_name=agent_name, + provider_settings=settings, + commands=commands, + user=email, + ) + if training_urls != []: + for url in training_urls: + ApiClient.learn_url(agent_name=agent_name, url=url) + if github_repos != []: + for repo in github_repos: + ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) + if zip_file_content != "": + ApiClient.learn_file( + agent_name=agent_name, + file_name="training_data.zip", + file_content=zip_file_content, + ) + return {"status": "Success"}, 200 @app.post(