Skip to content

Commit

Permalink
Merge pull request #720 from PrefectHQ/ai-speech
Browse files Browse the repository at this point in the history
Add ai speech
  • Loading branch information
jlowin authored Jan 4, 2024
2 parents 497f6ef + 0d01d93 commit 79934fe
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 11 deletions.
13 changes: 3 additions & 10 deletions src/marvin/components/ai_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect
from functools import partial, wraps
from typing import (
Any,
Expand All @@ -20,20 +19,14 @@

from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
from marvin.prompts.images import IMAGE_PROMPT
from marvin.utilities.jinja import (
BaseEnvironment,
)

T = TypeVar("T")
P = ParamSpec("P")

DEFAULT_PROMPT = inspect.cleandoc(
"""
{{_doc}}
{{_return_value}}
"""
)


class AIImageKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
Expand All @@ -45,7 +38,7 @@ class AIImageKwargs(TypedDict):
class AIImageKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = DEFAULT_PROMPT
prompt: Optional[str] = IMAGE_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None

Expand All @@ -54,7 +47,7 @@ class AIImage(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=DEFAULT_PROMPT)
prompt: Optional[str] = Field(default=IMAGE_PROMPT)
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

Expand Down
145 changes: 145 additions & 0 deletions src/marvin/components/ai_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
from functools import partial, wraps
from typing import (
Any,
Callable,
Coroutine,
Generic,
Optional,
TypedDict,
TypeVar,
Union,
overload,
)

from openai import AsyncClient, Client
from openai._base_client import HttpxBinaryResponseContent as AudioResponse
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import NotRequired, ParamSpec, Self, Unpack

from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
from marvin.prompts.speech import SPEECH_PROMPT
from marvin.utilities.jinja import (
BaseEnvironment,
)

T = TypeVar("T")
P = ParamSpec("P")


class AISpeechKwargs(TypedDict):
    """Optional keyword arguments accepted by ``ai_speech`` / ``AISpeech``.

    Every key is ``NotRequired``; callers pass only what they want to override.
    """

    # Jinja environment handed to the prompt renderer.
    environment: NotRequired[BaseEnvironment]
    # Prompt template text; the models in this module default to SPEECH_PROMPT.
    prompt: NotRequired[str]
    # Synchronous OpenAI client.
    client: NotRequired[Client]
    # Asynchronous OpenAI client.
    aclient: NotRequired[AsyncClient]


class AISpeechKwargsDefaults(BaseModel):
    """Default values for the keyword arguments described by ``AISpeechKwargs``.

    ``ai_speech`` instantiates this model from the user's kwargs and dumps it
    with ``exclude_none=True``, so fields left as ``None`` are simply omitted.
    """

    # arbitrary_types_allowed: the client fields hold non-pydantic objects.
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    environment: Optional[BaseEnvironment] = None
    # The only field with a non-None default: the shared speech template.
    prompt: Optional[str] = SPEECH_PROMPT
    client: Optional[Client] = None
    aclient: Optional[AsyncClient] = None


class AISpeech(BaseModel, Generic[P]):
    """Callable model that turns a function call into a speech request.

    Calling an instance renders the wrapped function into a prompt string
    (see ``as_prompt``) and passes that string as ``input`` to the ``speak``
    method of a Marvin client wrapper. Coroutine functions are dispatched to
    the async client; everything else uses the sync client.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    # The user function whose call is rendered into the prompt.
    fn: Optional[Callable[P, Any]] = None
    environment: Optional[BaseEnvironment] = None
    prompt: Optional[str] = Field(default=SPEECH_PROMPT)
    # Clients are created lazily per instance unless explicitly supplied.
    client: Client = Field(default_factory=lambda: MarvinClient().client)
    aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

    def __call__(
        self, *args: P.args, **kwargs: P.kwargs
    ) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
        """Dispatch to ``acall`` for coroutine functions, else to ``call``.

        Returns:
            The audio response, or a coroutine resolving to it when ``fn`` is
            a coroutine function (the caller must await it).
        """
        if asyncio.iscoroutinefunction(self.fn):
            return self.acall(*args, **kwargs)
        return self.call(*args, **kwargs)

    def call(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
        """Render the prompt and request speech through the sync client."""
        prompt_str = self.as_prompt(*args, **kwargs)
        return MarvinClient(client=self.client).speak(input=prompt_str)

    async def acall(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
        """Render the prompt and request speech through the async client."""
        prompt_str = self.as_prompt(*args, **kwargs)
        return await AsyncMarvinClient(client=self.aclient).speak(input=prompt_str)

    def as_prompt(
        self,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> str:
        """Render ``fn`` called with the given arguments into prompt text.

        Returns:
            The ``content`` of the first message produced by the rendered
            prompt function.
        """
        tool_call = PromptFunction[BaseModel].as_tool_call(
            fn=self.fn,
            environment=self.environment,
            prompt=self.prompt,
        )
        # Only the first rendered message is used as the text to speak.
        return tool_call(*args, **kwargs).messages[0].content

    @overload
    @classmethod
    def as_decorator(
        cls: type[Self],
        **kwargs: Unpack[AISpeechKwargs],
    ) -> Callable[[Callable[P, Any]], Self]:
        ...

    @overload
    @classmethod
    def as_decorator(
        cls: type[Self],
        fn: Callable[P, Any],
        **kwargs: Unpack[AISpeechKwargs],
    ) -> Self:
        ...

    @classmethod
    def as_decorator(
        cls: type[Self],
        fn: Optional[Callable[P, Any]] = None,
        **kwargs: Unpack[AISpeechKwargs],
    ) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
        """Build an instance from ``fn``, or return a decorator that will.

        Args:
            fn: The function to wrap. When omitted, a decorator is returned
                so the method can be used as ``@AISpeech.as_decorator(...)``.
            **kwargs: Optional overrides; entries whose value is ``None`` are
                dropped so the model's own defaults apply.

        Returns:
            An ``AISpeech`` instance when ``fn`` is given, otherwise a
            callable that accepts the function and returns the instance.
        """
        passed_kwargs: dict[str, Any] = {
            k: v for k, v in kwargs.items() if v is not None
        }
        if fn is None:
            # Pydantic models only accept keyword arguments, so a bare
            # ``partial(cls, ...)`` would raise TypeError when the decorator
            # machinery passes the function positionally. Routing back through
            # ``as_decorator`` accepts ``fn`` positionally or by keyword.
            return partial(cls.as_decorator, **passed_kwargs)

        return cls(fn=fn, **passed_kwargs)


def ai_speech(
    fn: Optional[Callable[P, Any]] = None,
    **kwargs: Unpack[AISpeechKwargs],
) -> Union[
    Callable[
        [Callable[P, Any]],
        Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
    ],
    Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
]:
    """Decorate a function so that calling it produces speech audio.

    Usable both bare (``@ai_speech``) and with options
    (``@ai_speech(prompt=...)``).

    Args:
        fn: The function to decorate; ``None`` when options are supplied.
        **kwargs: Optional ``AISpeechKwargs`` overrides, merged with the
            defaults from ``AISpeechKwargsDefaults``.

    Returns:
        The wrapped callable when ``fn`` is given, otherwise a decorator.
        The wrapped callable returns the audio response, or a coroutine
        resolving to it when the decorated function is a coroutine function.
    """

    def wrapper(
        func: Callable[P, Any], /, *args_: P.args, **kwargs_: P.kwargs
    ) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
        # ``func`` is positional-only so that a decorated function which has
        # its own keyword argument named ``func`` does not collide with it.
        # A fresh AISpeech is built on every call.
        # NOTE(review): model_dump() serializes nested pydantic models to
        # dicts; confirm ``environment`` round-trips as intended.
        f = AISpeech[P].as_decorator(
            func, **AISpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
        )
        return f(*args_, **kwargs_)

    def decorator(
        fn: Callable[P, Any],
    ) -> Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]]:
        # Bind the target function and copy its metadata onto the partial.
        return wraps(fn)(partial(wrapper, fn))

    if fn is not None:
        return decorator(fn)
    return decorator
Empty file added src/marvin/prompts/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions src/marvin/prompts/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import inspect

# Jinja template for image prompts: the function's docstring on the first
# line and its return value on the second. The `default("", true)` filter
# renders an empty string when a value is undefined or falsy.
IMAGE_PROMPT = inspect.cleandoc(
    """
    {{_doc | default("", true)}}
    {{_return_value | default("", true)}}
    """
)
8 changes: 8 additions & 0 deletions src/marvin/prompts/speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import inspect

# Jinja template for speech prompts: the function's docstring on the first
# line and its return value on the second. The `default("", true)` filter
# renders an empty string when a value is undefined or falsy.
SPEECH_PROMPT = inspect.cleandoc(
    """
    {{_doc | default("", true)}}
    {{_return_value | default("", true)}}
    """
)
2 changes: 1 addition & 1 deletion src/marvin/utilities/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def render_to_messages(
**kwargs: Any,
) -> list[Message]:
pairs = split_text_by_tokens(
text=self.render(**kwargs),
text=self.render(**kwargs).strip(),
split_tokens=[f"\n{role}" for role in self.roles.keys()],
)
return [
Expand Down

0 comments on commit 79934fe

Please sign in to comment.