diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py
index cddd0ed..bd3d56a 100644
--- a/lmclient/chat_engine.py
+++ b/lmclient/chat_engine.py
@@ -71,6 +71,16 @@ def chat(self, user_input: str, **extra_parameters: Any) -> str:
         else:
             return self._recursive_function_call(reply, parameters)
 
+    async def async_chat(self, user_input: str, **extra_parameters: Any) -> str:
+        parameters = self._parameters.model_copy(update=extra_parameters)
+        self.history.append(Message(role='user', content=user_input))
+        model_response = await self._chat_model.async_chat_completion(self.history, parameters)
+        self.history.extend(model_response.messages)
+        if isinstance(reply := model_response.messages[-1].content, str):
+            return reply
+        else:
+            return await self._async_recursive_function_call(reply, parameters)
+
     def run_function_call(self, function_call: FunctionCallDict):
         function = None
         for i in self.functions:
@@ -103,5 +113,17 @@ def _recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
         else:
             return self._recursive_function_call(reply, parameters)
 
+    async def _async_recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
+        function_output = self.run_function_call(function_call)
+        self.history.append(
+            Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
+        )
+        model_response = await self._chat_model.async_chat_completion(self.history, parameters)
+        self.history.extend(model_response.messages)
+        if isinstance(reply := model_response.messages[-1].content, str):
+            return reply
+        else:
+            return await self._async_recursive_function_call(reply, parameters)
+
     def reset(self) -> None:
         self.history.clear()
diff --git a/scripts/multimodel_chat.py b/scripts/multimodel_chat.py
new file mode 100644
index 0000000..55cd291
--- /dev/null
+++ b/scripts/multimodel_chat.py
@@ -0,0 +1,23 @@
+import asyncio
+
+from lmclient import ChatEngine, MinimaxProChat, OpenAIChat, WenxinChat, ZhiPuChat
+
+chat_models = {
+    'wenxin': WenxinChat(timeout=20),
+    'llama2-70b': WenxinChat(model='llama_2_70b', timeout=20),
+    'gpt4': OpenAIChat(model='gpt-4'),
+    'gpt3.5': OpenAIChat(model='gpt-3.5-turbo'),
+    'minimax': MinimaxProChat(),
+    'zhipu': ZhiPuChat(),
+}
+engines = {model_name: ChatEngine(chat_model=chat_model) for model_name, chat_model in chat_models.items()}  # type: ignore
+
+
+async def multimodel_chat(user_input: str):
+    reply_list = await asyncio.gather(*[engine.async_chat(user_input) for engine in engines.values()])
+    for model_name, reply in zip(engines.keys(), reply_list):
+        print(f'{model_name}: {reply}')
+
+
+if __name__ == '__main__':
+    asyncio.run(multimodel_chat('你好'))
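A minimal usage sketch of the new async_chat entry point with a single engine, assuming API credentials for OpenAIChat are configured in the environment (the model name is illustrative):

    import asyncio

    from lmclient import ChatEngine, OpenAIChat


    async def main() -> None:
        # async_chat awaits async_chat_completion on the underlying model and,
        # if the reply is a function call, resolves it via the async recursion.
        engine = ChatEngine(chat_model=OpenAIChat(model='gpt-3.5-turbo'))
        reply = await engine.async_chat('Hello!')
        print(reply)


    if __name__ == '__main__':
        asyncio.run(main())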