Skip to content

Commit

Permalink
hack(huggingface): hack to count tokens of prompt starting with assistant
Browse files Browse the repository at this point in the history
…tant
  • Loading branch information
zhudotexe committed Feb 3, 2024
1 parent 649e2f2 commit 1957501
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion kani/engines/huggingface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,16 @@ def build_prompt(
"""
_ensure_chat_template(self.tokenizer)
conversation = [{"role": msg.role.value, "content": msg.text} for msg in messages]
return self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
try:
return self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
except TemplateError:
# the template probably enforces user/assistant,
# HACK: let's try a dummy user message then the assistant one, and strip the len of the dummy off (pain)
dummy_conversation = [{"role": "user", "content": "a"}]
dummy_len = len(self.tokenizer.apply_chat_template(dummy_conversation, add_generation_prompt=False))
dummy_conversation.extend(conversation)
toks = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
return toks[dummy_len:]

async def predict(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
Expand Down

0 comments on commit 1957501

Please sign in to comment.