Skip to content

Commit

Permalink
支持异步chat
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyuxin committed Sep 3, 2023
1 parent 729239b commit 6d6a382
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
22 changes: 22 additions & 0 deletions lmclient/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def chat(self, user_input: str, **extra_parameters: Any) -> str:
else:
return self._recursive_function_call(reply, parameters)

async def async_chat(self, user_input: str, **extra_parameters: Any) -> str:
    """Send a user message asynchronously and return the model's text reply.

    Mirrors chat(): records the user message in the history, requests an
    async completion, and — when the model answers with a function call
    instead of text — resolves it via the async function-call loop.

    Args:
        user_input: The user's message text.
        **extra_parameters: Per-call overrides merged into the engine's
            default chat parameters.

    Returns:
        The model's final text reply.
    """
    merged_parameters = self._parameters.model_copy(update=extra_parameters)
    self.history.append(Message(role='user', content=user_input))
    response = await self._chat_model.async_chat_completion(self.history, merged_parameters)
    self.history.extend(response.messages)
    last_content = response.messages[-1].content
    if not isinstance(last_content, str):
        # The model asked to call a function rather than replying with text.
        return await self._async_recursive_function_call(last_content, merged_parameters)
    return last_content

def run_function_call(self, function_call: FunctionCallDict):
function = None
for i in self.functions:
Expand Down Expand Up @@ -103,5 +113,17 @@ def _recursive_function_call(self, function_call: FunctionCallDict, parameters:
else:
return self._recursive_function_call(reply, parameters)

async def _async_recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
    """Async counterpart of _recursive_function_call.

    Executes the function the model requested, appends its JSON-encoded
    output to the history as a 'function' message, asks the model for a
    follow-up completion, and recurses until the model returns plain text.

    Args:
        function_call: The function-call payload emitted by the model.
        parameters: Chat-model parameters for the follow-up request.

    Returns:
        The model's final text reply.
    """
    function_output = self.run_function_call(function_call)
    self.history.append(
        Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
    )
    model_response = await self._chat_model.async_chat_completion(self.history, parameters)
    self.history.extend(model_response.messages)
    if isinstance(reply := model_response.messages[-1].content, str):
        return reply
    else:
        # Bug fix: the original recursed into the *synchronous*
        # _recursive_function_call (without awaiting it), which would issue a
        # blocking chat completion from async code. Recurse asynchronously,
        # matching the call site in async_chat.
        return await self._async_recursive_function_call(reply, parameters)

def reset(self) -> None:
    """Discard the accumulated conversation history in place."""
    # Slice deletion empties the existing list object, same as .clear(),
    # so any external references to the history stay valid.
    del self.history[:]
23 changes: 23 additions & 0 deletions scripts/multimodel_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio

from lmclient import ChatEngine, MinimaxProChat, OpenAIChat, WenxinChat, ZhiPuChat

# One chat model per backend we want to query side by side.
chat_models = {}
chat_models['wenxin'] = WenxinChat(timeout=20)
chat_models['llama2-70b'] = WenxinChat(model='llama_2_70b', timeout=20)
chat_models['gpt4'] = OpenAIChat(model='gpt-4')
chat_models['gpt3.5'] = OpenAIChat(model='gpt-3.5-turbo')
chat_models['minimax'] = MinimaxProChat()
chat_models['zhipu'] = ZhiPuChat()
# Wrap each model in its own ChatEngine so conversation histories stay independent.
engines = {name: ChatEngine(chat_model=model) for name, model in chat_models.items()}  # type: ignore


async def multimodel_chat(user_input: str):
    """Send the same prompt to every engine concurrently and print each reply."""
    pending = [engine.async_chat(user_input) for engine in engines.values()]
    # gather preserves input order, so replies line up with the dict's keys.
    reply_list = await asyncio.gather(*pending)
    for model_name, reply in zip(engines, reply_list):
        print(f'{model_name}: {reply}')


# Entry point: run a single prompt against all configured models concurrently.
if __name__ == '__main__':
    asyncio.run(multimodel_chat('你好'))

0 comments on commit 6d6a382

Please sign in to comment.