-
Notifications
You must be signed in to change notification settings - Fork 16.3k
/
Copy pathrwkv.py
235 lines (184 loc) · 7.19 KB
/
rwkv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""RWKV models.
Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
and https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
"""
from typing import Any, Dict, List, Mapping, Optional, Set
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.utils import pre_init
from pydantic import BaseModel, ConfigDict
from langchain_community.llms.utils import enforce_stop_tokens
class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` python package installed, the
    pre-trained model file, the tokens file, and the model's config
    information.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV

            model = RWKV(
                model="./models/rwkv-3b-fp16.bin",
                tokens_path="./models/20B_tokenizer.json",
                strategy="cpu fp32",
            )

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens file."""

    strategy: str = "cpu fp32"
    """Strategy string passed to the ``rwkv`` model constructor
    (for example ``"cpu fp32"``)."""

    rwkv_verbose: bool = True
    """Forwarded as ``verbose`` to the ``rwkv`` model (print debug information)."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear
    in the text so far, increasing the model's likelihood to talk about
    new topics."""

    CHUNK_LEN: int = 256
    """Batch size for prompt processing: the number of tokens handed to the
    model per forward call."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:

    tokenizer: Any = None  #: :meta private:

    pipeline: Any = None  #: :meta private:

    model_tokens: Any = None  #: :meta private:

    model_state: Any = None  #: :meta private:

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default sampling and generation parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the names of the fields forwarded to the rwkv model constructor."""
        return {
            "verbose",
        }

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the required packages exist and build the clients.

        Populates ``values["tokenizer"]``, ``values["client"]`` and
        ``values["pipeline"]``.

        Raises:
            ImportError: If ``tokenizers`` or ``rwkv`` cannot be imported.
        """
        try:
            import tokenizers
        except ImportError as err:
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            ) from err

        # Only the imports are guarded: an ImportError raised while the model
        # itself is being loaded must not be misreported as "rwkv is missing".
        try:
            from rwkv.model import RWKV as RWKVMODEL
            from rwkv.utils import PIPELINE
        except ImportError as err:
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            ) from err

        values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])

        rwkv_keys = cls._rwkv_param_names()
        model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
        # The model's own verbosity is controlled by ``rwkv_verbose`` and
        # overrides any ``verbose`` value picked up from the LLM fields.
        model_kwargs["verbose"] = values["rwkv_verbose"]
        values["client"] = RWKVMODEL(
            values["model"], strategy=values["strategy"], **model_kwargs
        )
        values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            **self._default_params,
            **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

    def run_rnn(self, _tokens: List[int], newline_adj: int = 0) -> Any:
        """Feed tokens to the model and return the resulting logits.

        The tokens are appended to ``self.model_tokens`` and passed to the
        model in chunks of ``CHUNK_LEN``; ``self.model_state`` is updated
        after every chunk.

        Args:
            _tokens: Token ids to process (each item is coerced with ``int``).
            newline_adj: Value added to the logit of the end-of-line token.

        Returns:
            The output of the last forward call, with the end-of-line logit
            adjusted and, when the latest token is one of the "avoid repeat"
            punctuation marks, that token's logit suppressed.

        Raises:
            ValueError: If ``_tokens`` is empty.
        """
        tokens = [int(x) for x in _tokens]
        if not tokens:
            # Without at least one token there is no forward pass and hence
            # no logits to return.
            raise ValueError("run_rnn requires at least one token to process.")

        # Fullwidth comma, colon, question mark and exclamation mark, written
        # as escapes so the literal survives re-encoding of this file.
        AVOID_REPEAT = "\uff0c\uff1a\uff1f\uff01"
        AVOID_REPEAT_TOKENS: List[int] = []
        for char in AVOID_REPEAT:
            dd = self.pipeline.encode(char)
            assert len(dd) == 1
            AVOID_REPEAT_TOKENS += dd

        if self.model_tokens is None:
            # run_rnn was called before rwkv_generate initialised the history.
            self.model_tokens = []
        self.model_tokens += tokens

        out: Any = None
        while len(tokens) > 0:
            out, self.model_state = self.client.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]

        END_OF_LINE = 187
        out[END_OF_LINE] += newline_adj  # adjust \n probability

        last_token = self.model_tokens[-1]
        if last_token in AVOID_REPEAT_TOKENS:
            # Make an immediate repeat of the same punctuation mark
            # effectively impossible to sample.
            out[last_token] = -999999999
        return out

    def rwkv_generate(self, prompt: str) -> str:
        """Generate a completion for ``prompt``.

        Resets the model state and token history, processes the prompt, then
        samples tokens one at a time while applying presence and frequency
        penalties to tokens that were already generated.

        Args:
            prompt: The text to continue.

        Returns:
            The decoded generated text (the prompt is not included).
        """
        END_OF_TEXT = 0

        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}
        decoded = ""
        for i in range(self.max_tokens_per_generation):
            for n, count in occurrence.items():
                logits[n] -= (
                    self.penalty_alpha_presence + count * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )

            if token == END_OF_TEXT:
                break
            occurrence[token] = occurrence.get(token, 0) + 1

            logits = self.run_rnn([token])
            xxx = self.tokenizer.decode(self.model_tokens[out_last:])
            if "\ufffd" not in xxx:  # avoid utf-8 display issues
                decoded += xxx
                out_last = begin + i + 1
                # NOTE(review): generation stops as soon as the pending text
                # decodes cleanly and ``i`` has reached
                # ``max_tokens_per_generation - 100``, so it normally ends
                # about 100 tokens before the configured maximum. Kept as-is
                # to preserve existing outputs -- confirm this is intended.
                if i >= self.max_tokens_per_generation - 100:
                    break

        return decoded

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model.invoke(prompt)
        """
        text = self.rwkv_generate(prompt)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text