Skip to content

Commit

Permalink
Using langchain input types (#11204)
Browse files Browse the repository at this point in the history
Using langchain input type
  • Loading branch information
eyurtsev authored Sep 29, 2023
1 parent 77c7c9a commit 572968f
Show file tree
Hide file tree
Showing 9 changed files with 602 additions and 518 deletions.
38 changes: 20 additions & 18 deletions libs/langserve/langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import httpx
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.load.dump import dumpd
from langchain.load.load import load, loads
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import (
RunnableConfig,
Expand All @@ -19,6 +18,8 @@
)
from langchain.schema.runnable.utils import Input, Output

from langserve.serialization import simple_dumpd, simple_loads


def _without_callbacks(config: Optional[RunnableConfig]) -> RunnableConfig:
"""Evict callbacks from the config since those are definitely not supported."""
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
self.url = url
self.sync_client = httpx.Client(base_url=url, timeout=timeout)
self.async_client = httpx.AsyncClient(base_url=url, timeout=timeout)

# Register cleanup handler once RemoteRunnable is garbage collected
weakref.finalize(self, _close_clients, self.sync_client, self.async_client)

Expand All @@ -121,13 +123,13 @@ def _invoke(
response = self.sync_client.post(
"/invoke",
json={
"input": dumpd(input),
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
},
)
_raise_for_status(response)
return load(response.json())["output"]
return simple_loads(response.text)["output"]

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand All @@ -142,13 +144,13 @@ async def _ainvoke(
response = await self.async_client.post(
"/invoke",
json={
"input": dumpd(input),
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
},
)
_raise_for_status(response)
return load(response.json())["output"]
return simple_loads(response.text)["output"]

async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
Expand Down Expand Up @@ -180,13 +182,13 @@ def _batch(
response = self.sync_client.post(
"/batch",
json={
"inputs": dumpd(inputs),
"inputs": simple_dumpd(inputs),
"config": _config,
"kwargs": kwargs,
},
)
_raise_for_status(response)
return load(response.json())["output"]
return simple_loads(response.text)["output"]

def batch(
self,
Expand Down Expand Up @@ -222,13 +224,13 @@ async def _abatch(
response = await self.async_client.post(
"/batch",
json={
"inputs": dumpd(inputs),
"inputs": simple_dumpd(inputs),
"config": _config,
"kwargs": kwargs,
},
)
_raise_for_status(response)
return load(response.json())["output"]
return simple_loads(response.text)["output"]

async def abatch(
self,
Expand Down Expand Up @@ -259,11 +261,11 @@ def stream(

run_manager = callback_manager.on_chain_start(
dumpd(self),
dumpd(input),
simple_dumpd(input),
name=config.get("run_name"),
)
data = {
"input": dumpd(input),
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
}
Expand All @@ -283,7 +285,7 @@ def stream(
) as event_source:
for sse in event_source.iter_sse():
if sse.event == "data":
chunk = loads(sse.data)
chunk = simple_loads(sse.data)
yield chunk

if final_output:
Expand Down Expand Up @@ -313,11 +315,11 @@ async def astream(

run_manager = await callback_manager.on_chain_start(
dumpd(self),
dumpd(input),
simple_dumpd(input),
name=config.get("run_name"),
)
data = {
"input": dumpd(input),
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
}
Expand All @@ -334,7 +336,7 @@ async def astream(
) as event_source:
async for sse in event_source.aiter_sse():
if sse.event == "data":
chunk = loads(sse.data)
chunk = simple_loads(sse.data)
yield chunk

if final_output:
Expand Down Expand Up @@ -383,11 +385,11 @@ async def astream_log(

run_manager = await callback_manager.on_chain_start(
dumpd(self),
dumpd(input),
simple_dumpd(input),
name=config.get("run_name"),
)
data = {
"input": dumpd(input),
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
"include_names": include_names,
Expand All @@ -410,7 +412,7 @@ async def astream_log(
) as event_source:
async for sse in event_source.aiter_sse():
if sse.event == "data":
data = loads(sse.data)
data = simple_loads(sse.data)
chunk = RunLogPatch(*data["ops"])
yield chunk

Expand Down
94 changes: 94 additions & 0 deletions libs/langserve/langserve/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Serialization module for Well Known LangChain objects.
Specialized JSON serialization for well known LangChain objects that
can be expected to be frequently transmitted between chains.
"""
import json
from typing import Any, Union

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.schema.document import Document
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from pydantic import BaseModel, ValidationError


class WellKnownLCObject(BaseModel):
"""A well known LangChain object."""

__root__: Union[
Document,
HumanMessage,
SystemMessage,
ChatMessage,
FunctionMessage,
AIMessage,
HumanMessageChunk,
SystemMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
AIMessageChunk,
StringPromptValue,
ChatPromptValueConcrete,
]


# Custom JSON Encoder
class _LangChainEncoder(json.JSONEncoder):
"""Custom JSON Encoder that can encode pydantic objects as well."""

def default(self, obj) -> Any:
if isinstance(obj, BaseModel):
return obj.dict()
return super().default(obj)


# Custom JSON Decoder
class _LangChainDecoder(json.JSONDecoder):
"""Custom JSON Decoder that handles well known LangChain objects."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the LangChainDecoder."""
super().__init__(object_hook=self.decoder, *args, **kwargs)

def decoder(self, value) -> Any:
"""Decode the value."""
if isinstance(value, dict):
try:
obj = WellKnownLCObject.parse_obj(value)
return obj.__root__
except ValidationError:
return {key: self.decoder(v) for key, v in value.items()}
elif isinstance(value, list):
return [self.decoder(item) for item in value]
else:
return value


# PUBLIC API


def simple_dumpd(obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return json.loads(json.dumps(obj, cls=_LangChainEncoder))


def simple_dumps(obj: Any) -> str:
"""Dump the given object as a JSON string."""
return json.dumps(obj, cls=_LangChainEncoder)


def simple_loads(s: str) -> Any:
"""Load the given JSON string."""
return json.loads(s, cls=_LangChainDecoder)
Loading

0 comments on commit 572968f

Please sign in to comment.