diff --git a/.gitignore b/.gitignore index ab5a061..220ee9c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .env .venv __pycache__ +.idea diff --git a/pyproject.toml b/pyproject.toml index 1372a75..dc1c0ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "pillow", "tiktoken", "httpx", - "jinja2" + "jinja2", + "termcolor", ] [project.optional-dependencies] @@ -32,5 +33,4 @@ dev = [ "pyright", "pytest", "pytest-asyncio", - "termcolor", ] diff --git a/spice/spice_message.py b/spice/spice_message.py index 1e2c140..5f4ca88 100644 --- a/spice/spice_message.py +++ b/spice/spice_message.py @@ -37,6 +37,7 @@ class PromptMetadata(BaseModel): class SpiceMessage(BaseModel): role: Role content: Union[TextContent, ImageContent] + name: Optional[str] cache: bool = Field(default=False) prompt_metadata: Optional[PromptMetadata] @@ -45,6 +46,7 @@ def __init__( role: Role, text: Optional[str] = None, image_url: Optional[str] = None, + name: Optional[str] = None, cache: bool = False, prompt_metadata: Optional[PromptMetadata] = None, ): @@ -56,7 +58,7 @@ def __init__( content = ImageContent(type="image_url", image_url=image_url) else: raise ValueError("Either text or image_url must be provided.") - super().__init__(role=role, content=content, cache=cache, prompt_metadata=prompt_metadata) + super().__init__(role=role, content=content, name=name, cache=cache, prompt_metadata=prompt_metadata) class SpiceMessages(List[SpiceMessage]): @@ -66,26 +68,28 @@ def __init__(self, client: Optional[Spice] = None, messages: Collection[SpiceMes self._client = client super().__init__(messages) - def add_text(self, role: Role, text: str, cache: bool = False) -> SpiceMessages: - self.append(SpiceMessage(role=role, text=text, cache=cache)) + def add_text(self, role: Role, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages: + self.append(SpiceMessage(role=role, text=text, name=name, cache=cache)) return self - def add_user_text(self, text: str, cache: bool = False) -> SpiceMessages: - return self.add_text("user", text, cache) + def add_user_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages: + return self.add_text("user", text, name, cache) - def add_system_text(self, text: str, cache: bool = False) -> SpiceMessages: - return self.add_text("system", text, cache) + def add_system_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages: + return self.add_text("system", text, name, cache) - def add_assistant_text(self, text: str, cache: bool = False) -> SpiceMessages: - return self.add_text("assistant", text, cache) + def add_assistant_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages: + return self.add_text("assistant", text, name, cache) - def add_user_image_from_url(self, url: str, cache: bool = False) -> SpiceMessages: + def add_user_image_from_url(self, url: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages: if not (url.startswith("http://") or url.startswith("https://")): raise ImageError(f"Invalid image URL {url}: Must be http or https protocol.") - self.append(SpiceMessage(role="user", image_url=url, cache=cache)) + self.append(SpiceMessage(role="user", image_url=url, name=name, cache=cache)) return self - def add_user_image_from_file(self, file_path: Path | str, cache: bool = False) -> SpiceMessages: + def add_user_image_from_file( + self, file_path: Path | str, name: Optional[str] = None, cache: bool = False + ) -> SpiceMessages: file_path = Path(file_path).expanduser().resolve() if not file_path.exists(): raise ImageError(f"Invalid image at {file_path}: file does not exist.") @@ -95,10 +99,12 @@ def add_user_image_from_file(self, file_path: Path | str, cache: bool = False) - with file_path.open("rb") as file: image_bytes = file.read() image = base64.b64encode(image_bytes).decode("utf-8") - self.append(SpiceMessage(role="user", image_url=f"data:{media_type};base64,{image}", cache=cache)) + self.append(SpiceMessage(role="user", image_url=f"data:{media_type};base64,{image}", name=name, cache=cache)) return self - def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any) -> SpiceMessages: + def add_prompt( + self, role: Role, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any + ) -> SpiceMessages: """Appends a message with the given role and pre-loaded prompt using jinja to render the context.""" if self._client is None: raise ValueError("Cannot add prompt without a Spice client.") @@ -107,20 +113,27 @@ def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any) SpiceMessage( role=role, text=self._client.get_rendered_prompt(name, **context), + name=message_name, cache=cache, prompt_metadata=PromptMetadata(name=name, content=self._client.get_prompt(name), context=context), ) ) return self - def add_user_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages: - return self.add_prompt("user", name, cache, **context) + def add_user_prompt( + self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any + ) -> SpiceMessages: + return self.add_prompt("user", name, message_name, cache, **context) - def add_system_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages: - return self.add_prompt("system", name, cache, **context) + def add_system_prompt( + self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any + ) -> SpiceMessages: + return self.add_prompt("system", name, message_name, cache, **context) - def add_assistant_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages: - return self.add_prompt("assistant", name, cache, **context) + def add_assistant_prompt( + self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any + ) -> SpiceMessages: + return self.add_prompt("assistant", name, message_name, cache, **context) def copy(self): new_copy = SpiceMessages(self._client, self) diff --git a/spice/wrapped_clients.py b/spice/wrapped_clients.py index e26277a..bcf3cfa 100644 --- a/spice/wrapped_clients.py +++ b/spice/wrapped_clients.py @@ -123,10 +123,19 @@ def _convert_messages(self, messages: Collection[SpiceMessage]) -> List[ChatComp converted_messages = [] for message in messages: content_part = _spice_message_to_openai_content_part(message) - if converted_messages and converted_messages[-1]["role"] == message.role: + if ( + converted_messages + and converted_messages[-1]["role"] == message.role + and ( + ("name" in converted_messages[-1]) == (message.name is not None) + and ("name" not in converted_messages[-1] or message.name == converted_messages[-1]["name"]) + ) + ): converted_messages[-1]["content"].append(content_part) else: converted_messages.append({"role": message.role, "content": [content_part]}) + if message.name is not None: + converted_messages[-1]["name"] = message.name return converted_messages @override