'''Dolly chat: adapts a local Dolly text-generation pipeline to the LangChain
chat model interface.'''
import sys
import os
from typing import Mapping, Any, List, Optional

import torch
from transformers import pipeline
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, ChatGeneration

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from config import CHAT_CONFIG  # pylint: disable=C0413

# Select the Dolly section of the project-wide chat config.
CHAT_CONFIG = CHAT_CONFIG['dolly']
llm_kwargs = CHAT_CONFIG.get('llm_kwargs', {})
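
# A minimal sketch of the config shape this module expects from config.py.
# The keys below are the ones actually read here; the model name is
# illustrative, not necessarily what the repo ships:
#
#     CHAT_CONFIG = {
#         'dolly': {
#             'dolly_model': 'databricks/dolly-v2-3b',
#             'llm_kwargs': {'device': 'auto'},
#         },
#     }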


class ChatLLM(BaseChatModel):
    '''Chat with an LLM given context. Must be a LangChain BaseLanguageModel
    so it can plug into an agent.'''
    model_name: str = CHAT_CONFIG['dolly_model']
    device: str = llm_kwargs.get('device', 'auto')
    # Hugging Face text-generation pipeline for the Dolly model;
    # trust_remote_code is required because Dolly ships its own
    # instruct-pipeline code with the model.
    generate_text = pipeline(
        model=model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map=device)

    def _generate(self,
                  messages: List[BaseMessage],
                  stop: Optional[List[str]] = None,  # pylint: disable=W0613
                  run_manager: Optional[Any] = None,  # pylint: disable=W0613
                  ) -> ChatResult:
        # Run the latest user message through the local pipeline and wrap the output.
        prompt = self._create_prompt(messages)
        resp = self.generate_text(prompt)
        return self._create_chat_result(resp)

    async def _agenerate(self,
                         messages: List[BaseMessage],
                         stop: Optional[List[str]] = None,  # pylint: disable=W0613
                         run_manager: Optional[Any] = None,  # pylint: disable=W0613
                         ) -> ChatResult:
        # No true async path: generation still runs synchronously in the pipeline.
        prompt = self._create_prompt(messages)
        resp = self.generate_text(prompt)
        return self._create_chat_result(resp)

    def _create_prompt(self, messages: List[BaseMessage]) -> str:
        # Dolly's instruct pipeline takes a single plain-text instruction,
        # so only the latest human message is used as the prompt.
        if isinstance(messages[-1], HumanMessage):
            prompt = messages[-1].content
        else:
            raise AttributeError('Unsupported message for Dolly.')
        return prompt

    def _create_chat_result(self, response: List[Mapping[str, Any]]) -> ChatResult:
        # The pipeline returns a list of dicts like [{'generated_text': ...}];
        # wrap the first candidate as an AI message.
        message = AIMessage(content=response[0]['generated_text'])
        gen = ChatGeneration(message=message)
        return ChatResult(generations=[gen])

    @property
    def _llm_type(self) -> str:
        return 'dolly'


if __name__ == '__main__':
    # Smoke test: requires the Dolly weights, downloaded on first use.
    chat = ChatLLM()
    messages = [HumanMessage(content='Translate this sentence from English to French. I love programming.')]
    ans = chat(messages)
    print(type(ans), ans)
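
# To try the smoke test above (assuming config.py defines CHAT_CONFIG['dolly']
# as sketched near the top of this file; the first run downloads the model
# weights from the Hugging Face Hub):
#
#     python dolly_chat.py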