Skip to content

Commit

Permalink
Add type to message chunks (#11232)
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev authored Sep 29, 2023
1 parent fb66b39 commit 8b4cb4e
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 3 deletions.
30 changes: 28 additions & 2 deletions libs/langchain/langchain/schema/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class HumanMessage(BaseMessage):
"""

type: Literal["human"] = "human"
is_chunk: Literal[False] = False


HumanMessage.update_forward_refs()
Expand All @@ -157,7 +158,10 @@ class HumanMessage(BaseMessage):
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""

pass
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]


class AIMessage(BaseMessage):
Expand All @@ -169,6 +173,7 @@ class AIMessage(BaseMessage):
"""

type: Literal["ai"] = "ai"
is_chunk: Literal[False] = False


AIMessage.update_forward_refs()
Expand All @@ -177,6 +182,11 @@ class AIMessage(BaseMessage):
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""

# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
Expand All @@ -201,6 +211,7 @@ class SystemMessage(BaseMessage):
"""

type: Literal["system"] = "system"
is_chunk: Literal[False] = False


SystemMessage.update_forward_refs()
Expand All @@ -209,7 +220,10 @@ class SystemMessage(BaseMessage):
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""

pass
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]


class FunctionMessage(BaseMessage):
Expand All @@ -219,6 +233,7 @@ class FunctionMessage(BaseMessage):
"""The name of the function that was executed."""

type: Literal["function"] = "function"
is_chunk: Literal[False] = False


FunctionMessage.update_forward_refs()
Expand All @@ -227,6 +242,11 @@ class FunctionMessage(BaseMessage):
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""

# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
Expand All @@ -252,6 +272,7 @@ class ChatMessage(BaseMessage):
"""The speaker / role of the Message."""

type: Literal["chat"] = "chat"
is_chunk: Literal[False] = False


ChatMessage.update_forward_refs()
Expand All @@ -260,6 +281,11 @@ class ChatMessage(BaseMessage):
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""

# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
Expand Down
Loading

0 comments on commit 8b4cb4e

Please sign in to comment.