Skip to content

Commit

Permalink
Merge pull request #1036 from PrefectHQ/end-turn
Browse files Browse the repository at this point in the history
Refactor agents/teams/end-turn functions
  • Loading branch information
jlowin authored Jan 29, 2025
2 parents 7bb8c17 + dbf8f36 commit a2f63b9
Show file tree
Hide file tree
Showing 21 changed files with 569 additions and 793 deletions.
13 changes: 13 additions & 0 deletions .cursor/rules/python-style.mdc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
description:
globs:
---

# Python Style

- Marvin supports Python 3.10+
- Use modern syntax and full type annotations
- Use | instead of Union
- Use builtins like list, dict instead of typing.List and typing.Dict
- use | None instead of Optional
- Always type optional keyword arguments correctly e.g. def f(x: int | None = None)
50 changes: 27 additions & 23 deletions src/marvin/agents/actor.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import uuid
from collections.abc import Callable
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable, Sequence

import pydantic_ai
from pydantic_ai.result import RunResult

import marvin
import marvin.utilities.asyncio
from marvin.engine.llm import Message
from marvin.memory.memory import Memory
from marvin.prompts import Template
from marvin.thread import Thread

if TYPE_CHECKING:
from marvin.agents.team import Team
from marvin.engine.end_turn import EndTurn
from marvin.engine.orchestrator import Orchestrator


@dataclass(kw_only=True)
class Actor:
class Actor(ABC):
id: str = field(
default_factory=lambda: uuid.uuid4().hex[:8],
metadata={"description": "Unique identifier for this actor"},
Expand Down Expand Up @@ -47,41 +49,43 @@ class Actor:
def __hash__(self) -> int:
return hash(self.id)

def get_delegates(self) -> list["Actor"] | None:
"""A list of actors that this actor can delegate to."""
return None

def get_agentlet(
@abstractmethod
async def _run(
self,
result_types: list[type],
tools: list[Callable[..., Any]] | None = None,
**kwargs: Any,
) -> pydantic_ai.Agent[Any, Any]:
raise NotImplementedError("Subclass must implement get_agentlet")
messages: list[Message],
tools: Sequence[Callable[..., Any]],
end_turn_tools: Sequence["EndTurn"],
) -> RunResult:
raise NotImplementedError("Actor subclasses must implement _run")

def start_turn(self):
async def start_turn(self, orchestrator: "Orchestrator"):
"""Called when the actor starts its turn."""
pass

def end_turn(self):
async def end_turn(self, orchestrator: "Orchestrator", result: RunResult):
"""Called when the actor ends its turn."""
pass

def get_tools(self) -> list[Callable[..., Any]]:
"""A list of tools that this actor can use during its turn."""
return []

def get_memories(self) -> list[Memory]:
"""A list of memories that this actor can use during its turn."""
def get_end_turn_tools(self) -> list["EndTurn"]:
"""A list of `EndTurn` tools that this actor can use to end its turn."""
return []

def get_end_turn_tools(self) -> list[type["EndTurn"]]:
"""A list of `EndTurn` tools that this actor can use to end its turn."""
def get_memories(self) -> list[Memory]:
"""A list of memories that this actor can use during its turn."""
return []

def get_prompt(self) -> str:
return Template(source=self.prompt).render()
return Template(source=self.prompt).render(actor=self)

def friendly_name(self) -> str:
return f'{self.__class__.__name__} "{self.name}" ({self.id})'
def friendly_name(self, verbose: bool = True) -> str:
if verbose:
return f'{self.__class__.__name__} "{self.name}" ({self.id})'
else:
return self.name

async def run_async(
self,
Expand Down
143 changes: 98 additions & 45 deletions src/marvin/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,30 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar, Union
from typing import TYPE_CHECKING, Any, TypeVar, Union

import pydantic_ai
from pydantic_ai.messages import ModelRequestPart, RetryPromptPart, ToolCallPart
from pydantic_ai.models import KnownModelName, Model, ModelSettings
from pydantic_ai.result import AgentDepsT, RunContext, RunResult

import marvin
import marvin.engine.llm
from marvin.agents.actor import Actor
from marvin.agents.names import AGENT_NAMES
from marvin.agents.team import SoloTeam, Swarm, Team
from marvin.memory.memory import Memory
from marvin.prompts import Template
from marvin.tools.thread import post_message_to_agents

from .actor import Actor
from marvin.utilities.logging import get_logger
from marvin.utilities.tools import wrap_tool_errors
from marvin.utilities.types import issubclass_safe

logger = get_logger(__name__)
T = TypeVar("T")

if TYPE_CHECKING:
from marvin.engine.end_turn import EndTurn
from marvin.engine.handlers import AsyncHandler, Handler
from marvin.engine.llm import Message


@dataclass(kw_only=True)
class Agent(Actor):
Expand Down Expand Up @@ -59,14 +66,6 @@ class Agent(Actor):
repr=False,
)

delegates: list[Actor] | None = field(
default=None,
repr=False,
metadata={
"description": "List of agents that this agent can delegate to. Provide an empty list if this agent can not delegate.",
},
)

prompt: str | Path = field(
default=Path("agent.jinja"),
metadata={"description": "Template for the agent's prompt"},
Expand All @@ -76,24 +75,14 @@ class Agent(Actor):
def __hash__(self) -> int:
return super().__hash__()

def friendly_name(self) -> str:
return f'Agent "{self.name}" ({self.id})'

def get_delegates(self) -> list[Actor] | None:
return self.delegates

def get_model(self) -> Model | KnownModelName:
return self.model or marvin.defaults.model

def get_memories(self) -> list[Memory]:
return self.memories

def get_tools(self) -> list[Callable[..., Any]]:
return (
self.tools
+ [t for m in self.memories for t in m.get_tools()]
+ [post_message_to_agents]
)
return self.tools + [t for m in self.memories for t in m.get_tools()]

def get_memories(self) -> list[Memory]:
return list(self.memories)

def get_model_settings(self) -> ModelSettings:
defaults: ModelSettings = {}
Expand All @@ -103,31 +92,95 @@ def get_model_settings(self) -> ModelSettings:

def get_agentlet(
self,
result_types: list[type],
result_type: type,
tools: list[Callable[..., Any]] | None = None,
**kwargs: Any,
handlers: list["Handler | AsyncHandler"] | None = None,
result_tool_name: str | None = None,
) -> pydantic_ai.Agent[Any, Any]:
if len(result_types) == 1:
result_type = result_types[0]
else:
result_type = Union[tuple(result_types)]
from marvin.engine.events import Event
from marvin.engine.handlers import AsyncHandler

async def handle_event(event: Event):
for handler in handlers or []:
if isinstance(handler, AsyncHandler):
await handler.handle(event)
else:
handler.handle(event)

tools = [wrap_tool_errors(tool) for tool in tools or []]

return pydantic_ai.Agent[Any, result_type]( # type: ignore
agentlet = pydantic_ai.Agent[Any, result_type]( # type: ignore
model=self.get_model(),
result_type=result_type,
tools=self.get_tools() + (tools or []),
tools=tools,
model_settings=self.get_model_settings(),
end_strategy="exhaustive",
**kwargs,
result_tool_name=result_tool_name or "EndTurn",
result_tool_description="This tool will end your turn. You may only use one EndTurn tool per turn.",
)

def get_prompt(self) -> str:
return Template(source=self.prompt).render(agent=self)
from marvin.engine.events import (
ToolCallEvent,
ToolRetryEvent,
ToolReturnEvent,
)

def as_team(self, team_class: Callable[[list[Actor]], Team] | None = None) -> Team:
all_agents = [self] + (self.delegates or [])
if len(all_agents) == 1:
team_class = team_class or SoloTeam
for tool in agentlet._function_tools.values(): # type: ignore[reportPrivateUsage]
# Wrap the tool run function to emit events for each call / result
# this can be removed when Pydantic AI supports streaming events
async def run(
message: ToolCallPart,
run_context: RunContext[AgentDepsT],
# pass as arg to avoid late binding issues
original_run: Callable[..., Any] = tool.run,
) -> ModelRequestPart:
await handle_event(ToolCallEvent(actor=self, message=message))
result = await original_run(message, run_context)
if isinstance(result, RetryPromptPart):
await handle_event(ToolRetryEvent(message=result))
else:
await handle_event(ToolReturnEvent(message=result))
return result

tool.run = run

return agentlet

async def _run(
self,
messages: list["Message"],
tools: list[Callable[..., Any]],
end_turn_tools: list["EndTurn"],
) -> RunResult:
from marvin.engine.end_turn import EndTurn

tools = tools + self.get_tools()
end_turn_tools = end_turn_tools + self.get_end_turn_tools()

# if any tools are actually EndTurn classes, remove them from tools and
# add them to end turn tools
for t in tools:
if issubclass_safe(t, EndTurn):
tools.remove(t)
end_turn_tools.append(t)

if not end_turn_tools:
result_type = [EndTurn]

if len(end_turn_tools) == 1:
result_type = end_turn_tools[0]
result_tool_name = result_type.__name__
else:
team_class = team_class or Swarm
return team_class(agents=all_agents)
result_type = Union[tuple(end_turn_tools)]
result_tool_name = "EndTurn"

agentlet = self.get_agentlet(
result_type=result_type,
tools=tools,
result_tool_name=result_tool_name,
)
result = await agentlet.run("", message_history=messages)
return result

def get_prompt(self) -> str:
return Template(source=self.prompt).render(agent=self)
Loading

0 comments on commit a2f63b9

Please sign in to comment.