Add input/output schemas to runnables #11063

Merged · 17 commits · Sep 28, 2023
24 changes: 22 additions & 2 deletions libs/langchain/langchain/chains/base.py
@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import yaml

@@ -22,7 +22,13 @@
 )
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
-from langchain.pydantic_v1 import Field, root_validator, validator
+from langchain.pydantic_v1 import (
+    BaseModel,
+    Field,
+    create_model,
+    root_validator,
+    validator,
+)
 from langchain.schema import RUN_KEY, BaseMemory, RunInfo
 from langchain.schema.runnable import Runnable, RunnableConfig

@@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
     chains and cannot return as rich of an output as `__call__`.
     """
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "ChainInput", **{k: (Any, None) for k in self.input_keys}
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        # This is correct, but pydantic typings/mypy don't think so.
+        return create_model(  # type: ignore[call-overload]
+            "ChainOutput", **{k: (Any, None) for k in self.output_keys}
+        )
+
     def invoke(
         self,
         input: Dict[str, Any],
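Taken together, the two properties mean every `Chain` now publishes pydantic models derived from its `input_keys`/`output_keys`. A minimal sketch of what they produce; the `EchoChain` subclass below is hypothetical, written only to exercise the new properties:

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Hypothetical single-key chain used to inspect the derived schemas."""

    @property
    def input_keys(self) -> List[str]:
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        return ["answer"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {"answer": inputs["question"]}


chain = EchoChain()
# Each input/output key becomes an optional field of type Any.
print(chain.input_schema.schema())   # title "ChainInput", property "question"
print(chain.output_schema.schema())  # title "ChainOutput", property "answer"
```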
29 changes: 27 additions & 2 deletions libs/langchain/langchain/chains/combine_documents/base.py
@@ -1,15 +1,15 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.pydantic_v1 import Field
+from langchain.pydantic_v1 import BaseModel, Field, create_model
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


@@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "CombineDocumentsInput",
+            **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "CombineDocumentsOutput",
+            **{self.output_key: (str, None)},  # type: ignore[call-overload]
+        )
+
     @property
     def input_keys(self) -> List[str]:
         """Expect input key.
@@ -153,6 +167,17 @@ def output_keys(self) -> List[str]:
         """
         return self.combine_docs_chain.output_keys
 
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return create_model(
+            "AnalyzeDocumentChain",
+            **{self.input_key: (str, None)},  # type: ignore[call-overload]
+        )
+
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.combine_docs_chain.output_schema
+
     def _call(
         self,
         inputs: Dict[str, str],
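Document chains know their key types up front, so these overrides emit typed fields instead of `Any`. A sketch of inspecting the typed schemas; the fake-LLM wiring is illustrative, not part of the diff:

```python
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

chain = StuffDocumentsChain(
    llm_chain=LLMChain(
        llm=FakeListLLM(responses=["a summary"]),
        prompt=PromptTemplate.from_template("Summarize:\n{context}"),
    ),
    document_variable_name="context",
)
# "input_documents" is typed List[Document]; "output_text" is typed str.
print(chain.input_schema.schema()["properties"]["input_documents"])
print(chain.output_schema.schema()["properties"]["output_text"])
```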
15 changes: 14 additions & 1 deletion libs/langchain/langchain/chains/combine_documents/map_reduce.py
@@ -9,7 +9,7 @@
 from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
 
 
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     return_intermediate_steps: bool = False
     """Return the results of the map steps in the output."""
 
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        if self.return_intermediate_steps:
+            return create_model(
+                "MapReduceDocumentsOutput",
+                **{
+                    self.output_key: (str, None),
+                    "intermediate_steps": (List[str], None),
+                },  # type: ignore[call-overload]
+            )
+
+        return super().output_schema
+
     @property
     def output_keys(self) -> List[str]:
         """Expect input key.
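The schema widens only when `return_intermediate_steps` is set; otherwise the base `CombineDocumentsOutput` is reused. A rough sketch of the behavior, where the construction details (including the fake LLM) are assumptions, not code from this PR:

```python
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

llm_chain = LLMChain(
    llm=FakeListLLM(responses=["ok"]),
    prompt=PromptTemplate.from_template("Summarize:\n{context}"),
)
chain = MapReduceDocumentsChain(
    llm_chain=llm_chain,
    reduce_documents_chain=ReduceDocumentsChain(
        combine_documents_chain=StuffDocumentsChain(
            llm_chain=llm_chain, document_variable_name="context"
        )
    ),
    return_intermediate_steps=True,
)
# With the flag on, expect both fields in the derived model.
print(list(chain.output_schema.schema()["properties"]))
# -> ['output_text', 'intermediate_steps']
```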
14 changes: 13 additions & 1 deletion libs/langchain/langchain/chains/combine_documents/map_rerank.py
@@ -9,7 +9,7 @@
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.output_parsers.regex import RegexParser
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
 
 
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -77,6 +77,18 @@ class Config:
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        schema: Dict[str, Any] = {
+            self.output_key: (str, None),
+        }
+        if self.return_intermediate_steps:
+            schema["intermediate_steps"] = (List[str], None)
+        if self.metadata_keys:
+            schema.update({key: (Any, None) for key in self.metadata_keys})
+
+        return create_model("MapRerankOutput", **schema)
+
     @property
     def output_keys(self) -> List[str]:
         """Expect input key.
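Same pattern here, with `metadata_keys` folded in as optional `Any` fields. A sketch; the regex parser and fake LLM below are stand-ins chosen to satisfy the chain's validators, not code from this PR:

```python
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.output_parsers.regex import RegexParser
from langchain.prompts import PromptTemplate

parser = RegexParser(
    regex=r"Answer: (.*)\nScore: (\d+)", output_keys=["answer", "score"]
)
chain = MapRerankDocumentsChain(
    llm_chain=LLMChain(
        llm=FakeListLLM(responses=["Answer: x\nScore: 1"]),
        prompt=PromptTemplate.from_template("{context}", output_parser=parser),
    ),
    document_variable_name="context",
    rank_key="score",
    answer_key="answer",
    metadata_keys=["source"],
)
# The declared metadata key joins the output schema alongside output_text.
print(list(chain.output_schema.schema()["properties"]))
```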
17 changes: 17 additions & 0 deletions libs/langchain/langchain/chat_models/base.py
@@ -11,6 +11,7 @@
     List,
     Optional,
     Sequence,
+    Union,
     cast,
 )

@@ -37,9 +38,14 @@
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     BaseMessageChunk,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
+    SystemMessageChunk,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
@@ -107,6 +113,17 @@ class Config:
 
     # --- Runnable methods ---
 
+    @property
+    def OutputType(self) -> Any:
+        """Get the output type for this runnable."""
+        return Union[
+            HumanMessageChunk,
+            AIMessageChunk,
+            ChatMessageChunk,
+            FunctionMessageChunk,
+            SystemMessageChunk,
+        ]
+
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
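A quick way to see the declared type; `FakeListChatModel` (langchain's fake chat model, assumed available here) stands in for a real provider:

```python
from langchain.chat_models.fake import FakeListChatModel

model = FakeListChatModel(responses=["hi"])
# A closed union of concrete chunk classes rather than the abstract
# BaseMessage, which lets pydantic emit a fully enumerated schema.
print(model.OutputType)
```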
35 changes: 2 additions & 33 deletions libs/langchain/langchain/chat_models/litellm.py
@@ -37,6 +37,8 @@
     BaseMessageChunk,
     ChatMessage,
     ChatMessageChunk,
+    FunctionMessage,
+    FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
     SystemMessage,
@@ -52,39 +54,6 @@ class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
 
 
-def _truncate_at_stop_tokens(
-    text: str,
-    stop: Optional[List[str]],
-) -> str:
-    """Truncates text at the earliest stop token found."""
-    if stop is None:
-        return text
-
-    for stop_token in stop:
-        stop_token_idx = text.find(stop_token)
-        if stop_token_idx != -1:
-            text = text[:stop_token_idx]
-    return text
-
-
-class FunctionMessage(BaseMessage):
-    """Message for passing the result of executing a function back to a model."""
-
-    name: str
-    """The name of the function that was executed."""
-
-    @property
-    def type(self) -> str:
-        """Type of the message, used for serialization."""
-        return "function"
-
-
-class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
-    """Message Chunk for passing the result of executing a function back to a model."""
-
-    pass
-
-
 def _create_retry_decorator(
     llm: ChatLiteLLM,
     run_manager: Optional[
5 changes: 5 additions & 0 deletions libs/langchain/langchain/llms/base.py
@@ -199,6 +199,11 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool:
 
     # --- Runnable methods ---
 
+    @property
+    def OutputType(self) -> Type[str]:
+        """Get the output type for this runnable."""
+        return str
+
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
             return input
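For completion-style LLMs the declared output is simply `str`. A sketch with the built-in fake LLM:

```python
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["hello"])
print(llm.OutputType)  # <class 'str'>
```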
12 changes: 10 additions & 2 deletions libs/langchain/langchain/prompts/chat.py
@@ -28,6 +28,7 @@
 )
 from langchain.schema.messages import (
     AIMessage,
+    AnyMessage,
     BaseMessage,
     ChatMessage,
     HumanMessage,
@@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
     A type of a prompt value that is built from messages.
     """
 
-    messages: List[BaseMessage]
+    messages: Sequence[BaseMessage]
     """List of messages."""
 
     def to_string(self) -> str:
@@ -289,7 +290,14 @@ def to_string(self) -> str:
 
     def to_messages(self) -> List[BaseMessage]:
         """Return prompt as a list of messages."""
-        return self.messages
+        return list(self.messages)
+
+
+class ChatPromptValueConcrete(ChatPromptValue):
+    """Chat prompt value which explicitly lists out the message types it accepts.
+    For use in external schemas."""
+
+    messages: Sequence[AnyMessage]
 
 
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
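The concrete variant exists because pydantic cannot enumerate subclasses of the abstract `BaseMessage`, whereas `AnyMessage` is a union of the concrete message classes. A sketch of the effect on the emitted schema:

```python
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.schema.messages import HumanMessage

value = ChatPromptValueConcrete(messages=[HumanMessage(content="hi")])
print(value.to_messages())
# Each member of the AnyMessage union gets its own definition, so external
# consumers see exactly which message shapes are accepted.
print(sorted(ChatPromptValueConcrete.schema().get("definitions", {})))
```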
19 changes: 18 additions & 1 deletion libs/langchain/langchain/schema/language_model.py
@@ -13,8 +13,10 @@
     Union,
 )
 
+from typing_extensions import TypeAlias
+
 from langchain.load.serializable import Serializable
-from langchain.schema.messages import BaseMessage, get_buffer_string
+from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
 from langchain.schema.output import LLMResult
 from langchain.schema.prompt import PromptValue
 from langchain.schema.runnable import Runnable
@@ -70,6 +72,21 @@ class BaseLanguageModel(
     Each of these has an equivalent asynchronous method.
     """
 
+    @property
+    def InputType(self) -> TypeAlias:
+        """Get the input type for this runnable."""
+        from langchain.prompts.base import StringPromptValue
+        from langchain.prompts.chat import ChatPromptValueConcrete
+
+        # This is a version of LanguageModelInput which replaces the abstract
+        # base class BaseMessage with a union of its subclasses, which makes
+        # for a much better schema.
+        return Union[
+            str,
+            Union[StringPromptValue, ChatPromptValueConcrete],
+            List[AnyMessage],
+        ]
+
     @abstractmethod
     def generate_prompt(
         self,
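The resulting union is what downstream schema generation sees. A sketch, again leaning on the fake chat model as a stand-in:

```python
from langchain.chat_models.fake import FakeListChatModel

model = FakeListChatModel(responses=["hi"])
# Union[str, Union[StringPromptValue, ChatPromptValueConcrete], List[AnyMessage]]
print(model.InputType)
```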