Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Agents to support code and rag simultaneously #908

Merged
merged 4 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing

from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin

Expand Down Expand Up @@ -476,9 +477,12 @@ async def _run(
)
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
if retrieved_context:
last_message = input_messages[-1]
last_message.context = retrieved_context
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there another place where we set the context for RAG?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah how does RAG even work if you don't do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, all tests were passing for RAG, but then I realized nothing was actually being added to the context.
I tested this change and can now see RAG working properly.


# append retrieved_context to the last user message
for message in input_messages[::-1]:
if isinstance(message, UserMessage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other message type could this be if it's triggering RAG?

Separately curious why for RAG results we use .context but for other tool execs we do 708 input_messages = input_messages + [message, result_message]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ehhuang I think @hardikjshah did this so you could identify which message you added context to and then in the next turn get rid of it. we needed to keep only one context around as the turns proceeded.

I really think we should nuke that .context field completely and manage whatever state we need to manage completely within agent_instance.py (and ensure the invariant above)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other message type could this be if it's triggering RAG?

That was the original bug: if both code and RAG are enabled, we end up with a ToolResponseMessage coming from the code_interpreter side (check out def handle_documents).

And it was raising an error, since ToolResponseMessage does not have a context attribute.

I agree that we should find a better solution for context. Maybe RAG responses could also be passed in a ToolResponseMessage like all other tools — although that is outside the scope of this PR.

message.context = retrieved_context
break

output_attachments = []

Expand Down
69 changes: 67 additions & 2 deletions tests/client-sdk/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
}

codex_agent = Agent(llama_stack_client, agent_config)
session_id = codex_agent.create_session("test-session")
session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = AgentDocument(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="text/csv",
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_rag_agent(llama_stack_client, agent_config):
],
}
rag_agent = Agent(llama_stack_client, agent_config)
session_id = rag_agent.create_session("test-session")
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
"What are the top 5 topics that were explained? Only list succinct bullet points.",
]
Expand All @@ -312,3 +312,68 @@ def test_rag_agent(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:query_from_memory" in logs_str


def test_rag_and_code_agent(llama_stack_client, agent_config):
    """Regression test: an agent with BOTH builtin::rag and
    builtin::code_interpreter enabled must route each turn to the right tool.

    Previously, enabling both toolgroups caused RAG context attachment to fail
    because a ToolResponseMessage (from code_interpreter document handling)
    could be the last input message, and it has no ``context`` attribute.
    """
    urls = ["chat.rst"]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]
    # Unique vector DB id so repeated or parallel runs against a shared stack
    # do not collide — same pattern as the uuid4-suffixed session names below.
    vector_db_id = f"test-vector-db-{uuid4()}"
    llama_stack_client.vector_dbs.register(
        vector_db_id=vector_db_id,
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
    )
    llama_stack_client.tool_runtime.rag_tool.insert(
        documents=documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=128,
    )
    agent_config = {
        **agent_config,
        "toolgroups": [
            dict(
                name="builtin::rag",
                args={"vector_db_ids": [vector_db_id]},
            ),
            "builtin::code_interpreter",
        ],
    }
    agent = Agent(llama_stack_client, agent_config)
    inflation_doc = Document(
        document_id="test_csv",
        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
        mime_type="text/csv",
        metadata={},
    )
    # (prompt, attached documents, tool expected to be invoked for this turn)
    user_prompts = [
        (
            "Here is a csv file, can you describe it?",
            [inflation_doc],
            "code_interpreter",
        ),
        (
            "What are the top 5 topics that were explained? Only list succinct bullet points.",
            [],
            "query_from_memory",
        ),
    ]

    for prompt, docs, tool_name in user_prompts:
        print(f"User> {prompt}")
        # Fresh session per turn so each prompt is evaluated independently.
        session_id = agent.create_session(f"test-session-{uuid4()}")
        response = agent.create_turn(
            messages=[{"role": "user", "content": prompt}],
            session_id=session_id,
            documents=docs,
        )
        logs = [str(log) for log in EventLogger().log(response) if log is not None]
        logs_str = "".join(logs)
        # The turn must have dispatched to the expected tool.
        assert f"Tool:{tool_name}" in logs_str