Skip to content

Commit

Permalink
chore: upload sandbox
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Oct 1, 2024
1 parent 416bd9d commit 17ff9d8
Show file tree
Hide file tree
Showing 8 changed files with 685 additions and 2 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,3 @@ cython_debug/


**.DS_Store
sandbox/
3 changes: 2 additions & 1 deletion sandbox/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# sandbox

This directory contains experimental little scripts that are not part of the main library.
This directory contains experimental little scripts that are not part of the main library. Don't rely on these
for anything - they're just for testing.
48 changes: 48 additions & 0 deletions sandbox/llama31-fc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from typing import Annotated

import pydantic

from kani import AIParam, FunctionCall, Kani, ToolCall, ai_function, chat_in_terminal
from kani.engines import WrapperEngine
from kani.engines.huggingface import HuggingEngine


class NaiveLlamaJSONFunctionCallingEngine(WrapperEngine):
async def predict(self, *args, **kwargs):
completion = await self.engine.predict(*args, **kwargs)
# if the completion is only JSON, try parsing it as a function call
try:
data = json.loads(completion.message.text)
function_call = LlamaFunctionCall.model_validate(data)
except (json.JSONDecodeError, pydantic.ValidationError):
return completion
else:
tc = ToolCall.from_function_call(FunctionCall.with_args(function_call.name, **function_call.parameters))
completion.message.content = None
completion.message.tool_calls = [tc]
return completion


class LlamaFunctionCall(pydantic.BaseModel):
name: str
parameters: dict


class MyKani(Kani):
@ai_function()
def get_weather(
self,
location: Annotated[str, AIParam(desc="The city and state, e.g. San Francisco, CA")],
unit: Annotated[str, AIParam(desc="'f' or 'c'")],
):
"""Get the current weather in a given location."""
# call some weather API, or just mock it for this example
degrees = 72 if unit == "f" else 22
return f"Weather in {location}: Sunny, {degrees} degrees {unit}."


if __name__ == "__main__":
engine = NaiveLlamaJSONFunctionCallingEngine(HuggingEngine(model_id="meta-llama/Llama-3.1-70B-Instruct"))
ai = MyKani(engine)
chat_in_terminal(ai, verbose=True, stream=False)
120 changes: 120 additions & 0 deletions sandbox/mistral-tokenizer.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"cells": [
{
"metadata": {},
"cell_type": "code",
"source": [
"from mistral_common.protocol.instruct.messages import AssistantMessage, ToolMessage, UserMessage\n",
"from mistral_common.protocol.instruct.tool_calls import Function, FunctionCall, Tool, ToolCall\n",
"from mistral_common.tokens.instruct.normalize import ChatCompletionRequest\n",
"from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
"\n",
"mistral_tokenizer = MistralTokenizer.v3()\n",
"\n",
"completion_request = ChatCompletionRequest(\n",
" tools=[\n",
" Tool(\n",
" function=Function(\n",
" name=\"get_current_weather\",\n",
" description=\"Get the current weather\",\n",
" parameters={\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"format\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
" },\n",
" },\n",
" \"required\": [\"location\", \"format\"],\n",
" },\n",
" )\n",
" )\n",
" ],\n",
" messages=[\n",
" UserMessage(content=\"What's the weather like today in Paris?\"),\n",
" AssistantMessage(\n",
" tool_calls=[\n",
" ToolCall(\n",
" id=\"1bdc45f90\",\n",
" function=FunctionCall(name=\"get_weather\", arguments='{\"location\": \"Tokyo, JP\", \"unit\": \"celsius\"}'),\n",
" ),\n",
" ]\n",
" ),\n",
" ToolMessage(\n",
" content=\"Weather in Tokyo, JP: Partly cloudy, 21 degrees celsius.\",\n",
" tool_call_id=\"1bdc45f90\",\n",
" ),\n",
" AssistantMessage(content=\"It's partly cloudy and 21 degrees in Tokyo.\"),\n",
" UserMessage(content=\"What's the weather like today in Paris?\"),\n",
" AssistantMessage(content=\"It's partly cloudy and 21 degrees in Tokyo.\"),\n",
" UserMessage(content=\"What's the weather like today in Paris?\"),\n",
" ],\n",
")\n",
"\n",
"tokenized = mistral_tokenizer.encode_chat_completion(completion_request)\n",
"\n",
"print(tokenized.text)\n"
],
"id": "ed0cf3012c98f3bf",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"hf_tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.3\")"
],
"id": "3f3fcf1d6cd725ca",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"test = \"\"\"<s>[INST] What's the weather in Tokyo?[/INST] [TOOL_CALLS] [{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo, JP\", \"unit\": \"celsius\"}, \"id\": \"dccf5329c\"}]</s>[TOOL_RESULTS] {\"call_id\": \"dccf5329c\", \"content\": \"Weather in Tokyo, JP: Partly cloudy, 21 degrees celsius.\"}[/TOOL_RESULTS] It's partly cloudy and 21 degrees in Tokyo.</s>\"\"\"\n",
"hf_tokenizer(test)"
],
"id": "188f275f9efe9f90",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "695c5eafeff5fce9",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
32 changes: 32 additions & 0 deletions sandbox/mistral03-fc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import enum
from typing import Annotated

from kani import AIParam, Kani, ai_function, chat_in_terminal
from kani.engines.huggingface import HuggingEngine
from kani.prompts.impl.mistral import MistralFunctionCallingAdapter

model = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.3")
engine = MistralFunctionCallingAdapter(model)


class Unit(enum.Enum):
FAHRENHEIT = "fahrenheit"
CELSIUS = "celsius"


class MyKani(Kani):
@ai_function()
def get_weather(
self,
location: Annotated[str, AIParam(desc="The city and state, e.g. San Francisco, CA")],
unit: Unit,
):
"""Get the current weather in a given location."""
# call some weather API, or just mock it for this example
degrees = 72 if unit == Unit.FAHRENHEIT else 22
return f"Weather in {location}: Sunny, {degrees} degrees {unit.value}."


ai = MyKani(engine)
if __name__ == "__main__":
chat_in_terminal(ai, verbose=True, stream=False)
18 changes: 18 additions & 0 deletions sandbox/openelm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from kani import Kani, chat_in_terminal
from kani.engines.huggingface import HuggingEngine
from kani.prompts.impl import LLAMA2_PIPELINE

model_id = "apple/OpenELM-270M-Instruct"


engine = HuggingEngine(
model_id=model_id,
prompt_pipeline=LLAMA2_PIPELINE,
tokenizer_kwargs=dict(trust_remote_code=True),
model_load_kwargs=dict(trust_remote_code=True),
)

ai = Kani(engine)

if __name__ == "__main__":
chat_in_terminal(ai)
Loading

0 comments on commit 17ff9d8

Please sign in to comment.