Skip to content

Commit

Permalink
chore: format
Browse files Browse the repository at this point in the history
  • Loading branch information
mathislucka committed Jan 9, 2025
1 parent 26394f1 commit ab03473
Showing 1 changed file with 101 additions and 53 deletions.
154 changes: 101 additions & 53 deletions test/core/pipeline/features/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def pipeline_that_has_a_component_with_only_default_inputs():
"answers": [
GeneratedAnswer(
data="Paris",
query="What " "is " "the " "capital " "of " "France?",
query="What is the capital of France?",
documents=[
Document(
id="413dccdf51a54cca75b7ed2eddac04e6e58560bd2f0caf4106a3efc023fe3651",
Expand Down Expand Up @@ -917,7 +917,7 @@ def fake_generator_run(self, generation_kwargs: Optional[Dict[str, Any]] = None,
pipe,
[
PipelineRunData(
inputs={"prompt_builder": {"query": "What is the capital of " "Italy?"}},
inputs={"prompt_builder": {"query": "What is the capital of Italy?"}},
expected_outputs={"router": {"correct_replies": ["Rome"]}},
expected_run_order=["prompt_builder", "generator", "router", "prompt_builder", "generator", "router"],
)
Expand Down Expand Up @@ -1809,7 +1809,11 @@ def run(self, create_document: bool = False):
],
)

@given("a pipeline that has a variadic component that receives partial inputs in a different order", target_fixture="pipeline_data")

@given(
"a pipeline that has a variadic component that receives partial inputs in a different order",
target_fixture="pipeline_data",
)
def that_has_a_variadic_component_that_receives_partial_inputs_different_order():
@component
class ConditionalDocumentCreator:
Expand Down Expand Up @@ -2281,13 +2285,15 @@ def that_has_a_string_variadic_component():
],
)


@given("a pipeline that is an agent that can use RAG", target_fixture="pipeline_data")
def an_agent_that_can_use_RAG():
@component
class FixedGenerator:
def __init__(self, replies):
self.replies = replies
self.idx = 0

@component.output_types(replies=List[str])
def run(self, prompt: str):
if self.idx < len(self.replies):
Expand All @@ -2304,7 +2310,11 @@ def run(self, prompt: str):
class FakeRetriever:
@component.output_types(documents=List[Document])
def run(self, query: str):
return {"documents": [Document(content="This is a document potentially answering the question.", meta={"access_group": 1})]}
return {
"documents": [
Document(content="This is a document potentially answering the question.", meta={"access_group": 1})
]
}

agent_prompt_template = """
Your task is to answer the user's question.
Expand Down Expand Up @@ -2379,16 +2389,17 @@ def run(self, query: str):
pp.connect("joiner.value", "concatenator.current_prompt")
pp.connect("concatenator.output", "joiner.value")




query = "Does this run reliably?"

return (
pp,
[
PipelineRunData(
inputs={"agent_prompt": {"query": query}, "rag_prompt": {"query": query}, "answer_builder": {"query": query}},
inputs={
"agent_prompt": {"query": query},
"rag_prompt": {"query": query},
"answer_builder": {"query": query},
},
expected_outputs={
"answer_builder": {
"answers": [GeneratedAnswer(data="answer: here is my answer", query=query, documents=[])]
Expand All @@ -2406,19 +2417,21 @@ def run(self, query: str):
"joiner",
"agent_llm",
"router",
"answer_builder"
"answer_builder",
],
)
],
)


@given("a pipeline that has a feedback loop", target_fixture="pipeline_data")
def has_feedback_loop():
@component
class FixedGenerator:
def __init__(self, replies):
self.replies = replies
self.idx = 0

@component.output_types(replies=List[str])
def run(self, prompt: str):
if self.idx < len(self.replies):
Expand All @@ -2431,7 +2444,6 @@ def run(self, prompt: str):

return {"replies": replies}


code_prompt_template = """
Generate code to solve the task: {{ task }}
Expand Down Expand Up @@ -2494,9 +2506,6 @@ def run(self, prompt: str):
pp.connect("code_llm.replies", "concatenator.current_prompt")
pp.connect("concatenator.output", "code_prompt.feedback")




task = "Generate code to generate christmas ascii-art"

return (
Expand All @@ -2505,24 +2514,35 @@ def run(self, prompt: str):
PipelineRunData(
inputs={"code_prompt": {"task": task}, "answer_builder": {"query": task}},
expected_outputs={
"answer_builder": {
"answers": [GeneratedAnswer(data="valid code", query=task, documents=[])]
}
"answer_builder": {"answers": [GeneratedAnswer(data="valid code", query=task, documents=[])]}
},
expected_run_order=[
'code_prompt', 'code_llm', 'feedback_prompt', 'feedback_llm', 'router', 'concatenator',
'code_prompt', 'code_llm', 'feedback_prompt', 'feedback_llm', 'router', 'answer_builder'],
"code_prompt",
"code_llm",
"feedback_prompt",
"feedback_llm",
"router",
"concatenator",
"code_prompt",
"code_llm",
"feedback_prompt",
"feedback_llm",
"router",
"answer_builder",
],
)
],
)


@given("a pipeline created in a non-standard order that has a loop", target_fixture="pipeline_data")
def has_non_standard_order_loop():
@component
class FixedGenerator:
def __init__(self, replies):
self.replies = replies
self.idx = 0

@component.output_types(replies=List[str])
def run(self, prompt: str):
if self.idx < len(self.replies):
Expand All @@ -2535,7 +2555,6 @@ def run(self, prompt: str):

return {"replies": replies}


code_prompt_template = """
Generate code to solve the task: {{ task }}
Expand All @@ -2551,9 +2570,6 @@ def run(self, prompt: str):
Provide additional feedback on why it fails.
"""




code_llm = FixedGenerator(replies=["invalid code", "valid code"])
code_prompt = PromptBuilder(template=code_prompt_template)

Expand Down Expand Up @@ -2592,7 +2608,6 @@ def run(self, prompt: str):

pp.add_component("answer_builder", answer_builder)


pp.connect("concatenator.output", "code_prompt.feedback")
pp.connect("code_prompt.prompt", "code_llm.prompt")
pp.connect("code_llm.replies", "feedback_prompt.code")
Expand All @@ -2603,7 +2618,6 @@ def run(self, prompt: str):
pp.connect("code_llm.replies", "router.code")
pp.connect("code_llm.replies", "concatenator.current_prompt")


task = "Generate code to generate christmas ascii-art"

return (
Expand All @@ -2612,32 +2626,44 @@ def run(self, prompt: str):
PipelineRunData(
inputs={"code_prompt": {"task": task}, "answer_builder": {"query": task}},
expected_outputs={
"answer_builder": {
"answers": [GeneratedAnswer(data="valid code", query=task, documents=[])]
}
"answer_builder": {"answers": [GeneratedAnswer(data="valid code", query=task, documents=[])]}
},
expected_run_order=[
'code_prompt', 'code_llm', 'feedback_prompt', 'feedback_llm', 'router', 'concatenator',
'code_prompt', 'code_llm', 'feedback_prompt', 'feedback_llm', 'router', 'answer_builder'],
"code_prompt",
"code_llm",
"feedback_prompt",
"feedback_llm",
"router",
"concatenator",
"code_prompt",
"code_llm",
"feedback_prompt",
"feedback_llm",
"router",
"answer_builder",
],
)
],
)


@given("a pipeline that has an agent with a feedback cycle", target_fixture="pipeline_data")
def agent_with_feedback_cycle():
@component
class FixedGenerator:
def __init__(self, replies):
self.replies = replies
self.idx = 0

@component.output_types(replies=List[str])
def run(self, prompt: str):
if self.idx < len(self.replies):
replies = [self.replies[self.idx]]
self.idx += 1
else:
self.idx = 1
self.idx = 0
replies = [self.replies[self.idx]]
self.idx += 1

return {"replies": replies}

Expand All @@ -2647,7 +2673,6 @@ class FakeFileEditor:
def run(self, replies: List[str]):
return {"files": "This is the edited file content."}


code_prompt_template = """
Generate code to solve the task: {{ task }}
Expand Down Expand Up @@ -2694,7 +2719,6 @@ def run(self, replies: List[str]):
]
feedback_router = ConditionalRouter(routes=routes)


tool_use_routes = [
{
"condition": "{{ 'Edit:' in replies[0] }}",
Expand All @@ -2711,11 +2735,9 @@ def run(self, replies: List[str]):
]
tool_use_router = ConditionalRouter(routes=tool_use_routes)


joiner = BranchJoiner(type_=str)
agent_concatenator = OutputAdapter(template="{{current_prompt + '\n' + files}}", output_type=str)


pp = Pipeline(max_runs_per_component=100)

pp.add_component("code_prompt", code_prompt)
Expand All @@ -2728,7 +2750,6 @@ def run(self, replies: List[str]):
pp.add_component("feedback_llm", feedback_llm)
pp.add_component("feedback_router", feedback_router)


# Main Agent
pp.connect("code_prompt.prompt", "joiner.value")
pp.connect("joiner.value", "code_llm.prompt")
Expand All @@ -2746,32 +2767,59 @@ def run(self, replies: List[str]):
pp.connect("agent_concatenator.output", "feedback_router.current_prompt")
pp.connect("feedback_router.fail", "joiner.value")


task = "Generate code to generate christmas ascii-art"

return (
pp,
[
PipelineRunData(
inputs={"code_prompt": {"task": task}},
expected_outputs={
"feedback_router": {
"pass": ["PASS"]
}
},
expected_outputs={"feedback_router": {"pass": ["PASS"]}},
expected_run_order=[
'code_prompt',

"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "feedback_prompt", "feedback_llm", "feedback_router",

"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "file_editor", "agent_concatenator",
"joiner", "code_llm", "tool_use_router", "feedback_prompt", "feedback_llm", "feedback_router"
"code_prompt",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"feedback_prompt",
"feedback_llm",
"feedback_router",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"file_editor",
"agent_concatenator",
"joiner",
"code_llm",
"tool_use_router",
"feedback_prompt",
"feedback_llm",
"feedback_router",
],
)
],
)
)

0 comments on commit ab03473

Please sign in to comment.