Commit 0f4893c

update ernie llm with eb_bot_sdk (#84)

Signed-off-by: ChengZi <[email protected]>
zc277584121 authored Oct 23, 2023
1 parent a7a4a21 commit 0f4893c

Showing 5 changed files with 78 additions and 101 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit_test.yml
@@ -31,6 +31,7 @@ jobs:
           pip install coverage
           pip install pytest
           pip install -r requirements.txt
+          pip install -r test_requirements.txt
       - name: Install test dependency
         shell: bash
         working-directory: tests
5 changes: 3 additions & 2 deletions config.py
@@ -24,8 +24,9 @@
         }
     },
     'ernie': {
-        'ernie_api_key': None,  # If None, use environment value 'ERNIE_API_KEY'
-        'ernie_secret_key': None,  # If None, use environment value 'ERNIE_SECRET_KEY'
+        'ernie_model': 'ernie-bot-turbo',  # 'ernie-bot' or 'ernie-bot-turbo'
+        'eb_api_type': None,  # If None, use environment value 'EB_API_TYPE'
+        'eb_access_token': None,  # If None, use environment value 'EB_ACCESS_TOKEN'
         'llm_kwargs': {}
     },
     'minimax': {
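As the rewritten src_langchain/llm/ernie.py below shows, each None here falls back to an environment variable. A minimal sketch of supplying the credentials via the environment instead of config.py — the api_type value is an illustrative placeholder, not taken from this commit:

    import os

    # Must run before src_langchain.llm.ernie is imported, since the class
    # attributes read os.getenv() at class-definition time.
    os.environ['EB_API_TYPE'] = 'aistudio'  # placeholder api_type
    os.environ['EB_ACCESS_TOKEN'] = '<your-erniebot-access-token>'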
116 changes: 30 additions & 86 deletions src_langchain/llm/ernie.py
@@ -1,81 +1,52 @@
 from config import CHAT_CONFIG  # pylint: disable=C0413
-from typing import Mapping, Any, List, Optional, Tuple, Dict
-import requests
-import json
+from typing import Any, List, Dict
 import os
 import sys

 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage, ChatMessage, ChatGeneration
+from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage, ChatMessage, \
+    ChatGeneration

 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


 CHAT_CONFIG = CHAT_CONFIG['ernie']
 llm_kwargs = CHAT_CONFIG.get('llm_kwargs', {})


 class ChatLLM(BaseChatModel):
     '''Chat with LLM given context. Must be a LangChain BaseLanguageModel to adapt agent.'''
-    api_key: str = CHAT_CONFIG['ernie_api_key']
-    secret_key: str = CHAT_CONFIG['ernie_secret_key']
-    temperature: float = llm_kwargs.get('temperature', 0)
-    max_tokens: Optional[int] = llm_kwargs.get('max_tokens', None)
-    n: int = llm_kwargs.get('n', 1)
-
-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params['messages'] = message_dicts
-        payload = json.dumps(params)
-        headers = {
-            'Content-Type': 'application/json'
-        }
-
-        url = self._create_url()
-        response = requests.request(
-            'POST', url, headers=headers, data=payload)
+    model_name: str = CHAT_CONFIG['ernie_model']
+    eb_api_type: str = CHAT_CONFIG['eb_api_type'] or os.getenv('EB_API_TYPE')
+    eb_access_token: str = CHAT_CONFIG['eb_access_token'] or os.getenv('EB_ACCESS_TOKEN')
+    llm_kwargs: dict = llm_kwargs
+
+    def _generate(self, messages: List[BaseMessage]) -> ChatResult:
+        import erniebot  # pylint: disable=C0415
+        erniebot.api_type = self.eb_api_type
+        erniebot.access_token = self.eb_access_token
+
+        message_dicts = self._create_message_dicts(messages)
+        response = erniebot.ChatCompletion.create(
+            model=self.model_name,
+            messages=message_dicts,
+            **self.llm_kwargs
+        )
         return self._create_chat_result(response)

-    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params['messages'] = message_dicts
-        payload = json.dumps(params)
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        url = self._create_url()
-        response = requests.request(
-            'POST', url, headers=headers, data=payload)
-        return self._create_chat_result(response)
-
-    def _create_url(self):
-        access_token = self._get_access_token(
-            api_key=self.api_key, secret_key=self.secret_key)
-        url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \
-              + access_token
-        return url
-
     def _create_message_dicts(
-            self, messages: List[BaseMessage], stop: Optional[List[str]]
-    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
-        params: Dict[str, Any] = {**self._default_params}
-        if stop is not None:
-            if 'stop' in params:
-                raise ValueError(
-                    '`stop` found in both the input and default params.')
-            params['stop'] = stop
+            self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Any]]:

         message_dicts = []
         for m in messages:
-            message_dicts.append(self._convert_message_to_dict(m))
-            if isinstance(m, SystemMessage):
-                message_dicts.append(
-                    {'role': 'assistant', 'content': 'OK.'}
-                )
-        return message_dicts, params
+            m_dict = self._convert_message_to_dict(m)
+            if m_dict:
+                message_dicts.append(self._convert_message_to_dict(m))
+        return message_dicts

-    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+    def _create_chat_result(self, response: 'EBResponse') -> ChatResult:
         generations = []
-        response = response.json()
+        response = response.to_dict()
         if 'result' not in response:
             raise RuntimeError(response)
         message = self._convert_dict_to_message(
@@ -88,14 +59,13 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:

     @staticmethod
     def _convert_message_to_dict(message: BaseMessage) -> dict:
+        message_dict = {}
         if isinstance(message, ChatMessage):
             message_dict = {'role': message.role, 'content': message.content}
-        elif isinstance(message, (HumanMessage, SystemMessage)):
+        elif isinstance(message, HumanMessage):
             message_dict = {'role': 'user', 'content': message.content}
         elif isinstance(message, AIMessage):
             message_dict = {'role': 'assistant', 'content': message.content}
-        else:
-            raise ValueError(f'Got unknown type {message}')
         if 'name' in message.additional_kwargs:
             message_dict['name'] = message.additional_kwargs['name']
         return message_dict
@@ -112,32 +82,6 @@ def _convert_dict_to_message(_dict: dict) -> BaseMessage:  # pylint: disable=C01
         else:
             return ChatMessage(content=_dict['content'], role=role)

-    @staticmethod
-    def _get_access_token(api_key, secret_key):
-        url = 'https://aip.baidubce.com/oauth/2.0/token'
-        params = {
-            'grant_type': 'client_credentials',
-            'client_id': api_key,
-            'client_secret': secret_key
-        }
-        return str(requests.post(url, params=params).json().get('access_token'))
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        '''Get the default parameters for calling OpenAI API.'''
-        return {
-            'max_tokens': self.max_tokens,
-            'n': self.n,
-            'temperature': self.temperature,
-        }
-
     @property
     def _llm_type(self) -> str:
         return 'ernie'
-
-
-# if __name__ == '__main__':
-#     chat = ChatLLM()
-#     messages = [HumanMessage(content='Translate this sentence from English to French. I love programming.')]
-#     ans = chat(messages)
-#     print(type(ans), ans)
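With the SDK migration in place, calling the model reduces to setting the two credentials and invoking the chat model. A minimal usage sketch of the rewritten class, echoing the commented-out __main__ example this commit deletes (credential values are placeholders, not taken from this commit):

    from langchain.schema import HumanMessage

    from src_langchain.llm.ernie import ChatLLM

    # eb_api_type / eb_access_token normally come from config.py, or from the
    # EB_API_TYPE / EB_ACCESS_TOKEN environment variables when the config
    # entries are None.
    chat = ChatLLM(eb_api_type='aistudio', eb_access_token='<your-access-token>')
    messages = [HumanMessage(content='Translate this sentence from English to French. I love programming.')]
    res = chat._generate(messages)
    print(res.generations[0].text)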
15 changes: 15 additions & 0 deletions test_requirements.txt
@@ -0,0 +1,15 @@
+langchain==0.0.230
+unstructured
+pexpect
+pdf2image
+SQLAlchemy>=2.0.15
+psycopg2-binary
+openai
+gradio>=3.30.0
+fastapi
+uvicorn
+towhee>=1.1.0
+pymilvus
+elasticsearch>=8.0.0
+prometheus-client
+erniebot
42 changes: 29 additions & 13 deletions tests/unit_tests/src_langchain/llm/test_ernie.py
@@ -2,28 +2,44 @@
 import sys
 import unittest
 from unittest.mock import patch

-from langchain.schema import HumanMessage
-from requests import Response
+from langchain.schema import HumanMessage, AIMessage

 sys.path.append(os.path.join(os.path.dirname(__file__), '../../../../..'))


 class TestERNIE(unittest.TestCase):
     def test_generate(self):
-        with patch('requests.post') as mock_post, patch('requests.request') as mock_request:
-            mock_res1 = Response()
-            mock_res1._content = b'{ "access_token" : "mock_token" }'
-            mock_res2 = Response()
-            mock_res2._content = b'{ "result" : "mock answer", "usage" : 2 }'
-            mock_post.return_value = mock_res1
-            mock_request.return_value = mock_res2
+        from erniebot.response import EBResponse
+        with patch('erniebot.ChatCompletion.create') as mock_post:
+            mock_res = EBResponse(code=200,
+                                  body={'id': 'as-0000000000', 'object': 'chat.completion', 'created': 11111111,
+                                        'result': 'OK, this is a mock answer.',
+                                        'usage': {'prompt_tokens': 1, 'completion_tokens': 13, 'total_tokens': 14},
+                                        'need_clear_history': False, 'is_truncated': False},
+                                  headers={'Connection': 'keep-alive',
+                                           'Content-Security-Policy': 'frame-ancestors https://*.baidu.com/',
+                                           'Content-Type': 'application/json', 'Date': 'Mon, 23 Oct 2023 03:30:53 GMT',
+                                           'Server': 'nginx', 'Statement': 'AI-generated',
+                                           'Vary': 'Origin, Access-Control-Request-Method, Access-Control-Request-Headers',
+                                           'X-Frame-Options': 'allow-from https://*.baidu.com/',
+                                           'X-Request-Id': '0' * 32,
+                                           'Transfer-Encoding': 'chunked'}
+                                  )
+            mock_post.return_value = mock_res

             from src_langchain.llm.ernie import ChatLLM

-            chat_llm = ChatLLM(api_key='mock-key', secret_key='mock-key')
-            messages = [HumanMessage(content='hello')]
+            EB_API_TYPE = 'mock_type'
+            EB_ACCESS_TOKEN = 'mock_token'
+
+            chat_llm = ChatLLM(eb_api_type=EB_API_TYPE, eb_access_token=EB_ACCESS_TOKEN)
+            messages = [
+                HumanMessage(content='hello'),
+                AIMessage(content='hello, can I help you?'),
+                HumanMessage(content='Please give me a mock answer.'),
+            ]
             res = chat_llm._generate(messages)
-            self.assertEqual(res.generations[0].text, 'mock answer')
+            self.assertEqual(res.generations[0].text, 'OK, this is a mock answer.')


 if __name__ == '__main__':
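For reference, with the three messages built in this test, the rewritten _create_message_dicts yields the alternating user/assistant payload handed to erniebot.ChatCompletion.create — roughly:

    # Approximate message_dicts produced for the test's conversation
    [
        {'role': 'user', 'content': 'hello'},
        {'role': 'assistant', 'content': 'hello, can I help you?'},
        {'role': 'user', 'content': 'Please give me a mock answer.'},
    ]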
