Skip to content

Commit

Permalink
Merge pull request #114 from AbanteAI/add-name
Browse files Browse the repository at this point in the history
Add name
  • Loading branch information
PCSwingle authored Sep 27, 2024
2 parents 385e967 + 12f8526 commit 92bd428
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.env
.venv
__pycache__
.idea
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ dependencies = [
"pillow",
"tiktoken",
"httpx",
"jinja2"
"jinja2",
"termcolor",
]

[project.optional-dependencies]
Expand All @@ -32,5 +33,4 @@ dev = [
"pyright",
"pytest",
"pytest-asyncio",
"termcolor",
]
53 changes: 33 additions & 20 deletions spice/spice_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class PromptMetadata(BaseModel):
class SpiceMessage(BaseModel):
role: Role
content: Union[TextContent, ImageContent]
name: Optional[str]
cache: bool = Field(default=False)
prompt_metadata: Optional[PromptMetadata]

Expand All @@ -45,6 +46,7 @@ def __init__(
role: Role,
text: Optional[str] = None,
image_url: Optional[str] = None,
name: Optional[str] = None,
cache: bool = False,
prompt_metadata: Optional[PromptMetadata] = None,
):
Expand All @@ -56,7 +58,7 @@ def __init__(
content = ImageContent(type="image_url", image_url=image_url)
else:
raise ValueError("Either text or image_url must be provided.")
super().__init__(role=role, content=content, cache=cache, prompt_metadata=prompt_metadata)
super().__init__(role=role, content=content, name=name, cache=cache, prompt_metadata=prompt_metadata)


class SpiceMessages(List[SpiceMessage]):
Expand All @@ -66,26 +68,28 @@ def __init__(self, client: Optional[Spice] = None, messages: Collection[SpiceMes
self._client = client
super().__init__(messages)

def add_text(self, role: Role, text: str, cache: bool = False) -> SpiceMessages:
self.append(SpiceMessage(role=role, text=text, cache=cache))
def add_text(self, role: Role, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages:
self.append(SpiceMessage(role=role, text=text, name=name, cache=cache))
return self

def add_user_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("user", text, cache)
def add_user_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages:
return self.add_text("user", text, name, cache)

def add_system_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("system", text, cache)
def add_system_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages:
return self.add_text("system", text, name, cache)

def add_assistant_text(self, text: str, cache: bool = False) -> SpiceMessages:
return self.add_text("assistant", text, cache)
def add_assistant_text(self, text: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages:
return self.add_text("assistant", text, name, cache)

def add_user_image_from_url(self, url: str, cache: bool = False) -> SpiceMessages:
def add_user_image_from_url(self, url: str, name: Optional[str] = None, cache: bool = False) -> SpiceMessages:
if not (url.startswith("http://") or url.startswith("https://")):
raise ImageError(f"Invalid image URL {url}: Must be http or https protocol.")
self.append(SpiceMessage(role="user", image_url=url, cache=cache))
self.append(SpiceMessage(role="user", image_url=url, name=name, cache=cache))
return self

def add_user_image_from_file(self, file_path: Path | str, cache: bool = False) -> SpiceMessages:
def add_user_image_from_file(
self, file_path: Path | str, name: Optional[str] = None, cache: bool = False
) -> SpiceMessages:
file_path = Path(file_path).expanduser().resolve()
if not file_path.exists():
raise ImageError(f"Invalid image at {file_path}: file does not exist.")
Expand All @@ -95,10 +99,12 @@ def add_user_image_from_file(self, file_path: Path | str, cache: bool = False) -
with file_path.open("rb") as file:
image_bytes = file.read()
image = base64.b64encode(image_bytes).decode("utf-8")
self.append(SpiceMessage(role="user", image_url=f"data:{media_type};base64,{image}", cache=cache))
self.append(SpiceMessage(role="user", image_url=f"data:{media_type};base64,{image}", name=name, cache=cache))
return self

def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
def add_prompt(
self, role: Role, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any
) -> SpiceMessages:
"""Appends a message with the given role and pre-loaded prompt using jinja to render the context."""
if self._client is None:
raise ValueError("Cannot add prompt without a Spice client.")
Expand All @@ -107,20 +113,27 @@ def add_prompt(self, role: Role, name: str, cache: bool = False, **context: Any)
SpiceMessage(
role=role,
text=self._client.get_rendered_prompt(name, **context),
name=message_name,
cache=cache,
prompt_metadata=PromptMetadata(name=name, content=self._client.get_prompt(name), context=context),
)
)
return self

def add_user_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("user", name, cache, **context)
def add_user_prompt(
self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any
) -> SpiceMessages:
return self.add_prompt("user", name, message_name, cache, **context)

def add_system_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("system", name, cache, **context)
def add_system_prompt(
self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any
) -> SpiceMessages:
return self.add_prompt("system", name, message_name, cache, **context)

def add_assistant_prompt(self, name: str, cache: bool = False, **context: Any) -> SpiceMessages:
return self.add_prompt("assistant", name, cache, **context)
def add_assistant_prompt(
self, name: str, message_name: Optional[str] = None, cache: bool = False, **context: Any
) -> SpiceMessages:
return self.add_prompt("assistant", name, message_name, cache, **context)

def copy(self):
new_copy = SpiceMessages(self._client, self)
Expand Down
11 changes: 10 additions & 1 deletion spice/wrapped_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,19 @@ def _convert_messages(self, messages: Collection[SpiceMessage]) -> List[ChatComp
converted_messages = []
for message in messages:
content_part = _spice_message_to_openai_content_part(message)
if converted_messages and converted_messages[-1]["role"] == message.role:
if (
converted_messages
and converted_messages[-1]["role"] == message.role
and (
("name" in converted_messages[-1]) == (message.name is not None)
and ("name" not in converted_messages[-1] or message.name == converted_messages[-1]["name"])
)
):
converted_messages[-1]["content"].append(content_part)
else:
converted_messages.append({"role": message.role, "content": [content_part]})
if message.name is not None:
converted_messages[-1]["name"] = message.name
return converted_messages

@override
Expand Down

0 comments on commit 92bd428

Please sign in to comment.