diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index eca7364d7b..706dd74f1b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -66,6 +66,7 @@ from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing + from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin @@ -476,9 +477,12 @@ async def _run( ) span.set_attribute("output", retrieved_context) span.set_attribute("tool_name", MEMORY_QUERY_TOOL) - if retrieved_context: - last_message = input_messages[-1] - last_message.context = retrieved_context + + # append retrieved_context to the last user message + for message in input_messages[::-1]: + if isinstance(message, UserMessage): + message.context = retrieved_context + break output_attachments = [] diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 4a8fdd36a6..e0f86e3d7f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -211,7 +211,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config): } codex_agent = Agent(llama_stack_client, agent_config) - session_id = codex_agent.create_session("test-session") + session_id = codex_agent.create_session(f"test-session-{uuid4()}") inflation_doc = AgentDocument( content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", mime_type="text/csv", @@ -285,7 +285,8 @@ def test_rag_agent(llama_stack_client, agent_config): llama_stack_client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=vector_db_id, - chunk_size_in_tokens=512, + # small chunks help to get specific info out of the docs + chunk_size_in_tokens=128, ) agent_config = { **agent_config, @@ -299,11 +300,15 @@ def test_rag_agent(llama_stack_client, agent_config): ], } rag_agent = Agent(llama_stack_client, agent_config) - session_id = rag_agent.create_session("test-session") + session_id = rag_agent.create_session(f"test-session-{uuid4()}") user_prompts = [ - "What are the top 5 topics that were explained? Only list succinct bullet points.", + ( + "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", + "grouped-query", + ), + ("What command to use to get access to Llama3-8B-Instruct ?", "tune download"), ] - for prompt in user_prompts: + for prompt, expected_kw in user_prompts: print(f"User> {prompt}") response = rag_agent.create_turn( messages=[{"role": "user", "content": prompt}], @@ -312,3 +317,69 @@ def test_rag_agent(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:query_from_memory" in logs_str + assert expected_kw in logs_str.lower() + + +def test_rag_and_code_agent(llama_stack_client, agent_config): + urls = ["chat.rst"] + documents = [ + Document( + document_id=f"num-{i}", + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + metadata={}, + ) + for i, url in enumerate(urls) + ] + vector_db_id = "test-vector-db" + llama_stack_client.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model="all-MiniLM-L6-v2", + embedding_dimension=384, + ) + llama_stack_client.tool_runtime.rag_tool.insert( + documents=documents, + vector_db_id=vector_db_id, + chunk_size_in_tokens=128, + ) + agent_config = { + **agent_config, + "toolgroups": [ + dict( + name="builtin::rag", + args={"vector_db_ids": [vector_db_id]}, + ), + "builtin::code_interpreter", + ], + } + agent = Agent(llama_stack_client, agent_config) + inflation_doc = Document( + document_id="test_csv", + content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv", + mime_type="text/csv", + metadata={}, + ) + user_prompts = [ + ( + "Here is a csv file, can you describe it?", + [inflation_doc], + "code_interpreter", + ), + ( + "What are the top 5 topics that were explained? Only list succinct bullet points.", + [], + "query_from_memory", + ), + ] + + for prompt, docs, tool_name in user_prompts: + print(f"User> {prompt}") + session_id = agent.create_session(f"test-session-{uuid4()}") + response = agent.create_turn( + messages=[{"role": "user", "content": prompt}], + session_id=session_id, + documents=docs, + ) + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + assert f"Tool:{tool_name}" in logs_str