From 5e93649ccab50fc291fb8cb32012774c63eda2a1 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 26 Sep 2024 17:27:47 -0700
Subject: [PATCH 1/4] x

---
 .../language_models/fake_chat_models.py      | 14 ++++++++++++++
 .../unit_tests/fake/test_fake_chat_model.py  | 19 ++++++++++++++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py
index 3c41c1d462f50..2b54604c3bedc 100644
--- a/libs/core/langchain_core/language_models/fake_chat_models.py
+++ b/libs/core/langchain_core/language_models/fake_chat_models.py
@@ -13,6 +13,7 @@
 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import RunnableConfig
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@@ -128,6 +129,19 @@ async def _astream(
     def _identifying_params(self) -> dict[str, Any]:
         return {"responses": self.responses}
 
+    # manually override batch to preserve batch ordering with no concurrency
+    def batch(
+        self,
+        inputs: list[Any],
+        config: RunnableConfig | list[RunnableConfig] | None = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        return [self.invoke(m, config, **kwargs) for m in inputs]
+
 
 class FakeChatModel(SimpleChatModel):
     """Fake Chat Model wrapper for testing purposes."""

diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index 1829bf773ce3c..7667c57a13090 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -5,7 +5,11 @@
 from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
-from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatModel
+from langchain_core.language_models import (
+    FakeListChatModel,
+    GenericFakeChatModel,
+    ParrotFakeChatModel,
+)
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 from tests.unit_tests.stubs import (
@@ -205,3 +209,16 @@ def test_chat_model_inputs() -> None:
     assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message(
         content="blah"
     )
+
+
+def test_fake_list_chat_model_batch() -> None:
+    expected = [
+        _any_id_ai_message(content="a"),
+        _any_id_ai_message(content="b"),
+        _any_id_ai_message(content="c"),
+    ]
+    for _ in range(20):
+        # run this 20 times to test race condition in batch
+        fake = FakeListChatModel(responses=["a", "b", "c"])
+        resp = fake.batch(["1", "2", "3"])
+        assert resp == expected, [resp.content for resp in resp]

From 420430ed171ccb7688fe0298d314a76eddb1269f Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 26 Sep 2024 17:29:03 -0700
Subject: [PATCH 2/4] x

---
 libs/core/tests/unit_tests/fake/test_fake_chat_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index 7667c57a13090..7502e17c50fde 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -221,4 +221,4 @@ def test_fake_list_chat_model_batch() -> None:
         # run this 20 times to test race condition in batch
         fake = FakeListChatModel(responses=["a", "b", "c"])
         resp = fake.batch(["1", "2", "3"])
-        assert resp == expected, [resp.content for resp in resp]
+        assert resp == expected

From d83f7f3508637d27702df6b419cbd60a6f2d2351 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 26 Sep 2024 17:35:46 -0700
Subject: [PATCH 3/4] x

---
 .../language_models/fake_chat_models.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py
index 2b54604c3bedc..d8091c3e51c3a 100644
--- a/libs/core/langchain_core/language_models/fake_chat_models.py
+++ b/libs/core/langchain_core/language_models/fake_chat_models.py
@@ -133,7 +133,7 @@ def _identifying_params(self) -> dict[str, Any]:
     def batch(
         self,
         inputs: list[Any],
-        config: RunnableConfig | list[RunnableConfig] | None = None,
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
@@ -142,6 +142,20 @@ def batch(
             return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
         return [self.invoke(m, config, **kwargs) for m in inputs]
 
+    async def abatch(
+        self,
+        inputs: list[Any],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        **kwargs: Any,
+    ) -> list[BaseMessage]:
+        if isinstance(config, list):
+            # do Not use an async iterator here because need explicit ordering
+            return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
+        # do Not use an async iterator here because need explicit ordering
+        return [await self.ainvoke(m, config, **kwargs) for m in inputs]
+
 
 class FakeChatModel(SimpleChatModel):
     """Fake Chat Model wrapper for testing purposes."""

From b0861303a32fd2516ec96ab9a624d60af6cd8232 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 3 Oct 2024 16:10:49 -0700
Subject: [PATCH 4/4] x

---
 .../language_models/chat_models/test_cache.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index d2455f7a61977..0d5b89de7c354 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -199,19 +199,13 @@ async def test_global_cache_abatch() -> None:
         assert results[0].content == "hello"
         assert results[1].content == "hello"
 
-        ## RACE CONDITION -- note behavior is different from sync
-        # Now, reset cache and test the race condition
-        # For now we just hard-code the result, if this changes
-        # we can investigate further
         global_cache = InMemoryCache()
         set_llm_cache(global_cache)
         assert global_cache._cache == {}
         results = await chat_model.abatch(["prompt", "prompt"])
-        # suspecting that tasks will be scheduled and executed in order
-        # if this ever fails, we can relax to a set comparison
-        # Cache misses likely guaranteed?
+
         assert results[0].content == "meow"
-        assert results[1].content == "woof"
+        assert results[1].content == "meow"
     finally:
         set_llm_cache(None)