core[patch]: Fix llm string representation for serializable models #23416

Merged (5 commits, Jul 1, 2024)
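Before this change, _get_llm_string built the cache key with dumps(self). Any field that cannot be serialized (for example a BaseCache instance or a message iterator) is rendered as a "not_implemented" entry whose "repr" embeds a memory address, so two identically configured model instances produced different llm strings and never shared cache entries. The sketch below illustrates that failure mode; it is a standalone simplification, not langchain-core code, and its names are made up for the example.

# Illustrative sketch only: why repr()-based cache keys defeat caching.


class InMemoryCache:
    """Stand-in for a non-serializable field on a chat model."""


def naive_llm_string(cache: InMemoryCache) -> str:
    # Before the fix, the serialized model effectively contained repr() of
    # such fields, e.g. "<InMemoryCache object at 0x7f...>", which differs
    # for every instance.
    return f"model-config---cache={cache!r}"


key_a = naive_llm_string(InMemoryCache())
key_b = naive_llm_string(InMemoryCache())
# Same configuration, different keys: a lookup with key_b misses entries
# stored under key_a.
assert key_a != key_b

The diff below instead serializes with dumpd, strips the unstable "repr" (and the serialized "graph") via a new _cleanup_llm_representation helper, and re-serializes the cleaned dict with json.dumps(..., sort_keys=True), so the key depends only on the model's configuration.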
24 changes: 23 additions & 1 deletion libs/core/langchain_core/language_models/chat_models.py
@@ -2,6 +2,7 @@

import asyncio
import inspect
import json
import uuid
import warnings
from abc import ABC, abstractmethod
@@ -449,7 +450,11 @@ def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
if self.is_lc_serializable():
params = {**kwargs, **{"stop": stop}}
param_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = dumps(self)
# This code is not super efficient as it goes back and forth between
# json and dict.
serialized_repr = dumpd(self)
_cleanup_llm_representation(serialized_repr, 1)
llm_string = json.dumps(serialized_repr, sort_keys=True)
return llm_string + "---" + param_string
else:
params = self._get_invocation_params(stop=stop, **kwargs)
@@ -1216,3 +1221,20 @@ def _gen_info_and_msg_metadata(
**(generation.generation_info or {}),
**generation.message.response_metadata,
}


def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
"""Remove non-serializable objects from a serialized object."""
if depth > 100: # Don't cooperate for pathological cases
return
if serialized["type"] == "not_implemented" and "repr" in serialized:
del serialized["repr"]

if "graph" in serialized:
del serialized["graph"]

if "kwargs" in serialized:
kwargs = serialized["kwargs"]

for value in kwargs.values():
_cleanup_llm_representation(value, depth + 1)
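
As a quick illustration of the helper above, here is what it does to a single "not_implemented" entry; this is a trimmed-down version of the fixtures exercised in the new test file below.

serialized = {
    "lc": 1,
    "type": "not_implemented",
    "id": ["builtins", "list_iterator"],
    "repr": "<list_iterator object at 0x79ff437f8d30>",
}
_cleanup_llm_representation(serialized, 1)
# The unstable repr is gone; the stable identifying fields remain.
assert serialized == {
    "lc": 1,
    "type": "not_implemented",
    "id": ["builtins", "list_iterator"],
}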
135 changes: 135 additions & 0 deletions libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -5,6 +5,7 @@

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models.chat_models import _cleanup_llm_representation
from langchain_core.language_models.fake_chat_models import (
FakeListChatModel,
GenericFakeChatModel,
@@ -266,3 +267,137 @@ def test_global_cache_stream() -> None:
assert global_cache._cache != {}
finally:
set_llm_cache(None)


class CustomChat(GenericFakeChatModel):
@classmethod
def is_lc_serializable(cls) -> bool:
return True


async def test_can_swap_caches() -> None:
"""Test that we can use a different cache object.

This test verifies that when we fetch the llm_string representation
of the chat model, we can swap the cache object and still get the same
result.
"""
cache = InMemoryCache()
chat_model = CustomChat(cache=cache, messages=iter(["hello"]))
result = await chat_model.ainvoke("foo")
assert result.content == "hello"

new_cache = InMemoryCache()
new_cache._cache = cache._cache.copy()

# Confirm that we get a cache hit!
chat_model = CustomChat(cache=new_cache, messages=iter(["goodbye"]))
result = await chat_model.ainvoke("foo")
assert result.content == "hello"


def test_llm_representation_for_serializable() -> None:
"""Test that the llm representation of a serializable chat model is correct."""
cache = InMemoryCache()
chat = CustomChat(cache=cache, messages=iter([]))
assert chat._get_llm_string() == (
'{"id": ["tests", "unit_tests", "language_models", "chat_models", '
'"test_cache", "CustomChat"], "kwargs": {"cache": {"id": ["tests", '
'"unit_tests", "language_models", "chat_models", "test_cache", '
'"InMemoryCache"], "lc": 1, "type": "not_implemented"}, "messages": {"id": '
'["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
'1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
)


def test_cleanup_serialized() -> None:
cleanup_serialized = {
"lc": 1,
"type": "constructor",
"id": [
"tests",
"unit_tests",
"language_models",
"chat_models",
"test_cache",
"CustomChat",
],
"kwargs": {
"cache": {
"lc": 1,
"type": "not_implemented",
"id": [
"tests",
"unit_tests",
"language_models",
"chat_models",
"test_cache",
"InMemoryCache",
],
"repr": "<tests.unit_tests.language_models.chat_models."
"test_cache.InMemoryCache object at 0x79ff437fe7d0>",
},
"messages": {
"lc": 1,
"type": "not_implemented",
"id": ["builtins", "list_iterator"],
"repr": "<list_iterator object at 0x79ff437f8d30>",
},
},
"name": "CustomChat",
"graph": {
"nodes": [
{"id": 0, "type": "schema", "data": "CustomChatInput"},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"tests",
"unit_tests",
"language_models",
"chat_models",
"test_cache",
"CustomChat",
],
"name": "CustomChat",
},
},
{"id": 2, "type": "schema", "data": "CustomChatOutput"},
],
"edges": [{"source": 0, "target": 1}, {"source": 1, "target": 2}],
},
}
_cleanup_llm_representation(cleanup_serialized, 1)
assert cleanup_serialized == {
"id": [
"tests",
"unit_tests",
"language_models",
"chat_models",
"test_cache",
"CustomChat",
],
"kwargs": {
"cache": {
"id": [
"tests",
"unit_tests",
"language_models",
"chat_models",
"test_cache",
"InMemoryCache",
],
"lc": 1,
"type": "not_implemented",
},
"messages": {
"id": ["builtins", "list_iterator"],
"lc": 1,
"type": "not_implemented",
},
},
"lc": 1,
"name": "CustomChat",
"type": "constructor",
}