From b376ddd83e7a5e35ab4dcde1b134ff2cdd764477 Mon Sep 17 00:00:00 2001
From: Adam Azzam <33043305+AAAZZAM@users.noreply.github.com>
Date: Wed, 15 Nov 2023 18:14:30 -0500
Subject: [PATCH] ai_classifier hookups

---
 src/marvin/components/ai_classifier.py   | 60 ++++++++++++------------
 src/marvin/components/ai_function.py     |  6 +--
 src/marvin/components/prompt_function.py |  5 +-
 src/marvin/requests.py                   |  2 +-
 4 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/src/marvin/components/ai_classifier.py b/src/marvin/components/ai_classifier.py
index c62dd2fa3..7de44c2dc 100644
--- a/src/marvin/components/ai_classifier.py
+++ b/src/marvin/components/ai_classifier.py
@@ -39,15 +39,17 @@ class AIClassifier(BaseModel, Generic[P, T]):
     fn: Optional[Callable[P, T]] = None
     environment: Optional[BaseEnvironment] = None
-    prompt: Optional[str] = Field(default=inspect.cleandoc("""
-        You are an expert classifier that always choose correctly.
-        - {{_doc}}
-        - You must classify `{{text}}` into one of the following classes:
-        {% for option in _options %}
-        Class {{ loop.index - 1}} (value: {{ option }})
-        {% endfor %}
-        ASSISTANT: The correct class label is Class
-        """))
+    prompt: Optional[str] = Field(
+        default=inspect.cleandoc(
+            "You are an expert classifier that always choose correctly."
+            " \n- {{_doc}}"
+            " \n- You must classify `{{text}}` into one of the following classes:"
+            "{% for option in _options %}"
+            " Class {{ loop.index - 1}} (value: {{ option }})"
+            "{% endfor %}"
+            "ASSISTANT: The correct class label is Class"
+        )
+    )
     enumerate: bool = True
     encoder: Callable[[str], list[int]] = Field(default=None)
     max_tokens: Optional[int] = 1
@@ -180,12 +182,11 @@ def ai_classifier(
     *,
     environment: Optional[BaseEnvironment] = None,
     prompt: Optional[str] = None,
-    model_name: str = "FormatResponse",
-    model_description: str = "Formats the response.",
-    field_name: str = "data",
-    field_description: str = "The data to format.",
+    enumerate: bool = True,
+    encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+    max_tokens: Optional[int] = 1,
     **render_kwargs: Any,
-) -> Callable[[Callable[P, T]], Callable[P, list[T]]]:
+) -> Callable[[Callable[P, T]], Callable[P, T]]:
     pass


@@ -195,12 +196,11 @@ def ai_classifier(
     *,
     environment: Optional[BaseEnvironment] = None,
     prompt: Optional[str] = None,
-    model_name: str = "FormatResponse",
-    model_description: str = "Formats the response.",
-    field_name: str = "data",
-    field_description: str = "The data to format.",
+    enumerate: bool = True,
+    encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+    max_tokens: Optional[int] = 1,
     **render_kwargs: Any,
-) -> Callable[P, list[T]]:
+) -> Callable[P, T]:
     pass


@@ -209,28 +209,26 @@ def ai_classifier(
     *,
     environment: Optional[BaseEnvironment] = None,
     prompt: Optional[str] = None,
-    model_name: str = "FormatResponse",
-    model_description: str = "Formats the response.",
-    field_name: str = "data",
-    field_description: str = "The data to format.",
+    enumerate: bool = True,
+    encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
+    max_tokens: Optional[int] = 1,
     **render_kwargs: Any,
-) -> Union[Callable[[Callable[P, T]], Callable[P, list[T]]], Callable[P, list[T]]]:
-    def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> list[T]:
+) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
+    def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
         return AIClassifier[P, T].as_decorator(
             func,
             environment=environment,
             prompt=prompt,
-            model_name=model_name,
-            model_description=model_description,
-            field_name=field_name,
-            field_description=field_description,
+            enumerate=enumerate,
+            encoder=encoder,
+            max_tokens=max_tokens,
             **render_kwargs,
-        )(*args, **kwargs)
+        )(*args, **kwargs)[0]

     if fn is not None:
         return wraps(fn)(partial(wrapper, fn))

-    def decorator(fn: Callable[P, T]) -> Callable[P, list[T]]:
+    def decorator(fn: Callable[P, T]) -> Callable[P, T]:
         return wraps(fn)(partial(wrapper, fn))

     return decorator
diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py
index e26c513f5..f76d10b43 100644
--- a/src/marvin/components/ai_function.py
+++ b/src/marvin/components/ai_function.py
@@ -80,11 +80,11 @@ def parse(self, response: "ChatCompletion") -> T:
             model_description=self.description,
             field_name=self.field_name,
             field_description=self.field_description,
-        ).function.model
-        if not tool:
+        ).function
+        if not tool or not tool.model:
             raise NotImplementedError

-        return getattr(tool.model_validate_json(arguments), self.field_name)
+        return getattr(tool.model.model_validate_json(arguments), self.field_name)

     def as_prompt(
         self,
diff --git a/src/marvin/components/prompt_function.py b/src/marvin/components/prompt_function.py
index 89ece7c90..9a5468042 100644
--- a/src/marvin/components/prompt_function.py
+++ b/src/marvin/components/prompt_function.py
@@ -15,7 +15,6 @@
 from pydantic import BaseModel
 from typing_extensions import Self

-from marvin import settings
 from marvin.requests import BaseMessage as Message
 from marvin.requests import Prompt
 from marvin.serializers import (
@@ -23,6 +22,7 @@
     create_tool_from_type,
     create_vocabulary_from_type,
 )
+from marvin.settings import settings
 from marvin.utilities.jinja import (
     BaseEnvironment,
     Transcript,
@@ -176,6 +176,7 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
                 field_name=field_name,
                 field_description=field_description,
             )
+
             messages = Transcript(
                 content=prompt or func.__doc__ or ""
             ).render_to_messages(
@@ -193,7 +194,7 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
                 messages=messages,
                 tool_choice={
                     "type": "function",
-                    "function": {"name": tool.function.name},
+                    "function": {"name": getattr(tool.function, "name", model_name)},
                 },
                 tools=[tool],
             )
diff --git a/src/marvin/requests.py b/src/marvin/requests.py
index 00918e273..3e9f45aa8 100644
--- a/src/marvin/requests.py
+++ b/src/marvin/requests.py
@@ -3,7 +3,7 @@
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated, Self

-from marvin import settings
+from marvin.settings import settings

 T = TypeVar("T", bound=BaseModel)
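
Usage sketch: a minimal example of the reworked `ai_classifier` decorator, inferred only from the signature changes in this patch (enumerate/encoder/max_tokens replace the old model_*/field_* keyword arguments, and the decorated call now returns a single value rather than a list). The `Sentiment` enum, the sample text, and the `from marvin.components.ai_classifier import ai_classifier` import path are assumptions for illustration; actual runtime behavior depends on the configured OpenAI settings.

from enum import Enum

from marvin.components.ai_classifier import ai_classifier  # path taken from this patch


class Sentiment(Enum):
    """Hypothetical label set for the example."""

    POSITIVE = "positive"
    NEGATIVE = "negative"


# Per the new overloads, the decorator accepts enumerate/encoder/max_tokens;
# calling it with keyword arguments returns a decorator for the classifier fn.
@ai_classifier(max_tokens=1)
def classify_sentiment(text: str) -> Sentiment:
    """Classify the sentiment of a piece of text."""


# Because `wrapper` now indexes `[0]` into the classifier's output, the call
# yields a single Sentiment member instead of list[Sentiment].
label = classify_sentiment("This library keeps getting better!")  # e.g. Sentiment.POSITIVE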