Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use decorator for "register_auto_reply" #1176

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flaml/autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .agent import Agent
from .responsive_agent import ResponsiveAgent
from .responsive_agent import ResponsiveAgent, register_auto_reply
from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent
from .groupchat import GroupChatManager

__all__ = [
"Agent",
"ResponsiveAgent",
"register_auto_reply",
"AssistantAgent",
"UserProxyAgent",
"GroupChatManager",
Expand Down
4 changes: 2 additions & 2 deletions flaml/autogen/agentchat/contrib/math_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Union
from time import sleep

from flaml.autogen.agentchat import Agent, UserProxyAgent
from flaml.autogen.agentchat import Agent, UserProxyAgent, register_auto_reply
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from flaml.autogen.math_utils import get_answer

Expand Down Expand Up @@ -165,7 +165,6 @@ def __init__(
default_auto_reply=default_auto_reply,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_math_reply, 1)
# fixed var
self._max_invalid_q_per_step = max_invalid_q_per_step

Expand Down Expand Up @@ -276,6 +275,7 @@ def execute_one_wolfram_query(self, query: str):
is_success = False
return output, is_success

@register_auto_reply(Agent, 1)
def _generate_math_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down
4 changes: 2 additions & 2 deletions flaml/autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from typing import Dict, List, Optional, Union
from .agent import Agent
from .responsive_agent import ResponsiveAgent
from .responsive_agent import ResponsiveAgent, register_auto_reply


class GroupChatManager(ResponsiveAgent):
Expand Down Expand Up @@ -35,12 +35,12 @@ def __init__(
human_input_mode=human_input_mode,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_reply_for_participant)
self.max_round = max_round
self._agent_names = []
self._messages = []
# self._random = random.Random(seed)

@register_auto_reply(Agent)
def _generate_reply_for_participant(
self,
messages: Optional[List[Dict]] = None,
Expand Down
51 changes: 33 additions & 18 deletions flaml/autogen/agentchat/responsive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ def colored(x, *args, **kwargs):
return x


def register_auto_reply(class_type, position=0):
"""Register a class-specific reply function.
We achieve this by decorate the function with _regster_for and _insert_pos attributes.
BeibinLi marked this conversation as resolved.
Show resolved Hide resolved

The class-specific reply function will be called when the sender is an instance of the class_type.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.

The decorated reply_func (Callable): the reply function.

Args:
class_type (Class): the class type.
position (int): the position of the reply function in the reply function list.
"""
BeibinLi marked this conversation as resolved.
Show resolved Hide resolved

def decorator(reply_func):
reply_func._registered_for = class_type
reply_func._insert_pos = position
return reply_func

return decorator


class ResponsiveAgent(Agent):
"""(Experimental) A class for generic responsive agents which can be configured as assistant or user proxy.

Expand Down Expand Up @@ -110,24 +133,12 @@ def __init__(
self._default_auto_reply = default_auto_reply
self._class_specific_reply = []
self.reply_at_receive = defaultdict(bool)
self.register_auto_reply(Agent, self._generate_oai_reply)
self.register_auto_reply(Agent, self._generate_code_execution_reply)
self.register_auto_reply(Agent, self._generate_function_call_reply)
self.register_auto_reply(Agent, self._check_termination_and_human_reply)

def register_auto_reply(self, class_type, reply_func: Callable, position: int = 0):
"""Register a class-specific reply function.

The class-specific reply function will be called when the sender is an instance of the class_type.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.

Args:
class_type (Class): the class type.
reply_func (Callable): the reply function.
position (int): the position of the reply function in the reply function list.
"""
self._class_specific_reply.insert(position, (class_type, reply_func))
# Handle class-specific reply defined in the "register_auto_reply" decorator.
for cls in reversed(type(self).__mro__): # loop all supers
BeibinLi marked this conversation as resolved.
Show resolved Hide resolved
for name, method in vars(cls).items(): # loop all functions
if hasattr(method, "_registered_for") and hasattr(method, "_insert_pos"):
self._class_specific_reply.insert(method._insert_pos, (method._registered_for, method))

@property
def system_message(self):
Expand Down Expand Up @@ -388,6 +399,7 @@ def clear_history(self, agent: Optional[Agent] = None):
else:
self._oai_messages[agent].clear()

@register_auto_reply(Agent)
def _generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -404,6 +416,7 @@ def _generate_oai_reply(
)
return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]

@register_auto_reply(Agent)
def _generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -426,6 +439,7 @@ def _generate_code_execution_reply(
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}"

@register_auto_reply(Agent)
def _generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -439,6 +453,7 @@ def _generate_function_call_reply(
return True, func_return
return False, None

@register_auto_reply(Agent)
def _check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -542,7 +557,7 @@ def generate_reply(
if isinstance(sender, class_specifc_reply[0]) and (
not exclude or class_specifc_reply[1] not in exclude
):
final, reply = class_specifc_reply[1](messages, sender)
final, reply = class_specifc_reply[1](self, messages, sender)
if final:
return reply
return self._default_auto_reply
Expand Down
Loading