Using langchain input types #11204

Merged 10 commits · Sep 29, 2023

38 changes: 20 additions & 18 deletions libs/langserve/langserve/client.py
@@ -9,7 +9,6 @@
 import httpx
 from langchain.callbacks.tracers.log_stream import RunLogPatch
 from langchain.load.dump import dumpd
-from langchain.load.load import load, loads
 from langchain.schema.runnable import Runnable
 from langchain.schema.runnable.config import (
     RunnableConfig,
@@ -19,6 +18,8 @@
 )
 from langchain.schema.runnable.utils import Input, Output
 
+from langserve.serialization import simple_dumpd, simple_loads
+
 
 def _without_callbacks(config: Optional[RunnableConfig]) -> RunnableConfig:
     """Evict callbacks from the config since those are definitely not supported."""
@@ -111,6 +112,7 @@ def __init__(
         self.url = url
         self.sync_client = httpx.Client(base_url=url, timeout=timeout)
         self.async_client = httpx.AsyncClient(base_url=url, timeout=timeout)
+
         # Register cleanup handler once RemoteRunnable is garbage collected
         weakref.finalize(self, _close_clients, self.sync_client, self.async_client)

@@ -121,13 +123,13 @@ def _invoke(
         response = self.sync_client.post(
             "/invoke",
             json={
-                "input": dumpd(input),
+                "input": simple_dumpd(input),
                 "config": _without_callbacks(config),
                 "kwargs": kwargs,
             },
         )
         _raise_for_status(response)
-        return load(response.json())["output"]
+        return simple_loads(response.text)["output"]
 
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -142,13 +144,13 @@ async def _ainvoke(
         response = await self.async_client.post(
             "/invoke",
             json={
-                "input": dumpd(input),
+                "input": simple_dumpd(input),
                 "config": _without_callbacks(config),
                 "kwargs": kwargs,
             },
         )
         _raise_for_status(response)
-        return load(response.json())["output"]
+        return simple_loads(response.text)["output"]
 
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -180,13 +182,13 @@ def _batch(
         response = self.sync_client.post(
             "/batch",
             json={
-                "inputs": dumpd(inputs),
+                "inputs": simple_dumpd(inputs),
                 "config": _config,
                 "kwargs": kwargs,
             },
         )
         _raise_for_status(response)
-        return load(response.json())["output"]
+        return simple_loads(response.text)["output"]
 
     def batch(
         self,
@@ -222,13 +224,13 @@ async def _abatch(
         response = await self.async_client.post(
             "/batch",
             json={
-                "inputs": dumpd(inputs),
+                "inputs": simple_dumpd(inputs),
                 "config": _config,
                 "kwargs": kwargs,
             },
         )
         _raise_for_status(response)
-        return load(response.json())["output"]
+        return simple_loads(response.text)["output"]
 
     async def abatch(
         self,
@@ -259,11 +261,11 @@ def stream(
 
         run_manager = callback_manager.on_chain_start(
             dumpd(self),
-            dumpd(input),
+            simple_dumpd(input),
             name=config.get("run_name"),
         )
         data = {
-            "input": dumpd(input),
+            "input": simple_dumpd(input),
             "config": _without_callbacks(config),
             "kwargs": kwargs,
         }
@@ -283,7 +285,7 @@ def stream(
         ) as event_source:
             for sse in event_source.iter_sse():
                 if sse.event == "data":
-                    chunk = loads(sse.data)
+                    chunk = simple_loads(sse.data)
                     yield chunk
 
         if final_output:
@@ -313,11 +315,11 @@ async def astream(
 
         run_manager = await callback_manager.on_chain_start(
             dumpd(self),
-            dumpd(input),
+            simple_dumpd(input),
             name=config.get("run_name"),
         )
         data = {
-            "input": dumpd(input),
+            "input": simple_dumpd(input),
             "config": _without_callbacks(config),
             "kwargs": kwargs,
         }
@@ -334,7 +336,7 @@ async def astream(
         ) as event_source:
             async for sse in event_source.aiter_sse():
                 if sse.event == "data":
-                    chunk = loads(sse.data)
+                    chunk = simple_loads(sse.data)
                     yield chunk
 
         if final_output:
@@ -383,11 +385,11 @@ async def astream_log(
 
         run_manager = await callback_manager.on_chain_start(
             dumpd(self),
-            dumpd(input),
+            simple_dumpd(input),
             name=config.get("run_name"),
         )
         data = {
-            "input": dumpd(input),
+            "input": simple_dumpd(input),
             "config": _without_callbacks(config),
             "kwargs": kwargs,
             "include_names": include_names,
@@ -410,7 +412,7 @@ async def astream_log(
         ) as event_source:
             async for sse in event_source.aiter_sse():
                 if sse.event == "data":
-                    data = loads(sse.data)
+                    data = simple_loads(sse.data)
                     chunk = RunLogPatch(*data["ops"])
                     yield chunk

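The client.py changes above swap langchain's dumpd/load on request inputs and response outputs for the plain-JSON helpers simple_dumpd/simple_loads defined in the new serialization module below (dumpd is still used for the self argument to on_chain_start). A minimal sketch of the difference on the wire, assuming a langchain version with pydantic-v1 message schemas; the output shapes in the comments are illustrative, not exact:

from langchain.load.dump import dumpd
from langchain.schema.messages import HumanMessage
from langserve.serialization import simple_dumpd

msg = HumanMessage(content="hi")

# dumpd wraps the object in langchain's serialization envelope, roughly:
#   {"lc": 1, "type": "constructor", "id": [...], "kwargs": {"content": "hi"}}
print(dumpd(msg))

# simple_dumpd flattens it via pydantic's .dict(), roughly:
#   {"content": "hi", "additional_kwargs": {}, "type": "human", ...}
print(simple_dumpd(msg))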
94 changes: 94 additions & 0 deletions libs/langserve/langserve/serialization.py
@@ -0,0 +1,94 @@
"""Serialization module for Well Known LangChain objects.

Specialized JSON serialization for well known LangChain objects that
can be expected to be frequently transmitted between chains.
"""
import json
from typing import Any, Union

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.schema.document import Document
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from pydantic import BaseModel, ValidationError


class WellKnownLCObject(BaseModel):
"""A well known LangChain object."""

__root__: Union[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

Document,
HumanMessage,
SystemMessage,
ChatMessage,
FunctionMessage,
AIMessage,
HumanMessageChunk,
SystemMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
AIMessageChunk,
StringPromptValue,
ChatPromptValueConcrete,
]


# Custom JSON Encoder
class _LangChainEncoder(json.JSONEncoder):
"""Custom JSON Encoder that can encode pydantic objects as well."""

def default(self, obj) -> Any:
if isinstance(obj, BaseModel):
return obj.dict()
return super().default(obj)


# Custom JSON Decoder
class _LangChainDecoder(json.JSONDecoder):
"""Custom JSON Decoder that handles well known LangChain objects."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the LangChainDecoder."""
super().__init__(object_hook=self.decoder, *args, **kwargs)

def decoder(self, value) -> Any:
"""Decode the value."""
if isinstance(value, dict):
try:
obj = WellKnownLCObject.parse_obj(value)
return obj.__root__
except ValidationError:
return {key: self.decoder(v) for key, v in value.items()}
elif isinstance(value, list):
return [self.decoder(item) for item in value]
else:
return value


# PUBLIC API


def simple_dumpd(obj: Any) -> Any:
    """Convert the given object to a JSON serializable object."""
    return json.loads(json.dumps(obj, cls=_LangChainEncoder))


def simple_dumps(obj: Any) -> str:
    """Dump the given object as a JSON string."""
    return json.dumps(obj, cls=_LangChainEncoder)


def simple_loads(s: str) -> Any:
    """Load the given JSON string."""
    return json.loads(s, cls=_LangChainDecoder)
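A minimal round-trip sketch of the public API above, assuming a langchain version whose HumanMessage schema matches the WellKnownLCObject union:

from langchain.schema.messages import HumanMessage
from langserve.serialization import simple_dumps, simple_loads

# Encode to a JSON string, then decode it back. The decoder's object_hook
# runs on every parsed dict, and any dict that validates as a
# WellKnownLCObject is replaced by the concrete LangChain type.
payload = simple_dumps({"output": HumanMessage(content="hello")})
restored = simple_loads(payload)
assert isinstance(restored["output"], HumanMessage)

This mirrors how the client above calls simple_loads(response.text)["output"]: dicts that match a well-known schema come back as real objects, while everything else is left as plain JSON.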