diff --git a/agixt/Chain.py b/agixt/Chain.py index 054a402b3b95..a327cf3596ee 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -3,6 +3,7 @@ Chain as ChainDB, ChainStep, ChainStepResponse, + ChainRun, Agent, Argument, ChainStepArgument, @@ -11,6 +12,9 @@ User, ) from Globals import getenv, DEFAULT_USER +from Prompts import Prompts +from Conversations import Conversations +from Extensions import Extensions import logging logging.basicConfig( @@ -20,9 +24,10 @@ class Chain: - def __init__(self, user=DEFAULT_USER): + def __init__(self, user=DEFAULT_USER, ApiClient=None): self.session = get_session() self.user = user + self.ApiClient = ApiClient try: user_data = self.session.query(User).filter(User.email == self.user).first() self.user_id = user_data.id @@ -430,12 +435,14 @@ def move_step(self, chain_name, current_step_number, new_step_number): ) self.session.commit() - def get_step_response(self, chain_name, step_number="all"): - chain = self.get_chain(chain_name=chain_name) + def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): + if chain_run_id is None: + chain_run_id = self.get_last_chain_run_id(chain_name=chain_name) + chain_data = self.get_chain(chain_name=chain_name) if step_number == "all": chain_steps = ( self.session.query(ChainStep) - .filter(ChainStep.chain_id == chain["id"]) + .filter(ChainStep.chain_id == chain_data["id"]) .order_by(ChainStep.step_number) .all() ) @@ -444,7 +451,10 @@ def get_step_response(self, chain_name, step_number="all"): for step in chain_steps: chain_step_responses = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == step.id) + .filter( + ChainStepResponse.chain_step_id == step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp) .all() ) @@ -456,7 +466,7 @@ def get_step_response(self, chain_name, step_number="all"): chain_step = ( self.session.query(ChainStep) .filter( - ChainStep.chain_id == chain["id"], + ChainStep.chain_id == chain_data["id"], ChainStep.step_number == step_number, ) .first() @@ -465,7 +475,10 @@ def get_step_response(self, chain_name, step_number="all"): if chain_step: chain_step_responses = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == chain_step.id) + .filter( + ChainStepResponse.chain_step_id == chain_step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp) .all() ) @@ -585,7 +598,9 @@ def import_chain(self, chain_name: str, steps: dict): return f"Imported chain: {chain_name}" - def get_step_content(self, chain_name, prompt_content, user_input, agent_name): + def get_step_content( + self, chain_run_id, chain_name, prompt_content, user_input, agent_name + ): if isinstance(prompt_content, dict): new_prompt_content = {} for arg, value in prompt_content.items(): @@ -599,7 +614,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): for i in range(step_count): new_step_number = int(value.split("{STEP")[1].split("}")[0]) step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=new_step_number, ) if step_response: resp = ( @@ -629,7 +646,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): prompt_content.split("{STEP")[1].split("}")[0] ) step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=new_step_number, ) if step_response: resp = ( @@ -644,13 +663,17 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): else: return prompt_content - async def update_step_response(self, chain_name, step_number, response): - chain = self.get_chain(chain_name=chain_name) + async def update_step_response( + self, chain_run_id, chain_name, step_number, response + ): chain_step = self.get_step(chain_name=chain_name, step_number=step_number) if chain_step: existing_response = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == chain_step.id) + .filter( + ChainStepResponse.chain_step_id == chain_step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp.desc()) .first() ) @@ -667,13 +690,77 @@ async def update_step_response(self, chain_name, step_number, response): self.session.commit() else: chain_step_response = ChainStepResponse( - chain_step_id=chain_step.id, content=response + chain_step_id=chain_step.id, + chain_run_id=chain_run_id, + content=response, ) self.session.add(chain_step_response) self.session.commit() else: chain_step_response = ChainStepResponse( - chain_step_id=chain_step.id, content=response + chain_step_id=chain_step.id, + chain_run_id=chain_run_id, + content=response, ) self.session.add(chain_step_response) self.session.commit() + + async def get_chain_run_id(self, chain_name): + chain_run = ChainRun( + chain_id=self.get_chain(chain_name=chain_name)["id"], + user_id=self.user_id, + ) + self.session.add(chain_run) + self.session.commit() + return chain_run.id + + async def get_last_chain_run_id(self, chain_name): + chain_data = self.get_chain(chain_name=chain_name) + chain_run = ( + self.session.query(ChainRun) + .filter(ChainRun.chain_id == chain_data["id"]) + .order_by(ChainRun.timestamp.desc()) + .first() + ) + if chain_run: + return chain_run.id + else: + return await self.get_chain_run_id(chain_name=chain_name) + + def get_chain_args(self, chain_name): + skip_args = [ + "command_list", + "context", + "COMMANDS", + "date", + "conversation_history", + "agent_name", + "working_directory", + "helper_agent_name", + ] + chain_data = self.get_chain(chain_name=chain_name) + steps = chain_data["steps"] + prompt_args = [] + args = [] + for step in steps: + try: + prompt = step["prompt"] + if "prompt_name" in prompt: + prompt_text = Prompts(user=self.user).get_prompt( + prompt_name=prompt["prompt_name"] + ) + args = Prompts(user=self.user).get_prompt_args( + prompt_text=prompt_text + ) + elif "command_name" in prompt: + args = Extensions().get_command_args( + command_name=prompt["command_name"] + ) + elif "chain_name" in prompt: + args = self.get_chain_args(chain_name=prompt["chain_name"]) + for arg in args: + if arg not in prompt_args and arg not in skip_args: + prompt_args.append(arg) + except Exception as e: + logging.error(f"Error getting chain args: {e}") + return prompt_args diff --git a/agixt/Chains.py b/agixt/Chains.py deleted file mode 100644 index 6951a5055c31..000000000000 --- a/agixt/Chains.py +++ /dev/null @@ -1,212 +0,0 @@ -import logging -from Globals import getenv -from ApiClient import Chain, Prompts, Conversations -from Extensions import Extensions - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -class Chains: - def __init__(self, user="USER", ApiClient=None): - self.user = user - self.chain = Chain(user=user) - self.ApiClient = ApiClient - - async def run_chain_step( - self, - step: dict = {}, - chain_name="", - user_input="", - agent_override="", - chain_args={}, - ): - if step: - if "prompt_type" in step: - if agent_override != "": - agent_name = agent_override - else: - agent_name = step["agent_name"] - - prompt_type = step["prompt_type"] - step_number = step["step"] - if "prompt_name" in step["prompt"]: - prompt_name = step["prompt"]["prompt_name"] - else: - prompt_name = "" - args = self.chain.get_step_content( - chain_name=chain_name, - prompt_content=step["prompt"], - user_input=user_input, - agent_name=step["agent_name"], - ) - if chain_args != {}: - for arg, value in chain_args.items(): - args[arg] = value - if "chain_name" in args: - args["chain"] = args["chain_name"] - if "chain" not in args: - args["chain"] = chain_name - if "conversation_name" not in args: - args["conversation_name"] = f"Chain Execution History: {chain_name}" - if "conversation" in args: - args["conversation_name"] = args["conversation"] - if prompt_type == "Command": - return self.ApiClient.execute_command( - agent_name=agent_name, - command_name=step["prompt"]["command_name"], - command_args=args, - conversation_name=args["conversation_name"], - ) - elif prompt_type == "Prompt": - result = self.ApiClient.prompt_agent( - agent_name=agent_name, - prompt_name=prompt_name, - prompt_args={ - "chain_name": chain_name, - "step_number": step_number, - "user_input": user_input, - **args, - }, - ) - elif prompt_type == "Chain": - result = self.ApiClient.run_chain( - chain_name=args["chain"], - user_input=args["input"], - agent_name=agent_name, - all_responses=( - args["all_responses"] if "all_responses" in args else False - ), - from_step=args["from_step"] if "from_step" in args else 1, - chain_args=( - args["chain_args"] - if "chain_args" in args - else {"conversation_name": args["conversation_name"]} - ), - ) - if result: - if isinstance(result, dict) and "response" in result: - result = result["response"] - if result == "Unable to retrieve data.": - result = None - if not isinstance(result, str): - result = str(result) - return result - else: - return None - - async def run_chain( - self, - chain_name, - user_input=None, - all_responses=True, - agent_override="", - from_step=1, - chain_args={}, - ): - chain_data = self.ApiClient.get_chain(chain_name=chain_name) - if chain_data == {}: - return f"Chain `{chain_name}` not found." - c = Conversations( - conversation_name=( - f"Chain Execution History: {chain_name}" - if "conversation_name" not in chain_args - else chain_args["conversation_name"] - ), - user=self.user, - ) - c.log_interaction( - role="USER", - message=user_input, - ) - logging.info(f"Running chain '{chain_name}'") - responses = {} # Create a dictionary to hold responses. - last_response = "" - for step_data in chain_data["steps"]: - if int(step_data["step"]) >= int(from_step): - if "prompt" in step_data and "step" in step_data: - step = {} - step["agent_name"] = ( - agent_override - if agent_override != "" - else step_data["agent_name"] - ) - step["prompt_type"] = step_data["prompt_type"] - step["prompt"] = step_data["prompt"] - step["step"] = step_data["step"] - logging.info( - f"Running step {step_data['step']} with agent {step['agent_name']}." - ) - # try: - step_response = await self.run_chain_step( - step=step, - chain_name=chain_name, - user_input=user_input, - agent_override=agent_override, - chain_args=chain_args, - ) # Get the response of the current step. - # except Exception as e: - # logging.error(f"Error running chain step: {e}") - # step_response = None - if step_response == None: - return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed." - step["response"] = step_response - last_response = step_response - logging.info(f"Last response: {last_response}") - responses[step_data["step"]] = step # Store the response. - logging.info(f"Step {step_data['step']} response: {step_response}") - # Write the response to the chain responses file. - await self.chain.update_step_response( - chain_name=chain_name, - step_number=step_data["step"], - response=step_response, - ) - if all_responses: - return responses - else: - # Return only the last response in the chain. - c.log_interaction( - role=agent_override if agent_override != "" else "AGiXT", - message=last_response, - ) - return last_response - - def get_chain_args(self, chain_name): - skip_args = [ - "command_list", - "context", - "COMMANDS", - "date", - "conversation_history", - "agent_name", - "working_directory", - "helper_agent_name", - ] - chain_data = self.chain.get_chain(chain_name=chain_name) - steps = chain_data["steps"] - prompt_args = [] - args = [] - for step in steps: - try: - prompt = step["prompt"] - if "prompt_name" in prompt: - prompt_text = Prompts(user=self.user).get_prompt( - prompt_name=prompt["prompt_name"] - ) - args = Prompts(user=self.user).get_prompt_args( - prompt_text=prompt_text - ) - elif "command_name" in prompt: - args = Extensions().get_command_args( - command_name=prompt["command_name"] - ) - elif "chain_name" in prompt: - args = self.get_chain_args(chain_name=prompt["chain_name"]) - for arg in args: - if arg not in prompt_args and arg not in skip_args: - prompt_args.append(arg) - except Exception as e: - logging.error(f"Error getting chain args: {e}") - return prompt_args diff --git a/agixt/DB.py b/agixt/DB.py index 6d78c2879516..5dea1e5e0880 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -314,6 +314,7 @@ class Chain(Base): "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" ) user = relationship("User", backref="chain") + runs = relationship("ChainRun", backref="chain", cascade="all, delete-orphan") class ChainStep(Base): @@ -379,6 +380,29 @@ class ChainStepArgument(Base): value = Column(Text, nullable=True) +class ChainRun(Base): + __tablename__ = "chain_run" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="CASCADE"), + nullable=False, + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + timestamp = Column(DateTime, server_default=text("now()")) + chain_step_responses = relationship( + "ChainStepResponse", backref="chain_run", cascade="all, delete-orphan" + ) + + class ChainStepResponse(Base): __tablename__ = "chain_step_response" id = Column( @@ -391,6 +415,11 @@ class ChainStepResponse(Base): ForeignKey("chain_step.id", ondelete="CASCADE"), nullable=False, # Add the ondelete option ) + chain_run_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_run.id", ondelete="CASCADE"), + nullable=True, + ) timestamp = Column(DateTime, server_default=text("now()")) content = Column(Text, nullable=False) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index b58eef729934..47e550c175c6 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -99,8 +99,6 @@ async def format_prompt( user_input: str = "", top_results: int = 5, prompt="", - chain_name="", - step_number=0, conversation_name="", vision_response: str = "", **kwargs, @@ -195,36 +193,6 @@ async def format_prompt( context = f"The user's input causes you remember these things:\n{context}\n" else: context = "" - if chain_name != "": - try: - for arg, value in kwargs.items(): - if "{STEP" in value: - # get the response from the step number - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - # replace the {STEPx} with the response - value = value.replace( - f"{{STEP{step_number}}}", - step_response if step_response else "", - ) - kwargs[arg] = value - except: - logging.info("No args to replace.") - if "{STEP" in prompt: - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - prompt = prompt.replace( - f"{{STEP{step_number}}}", step_response if step_response else "" - ) - if "{STEP" in user_input: - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - user_input = user_input.replace( - f"{{STEP{step_number}}}", step_response if step_response else "" - ) try: working_directory = self.agent.AGENT_CONFIG["settings"]["WORKING_DIRECTORY"] except: @@ -385,8 +353,6 @@ async def run( self, user_input: str = "", context_results: int = 5, - chain_name: str = "", - step_number: int = 0, shots: int = 1, disable_memory: bool = True, conversation_name: str = "", @@ -501,8 +467,6 @@ async def run( top_results=int(context_results), prompt=prompt, prompt_category=prompt_category, - chain_name=chain_name, - step_number=step_number, conversation_name=conversation_name, websearch=websearch, vision_response=vision_response, @@ -541,8 +505,6 @@ async def run( if context_results > 0: context_results = context_results - 1 prompt_args = { - "chain_name": chain_name, - "step_number": step_number, "shots": shots, "disable_memory": disable_memory, "user_input": user_input, @@ -648,8 +610,6 @@ async def run( responses = [self.response] for shot in range(shots - 1): prompt_args = { - "chain_name": chain_name, - "step_number": step_number, "user_input": user_input, "context_results": context_results, "conversation_name": conversation_name, diff --git a/agixt/Models.py b/agixt/Models.py index 2fb75888dfa3..d25b69a0066f 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -117,6 +117,7 @@ class RunChainStep(BaseModel): prompt: str agent_override: Optional[str] = "" chain_args: Optional[dict] = {} + chain_run_id: Optional[str] = "" class StepInfo(BaseModel): @@ -285,15 +286,11 @@ class Register(BaseModel): email: str first_name: str last_name: str - company_name: str - job_title: str class UserInfo(BaseModel): first_name: str last_name: str - company_name: str - job_title: str class Detail(BaseModel): diff --git a/agixt/AGiXT.py b/agixt/XT.py similarity index 77% rename from agixt/AGiXT.py rename to agixt/XT.py index 494ee8b0a126..60bb75b48575 100644 --- a/agixt/AGiXT.py +++ b/agixt/XT.py @@ -2,7 +2,6 @@ from ApiClient import get_api_client, Conversations, Prompts, Chain from readers.file import FileReader from Extensions import Extensions -from Chains import Chains from pydub import AudioSegment from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions @@ -31,6 +30,7 @@ def __init__(self, user: str, agent_name: str, api_key: str): if "settings" in self.agent.AGENT_CONFIG else DEFAULT_SETTINGS ) + self.chain = Chain(user=self.user_email) async def prompts(self, prompt_category: str = "Default"): """ @@ -53,7 +53,7 @@ async def chains(self): Returns: list: List of available chains """ - return Chain(user=self.user_email).get_chains() + return self.chain.get_chains() async def settings(self): """ @@ -264,8 +264,8 @@ async def execute_command( if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( - role=self.agent, - message=f"[ACTIVITY_START] Execute command: {command_name} with args: {command_args} [ACTIVITY_END]", + role=self.agent_name, + message=f"[ACTIVITY_START] Executing command: {command_name} with args: {command_args} [ACTIVITY_END]", ) response = await Extensions( agent_name=self.agent_name, @@ -293,41 +293,181 @@ async def execute_command( ) return response - async def execute_chain( + async def run_chain_step( self, - chain_name: str, - user_input: str, - chain_args: dict = {}, - use_current_agent: bool = True, - conversation_name: str = "", - voice_response: bool = False, + chain_run_id=None, + step: dict = {}, + chain_name="", + user_input="", + agent_override="", + chain_args={}, + conversation_name="", ): - """ - Execute a chain with arguments - - Args: - chain_name (str): Name of the chain to execute - user_input (str): Message to add to conversation log pre-execution - chain_args (dict): Arguments for the chain - use_current_agent (bool): Whether to use the current agent - conversation_name (str): Name of the conversation - voice_response (bool): Whether to generate a voice response + if not chain_run_id: + chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) + if step: + if "prompt_type" in step: + c = None + if conversation_name != "": + c = Conversations( + conversation_name=conversation_name, + user=self.user_email, + ) + if agent_override != "": + agent_name = agent_override + else: + agent_name = step["agent_name"] + prompt_type = step["prompt_type"] + step_number = step["step"] + if "prompt_name" in step["prompt"]: + prompt_name = step["prompt"]["prompt_name"] + else: + prompt_name = "" + args = self.chain.get_step_content( + chain_run_id=chain_run_id, + chain_name=chain_name, + prompt_content=step["prompt"], + user_input=user_input, + agent_name=agent_name, + ) + if chain_args != {}: + for arg, value in chain_args.items(): + args[arg] = value + if "chain_name" in args: + args["chain"] = args["chain_name"] + if "chain" not in args: + args["chain"] = chain_name + if "conversation_name" not in args: + args["conversation_name"] = f"Chain Execution History: {chain_name}" + if "conversation" in args: + args["conversation_name"] = args["conversation"] + if prompt_type == "Command": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Executing command: {step['prompt']['command_name']} with args: {args} [ACTIVITY_END]", + ) + result = await self.execute_command( + command_name=step["prompt"]["command_name"], + command_args=args, + conversation_name=args["conversation_name"], + voice_response=False, + ) + elif prompt_type == "Prompt": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Running prompt: {prompt_name} with args: {args} [ACTIVITY_END]", + ) + if "prompt_name" not in args: + args["prompt_name"] = prompt_name + result = await self.inference( + agent_name=agent_name, + user_input=user_input, + log_user_input=False, + **args, + ) + elif prompt_type == "Chain": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Running chain: {step['prompt']['chain_name']} with args: {args} [ACTIVITY_END]", + ) + result = await self.execute_chain( + chain_name=args["chain"], + user_input=args["input"], + agent_override=agent_name, + from_step=args["from_step"] if "from_step" in args else 1, + chain_args=( + args["chain_args"] + if "chain_args" in args + else {"conversation_name": args["conversation_name"]} + ), + conversation_name=args["conversation_name"], + log_user_input=False, + voice_response=False, + ) + if result: + if isinstance(result, dict) and "response" in result: + result = result["response"] + if result == "Unable to retrieve data.": + result = None + if isinstance(result, dict): + result = json.dumps(result) + if not isinstance(result, str): + result = str(result) + await self.chain.update_step_response( + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=step_number, + response=result, + ) + return result + else: + return None - Returns: - str: Response from the chain - """ - c = Conversations(conversation_name=conversation_name, user=self.user_email) - c.log_interaction(role="USER", message=user_input) - response = await Chains( - user=self.user_email, ApiClient=self.ApiClient - ).run_chain( - chain_name=chain_name, - user_input=user_input, - agent_override=self.agent_name if use_current_agent else None, - all_responses=False, - chain_args=chain_args, - from_step=1, + async def execute_chain( + self, + chain_name, + chain_run_id=None, + user_input=None, + agent_override="", + from_step=1, + chain_args={}, + log_user_input=False, + conversation_name="", + voice_response=False, + ): + chain_data = self.chain.get_chain(chain_name=chain_name) + if not chain_run_id: + chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) + if chain_data == {}: + return f"Chain `{chain_name}` not found." + c = Conversations( + conversation_name=conversation_name, + user=self.user_email, ) + if log_user_input: + c.log_interaction( + role="USER", + message=user_input, + ) + agent_name = agent_override if agent_override != "" else "AGiXT" + if conversation_name != "": + c.log_interaction( + role=agent_name, + message=f"[ACTIVITY_START] Running chain `{chain_name}`... [ACTIVITY_END]", + ) + response = "" + for step_data in chain_data["steps"]: + if int(step_data["step"]) >= int(from_step): + if "prompt" in step_data and "step" in step_data: + step = {} + step["agent_name"] = ( + agent_override + if agent_override != "" + else step_data["agent_name"] + ) + step["prompt_type"] = step_data["prompt_type"] + step["prompt"] = step_data["prompt"] + step["step"] = step_data["step"] + step_response = await self.run_chain_step( + chain_run_id=chain_run_id, + step=step, + chain_name=chain_name, + user_input=user_input, + agent_override=agent_override, + chain_args=chain_args, + conversation_name=conversation_name, + ) + if step_response == None: + return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." + response = step_response + if conversation_name != "": + c.log_interaction( + role=agent_name, + message=response, + ) if "tts_provider" in self.agent_settings and voice_response: if ( self.agent_settings["tts_provider"] != "None" @@ -667,8 +807,9 @@ async def chat_completions(self, prompt: ChatCompletions): response = await self.execute_chain( chain_name=chain_name, user_input=new_prompt, + agent_override=self.agent_name, chain_args=chain_args, - use_current_agent=True, + log_user_input=False, conversation_name=conversation_name, voice_response=tts, ) diff --git a/agixt/endpoints/Chain.py b/agixt/endpoints/Chain.py index 9f245a21489b..b5c70001d77b 100644 --- a/agixt/endpoints/Chain.py +++ b/agixt/endpoints/Chain.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, Depends, Header from ApiClient import Chain, verify_api_key, get_api_client, is_admin -from Chains import Chains +from XT import AGiXT from Models import ( RunChain, RunChainStep, @@ -32,21 +32,6 @@ async def get_chain(chain_name: str, user=Depends(verify_api_key)): return {"chain": chain_data} -@app.get( - "/api/chain/{chain_name}/responses", - tags=["Chain"], - dependencies=[Depends(verify_api_key)], -) -async def get_chain_responses(chain_name: str, user=Depends(verify_api_key)): - try: - chain_data = Chain(user=user).get_step_response( - chain_name=chain_name, step_number="all" - ) - return {"chain": chain_data} - except: - raise HTTPException(status_code=404, detail="Chain not found") - - @app.post( "/api/chain/{chain_name}/run", tags=["Chain", "Admin"], @@ -60,14 +45,18 @@ async def run_chain( ): if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - chain_response = await Chains(user=user, ApiClient=ApiClient).run_chain( + agent_name = user_input.agent_override if user_input.agent_override else "gpt4free" + chain_response = await AGiXT( + user=user, + agent_name=agent_name, + api_key=authorization, + ).execute_chain( chain_name=chain_name, user_input=user_input.prompt, agent_override=user_input.agent_override, - all_responses=user_input.all_responses, from_step=user_input.from_step, chain_args=user_input.chain_args, + log_user_input=False, ) try: if "Chain failed to complete" in chain_response: @@ -99,8 +88,15 @@ async def run_chain_step( raise HTTPException( status_code=404, detail=f"Step {step_number} not found. {e}" ) - ApiClient = get_api_client(authorization=authorization) - chain_step_response = await Chains(user=user, ApiClient=ApiClient).run_chain_step( + agent_name = ( + user_input.agent_override if user_input.agent_override else step["agent"] + ) + chain_step_response = await AGiXT( + user=user, + agent_name=agent_name, + api_key=authorization, + ).run_chain_step( + chain_run_id=user_input.chain_run_id, step=step, chain_name=chain_name, user_input=user_input.prompt, @@ -129,7 +125,7 @@ async def get_chain_args( if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") ApiClient = get_api_client(authorization=authorization) - chain_args = Chains(user=user, ApiClient=ApiClient).get_chain_args( + chain_args = Chain(user=user, ApiClient=ApiClient).get_chain_args( chain_name=chain_name ) return {"chain_args": chain_args} diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py index 1362f282dd78..59f0b4d2bbc8 100644 --- a/agixt/endpoints/Completions.py +++ b/agixt/endpoints/Completions.py @@ -13,7 +13,7 @@ TextToSpeech, ImageCreation, ) -from AGiXT import AGiXT +from XT import AGiXT app = APIRouter() diff --git a/tests/tests.ipynb b/tests/tests.ipynb index 44a6057cb6ff..fff013d4d0c3 100644 --- a/tests/tests.ipynb +++ b/tests/tests.ipynb @@ -1534,37 +1534,6 @@ "print(\"Run chain response:\", run_chain_resp)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Get the responses from the chain running\n" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chain: {'1': {'agent_name': 'new_agent', 'prompt_type': 'Prompt', 'prompt': {'prompt_name': 'Write a Poem', 'subject': 'Quantum Computers'}, 'step': 1, 'response': \"In the depths of the quantum realm,\\nWhere possibilities unfurl,\\nLies a marvel of modern science,\\nA computer that defies our world.\\n\\nQuantum entanglement's mysterious dance,\\nHarnessing particles, in a cosmic trance,\\nBits of information, not just zero or one,\\nA quantum computer, where wonders are spun.\\n\\nThe qubits, like tiny dancers on a stage,\\nCan exist in multiple states, they engage,\\nA quantum superposition, a delicate balance,\\nComputing power that leaves us in a trance.\\n\\nThrough quantum gates, these qubits entwine,\\nCreating a web of possibilities, oh so fine,\\nParallel universes, in computation they roam,\\nQuantum computers, bringing the unknown home.\\n\\nComplex algorithms, they can quickly solve,\\nShattering encryption, with problems they evolve,\\nFrom cryptography to simulating the universe,\\nQuantum computers, a scientific traverse.\\n\\nYet, in this realm of infinite potential,\\nErrors and decoherence can be consequential,\\nNoise and disturbances, they threaten the state,\\nA challenge to overcome, for quantum's fate.\\n\\nBut fear not, for scientists persist,\\nAdvancing quantum technology, a fervent twist,\\nWith every breakthrough, a step closer we come,\\nTo a future where quantum computers will hum.\\n\\nIn this world of uncertainty and flux,\\nQuantum computers, the next paradigm, unbox,\\nUnveiling the secrets of our reality's core,\\nA technological marvel, forever to adore.\"}, '2': {'agent_name': 'new_agent', 'prompt_type': 'Command', 'prompt': {'command_name': 'Write to File', 'filename': '{user_input}.txt', 'text': 'Poem:\\n{STEP1}'}, 'step': 2, 'response': 'File written to successfully.'}}\n" - ] - } - ], - "source": [ - "from agixtsdk import AGiXTSDK\n", - "\n", - "base_uri = \"http://localhost:7437\"\n", - "ApiClient = AGiXTSDK(base_uri=base_uri)\n", - "chain_name = \"Poem Writing Chain\"\n", - "chain = ApiClient.get_chain_responses(chain_name=chain_name)\n", - "print(\"Chain:\", chain)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -1591,8 +1560,7 @@ "\n", "base_uri = \"http://localhost:7437\"\n", "ApiClient = AGiXTSDK(base_uri=base_uri)\n", - "chain_name = \"Poem Writing Chain\"\n", - "delete_chain_resp = ApiClient.delete_chain(chain_name=chain_name)\n", + "delete_chain_resp = ApiClient.delete_chain(chain_name=\"Poem Writing Chain\")\n", "print(\"Delete chain response:\", delete_chain_resp)" ] },