diff --git a/.gitignore b/.gitignore
index 3b3cff889..43817e659 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,7 +13,7 @@
 blob*/
 blob/**/*
 blob/*
-*.db
+*.db
 
 # C extensions
 
@@ -306,8 +306,9 @@
 dist
 .yarn/install-state.gz
 .pnp.*
-
 # Ignore static files
 static/
 *.db-shm
 *.db-wal
+
+uv.lock
diff --git a/examples/providers/local_ex.py b/examples/providers/local_ex.py
new file mode 100644
index 000000000..a24d33b56
--- /dev/null
+++ b/examples/providers/local_ex.py
@@ -0,0 +1,27 @@
+"""
+Example of serving a GGUF model from local storage.
+"""
+
+from os.path import expanduser
+
+import ell
+
+ell.init(verbose=True, store="./logdir")
+# Unlike the ollama example, local models are not auto-discovered; register one explicitly.
+
+# These values are just an example for a small model assumed to be present locally.
+# Ideally this should also be able to point to a folder of models/blobs.
+model_name = "Llama-3.2-1B-Instruct-Q8_0.gguf"
+model_path = expanduser(
+    "~/.cache/lm-studio/models/lmstudio-community/Llama-3.2-1B-Instruct-GGUF/"
+)
+ell.models.local.register(model_name=model_name, model_path=model_path)
+
+
+@ell.simple(model=model_name, temperature=0.7)
+def hello(world: str):
+    """You are a helpful assistant that writes in lower case."""  # System Message
+    return f"Say hello to {world[::-1]} with a poem."  # User Message
+
+
+print(hello("sama"))
diff --git a/src/ell/configurator.py b/src/ell/configurator.py
index 57e281fe1..263b2ece1 100644
--- a/src/ell/configurator.py
+++ b/src/ell/configurator.py
@@ -16,9 +16,9 @@ class _Model:
     name: str
     default_client: Optional[Union[openai.Client, Any]] = None
     #XXX: Deprecation in 0.1.0
-    #XXX: We will depreciate this when streaming is implemented.
+    #XXX: We will deprecate this when streaming is implemented.
     # Currently we stream by default for the verbose renderer,
-    # but in the future we will not support streaming by default
+    # but in the future we will not support streaming by default
     # and stream=True must be passed which will then make API providers the
     # single source of truth for whether or not a model supports an api parameter.
     # This makes our implementation extremely light, only requiring us to provide
@@ -44,9 +44,9 @@ def __init__(self, **data):
         self._lock = threading.Lock()
         self._local = threading.local()
-
+
     def register_model(
-        self,
+        self,
        name: str,
        default_client: Optional[Union[openai.Client, Any]] = None,
        supports_streaming: Optional[bool] = None
@@ -74,12 +74,12 @@ def model_registry_override(self, overrides: Dict[str, _Model]):
         """
         if not hasattr(self._local, 'stack'):
             self._local.stack = []
-
+
         with self._lock:
             current_registry = self._local.stack[-1] if self._local.stack else self.registry
             new_registry = current_registry.copy()
             new_registry.update(overrides)
-
+
         self._local.stack.append(new_registry)
         try:
             yield
@@ -187,7 +187,7 @@ def init(
 def get_store() -> Union[Store, None]:
     return config.store
 
-# Will be deprecated at 0.1.0
+# Will be deprecated at 0.1.0
 # You can add more helper functions here if needed
 
 def register_provider(provider: Provider, client_type: Type[Any]) -> None:
diff --git a/src/ell/models/__init__.py b/src/ell/models/__init__.py
index f57479deb..0cdec6e70 100644
--- a/src/ell/models/__init__.py
+++ b/src/ell/models/__init__.py
@@ -6,8 +6,9 @@
 
 """
 
-import ell.models.openai
 import ell.models.anthropic
-import ell.models.ollama
+import ell.models.bedrock
 import ell.models.groq
-import ell.models.bedrock
\ No newline at end of file
+import ell.models.local
+import ell.models.ollama
+import ell.models.openai
diff --git a/src/ell/models/local.py b/src/ell/models/local.py
new file mode 100644
index 000000000..8c2575d36
--- /dev/null
+++ b/src/ell/models/local.py
@@ -0,0 +1,29 @@
+import logging
+
+from ell.configurator import config
+from ell.providers.local import LocalModelClient
+
+logger = logging.getLogger(__name__)
+client = None
+
+
+def register(model_name: str, model_path: str):
+    """
+    Registers a GGUF model stored on local disk.
+
+    This function creates a LocalModelClient for the model file found under
+    the given path and registers it with the global configuration, allowing
+    the model to be used within the ell framework.
+
+    Args:
+        model_name (str): The name of the model to register.
+        model_path (str): The path to the model on disk.
+
+    Note:
+        This function updates the module-level client and the global
+        configuration.
+    """
+    global client
+
+    client = LocalModelClient(model_name=model_name, model_path=model_path)
+    config.register_model(model_name, client)
diff --git a/src/ell/providers/__init__.py b/src/ell/providers/__init__.py
index 5d520960f..9b500612e 100644
--- a/src/ell/providers/__init__.py
+++ b/src/ell/providers/__init__.py
@@ -1,7 +1,8 @@
-import ell.providers.openai
-import ell.providers.groq
 import ell.providers.anthropic
 import ell.providers.bedrock
+import ell.providers.local
+import ell.providers.groq
+import ell.providers.openai
 # import ell.providers.mistral
 # import ell.providers.cohere
 # import ell.providers.gemini
diff --git a/src/ell/providers/local.py b/src/ell/providers/local.py
new file mode 100644
index 000000000..48cbdaf2d
--- /dev/null
+++ b/src/ell/providers/local.py
@@ -0,0 +1,141 @@
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Optional,
+)
+
+from ell.configurator import register_provider
+from ell.provider import EllCallParams, Metadata, Provider
+from ell.types import ContentBlock, Message
+from ell.types._lstr import _lstr
+
+try:
+    import gpt4all
+
+    class LocalModelClient(gpt4all.GPT4All):
+        # should probably fix the way this is done in _warnings via: `not client_to_use.api_key`
+        api_key = "okay"
+
+except ImportError:
+    raise ImportError("Please install the gpt4all package to use the LocalProvider.")
+
+
+class LocalProvider(Provider):
+    """
+    Custom provider for local GGUF models served through gpt4all.
+    """
+
+    dangerous_disable_validation = True  # Set to True to bypass validation if necessary
+
+    def _construct_prompt(self, messages: list[Message]) -> str:
+        """
+        Constructs a single prompt string from the list of ell Messages.
+        Adjust this method based on how the local model expects the prompt to be formatted.
+
+        Might need this as part of client.chat_session() in provider_call_function.
+        """
+        prompt = ""
+        for message in messages:
+            if message.role == "system":
+                prompt += f"System: {message.text_only}\n"
+            elif message.role == "user":
+                prompt += f"User: {message.text_only}\n"
+            elif message.role == "assistant":
+                prompt += f"Assistant: {message.text_only}\n"
+            # Handle other roles if necessary
+        return prompt.strip()
+
+    def provider_call_function(
+        self,
+        client: Any,
+        api_call_params: dict[str, Any] = {},
+    ) -> Callable[..., Any]:
+        """
+        Returns the function to call on the client with the given API call parameters.
+        """
+        if api_call_params.get("streaming", False):
+            raise NotImplementedError("Streaming responses not yet supported.")
+
+        with client.chat_session():
+            # not clear to me if you need to put the system prompt and prompt template in chat_session
+            return client.generate
+
+    def translate_to_provider(self, ell_call: EllCallParams) -> dict[str, Any]:
+        """
+        Translates EllCallParams into keyword arguments for the client's generate() method.
+        """
+        final_call_params = {
+            "prompt": self._construct_prompt(ell_call.messages),
+            "max_tokens": ell_call.api_params.get("max_tokens", 200),
+            "temp": ell_call.api_params.get("temperature", 0.7),
+            "top_k": ell_call.api_params.get("top_k", 40),
+            "top_p": ell_call.api_params.get("top_p", 0.4),
+            "min_p": ell_call.api_params.get("min_p", 0.0),
+            "repeat_penalty": ell_call.api_params.get("repeat_penalty", 1.18),
+            "repeat_last_n": ell_call.api_params.get("repeat_last_n", 64),
+            "n_batch": ell_call.api_params.get("n_batch", 8),
+            "n_predict": ell_call.api_params.get("n_predict", None),
+            "streaming": ell_call.api_params.get("stream", False),
+            # callback of `None` type on gpt4all will cause errors
+            # "callback": ell_call.api_params.get("callback", None),
+        }
+
+        # Handle tools if any
+        if ell_call.tools:
+            # LocalProvider might not support tools directly; handle accordingly
+            # This is a placeholder for tool integration
+            final_call_params["tools"] = [
+                {
+                    "name": tool.__name__,
+                    "description": tool.__doc__,
+                    "parameters": tool.__ell_params_model__.model_json_schema(),
+                }
+                for tool in ell_call.tools
+            ]
+
+        return final_call_params
+
+    def translate_from_provider(
+        self,
+        provider_response: Iterable[str] | str,
+        ell_call: EllCallParams,
+        provider_call_params: dict[str, Any],
+        origin_id: Optional[str] = None,
+        logger: Optional[Callable[..., None]] = None,
+    ) -> tuple[list[Message], Metadata]:
+        """
+        Translates LocalProvider's response back into ell's Message and Metadata formats.
+        Handles both streaming and non-streaming responses.
+        """
+        metadata: Metadata = {}
+        messages: list[Message] = []
+        streaming = provider_call_params.get("streaming", False)
+
+        if streaming and isinstance(provider_response, Iterable):
+            # Handle streaming responses
+            raise NotImplementedError("Streaming responses not yet supported.")
+        else:
+            # Handle non-streaming responses
+            if isinstance(provider_response, str):
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=[
+                            ContentBlock(
+                                text=_lstr(
+                                    content=provider_response, origin_trace=origin_id
+                                )
+                            )
+                        ],
+                    )
+                )
+            else:
+                raise ValueError(
+                    "Unexpected provider_response type for non-streaming response."
+                )
+
+        return messages, metadata
+
+
+register_provider(LocalProvider(), LocalModelClient)
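
A possible follow-up for the streaming TODO: both provider_call_function and translate_from_provider raise NotImplementedError when streaming is requested. The sketch below shows one way the streaming branch of translate_from_provider could look, assuming gpt4all's generate(streaming=True) yields plain text chunks; it only reuses names already defined in that method (provider_response, logger, messages, origin_id) and is an illustration, not something this change implements.

    if streaming and not isinstance(provider_response, str):
        # Accumulate streamed chunks so the final assistant Message carries
        # the full text together with its origin trace.
        chunks: list[str] = []
        for chunk in provider_response:
            if logger:
                logger(chunk)  # let the verbose renderer print tokens as they arrive
            chunks.append(chunk)
        messages.append(
            Message(
                role="assistant",
                content=[
                    ContentBlock(
                        text=_lstr(content="".join(chunks), origin_trace=origin_id)
                    )
                ],
            )
        )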