From 3028279549d5ba2cbde3b3ade1444298bacb1a2d Mon Sep 17 00:00:00 2001
From: BioBootloader
Date: Sat, 23 Mar 2024 17:50:21 -0700
Subject: [PATCH] add examples

---
 scripts/run.py | 81 ++++++++++++++++++++++++++++++++++----------------
 spice/spice.py | 25 +++++++---------
 2 files changed, 66 insertions(+), 40 deletions(-)

diff --git a/scripts/run.py b/scripts/run.py
index b096652..a71b25f 100644
--- a/scripts/run.py
+++ b/scripts/run.py
@@ -2,53 +2,82 @@
 import os
 import sys
 
-import fire
-
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
 from spice import Spice
-from spice.utils import fuzzy_model_lookup
 
 
-def display_stats(response):
-    input_tokens = response.input_tokens
-    output_tokens = response.output_tokens
-    total_time = response.total_time
+async def basic_example():
+    client = Spice()
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "list 5 random words"},
+    ]
+
+    response = await client.call_llm(messages=messages, model="gpt-4-0125-preview")
+
+    print(response.text)
 
-    print(f"\n\nlogged: {input_tokens} input tokens, {output_tokens} output tokens, {total_time:.2f}s\n\n")
 
 
+async def streaming_example():
+    # you can set a default model for the client instead of passing it with each call
+    client = Spice(model="claude-3-opus-20240229")
 
-async def run(model="", stream=False):
-    model = fuzzy_model_lookup(model)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "list 5 random words"},
     ]
 
-    client = Spice(model)
+    response = await client.call_llm(messages=messages, stream=True)
 
-    response = await client.call_llm(messages=messages, stream=stream, logging_callback=display_stats)
+    async for text in response.stream():
+        print(text, end="", flush=True)
 
-    print(">>>>>>>>>>>>>")
-    if stream:
-        async for t in response.stream():
-            print(t, end="", flush=True)
-    else:
-        print(response.text)
-    print("\n<<<<<<<<<<<<<")
+    # response always includes the final text, no need to build it from the stream yourself
+    print(response.text)
 
+    # response also includes helpful stats
     print(f"Took {response.total_time:.2f}s")
+    print(f"Time to first token: {response.time_to_first_token:.2f}s")
+    print(f"Input/Output tokens: {response.input_tokens}/{response.output_tokens}")
 
-    if stream:
-        print(f"Time to first token: {response.time_to_first_token:.2f}s")
-        print(f"Input/Output tokens: {response.input_tokens}/{response.output_tokens}")
-        print(f"Characters per second: {response.characters_per_second:.2f}")
 
 
+async def multiple_providers_example():
+    # alias models for easy configuration, even mixing providers
+    model_aliases = {
+        "task1_model": {"model": "gpt-4-0125-preview"},
+        "task2_model": {"model": "claude-3-opus-20240229"},
+        "task3_model": {"model": "claude-3-haiku-20240307"},
+    }
+
+    client = Spice(model_aliases=model_aliases)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "list 5 random words"},
+    ]
+
+    responses = await asyncio.gather(
+        client.call_llm(messages=messages, model="task1_model"),
+        client.call_llm(messages=messages, model="task2_model"),
+        client.call_llm(messages=messages, model="task3_model"),
+    )
+
+    for i, response in enumerate(responses, 1):
+        print(f"\nModel {i} response:")
+        print(response.text)
+        print(f"Characters per second: {response.characters_per_second:.2f}")
 
 
-def run_async(model="", stream=False):
-    asyncio.run(run(model, stream))
+async def run_all_examples():
print("Running basic example:") + await basic_example() + print("\n\nRunning streaming example:") + await streaming_example() + print("\n\nRunning multiple providers example:") + await multiple_providers_example() if __name__ == "__main__": - fire.Fire(run_async) + asyncio.run(run_all_examples()) diff --git a/spice/spice.py b/spice/spice.py index 089b383..97e1920 100644 --- a/spice/spice.py +++ b/spice/spice.py @@ -76,19 +76,19 @@ def characters_per_second(self): class Spice: - def __init__(self, default_model=None, default_provider=None, model_aliases=None): - self._default_model = default_model + def __init__(self, model=None, provider=None, model_aliases=None): + self._default_model = model - if default_model is not None: + if model is not None: if model_aliases is not None: - raise SpiceError("model_aliases not supported when default_model is set") + raise SpiceError("model_aliases not supported when model is set") self._model_aliases = None - if default_provider is None: - default_provider = _get_provider_from_model_name(default_model) - self._default_client = _get_client(default_provider) + if provider is None: + provider = _get_provider_from_model_name(model) + self._default_client = _get_client(provider) else: - if default_provider is not None: - self._default_client = _get_client(default_provider) + if provider is not None: + self._default_client = _get_client(provider) else: self._default_client = None self._clients = _get_clients_from_env() @@ -97,7 +97,7 @@ def __init__(self, default_model=None, default_provider=None, model_aliases=None if model_aliases is not None: _validate_model_aliases( self._model_aliases, - self._clients if self._default_client is None else {default_provider: self._default_client}, + self._clients if self._default_client is None else {provider: self._default_client}, ) async def call_llm( @@ -112,11 +112,8 @@ async def call_llm( ): if model is None: if self._default_model is None: - raise SpiceError("model argument is required when default_model is not set") + raise SpiceError("model argument is required when default model is not set at initialization") model = self._default_model - else: - if self._default_model is not None: - raise SpiceError("model argument cannot be used when default_model is set") if self._model_aliases is not None: if model in self._model_aliases: