Skip to content

Commit

Permalink
Create new RunnableSerializable base class in preparation for configu…
Browse files Browse the repository at this point in the history
…rable runnables (#11279)

- Also move RunnableBranch to its own file

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Oct 2, 2023
2 parents 1cbe7f5 + c6a720f commit 0638f7b
Show file tree
Hide file tree
Showing 17 changed files with 607 additions and 509 deletions.
5 changes: 2 additions & 3 deletions libs/langchain/langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import (
BaseModel,
Field,
Expand All @@ -30,7 +29,7 @@
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable

logger = logging.getLogger(__name__)

Expand All @@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
return langchain.verbose


class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/llms/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""

responses: List
responses: List[str]
sleep: Optional[float] = None
i: int = 0

Expand Down
5 changes: 2 additions & 3 deletions libs/langchain/langchain/schema/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@

from typing_extensions import TypeAlias

from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
from langchain.schema.runnable import RunnableSerializable
from langchain.utils import get_pydantic_field_names

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,7 +53,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:


class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.
Expand Down
11 changes: 6 additions & 5 deletions libs/langchain/langchain/schema/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from typing_extensions import get_args

from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import (
ChatGeneration,
Expand All @@ -25,12 +24,12 @@
GenerationChunk,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable

T = TypeVar("T")


class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""

@abstractmethod
Expand Down Expand Up @@ -63,7 +62,7 @@ async def aparse_result(


class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""

Expand Down Expand Up @@ -121,7 +120,9 @@ async def ainvoke(
)


class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
class BaseOutputParser(
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.
Expand Down
5 changes: 2 additions & 3 deletions libs/langchain/langchain/schema/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

import yaml

from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable


class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""

input_variables: List[str]
Expand Down
5 changes: 2 additions & 3 deletions libs/langchain/langchain/schema/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
from langchain.schema.runnable import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
from langchain.callbacks.manager import (
Expand All @@ -18,7 +17,7 @@
)


class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain/langchain/schema/runnable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnableSequence,
RunnableWithFallbacks,
RunnableSerializable,
)
from langchain.schema.runnable.branch import RunnableBranch
from langchain.schema.runnable.config import RunnableConfig, patch_config
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable

Expand All @@ -19,6 +20,7 @@
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
Expand Down
5 changes: 2 additions & 3 deletions libs/langchain/langchain/schema/runnable/_locals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
Union,
)

from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough

Expand Down Expand Up @@ -104,7 +103,7 @@ async def atransform(


class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""
Expand Down
Loading

0 comments on commit 0638f7b

Please sign in to comment.