
Commit

Fix llama_cpp and Llama type signatures. Closes ggml-org#221
abetlen committed May 19, 2023
1 parent fb57b94 commit 01a010b
Showing 3 changed files with 58 additions and 64 deletions.
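In practice, the fix means the high-level Llama API now takes and returns plain Python ints for tokens, converting to the ctypes llama_cpp.llama_token type only at the llama.cpp call boundary. A minimal usage sketch of the updated API (illustrative only, not part of the commit; the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin", verbose=False)  # placeholder path

tokens = llm.tokenize(b"Hello, world!")  # List[int] after this commit
assert all(isinstance(t, int) for t in tokens)
text = llm.detokenize(tokens)            # accepts plain ints, returns bytes
print(text.decode("utf-8", errors="ignore"))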
76 changes: 35 additions & 41 deletions llama_cpp/llama.py
@@ -15,9 +15,7 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[
-            Tuple[llama_cpp.llama_token, ...], "LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
         self.capacity_bytes = capacity_bytes
 
     @property
@@ -26,8 +24,8 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ def _find_longest_prefix_key(
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
         self.cache_state.move_to_end(_key)
         return value
 
-    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+    def __contains__(self, key: Sequence[int]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None
 
-    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
         if key in self.cache_state:
             del self.cache_state[key]
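With cache keys typed as Tuple[int, ...], LlamaCache can be indexed directly with the token lists that tokenize() returns. A rough sketch of the longest-prefix lookup implemented by the methods above, continuing from the llm instance in the earlier sketch and assuming save_state() (the counterpart of the load_state() method later in this diff) is available; the token values themselves are whatever the tokenizer produces:

from llama_cpp import LlamaCache

cache = LlamaCache()
short = tuple(llm.tokenize(b"Hello,"))
longer = tuple(llm.tokenize(b"Hello, world!"))

cache[short] = llm.save_state()  # store a LlamaState under the shorter prompt's int tokens
assert longer in cache           # __contains__ matches on the shared token prefix
reused = cache[longer]           # __getitem__ returns the state stored under `short`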
@@ -63,7 +61,7 @@ def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState")
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
@@ -141,7 +139,7 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
@@ -176,9 +174,7 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
         Args:
@@ -197,7 +193,7 @@ def tokenize(
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
@@ -216,7 +212,7 @@ def tokenize(
                 )
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.
         Args:
@@ -228,7 +224,9 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output
 
     def set_cache(self, cache: Optional[LlamaCache]):
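detokenize() now wraps each plain int in llama_cpp.llama_token only at the point of the C call, so callers can stream text token by token without touching ctypes themselves. A small sketch, continuing from the llm instance in the earlier example (the prompt text is arbitrary):

tokens = llm.tokenize(b" The capital of France is")
for token in tokens:                 # token is a plain int
    piece = llm.detokenize([token])  # bytes for a single token
    print(token, piece.decode("utf-8", errors="ignore"))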
@@ -244,7 +242,7 @@ def reset(self):
         self.eval_tokens.clear()
         self.eval_logits.clear()
 
-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
         Args:
@@ -458,7 +456,7 @@ def sample(
 
     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
         Examples:
@@ -617,14 +613,14 @@ def _create_completion(
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
@@ -724,7 +720,9 @@ def _create_completion(
             for token in remaining_tokens:
                 token_end_position += len(self.detokenize([token]))
                 # Check if stop sequence is in the token
-                if token_end_position >= (remaining_length - first_stop_position - 1):
+                if token_end_position >= (
+                    remaining_length - first_stop_position - 1
+                ):
                     break
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -744,7 +742,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            self.detokenize([i]).decode(
                                 "utf-8", errors="ignore"
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
@@ -822,9 +820,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
-                                "utf-8", errors="ignore"
-                            ): logprob
+                            self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
                         }
                         top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ def load_state(self, state: LlamaState) -> None:
             raise RuntimeError("Failed to set llama state data")
 
     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()
 
     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()
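Because token_eos() now returns an int, it compares directly against the int tokens yielded by generate(). A hedged generation sketch, again reusing the llm instance from the first example (sampling values mirror the defaults visible in the diff above; the prompt and the 32-token cap are arbitrary):

prompt = llm.tokenize(b" Q: Name the planets in the solar system. A:")
generated = []
for token in llm.generate(prompt, top_k=40, top_p=0.95, temp=0.80):
    if token == llm.token_eos():  # plain int comparison after this commit
        break
    generated.append(token)
    if len(generated) >= 32:      # keep the sketch short
        break
print(llm.detokenize(generated).decode("utf-8", errors="ignore"))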

@@ -1317,9 +1313,7 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         return [math.log(x / sum_exps) for x in exps]
 
     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b: