feat: make llamacppengine use HF chat templates
zhudotexe committed Feb 14, 2025
1 parent eae979f commit 464bfe2
Showing 5 changed files with 192 additions and 9 deletions.
11 changes: 10 additions & 1 deletion examples/4_engines_zoo.py
@@ -72,9 +72,18 @@
engine = MistralToolCallParser(model)

# ========== llama.cpp ==========
# ---- Any Model (Chat Templates) ----
from kani.engines.huggingface import ChatTemplatePromptPipeline
from kani.engines.llamacpp import LlamaCppEngine
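# load the chat template from the base (unquantized) model on the HF Hub, then reuse it
# for a GGUF quantization of that model loaded through llama.cpp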
pipeline = ChatTemplatePromptPipeline.from_pretrained("org-id/base-model-id")
engine = LlamaCppEngine(repo_id="org-id/quant-model-id", filename="*.your-quant-type.gguf", prompt_pipeline=pipeline)

# ---- LLaMA v2 (llama.cpp) ----
from kani.engines.llamacpp import LlamaCppEngine
engine = LlamaCppEngine(repo_id="TheBloke/Llama-2-7B-Chat-GGUF", filename="*.Q4_K_M.gguf")
from kani.prompts.impl import LLAMA2_PIPELINE
engine = LlamaCppEngine(
repo_id="TheBloke/Llama-2-7B-Chat-GGUF", filename="*.Q4_K_M.gguf", prompt_pipeline=LLAMA2_PIPELINE
)

# ---- Mistral-7B (llama.cpp) ----
from kani.engines.llamacpp import LlamaCppEngine
1 change: 1 addition & 0 deletions kani/engines/huggingface/__init__.py
@@ -1 +1,2 @@
from .base import HuggingEngine
from .chat_template_pipeline import ChatTemplatePromptPipeline
9 changes: 9 additions & 0 deletions kani/engines/huggingface/chat_template_pipeline.py
@@ -53,6 +53,15 @@ def __init__(self, tokenizer, steps=None):
self._padding_len_by_role: dict[ChatRole, int] = defaultdict(lambda: 0)
self._has_inferred_role_paddings = False

@classmethod
def from_pretrained(cls, model_id: str):
"""
Create a ChatTemplatePromptPipeline from a model ID.
Useful for applying a HF model's chat template to another engine.
"""
tok = transformers.AutoTokenizer.from_pretrained(model_id)
return cls(tok)
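
For example (mirroring examples/4_engines_zoo.py; the model IDs are placeholders), the new classmethod lets another engine reuse a Hugging Face model's chat template:

from kani.engines.huggingface import ChatTemplatePromptPipeline
from kani.engines.llamacpp import LlamaCppEngine

# build a prompt pipeline from the base model's chat template...
pipeline = ChatTemplatePromptPipeline.from_pretrained("org-id/base-model-id")
# ...and hand it to the llama.cpp engine loading a GGUF quantization of that model
engine = LlamaCppEngine(repo_id="org-id/quant-model-id", filename="*.your-quant-type.gguf", prompt_pipeline=pipeline)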

# ===== auto reserve inference =====
_chat_template_dummy_msg = {"role": "user", "content": "dummy"}

65 changes: 57 additions & 8 deletions kani/engines/llamacpp/base.py
@@ -1,9 +1,10 @@
import logging
import warnings
from typing import AsyncIterable

from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage
from kani.prompts.impl import LLAMA2_PIPELINE
from kani.prompts.pipeline import PromptPipeline
from ..base import BaseCompletion, BaseEngine, Completion

@@ -15,13 +16,13 @@
'The LlamaCppEngine requires extra dependencies. Please install kani with "pip install kani[cpp]". '
) from None

log = logging.getLogger(__name__)


class LlamaCppEngine(BaseEngine):
"""
This class implements the main decoding logic for any GGUF model (not just LLaMA as the name might suggest).
This engine defaults to LLaMA 2 Chat 7B with 4-bit quantization.
**GPU Support**
llama.cpp supports multiple acceleration backends, which may require different flags to be set during installation.
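
For example, GPU offload is usually requested at model load time; a minimal sketch, assuming a GPU-enabled llama-cpp-python build and its n_gpu_layers option passed through model_load_kwargs:

from kani.engines.llamacpp import LlamaCppEngine

engine = LlamaCppEngine(
    repo_id="TheBloke/Llama-2-7B-Chat-GGUF",
    filename="*.Q4_K_M.gguf",
    model_load_kwargs={"n_gpu_layers": -1},  # -1 offloads every layer to the GPU
)
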
@@ -36,7 +37,7 @@ def __init__(
repo_id: str,
filename: str = None,
max_context_size: int = 0,
prompt_pipeline: PromptPipeline[str | list[int]] = LLAMA2_PIPELINE,
prompt_pipeline: PromptPipeline[str | list[int]] = None,
*,
model_load_kwargs: dict = None,
**hyperparams,
@@ -52,7 +53,6 @@ def __init__(
for more info.
:param hyperparams: Additional arguments to supply the model during generation.
"""

if model_load_kwargs is None:
model_load_kwargs = {}

@@ -64,6 +64,8 @@ def __init__(
self.model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **model_load_kwargs)
self.hyperparams = hyperparams

self.model.chat_format

self.max_context_size = max_context_size or self.model.n_ctx()

if self.token_reserve == 0 and self.pipeline:
@@ -81,14 +83,54 @@ def message_len(self, message: ChatMessage) -> int:
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
"You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
" uses a chat template (or is a quantization of a model with a chat template), you can use the"
" following:\n"
"from kani.engines.huggingface import ChatTemplatePromptPipeline\n"
"pipeline = ChatTemplatePromptPipeline.from_pretrained(base_model_id)\n"
"engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
)
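# the pipeline may return pre-tokenized ids (a list), a tensor (HF-style pipelines), or a prompt string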
prompt = self.pipeline.execute([message], for_measurement=True)
if isinstance(prompt, list):
return len(prompt)
elif isinstance(prompt, torch.Tensor):
return len(prompt[0])
tokenized = self.model.tokenize(prompt.encode(), add_bos=False, special=True)
return len(tokenized)

def function_token_reserve(self, functions: list[AIFunction]) -> int:
if not functions:
return 0
# default concrete base behaviour:
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
" uses a chat template (or is a quantization of a model with a chat template), you can use the"
" following:\n"
"from kani.engines.huggingface import ChatTemplatePromptPipeline\n"
"pipeline = ChatTemplatePromptPipeline.from_pretrained(base_model_id)\n"
"engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
)
prompt = self.pipeline.execute([], functions, for_measurement=True)
if isinstance(prompt, list):
return len(prompt)
elif isinstance(prompt, torch.Tensor):
toklen = len(prompt[0])
else:
# prompt str to tokens
tokenized = self.model.tokenize(prompt.encode(), add_bos=False, special=False)
toklen = len(tokenized)

# warn if there are functions but no tokens
if toklen == 0:
warnings.warn(
"Functions were given to the model, but the function prompt returned 0 tokens! This model may not"
" support function calling, or you may need to implement"
f" `{type(self).__name__}.function_token_reserve()`."
)

return toklen

def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str | list[int]:
"""
Given the list of messages from kani, build either a single string representing the prompt for the model,
@@ -98,9 +140,16 @@ def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction]
"""
if self.pipeline is None:
raise NotImplementedError(
"You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class."
"You must pass a prompt_pipeline to the LlamaCppEngine to use it as a non-abstract class. If your model"
" uses a chat template (or is a quantization of a model with a chat template), you can use the"
" following:\n"
"from kani.engines.huggingface import ChatTemplatePromptPipeline\n"
"pipeline = ChatTemplatePromptPipeline.from_pretrained(base_model_id)\n"
"engine = LlamaCppEngine(..., prompt_pipeline=pipeline)"
)
return self.pipeline(messages)
prompt = self.pipeline(messages, functions)
log.debug(f"BUILT PROMPT: {prompt}")
return prompt
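
Since build_prompt now logs the fully built prompt at DEBUG level, the rendered chat template can be inspected while developing; a minimal sketch (the logger name kani.engines.llamacpp.base is an assumption based on logging.getLogger(__name__) above):

import logging

logging.basicConfig()  # attach a root handler so log records are printed
logging.getLogger("kani.engines.llamacpp.base").setLevel(logging.DEBUG)
# subsequent chat rounds will print "BUILT PROMPT: ..." for each generation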

def _get_generate_args(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
"""Internal method to build common params for the generate call"""
115 changes: 115 additions & 0 deletions sandbox/r1-quant.py
@@ -0,0 +1,115 @@
"""
Usage: python r1-quant.py (this script takes no CLI arguments; the model is hard-coded below)
(This file isn't about training models - I just like Japanese trains.)
"""

import asyncio
import json
from typing import Annotated

import httpx

from kani import AIParam, ChatRole, Kani, ai_function, print_stream, print_width
from kani.engines.huggingface import ChatTemplatePromptPipeline
from kani.engines.llamacpp import LlamaCppEngine
from kani.utils.message_formatters import assistant_message_contents_thinking, assistant_message_thinking

pipeline = ChatTemplatePromptPipeline.from_pretrained("deepseek-ai/DeepSeek-R1")
engine = LlamaCppEngine(
repo_id="unsloth/DeepSeek-R1-GGUF", filename="DeepSeek-R1-GGUF/*UD-Q2_K_XL*.gguf", prompt_pipeline=pipeline
)


class WikipediaRetrievalKani(Kani):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wikipedia_client = httpx.AsyncClient(base_url="https://en.wikipedia.org/w/api.php", follow_redirects=True)

@ai_function()
async def wikipedia(
self,
title: Annotated[str, AIParam(desc='The article title on Wikipedia, e.g. "Train_station".')],
):
"""Get additional information about a topic from Wikipedia."""
# https://en.wikipedia.org/w/api.php?action=query&format=json&prop=extracts&titles=Train&explaintext=1&formatversion=2
resp = await self.wikipedia_client.get(
"/",
params={
"action": "query",
"format": "json",
"prop": "extracts",
"titles": title,
"explaintext": 1,
"formatversion": 2,
},
)
data = resp.json()
page = data["query"]["pages"][0]
if extract := page.get("extract"):
return extract
return f"The page {title!r} does not exist on Wikipedia."

@ai_function()
async def search(self, query: str):
"""Find titles of Wikipedia articles similar to the given query."""
# https://en.wikipedia.org/w/api.php?action=opensearch&format=json&search=Train
resp = await self.wikipedia_client.get("/", params={"action": "opensearch", "format": "json", "search": query})
return json.dumps(resp.json()[1])


async def stream_query(query: str):
async for stream in ai.full_round_stream(query):
# assistant
if stream.role == ChatRole.ASSISTANT:
await print_stream(stream, prefix="AI: ")
msg = await stream.message()
text = assistant_message_thinking(msg, show_args=True)
if text:
print_width(text, prefix="AI: ")
# function
elif stream.role == ChatRole.FUNCTION:
msg = await stream.message()
print_width(msg.text, prefix="FUNC: ")


async def print_query(query: str):
async for msg in ai.full_round(query):
# assistant
if msg.role == ChatRole.ASSISTANT:
text = assistant_message_contents_thinking(msg, show_args=True)
print_width(text, prefix="AI: ")
# function
elif msg.role == ChatRole.FUNCTION:
print_width(msg.text, prefix="FUNC: ")


async def main():
print(engine)
print("======== testing query simple ========")
await print_query("Tell me about the Yamanote line.")

print("======== testing query complex ========")
await print_query("How many subway lines does each station on the Yamanote line connect to?")

print("======== testing stream simple ========")
await stream_query("What are some of the weirdest trains in Japan?")

print("======== testing stream complex ========")
await stream_query(
"What is the fastest way from Oku-Tama to Noboribetsu? What is the cheapest way? Use JR lines only."
)


# basic system prompt since many models don't include their FC prompt in the chat template...
system_prompt = """\
You can use the following functions:
search(query: str) -- Searches for titles of Wikipedia articles.
wikipedia(title: Annotated[str, AIParam(desc='The article title on Wikipedia, e.g. "Train_station".')]) -- Gets the \
article text of a Wikipedia article given its title.
"""

ai = WikipediaRetrievalKani(engine, system_prompt=system_prompt)
if __name__ == "__main__":
asyncio.run(main())
