diff --git a/src/marvin/components/ai_image.py b/src/marvin/components/ai_image.py
index 5818ac18b..d5c059f1d 100644
--- a/src/marvin/components/ai_image.py
+++ b/src/marvin/components/ai_image.py
@@ -1,5 +1,4 @@
 import asyncio
-import inspect
 from functools import partial, wraps
 from typing import (
     Any,
@@ -20,6 +19,7 @@
 
 from marvin.client.openai import AsyncMarvinClient, MarvinClient
 from marvin.components.prompt.fn import PromptFunction
+from marvin.prompts.images import IMAGE_PROMPT
 from marvin.utilities.jinja import (
     BaseEnvironment,
 )
@@ -27,13 +27,6 @@
 T = TypeVar("T")
 P = ParamSpec("P")
 
-DEFAULT_PROMPT = inspect.cleandoc(
-    """
-    {{_doc}}
-    {{_return_value}}
-    """
-)
-
 
 class AIImageKwargs(TypedDict):
     environment: NotRequired[BaseEnvironment]
@@ -45,7 +38,7 @@ class AIImageKwargs(TypedDict):
 class AIImageKwargsDefaults(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
     environment: Optional[BaseEnvironment] = None
-    prompt: Optional[str] = DEFAULT_PROMPT
+    prompt: Optional[str] = IMAGE_PROMPT
     client: Optional[Client] = None
     aclient: Optional[AsyncClient] = None
 
@@ -54,7 +47,7 @@ class AIImage(BaseModel, Generic[P]):
     model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
     fn: Optional[Callable[P, Any]] = None
     environment: Optional[BaseEnvironment] = None
-    prompt: Optional[str] = Field(default=DEFAULT_PROMPT)
+    prompt: Optional[str] = Field(default=IMAGE_PROMPT)
     client: Client = Field(default_factory=lambda: MarvinClient().client)
     aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)
 
diff --git a/src/marvin/components/ai_speech.py b/src/marvin/components/ai_speech.py
new file mode 100644
index 000000000..988ccde07
--- /dev/null
+++ b/src/marvin/components/ai_speech.py
@@ -0,0 +1,167 @@
+import asyncio
+from functools import partial, wraps
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Generic,
+    Optional,
+    TypedDict,
+    TypeVar,
+    Union,
+    overload,
+)
+
+from openai import AsyncClient, Client
+from openai._base_client import HttpxBinaryResponseContent as AudioResponse
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import NotRequired, ParamSpec, Self, Unpack
+
+from marvin.client.openai import AsyncMarvinClient, MarvinClient
+from marvin.components.prompt.fn import PromptFunction
+from marvin.prompts.speech import SPEECH_PROMPT
+from marvin.utilities.jinja import (
+    BaseEnvironment,
+)
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
+class AISpeechKwargs(TypedDict):
+    """Optional keyword arguments accepted by `ai_speech` and `as_decorator`."""
+
+    environment: NotRequired[BaseEnvironment]
+    prompt: NotRequired[str]
+    client: NotRequired[Client]
+    aclient: NotRequired[AsyncClient]
+
+
+class AISpeechKwargsDefaults(BaseModel):
+    """Defaults that `ai_speech` applies to any option the caller omits."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+    environment: Optional[BaseEnvironment] = None
+    prompt: Optional[str] = SPEECH_PROMPT
+    client: Optional[Client] = None
+    aclient: Optional[AsyncClient] = None
+
+
+class AISpeech(BaseModel, Generic[P]):
+    """Turn a call to `fn` into a prompt and pass it to the client's `speak`."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+    fn: Optional[Callable[P, Any]] = None
+    environment: Optional[BaseEnvironment] = None
+    prompt: Optional[str] = Field(default=SPEECH_PROMPT)
+    client: Client = Field(default_factory=lambda: MarvinClient().client)
+    aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)
+
+    def __call__(
+        self, *args: P.args, **kwargs: P.kwargs
+    ) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
+        """Dispatch to `acall` when `fn` is a coroutine function, else `call`."""
+        if asyncio.iscoroutinefunction(self.fn):
+            return self.acall(*args, **kwargs)
+        return self.call(*args, **kwargs)
+
+    def call(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
+        """Render the prompt and send it through the synchronous client."""
+        prompt_str = self.as_prompt(*args, **kwargs)
+        return MarvinClient(client=self.client).speak(input=prompt_str)
+
+    async def acall(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
+        """Render the prompt and send it through the asynchronous client."""
+        prompt_str = self.as_prompt(*args, **kwargs)
+        return await AsyncMarvinClient(client=self.aclient).speak(input=prompt_str)
+
+    def as_prompt(
+        self,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> str:
+        """Return the content of the first message rendered for this call."""
+        tool_call = PromptFunction[BaseModel].as_tool_call(
+            fn=self.fn,
+            environment=self.environment,
+            prompt=self.prompt,
+        )
+        return tool_call(*args, **kwargs).messages[0].content
+
+    @overload
+    @classmethod
+    def as_decorator(
+        cls: type[Self],
+        **kwargs: Unpack[AISpeechKwargs],
+    ) -> Callable[[Callable[P, Any]], Self]:
+        pass
+
+    @overload
+    @classmethod
+    def as_decorator(
+        cls: type[Self],
+        fn: Callable[P, Any],
+        **kwargs: Unpack[AISpeechKwargs],
+    ) -> Self:
+        pass
+
+    @classmethod
+    def as_decorator(
+        cls: type[Self],
+        fn: Optional[Callable[P, Any]] = None,
+        **kwargs: Unpack[AISpeechKwargs],
+    ) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
+        """Build an instance around `fn`, or return a decorator that will.
+
+        Keyword arguments whose value is None are dropped so that the field
+        defaults apply instead.
+        """
+        passed_kwargs: dict[str, Any] = {
+            k: v for k, v in kwargs.items() if v is not None
+        }
+        if fn is None:
+            # Bind the decorated function by keyword: pydantic models reject
+            # positional constructor arguments, so `partial(cls, ...)(func)` fails.
+            def decorator(func: Callable[P, Any]) -> Self:
+                return cls(fn=func, **passed_kwargs)
+
+            return decorator
+
+        return cls(
+            fn=fn,
+            **passed_kwargs,
+        )
+
+
+def ai_speech(
+    fn: Optional[Callable[P, Any]] = None,
+    **kwargs: Unpack[AISpeechKwargs],
+) -> Union[
+    Callable[
+        [Callable[P, Any]],
+        Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
+    ],
+    Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
+]:
+    """Decorate `fn` so that calling it returns what `AISpeech` produces for it.
+
+    Usable bare (`@ai_speech`) or with options (`@ai_speech(prompt=...)`).
+    """
+
+    def wrapper(
+        func: Callable[P, Any], *args_: P.args, **kwargs_: P.kwargs
+    ) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
+        f = AISpeech[P].as_decorator(
+            func, **AISpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
+        )
+        return f(*args_, **kwargs_)
+
+    if fn is not None:
+        return wraps(fn)(partial(wrapper, fn))
+
+    def decorator(
+        fn: Callable[P, Any],
+    ) -> Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]]:
+        return wraps(fn)(partial(wrapper, fn))
+
+    return decorator
diff --git a/src/marvin/prompts/__init__.py b/src/marvin/prompts/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/marvin/prompts/images.py b/src/marvin/prompts/images.py
new file mode 100644
index 000000000..600a82c9d
--- /dev/null
+++ b/src/marvin/prompts/images.py
@@ -0,0 +1,8 @@
+import inspect
+
+IMAGE_PROMPT = inspect.cleandoc(
+    """
+    {{_doc | default("", true)}}
+    {{_return_value | default("", true)}}
+    """
+)
diff --git a/src/marvin/prompts/speech.py b/src/marvin/prompts/speech.py
new file mode 100644
index 000000000..7c626500b
--- /dev/null
+++ b/src/marvin/prompts/speech.py
@@ -0,0 +1,8 @@
+import inspect
+
+SPEECH_PROMPT = inspect.cleandoc(
+    """
+    {{_doc | default("", true)}}
+    {{_return_value | default("", true)}}
+    """
+)
diff --git a/src/marvin/utilities/jinja.py b/src/marvin/utilities/jinja.py
index 2e537c121..8fde971db 100644
--- a/src/marvin/utilities/jinja.py
+++ b/src/marvin/utilities/jinja.py
@@ -186,7 +186,7 @@ def render_to_messages(
         **kwargs: Any,
     ) -> list[Message]:
         pairs = split_text_by_tokens(
-            text=self.render(**kwargs),
+            text=self.render(**kwargs).strip(),
             split_tokens=[f"\n{role}" for role in self.roles.keys()],
         )
         return [