Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ai speech #720

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions src/marvin/components/ai_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect
from functools import partial, wraps
from typing import (
Any,
Expand All @@ -20,20 +19,14 @@

from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
from marvin.prompts.images import IMAGE_PROMPT
from marvin.utilities.jinja import (
BaseEnvironment,
)

T = TypeVar("T")
P = ParamSpec("P")

DEFAULT_PROMPT = inspect.cleandoc(
"""
{{_doc}}
{{_return_value}}
"""
)


class AIImageKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
Expand All @@ -45,7 +38,7 @@ class AIImageKwargs(TypedDict):
class AIImageKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = DEFAULT_PROMPT
prompt: Optional[str] = IMAGE_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None

Expand All @@ -54,7 +47,7 @@ class AIImage(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=DEFAULT_PROMPT)
prompt: Optional[str] = Field(default=IMAGE_PROMPT)
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

Expand Down
145 changes: 145 additions & 0 deletions src/marvin/components/ai_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
from functools import partial, wraps
from typing import (
Any,
Callable,
Coroutine,
Generic,
Optional,
TypedDict,
TypeVar,
Union,
overload,
)

from openai import AsyncClient, Client
from openai._base_client import HttpxBinaryResponseContent as AudioResponse
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import NotRequired, ParamSpec, Self, Unpack

from marvin.client.openai import AsyncMarvinClient, MarvinClient
from marvin.components.prompt.fn import PromptFunction
from marvin.prompts.speech import SPEECH_PROMPT
from marvin.utilities.jinja import (
BaseEnvironment,
)

T = TypeVar("T")
P = ParamSpec("P")


class AISpeechKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
client: NotRequired[Client]
aclient: NotRequired[AsyncClient]


class AISpeechKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = SPEECH_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None


class AISpeech(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=SPEECH_PROMPT)
client: Client = Field(default_factory=lambda: MarvinClient().client)
aclient: AsyncClient = Field(default_factory=lambda: AsyncMarvinClient().client)

def __call__(
self, *args: P.args, **kwargs: P.kwargs
) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
if asyncio.iscoroutinefunction(self.fn):
return self.acall(*args, **kwargs)
return self.call(*args, **kwargs)

def call(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
prompt_str = self.as_prompt(*args, **kwargs)
return MarvinClient(client=self.client).speak(input=prompt_str)

async def acall(self, *args: P.args, **kwargs: P.kwargs) -> AudioResponse:
prompt_str = self.as_prompt(*args, **kwargs)
return await AsyncMarvinClient(client=self.aclient).speak(input=prompt_str)

def as_prompt(
self,
*args: P.args,
**kwargs: P.kwargs,
) -> str:
tool_call = PromptFunction[BaseModel].as_tool_call(
fn=self.fn,
environment=self.environment,
prompt=self.prompt,
)
return tool_call(*args, **kwargs).messages[0].content

@overload
@classmethod
def as_decorator(
cls: type[Self],
**kwargs: Unpack[AISpeechKwargs],
) -> Callable[P, Self]:
pass

@overload
@classmethod
def as_decorator(
cls: type[Self],
fn: Callable[P, Any],
**kwargs: Unpack[AISpeechKwargs],
) -> Self:
pass

@classmethod
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Any]] = None,
**kwargs: Unpack[AISpeechKwargs],
) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
passed_kwargs: dict[str, Any] = {
k: v for k, v in kwargs.items() if v is not None
}
if fn is None:
return partial(
cls,
**passed_kwargs,
)

return cls(
fn=fn,
**passed_kwargs,
)


def ai_speech(
fn: Optional[Callable[P, Any]] = None,
**kwargs: Unpack[AISpeechKwargs],
) -> Union[
Callable[
[Callable[P, Any]],
Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
],
Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]],
]:
def wrapper(
func: Callable[P, Any], *args_: P.args, **kwargs_: P.kwargs
) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
f = AISpeech[P].as_decorator(
func, **AISpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)
return f(*args_, **kwargs_)

if fn is not None:
return wraps(fn)(partial(wrapper, fn))

def decorator(
fn: Callable[P, Any],
) -> Callable[P, Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]]:
return wraps(fn)(partial(wrapper, fn))

return decorator
Empty file added src/marvin/prompts/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions src/marvin/prompts/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import inspect

IMAGE_PROMPT = inspect.cleandoc(
"""
{{_doc | default("", true)}}
{{_return_value | default("", true)}}
"""
)
8 changes: 8 additions & 0 deletions src/marvin/prompts/speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import inspect

SPEECH_PROMPT = inspect.cleandoc(
"""
{{_doc | default("", true)}}
{{_return_value | default("", true)}}
"""
)
2 changes: 1 addition & 1 deletion src/marvin/utilities/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def render_to_messages(
**kwargs: Any,
) -> list[Message]:
pairs = split_text_by_tokens(
text=self.render(**kwargs),
text=self.render(**kwargs).strip(),
split_tokens=[f"\n{role}" for role in self.roles.keys()],
)
return [
Expand Down