Skip to content

Commit

Permalink
add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
biobootloader committed Mar 24, 2024
1 parent 1bb7b38 commit 3028279
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 40 deletions.
81 changes: 55 additions & 26 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,82 @@
import os
import sys

import fire

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from spice import Spice
from spice.utils import fuzzy_model_lookup


def display_stats(response):
input_tokens = response.input_tokens
output_tokens = response.output_tokens
total_time = response.total_time
async def basic_example():
client = Spice()

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]

response = await client.call_llm(messages=messages, model="gpt-4-0125-preview")

print(response.text)

print(f"\n\nlogged: {input_tokens} input tokens, {output_tokens} output tokens, {total_time:.2f}s\n\n")

async def streaming_example():
# you can set a default model for the client instead of passing it with each call
client = Spice(model="claude-3-opus-20240229")

async def run(model="", stream=False):
model = fuzzy_model_lookup(model)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]

client = Spice(model)
response = await client.call_llm(messages=messages, stream=True)

response = await client.call_llm(messages=messages, stream=stream, logging_callback=display_stats)
async for text in response.stream():
print(text, end="", flush=True)

print(">>>>>>>>>>>>>")
if stream:
async for t in response.stream():
print(t, end="", flush=True)
else:
print(response.text)
print("\n<<<<<<<<<<<<<")
# response always includes the final text, no need build it from the stream yourself
print(response.text)

# response also includes helpful stats
print(f"Took {response.total_time:.2f}s")
print(f"Time to first token: {response.time_to_first_token:.2f}s")
print(f"Input/Output tokens: {response.input_tokens}/{response.output_tokens}")

if stream:
print(f"Time to first token: {response.time_to_first_token:.2f}s")

print(f"Input/Output tokens: {response.input_tokens}/{response.output_tokens}")
print(f"Characters per second: {response.characters_per_second:.2f}")
async def multiple_providers_example():
# alias models for easy configuration, even mixing providers
model_aliases = {
"task1_model": {"model": "gpt-4-0125-preview"},
"task2_model": {"model": "claude-3-opus-20240229"},
"task3_model": {"model": "claude-3-haiku-20240307"},
}

client = Spice(model_aliases=model_aliases)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "list 5 random words"},
]

responses = await asyncio.gather(
client.call_llm(messages=messages, model="task1_model"),
client.call_llm(messages=messages, model="task2_model"),
client.call_llm(messages=messages, model="task3_model"),
)

for i, response in enumerate(responses, 1):
print(f"\nModel {i} response:")
print(response.text)
print(f"Characters per second: {response.characters_per_second:.2f}")


def run_async(model="", stream=False):
asyncio.run(run(model, stream))
async def run_all_examples():
print("Running basic example:")
await basic_example()
print("\n\nRunning streaming example:")
await streaming_example()
print("\n\nRunning multiple providers example:")
await multiple_providers_example()


if __name__ == "__main__":
fire.Fire(run_async)
asyncio.run(run_all_examples())
25 changes: 11 additions & 14 deletions spice/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ def characters_per_second(self):


class Spice:
def __init__(self, default_model=None, default_provider=None, model_aliases=None):
self._default_model = default_model
def __init__(self, model=None, provider=None, model_aliases=None):
self._default_model = model

if default_model is not None:
if model is not None:
if model_aliases is not None:
raise SpiceError("model_aliases not supported when default_model is set")
raise SpiceError("model_aliases not supported when model is set")
self._model_aliases = None
if default_provider is None:
default_provider = _get_provider_from_model_name(default_model)
self._default_client = _get_client(default_provider)
if provider is None:
provider = _get_provider_from_model_name(model)
self._default_client = _get_client(provider)
else:
if default_provider is not None:
self._default_client = _get_client(default_provider)
if provider is not None:
self._default_client = _get_client(provider)
else:
self._default_client = None
self._clients = _get_clients_from_env()
Expand All @@ -97,7 +97,7 @@ def __init__(self, default_model=None, default_provider=None, model_aliases=None
if model_aliases is not None:
_validate_model_aliases(
self._model_aliases,
self._clients if self._default_client is None else {default_provider: self._default_client},
self._clients if self._default_client is None else {provider: self._default_client},
)

async def call_llm(
Expand All @@ -112,11 +112,8 @@ async def call_llm(
):
if model is None:
if self._default_model is None:
raise SpiceError("model argument is required when default_model is not set")
raise SpiceError("model argument is required when default model is not set at initialization")
model = self._default_model
else:
if self._default_model is not None:
raise SpiceError("model argument cannot be used when default_model is set")

if self._model_aliases is not None:
if model in self._model_aliases:
Expand Down

0 comments on commit 3028279

Please sign in to comment.