Skip to content

Commit

Permalink
Chain improvements for tracking step responses (#1196)
Browse files Browse the repository at this point in the history
* Chain improvements for tracking step responses

* fix requirement on user

* fix import issue, move run chain functions

* fix user ref

* fix user_email ref

* add user input to inference call

* fix role

* fix endpoint

* remove old endpoint

* rerun tests

* cascade on delete

* cascade when deleting chain
  • Loading branch information
Josh-XT authored May 30, 2024
1 parent dc1a28d commit 1271f9c
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 363 deletions.
117 changes: 102 additions & 15 deletions agixt/Chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Chain as ChainDB,
ChainStep,
ChainStepResponse,
ChainRun,
Agent,
Argument,
ChainStepArgument,
Expand All @@ -11,6 +12,9 @@
User,
)
from Globals import getenv, DEFAULT_USER
from Prompts import Prompts
from Conversations import Conversations
from Extensions import Extensions
import logging

logging.basicConfig(
Expand All @@ -20,9 +24,10 @@


class Chain:
def __init__(self, user=DEFAULT_USER):
def __init__(self, user=DEFAULT_USER, ApiClient=None):
self.session = get_session()
self.user = user
self.ApiClient = ApiClient
try:
user_data = self.session.query(User).filter(User.email == self.user).first()
self.user_id = user_data.id
Expand Down Expand Up @@ -430,12 +435,14 @@ def move_step(self, chain_name, current_step_number, new_step_number):
)
self.session.commit()

def get_step_response(self, chain_name, step_number="all"):
chain = self.get_chain(chain_name=chain_name)
def get_step_response(self, chain_name, chain_run_id=None, step_number="all"):
if chain_run_id is None:
chain_run_id = self.get_last_chain_run_id(chain_name=chain_name)
chain_data = self.get_chain(chain_name=chain_name)
if step_number == "all":
chain_steps = (
self.session.query(ChainStep)
.filter(ChainStep.chain_id == chain["id"])
.filter(ChainStep.chain_id == chain_data["id"])
.order_by(ChainStep.step_number)
.all()
)
Expand All @@ -444,7 +451,10 @@ def get_step_response(self, chain_name, step_number="all"):
for step in chain_steps:
chain_step_responses = (
self.session.query(ChainStepResponse)
.filter(ChainStepResponse.chain_step_id == step.id)
.filter(
ChainStepResponse.chain_step_id == step.id,
ChainStepResponse.chain_run_id == chain_run_id,
)
.order_by(ChainStepResponse.timestamp)
.all()
)
Expand All @@ -456,7 +466,7 @@ def get_step_response(self, chain_name, step_number="all"):
chain_step = (
self.session.query(ChainStep)
.filter(
ChainStep.chain_id == chain["id"],
ChainStep.chain_id == chain_data["id"],
ChainStep.step_number == step_number,
)
.first()
Expand All @@ -465,7 +475,10 @@ def get_step_response(self, chain_name, step_number="all"):
if chain_step:
chain_step_responses = (
self.session.query(ChainStepResponse)
.filter(ChainStepResponse.chain_step_id == chain_step.id)
.filter(
ChainStepResponse.chain_step_id == chain_step.id,
ChainStepResponse.chain_run_id == chain_run_id,
)
.order_by(ChainStepResponse.timestamp)
.all()
)
Expand Down Expand Up @@ -585,7 +598,9 @@ def import_chain(self, chain_name: str, steps: dict):

return f"Imported chain: {chain_name}"

def get_step_content(self, chain_name, prompt_content, user_input, agent_name):
def get_step_content(
self, chain_run_id, chain_name, prompt_content, user_input, agent_name
):
if isinstance(prompt_content, dict):
new_prompt_content = {}
for arg, value in prompt_content.items():
Expand All @@ -599,7 +614,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name):
for i in range(step_count):
new_step_number = int(value.split("{STEP")[1].split("}")[0])
step_response = self.get_step_response(
chain_name=chain_name, step_number=new_step_number
chain_run_id=chain_run_id,
chain_name=chain_name,
step_number=new_step_number,
)
if step_response:
resp = (
Expand Down Expand Up @@ -629,7 +646,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name):
prompt_content.split("{STEP")[1].split("}")[0]
)
step_response = self.get_step_response(
chain_name=chain_name, step_number=new_step_number
chain_run_id=chain_run_id,
chain_name=chain_name,
step_number=new_step_number,
)
if step_response:
resp = (
Expand All @@ -644,13 +663,17 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name):
else:
return prompt_content

async def update_step_response(self, chain_name, step_number, response):
chain = self.get_chain(chain_name=chain_name)
async def update_step_response(
self, chain_run_id, chain_name, step_number, response
):
chain_step = self.get_step(chain_name=chain_name, step_number=step_number)
if chain_step:
existing_response = (
self.session.query(ChainStepResponse)
.filter(ChainStepResponse.chain_step_id == chain_step.id)
.filter(
ChainStepResponse.chain_step_id == chain_step.id,
ChainStepResponse.chain_run_id == chain_run_id,
)
.order_by(ChainStepResponse.timestamp.desc())
.first()
)
Expand All @@ -667,13 +690,77 @@ async def update_step_response(self, chain_name, step_number, response):
self.session.commit()
else:
chain_step_response = ChainStepResponse(
chain_step_id=chain_step.id, content=response
chain_step_id=chain_step.id,
chain_run_id=chain_run_id,
content=response,
)
self.session.add(chain_step_response)
self.session.commit()
else:
chain_step_response = ChainStepResponse(
chain_step_id=chain_step.id, content=response
chain_step_id=chain_step.id,
chain_run_id=chain_run_id,
content=response,
)
self.session.add(chain_step_response)
self.session.commit()

async def get_chain_run_id(self, chain_name):
chain_run = ChainRun(
chain_id=self.get_chain(chain_name=chain_name)["id"],
user_id=self.user_id,
)
self.session.add(chain_run)
self.session.commit()
return chain_run.id

async def get_last_chain_run_id(self, chain_name):
chain_data = self.get_chain(chain_name=chain_name)
chain_run = (
self.session.query(ChainRun)
.filter(ChainRun.chain_id == chain_data["id"])
.order_by(ChainRun.timestamp.desc())
.first()
)
if chain_run:
return chain_run.id
else:
return await self.get_chain_run_id(chain_name=chain_name)

def get_chain_args(self, chain_name):
skip_args = [
"command_list",
"context",
"COMMANDS",
"date",
"conversation_history",
"agent_name",
"working_directory",
"helper_agent_name",
]
chain_data = self.get_chain(chain_name=chain_name)
steps = chain_data["steps"]
prompt_args = []
args = []
for step in steps:
try:
prompt = step["prompt"]
if "prompt_name" in prompt:
prompt_text = Prompts(user=self.user).get_prompt(
prompt_name=prompt["prompt_name"]
)
args = Prompts(user=self.user).get_prompt_args(
prompt_text=prompt_text
)
elif "command_name" in prompt:
args = Extensions().get_command_args(
command_name=prompt["command_name"]
)
elif "chain_name" in prompt:
args = self.get_chain_args(chain_name=prompt["chain_name"])
for arg in args:
if arg not in prompt_args and arg not in skip_args:
prompt_args.append(arg)
except Exception as e:
logging.error(f"Error getting chain args: {e}")
return prompt_args
Loading

0 comments on commit 1271f9c

Please sign in to comment.