Skip to content

Commit

Permalink
Merge pull request #599 from better629/feat_other_basemodel
Browse files Browse the repository at this point in the history
support Message() without content param
  • Loading branch information
geekan authored Dec 21, 2023
2 parents 5612631 + 64c5673 commit eac2ba1
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 14 deletions.
4 changes: 2 additions & 2 deletions examples/search_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ async def search():
role = Sales(profile="Sales", store=store)
role._watch({Action})
queries = [
Message("Which facial cleanser is good for oily skin?", cause_by=Action),
Message("Is L'Oreal good to use?", cause_by=Action),
Message(content="Which facial cleanser is good for oily skin?", cause_by=Action),
Message(content="Is L'Oreal good to use?", cause_by=Action),
]
for query in queries:
logger.info(f"User: {query}")
Expand Down
15 changes: 12 additions & 3 deletions metagpt/roles/researcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,27 @@ async def _act(self) -> Message:
research_system_text = self.research_system_text(topic, todo)
if isinstance(todo, CollectLinks):
links = await todo.run(topic, 4, 4)
ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=todo)
ret = Message(
content="", instruct_content=Report(topic=topic, links=links), role=self.profile, cause_by=todo
)
elif isinstance(todo, WebBrowseAndSummarize):
links = instruct_content.links
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
summaries = await asyncio.gather(*todos)
summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo)
ret = Message(
content="", instruct_content=Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo
)
else:
summaries = instruct_content.summaries
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=self._rc.todo)
ret = Message(
content="",
instruct_content=Report(topic=topic, content=content),
role=self.profile,
cause_by=self._rc.todo,
)
self._rc.memory.add(ret)
return ret

Expand Down
3 changes: 2 additions & 1 deletion metagpt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class Message(BaseModel):
sent_from: str = ""
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})

def __init__(self, **kwargs):
def __init__(self, content: str = "", **kwargs):
ic = kwargs.get("instruct_content", None)
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
Expand All @@ -122,6 +122,7 @@ def __init__(self, **kwargs):
kwargs["instruct_content"] = ic_new

kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
kwargs["content"] = kwargs.get("content", content)
kwargs["cause_by"] = any_to_str(
kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement"))
)
Expand Down
2 changes: 1 addition & 1 deletion metagpt/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SubscriptionRunner(BaseModel):
>>> async def trigger():
... while True:
... yield Message("the latest news about OpenAI")
... yield Message(content="the latest news about OpenAI")
... await asyncio.sleep(3600 * 24)
>>> async def callback(msg: Message):
Expand Down
2 changes: 1 addition & 1 deletion tests/metagpt/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_all_messages():
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role="QA"),
Message(content=test_content, role="QA"),
]
for msg in msgs:
assert msg.content == test_content
Expand Down
10 changes: 5 additions & 5 deletions tests/metagpt/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ async def test_subscription_run():

async def trigger():
while True:
yield Message("the latest news about OpenAI")
yield Message(content="the latest news about OpenAI")
await asyncio.sleep(3600 * 24)

class MockRole(Role):
async def run(self, message=None):
return Message("")
return Message(content="")

async def callback(message):
nonlocal callback_done
Expand Down Expand Up @@ -61,19 +61,19 @@ async def callback(message):
async def test_subscription_run_error(loguru_caplog):
async def trigger1():
while True:
yield Message("the latest news about OpenAI")
yield Message(content="the latest news about OpenAI")
await asyncio.sleep(3600 * 24)

async def trigger2():
yield Message("the latest news about OpenAI")
yield Message(content="the latest news about OpenAI")

class MockRole1(Role):
async def run(self, message=None):
raise RuntimeError

class MockRole2(Role):
async def run(self, message=None):
return Message("")
return Message(content="")

async def callback(msg: Message):
print(msg)
Expand Down
2 changes: 1 addition & 1 deletion tests/metagpt/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Input(BaseModel):
Input(x=RunCode, want="metagpt.actions.run_code.RunCode"),
Input(x=RunCode(), want="metagpt.actions.run_code.RunCode"),
Input(x=Message, want="metagpt.schema.Message"),
Input(x=Message(""), want="metagpt.schema.Message"),
Input(x=Message(content=""), want="metagpt.schema.Message"),
Input(x="A", want="A"),
]
for i in inputs:
Expand Down

0 comments on commit eac2ba1

Please sign in to comment.