Skip to content

Commit

Permalink
Add Fork conversation function (#1241)
Browse files Browse the repository at this point in the history
* Fork conversations

* Add test for forking conversation

* Fix conversation used

* Improve fork test
  • Loading branch information
Josh-XT authored Sep 7, 2024
1 parent 6dfb0e6 commit fc8c32a
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 1 deletion.
66 changes: 66 additions & 0 deletions agixt/Conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,72 @@ def get_conversation(self, limit=100, page=1):
session.close()
return {"interactions": return_messages}

def fork_conversation(self, message_id):
session = get_session()
user_data = session.query(User).filter(User.email == self.user).first()
user_id = user_data.id

# Get the original conversation
original_conversation = (
session.query(Conversation)
.filter(
Conversation.name == self.conversation_name,
Conversation.user_id == user_id,
)
.first()
)

if not original_conversation:
logging.info(f"No conversation found to fork.")
session.close()
return None

# Get all messages up to and including the specified message_id
messages = (
session.query(Message)
.filter(
Message.conversation_id == original_conversation.id,
Message.id <= message_id,
)
.order_by(Message.timestamp.asc())
.all()
)

if not messages:
logging.info(f"No messages found in the conversation to fork.")
session.close()
return None

# Create a new conversation
new_conversation_name = (
f"{self.conversation_name}_fork_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)
new_conversation = Conversation(name=new_conversation_name, user_id=user_id)
session.add(new_conversation)
session.flush() # This will assign an id to new_conversation

# Copy messages to the new conversation
for message in messages:
new_message = Message(
role=message.role,
content=message.content,
conversation_id=new_conversation.id,
timestamp=message.timestamp,
updated_at=message.updated_at,
updated_by=message.updated_by,
feedback_received=message.feedback_received,
)
session.add(new_message)

session.commit()
forked_conversation_id = str(new_conversation.id)
session.close()

logging.info(
f"Conversation forked successfully. New conversation ID: {forked_conversation_id}"
)
return new_conversation_name

def get_activities(self, limit=100, page=1):
session = get_session()
user_data = session.query(User).filter(User.email == self.user).first()
Expand Down
5 changes: 5 additions & 0 deletions agixt/Models.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ class RenameConversationModel(BaseModel):
new_conversation_name: Optional[str] = "-"


class ConversationFork(BaseModel):
conversation_name: str
message_id: str


class TTSInput(BaseModel):
text: str

Expand Down
15 changes: 15 additions & 0 deletions agixt/endpoints/Conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RenameConversationModel,
UpdateMessageModel,
DeleteMessageModel,
ConversationFork,
)
import json
from datetime import datetime
Expand Down Expand Up @@ -262,3 +263,17 @@ async def rename_conversation(
role=rename.agent_name,
)
return {"conversation_name": rename.new_conversation_name}


@app.post(
"/api/conversation/fork",
tags=["Conversation"],
dependencies=[Depends(verify_api_key)],
)
async def fork_conversation(
fork: ConversationFork, user=Depends(verify_api_key)
) -> ResponseMessage:
new_conversation_name = Conversations(
conversation_name=fork.conversation_name, user=user
).fork_conversation(message_id=fork.message_id)
return ResponseMessage(message=f"Forked conversation to {new_conversation_name}")
35 changes: 34 additions & 1 deletion tests/endpoint-tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,31 @@
"conversations = agixt.get_conversations(agent_name=\"new_agent\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manual Conversation Message"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agixt.new_conversation_message(\n",
" role=\"USER\",\n",
" conversation_name=\"AGiXT Conversation\",\n",
" message=\"This is a test message from the user!\",\n",
")\n",
"agixt.new_conversation_message(\n",
" role=\"new_agent\",\n",
" conversation_name=\"AGiXT Conversation\",\n",
" message=\"This is a test message from the agent!\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -742,7 +767,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manual Conversation Message"
"## Fork a Conversation"
]
},
{
Expand All @@ -751,6 +776,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Add an extra interaction to the conversation so that there is more than just one\n",
"agixt.new_conversation_message(\n",
" role=\"USER\",\n",
" conversation_name=\"AGiXT Conversation\",\n",
Expand All @@ -760,6 +786,13 @@
" role=\"new_agent\",\n",
" conversation_name=\"AGiXT Conversation\",\n",
" message=\"This is a test message from the agent!\",\n",
")\n",
"forked_conversation = agixt.fork_conversation(\n",
" conversation_name=conversation_name, message_id=conversation[1][\"id\"]\n",
")\n",
"fork = agixt.get_conversation(\n",
" agent_name=agent_name,\n",
" conversation_name=forked_conversation,\n",
")"
]
},
Expand Down

0 comments on commit fc8c32a

Please sign in to comment.