forked from zilliztech/akcio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ernie.py
143 lines (124 loc) · 5.46 KB
/
ernie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json
import os
import sys
from typing import Mapping, Any, List, Optional, Tuple, Dict

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage, ChatMessage, ChatGeneration

# Make the project root importable BEFORE any project-local import.
# (Previously `from config import CHAT_CONFIG` ran first, so this path hack
# came too late to help that import.)
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from config import CHAT_CONFIG  # pylint: disable=C0413

# Narrow the global chat configuration down to the ERNIE-specific section,
# and pull out optional generation parameters (temperature, max_tokens, n).
CHAT_CONFIG = CHAT_CONFIG['ernie']
llm_kwargs = CHAT_CONFIG.get('llm_kwargs', {})
class ChatLLM(BaseChatModel):
    '''Chat with Baidu ERNIE given context. Must be a LangChain BaseChatModel to adapt agent.

    Credentials and generation parameters are read from CHAT_CONFIG at import
    time. Every call exchanges the API key/secret for an OAuth access token,
    then POSTs the conversation to the ERNIE completion endpoint.
    '''
    # Baidu Qianfan credentials used to obtain an OAuth access token.
    api_key: str = CHAT_CONFIG['ernie_api_key']
    secret_key: str = CHAT_CONFIG['ernie_secret_key']
    # Generation parameters forwarded to the ERNIE endpoint.
    temperature: float = llm_kwargs.get('temperature', 0)
    max_tokens: Optional[int] = llm_kwargs.get('max_tokens', None)
    n: int = llm_kwargs.get('n', 1)
    # Seconds before an HTTP request to the Baidu API is aborted.
    # Previously no timeout was set, so a dead connection hung forever.
    request_timeout: float = 60.0

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
        '''Send the conversation to ERNIE and return the chat result.

        Raises:
            RuntimeError: if the API response contains no `result` field,
                or the access-token exchange fails.
            requests.RequestException: on network failure or timeout.
        '''
        message_dicts, params = self._create_message_dicts(messages, stop)
        params['messages'] = message_dicts
        response = requests.post(
            self._create_url(),
            headers={'Content-Type': 'application/json'},
            data=json.dumps(params),
            timeout=self.request_timeout,
        )
        return self._create_chat_result(response)

    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
        '''Async entry point required by BaseChatModel.

        NOTE(review): this delegates to the synchronous `_generate` (the
        original was a verbatim copy of it), so it still blocks the event
        loop while the HTTP request is in flight. A true async HTTP client
        would be needed for non-blocking behavior.
        '''
        return self._generate(messages, stop)

    def _create_url(self):
        '''Build the completion endpoint URL with a fresh access token.'''
        access_token = self._get_access_token(
            api_key=self.api_key, secret_key=self.secret_key)
        return ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions'
                f'?access_token={access_token}')

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        '''Convert LangChain messages to ERNIE payload dicts plus request params.

        After each SystemMessage an 'OK.' assistant turn is inserted —
        presumably because the ERNIE API expects strictly alternating
        user/assistant roles (system messages are sent as 'user'); confirm
        against the Qianfan API docs.
        '''
        params: Dict[str, Any] = {**self._default_params}
        if stop is not None:
            if 'stop' in params:
                raise ValueError(
                    '`stop` found in both the input and default params.')
            params['stop'] = stop
        message_dicts = []
        for message in messages:
            message_dicts.append(self._convert_message_to_dict(message))
            if isinstance(message, SystemMessage):
                message_dicts.append(
                    {'role': 'assistant', 'content': 'OK.'}
                )
        return message_dicts, params

    def _create_chat_result(self, response: requests.Response) -> ChatResult:
        '''Parse the HTTP response into a single-generation ChatResult.

        Raises:
            RuntimeError: if the decoded body has no `result` key (the API
                reports errors in-band rather than via HTTP status).
        '''
        body = response.json()
        if 'result' not in body:
            raise RuntimeError(body)
        message = self._convert_dict_to_message(
            {'role': 'assistant', 'content': body['result']})
        generations = [ChatGeneration(message=message)]
        llm_output = {
            # `.get` guards against responses that omit usage accounting.
            'token_usage': body.get('usage', {}), 'model_name': 'ernie'}
        return ChatResult(generations=generations, llm_output=llm_output)

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        '''Map a LangChain message to an ERNIE role/content dict.

        SystemMessage is deliberately sent with role 'user' (see
        `_create_message_dicts` for the paired 'OK.' assistant turn).
        '''
        if isinstance(message, ChatMessage):
            message_dict = {'role': message.role, 'content': message.content}
        elif isinstance(message, (HumanMessage, SystemMessage)):
            message_dict = {'role': 'user', 'content': message.content}
        elif isinstance(message, AIMessage):
            message_dict = {'role': 'assistant', 'content': message.content}
        else:
            raise ValueError(f'Got unknown type {message}')
        if 'name' in message.additional_kwargs:
            message_dict['name'] = message.additional_kwargs['name']
        return message_dict

    @staticmethod
    def _convert_dict_to_message(_dict: dict) -> BaseMessage:  # pylint: disable=C0103
        '''Map an ERNIE role/content dict back to a LangChain message.'''
        role = _dict['role']
        if role == 'user':
            return HumanMessage(content=_dict['content'])
        if role == 'assistant':
            return AIMessage(content=_dict['content'])
        if role == 'system':
            return SystemMessage(content=_dict['content'])
        return ChatMessage(content=_dict['content'], role=role)

    @staticmethod
    def _get_access_token(api_key, secret_key):
        '''Exchange API credentials for an OAuth access token.

        Raises:
            RuntimeError: when the token endpoint returns no access_token.
                (Previously a failure returned the literal string 'None',
                which was silently embedded in the completion URL.)
        '''
        url = 'https://aip.baidubce.com/oauth/2.0/token'
        params = {
            'grant_type': 'client_credentials',
            'client_id': api_key,
            'client_secret': secret_key
        }
        body = requests.post(url, params=params, timeout=30).json()
        token = body.get('access_token')
        if not token:
            raise RuntimeError(f'Failed to obtain ERNIE access token: {body}')
        return str(token)

    @property
    def _default_params(self) -> Dict[str, Any]:
        '''Default generation parameters sent with every request.'''
        return {
            'max_tokens': self.max_tokens,
            'n': self.n,
            'temperature': self.temperature,
        }

    @property
    def _llm_type(self) -> str:
        return 'ernie'
# if __name__ == '__main__':
# chat = ChatLLM()
# messages = [HumanMessage(content='Translate this sentence from English to French. I love programming.')]
# ans = chat(messages)
# print(type(ans), ans)