Skip to content

Commit

Permalink
Use decorator for "register_auto_reply"
Browse files Browse the repository at this point in the history
  • Loading branch information
BeibinLi committed Aug 4, 2023
1 parent 2208dfb commit aec518c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
3 changes: 2 additions & 1 deletion flaml/autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .agent import Agent
from .responsive_agent import ResponsiveAgent
from .responsive_agent import ResponsiveAgent, register_auto_reply
from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent
from .groupchat import GroupChatManager

__all__ = [
"Agent",
"ResponsiveAgent",
"register_auto_reply",
"AssistantAgent",
"UserProxyAgent",
"GroupChatManager",
Expand Down
4 changes: 2 additions & 2 deletions flaml/autogen/agentchat/contrib/math_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Union
from time import sleep

from flaml.autogen.agentchat import Agent, UserProxyAgent
from flaml.autogen.agentchat import Agent, UserProxyAgent, register_auto_reply
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from flaml.autogen.math_utils import get_answer

Expand Down Expand Up @@ -165,7 +165,6 @@ def __init__(
default_auto_reply=default_auto_reply,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_math_reply, 1)
# fixed var
self._max_invalid_q_per_step = max_invalid_q_per_step

Expand Down Expand Up @@ -276,6 +275,7 @@ def execute_one_wolfram_query(self, query: str):
is_success = False
return output, is_success

@register_auto_reply(Agent, 1)
def _generate_math_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down
4 changes: 2 additions & 2 deletions flaml/autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from typing import Dict, List, Optional, Union
from .agent import Agent
from .responsive_agent import ResponsiveAgent
from .responsive_agent import ResponsiveAgent, register_auto_reply


class GroupChatManager(ResponsiveAgent):
Expand Down Expand Up @@ -35,12 +35,12 @@ def __init__(
human_input_mode=human_input_mode,
**kwargs,
)
self.register_auto_reply(Agent, self._generate_reply_for_participant)
self.max_round = max_round
self._agent_names = []
self._messages = []
# self._random = random.Random(seed)

@register_auto_reply(Agent)
def _generate_reply_for_participant(
self,
messages: Optional[List[Dict]] = None,
Expand Down
44 changes: 29 additions & 15 deletions flaml/autogen/agentchat/responsive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
def colored(x, *args, **kwargs):
return x

def register_auto_reply(class_type, position=0):
"""Register a class-specific reply function.
We achieve this by decorate the function with _regster_for and _insert_pos attributes.
The class-specific reply function will be called when the sender is an instance of the class_type.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.
The decorated reply_func (Callable): the reply function.
Args:
class_type (Class): the class type.
position (int): the position of the reply function in the reply function list.
"""
def decorator(reply_func):
reply_func._registered_for = class_type
reply_func._insert_pos = position
return reply_func
return decorator

class ResponsiveAgent(Agent):
"""(Experimental) A class for generic responsive agents which can be configured as assistant or user proxy.
Expand Down Expand Up @@ -110,24 +129,15 @@ def __init__(
self._default_auto_reply = default_auto_reply
self._class_specific_reply = []
self.reply_at_receive = defaultdict(bool)
self.register_auto_reply(Agent, self._generate_oai_reply)
self.register_auto_reply(Agent, self._generate_code_execution_reply)
self.register_auto_reply(Agent, self._generate_function_call_reply)
self.register_auto_reply(Agent, self._check_termination_and_human_reply)

def register_auto_reply(self, class_type, reply_func: Callable, position: int = 0):
"""Register a class-specific reply function.
# Handle class-specific reply defined in the "register_auto_reply" decorator.
for name, method in vars(type(self)).items():
if hasattr(method, '_registered_for') and hasattr(method, '_insert_pos'):
self._class_specific_reply.insert(method._insert_pos, (method._registered_for, method))



The class-specific reply function will be called when the sender is an instance of the class_type.
The function registered later will be checked earlier by default.
To change the order, set the position to a positive integer.

Args:
class_type (Class): the class type.
reply_func (Callable): the reply function.
position (int): the position of the reply function in the reply function list.
"""
self._class_specific_reply.insert(position, (class_type, reply_func))

@property
def system_message(self):
Expand Down Expand Up @@ -388,6 +398,7 @@ def clear_history(self, agent: Optional[Agent] = None):
else:
self._oai_messages[agent].clear()

@register_auto_reply(Agent)
def _generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -404,6 +415,7 @@ def _generate_oai_reply(
)
return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]

@register_auto_reply(Agent)
def _generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -426,6 +438,7 @@ def _generate_code_execution_reply(
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed"
return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}"

@register_auto_reply(Agent)
def _generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
Expand All @@ -439,6 +452,7 @@ def _generate_function_call_reply(
return True, func_return
return False, None

@register_auto_reply(Agent)
def _check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down

0 comments on commit aec518c

Please sign in to comment.