core: fix batch race condition in FakeListChatModel #26924

Merged · 5 commits · Oct 3, 2024
28 changes: 28 additions & 0 deletions libs/core/langchain_core/language_models/fake_chat_models.py
@@ -13,6 +13,7 @@
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import RunnableConfig


class FakeMessagesListChatModel(BaseChatModel):
@@ -128,6 +129,33 @@ async def _astream(
    def _identifying_params(self) -> dict[str, Any]:
        return {"responses": self.responses}

+    # Manually override batch to preserve batch ordering with no concurrency.
+    def batch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        return [self.invoke(m, config, **kwargs) for m in inputs]
+
+    async def abatch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            # Do not use an async iterator here because explicit ordering is needed.
+            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        # Do not use an async iterator here because explicit ordering is needed.
+        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
+

class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
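Note on why the override is needed (not part of the diff itself): FakeListChatModel serves its canned responses from a shared, incrementing index, while the default Runnable.batch runs the per-input invocations concurrently, so parallel calls can interleave and pair responses with the wrong inputs. Invoking the inputs one at a time removes that race. A minimal sketch of the ordering guarantee, mirroring the unit test added below:

from langchain_core.language_models import FakeListChatModel

fake = FakeListChatModel(responses=["a", "b", "c"])
messages = fake.batch(["first", "second", "third"])
# With the sequential override, the i-th response is always paired with the i-th input.
assert [m.content for m in messages] == ["a", "b", "c"]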
19 changes: 18 additions & 1 deletion libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -5,7 +5,11 @@
from uuid import UUID

from langchain_core.callbacks.base import AsyncCallbackHandler
-from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
+from langchain_core.language_models import (
+    FakeListChatModel,
+    GenericFakeChatModel,
+    ParrotFakeChatModel,
+)
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import (
@@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
    assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
        content="blah"
    )


+def test_fake_list_chat_model_batch() -> None:
+    expected = [
+        _any_id_ai_message(content="a"),
+        _any_id_ai_message(content="b"),
+        _any_id_ai_message(content="c"),
+    ]
+    for _ in range(20):
+        # Run 20 times to exercise the race condition formerly present in batch.
+        fake = FakeListChatModel(responses=["a", "b", "c"])
+        resp = fake.batch(["1", "2", "3"])
+        assert resp == expected
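The new test exercises only the synchronous batch path. An analogous check for the abatch override, shown here as an illustrative sketch rather than part of this diff, could look like:

import asyncio

from langchain_core.language_models import FakeListChatModel


async def check_abatch_ordering() -> None:
    fake = FakeListChatModel(responses=["a", "b", "c"])
    resp = await fake.abatch(["1", "2", "3"])
    # Sequential awaits keep responses aligned with inputs.
    assert [m.content for m in resp] == ["a", "b", "c"]


asyncio.run(check_abatch_ordering())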
@@ -199,19 +199,13 @@ async def test_global_cache_abatch() -> None:
assert results[0].content == "hello"
assert results[1].content == "hello"

## RACE CONDITION -- note behavior is different from sync
# Now, reset cache and test the race condition
# For now we just hard-code the result, if this changes
# we can investigate further
global_cache = InMemoryCache()
set_llm_cache(global_cache)
assert global_cache._cache == {}
results = await chat_model.abatch(["prompt", "prompt"])
# suspecting that tasks will be scheduled and executed in order
# if this ever fails, we can relax to a set comparison
# Cache misses likely guaranteed?

assert results[0].content == "meow"
assert results[1].content == "woof"
assert results[1].content == "meow"
finally:
set_llm_cache(None)

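The updated assertion follows from the fix: abatch now resolves inputs in order with no concurrency, so the first "prompt" computes a response and writes it to the global cache, and the second identical "prompt" is served from that cache, making both results equal. A self-contained sketch of that interaction, illustrative only and assuming the in-memory cache keys on the prompt plus the model's parameters:

import asyncio

from langchain_core.caches import InMemoryCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListChatModel


async def check_cache_with_abatch() -> None:
    set_llm_cache(InMemoryCache())
    try:
        chat_model = FakeListChatModel(cache=True, responses=["meow", "woof"])
        results = await chat_model.abatch(["prompt", "prompt"])
        # The first call generates "meow" and caches it; the second, identical
        # prompt is a cache hit, so "woof" is never generated.
        assert [r.content for r in results] == ["meow", "meow"]
    finally:
        set_llm_cache(None)


asyncio.run(check_cache_with_abatch())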