From 3a1b9259a7e11ed2221b9a1e9a1b1e55d1d78752 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 25 Sep 2024 15:34:17 +0200 Subject: [PATCH] core: Add ruff rules for comprehensions (C4) (#26829) --- .../langchain_core/beta/runnables/context.py | 6 +- .../language_models/chat_models.py | 4 +- .../langchain_core/language_models/llms.py | 4 +- libs/core/langchain_core/load/load.py | 6 +- libs/core/langchain_core/load/serializable.py | 4 +- .../langchain_core/output_parsers/json.py | 2 +- .../langchain_core/output_parsers/pydantic.py | 2 +- .../core/langchain_core/outputs/llm_result.py | 2 +- libs/core/langchain_core/prompts/chat.py | 10 +- libs/core/langchain_core/prompts/image.py | 2 +- libs/core/langchain_core/prompts/prompt.py | 2 +- libs/core/langchain_core/runnables/base.py | 2 +- .../langchain_core/runnables/fallbacks.py | 4 +- libs/core/langchain_core/runnables/retry.py | 2 +- libs/core/langchain_core/sys_info.py | 4 +- libs/core/langchain_core/tools/base.py | 10 +- .../langchain_core/utils/function_calling.py | 4 +- libs/core/pyproject.toml | 2 +- .../document_loaders/test_langsmith.py | 2 +- .../unit_tests/fake/test_fake_chat_model.py | 2 +- .../unit_tests/indexing/test_indexing.py | 32 +-- .../language_models/chat_models/test_base.py | 12 +- .../language_models/chat_models/test_cache.py | 2 +- .../language_models/llms/test_base.py | 12 +- .../output_parsers/test_base_parsers.py | 2 +- .../unit_tests/output_parsers/test_json.py | 4 +- .../unit_tests/runnables/test_context.py | 2 +- .../unit_tests/runnables/test_fallbacks.py | 2 +- .../unit_tests/runnables/test_runnable.py | 204 +++++++++--------- .../runnables/test_runnable_events_v1.py | 4 +- .../runnables/test_runnable_events_v2.py | 4 +- libs/core/tests/unit_tests/test_messages.py | 116 +++++----- .../tracers/test_async_base_tracer.py | 26 +-- .../unit_tests/tracers/test_base_tracer.py | 26 +-- 34 files changed, 259 insertions(+), 265 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 70cc1fae324d4..36222249d16ab 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -86,9 +86,9 @@ def _config_with_context( ) } deps_by_key = { - key: set( + key: { _key_from_id(dep) for spec in group for dep in (spec[0].dependencies or []) - ) + } for key, group in grouped_by_key.items() } @@ -198,7 +198,7 @@ async def ainvoke( configurable = config.get("configurable", {}) if isinstance(self.key, list): values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) - return {key: value for key, value in zip(self.key, values)} + return dict(zip(self.key, values)) else: return await configurable[self.ids[0]]() diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f322383065f02..6336dfbab8316 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -551,7 +551,7 @@ def _get_ls_params( def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str: if self.is_lc_serializable(): params = {**kwargs, **{"stop": stop}} - param_string = str(sorted([(k, v) for k, v in params.items()])) + param_string = str(sorted(params.items())) # This code is not super efficient as it goes back and forth between # json and dict. 
serialized_repr = self._serialized @@ -561,7 +561,7 @@ def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> st else: params = self._get_invocation_params(stop=stop, **kwargs) params = {**params, **kwargs} - return str(sorted([(k, v) for k, v in params.items()])) + return str(sorted(params.items())) def generate( self, diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 7eeb679b7a2aa..af2b5ef0d8b54 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -166,7 +166,7 @@ def get_prompts( Raises: ValueError: If the cache is not set and cache is True. """ - llm_string = str(sorted([(k, v) for k, v in params.items()])) + llm_string = str(sorted(params.items())) missing_prompts = [] missing_prompt_idxs = [] existing_prompts = {} @@ -202,7 +202,7 @@ async def aget_prompts( Raises: ValueError: If the cache is not set and cache is True. """ - llm_string = str(sorted([(k, v) for k, v in params.items()])) + llm_string = str(sorted(params.items())) missing_prompts = [] missing_prompt_idxs = [] existing_prompts = {} diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index 0fe91975c469b..d050d4696c94c 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -67,14 +67,14 @@ def __init__( Defaults to None. """ self.secrets_from_env = secrets_from_env - self.secrets_map = secrets_map or dict() + self.secrets_map = secrets_map or {} # By default, only support langchain, but user can pass in additional namespaces self.valid_namespaces = ( [*DEFAULT_NAMESPACES, *valid_namespaces] if valid_namespaces else DEFAULT_NAMESPACES ) - self.additional_import_mappings = additional_import_mappings or dict() + self.additional_import_mappings = additional_import_mappings or {} self.import_mappings = ( { **ALL_SERIALIZABLE_MAPPINGS, @@ -146,7 +146,7 @@ def __call__(self, value: dict[str, Any]) -> Any: # We don't need to recurse on kwargs # as json.loads will do that for us. - kwargs = value.get("kwargs", dict()) + kwargs = value.get("kwargs", {}) return cls(**kwargs) return value diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index f42ca89211644..9158c1e5b8be5 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -138,7 +138,7 @@ def lc_secrets(self) -> dict[str, str]: For example, {"openai_api_key": "OPENAI_API_KEY"} """ - return dict() + return {} @property def lc_attributes(self) -> dict: @@ -188,7 +188,7 @@ def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: if not self.is_lc_serializable(): return self.to_json_not_implemented() - secrets = dict() + secrets = {} # Get latest values for kwargs if there is an attribute with same name lc_kwargs = {} for k, v in self: diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index fdb7c38fc2cd2..ff174767f1cd3 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -108,7 +108,7 @@ def get_format_instructions(self) -> str: return "Return a JSON object." else: # Copy schema to avoid altering original Pydantic schema. - schema = {k: v for k, v in self._get_schema(self.pydantic_object).items()} + schema = dict(self._get_schema(self.pydantic_object).items()) # Remove extraneous fields. 
reduced_schema = schema diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 23c0dcf90a8b6..fb6e3dcd71786 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -90,7 +90,7 @@ def get_format_instructions(self) -> str: The format instructions for the JSON output. """ # Copy schema to avoid altering original Pydantic schema. - schema = {k: v for k, v in self.pydantic_object.model_json_schema().items()} + schema = dict(self.pydantic_object.model_json_schema().items()) # Remove extraneous fields. reduced_schema = schema diff --git a/libs/core/langchain_core/outputs/llm_result.py b/libs/core/langchain_core/outputs/llm_result.py index 8a429616dcf72..4430fae133f12 100644 --- a/libs/core/langchain_core/outputs/llm_result.py +++ b/libs/core/langchain_core/outputs/llm_result.py @@ -76,7 +76,7 @@ def flatten(self) -> list[LLMResult]: else: if self.llm_output is not None: llm_output = deepcopy(self.llm_output) - llm_output["token_usage"] = dict() + llm_output["token_usage"] = {} else: llm_output = None llm_results.append( diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 4c5dc492af07d..ab4a401057297 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1007,11 +1007,11 @@ def __init__( input_vars.update(_message.input_variables) kwargs = { - **dict( - input_variables=sorted(input_vars), - optional_variables=sorted(optional_variables), - partial_variables=partial_vars, - ), + **{ + "input_variables": sorted(input_vars), + "optional_variables": sorted(optional_variables), + "partial_variables": partial_vars, + }, **kwargs, } cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index f28b2ecb1a50f..c898dac31cf3f 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -18,7 +18,7 @@ def __init__(self, **kwargs: Any) -> None: if "input_variables" not in kwargs: kwargs["input_variables"] = [] - overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail")) + overlap = set(kwargs["input_variables"]) & {"url", "path", "detail"} if overlap: raise ValueError( "input_variables for the image template cannot contain" diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 659bf1de258ea..7008be5883938 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -144,7 +144,7 @@ def __add__(self, other: Any) -> PromptTemplate: template = self.template + other.template # If any do not want to validate, then don't validate_template = self.validate_template and other.validate_template - partial_variables = {k: v for k, v in self.partial_variables.items()} + partial_variables = dict(self.partial_variables.items()) for k, v in other.partial_variables.items(): if k in partial_variables: raise ValueError("Cannot have same variable partialed twice.") diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3c7262fb949c1..4aa8681a3913b 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -3778,7 +3778,7 @@ async def _ainvoke_step( for key, step in steps.items() ) ) - output = {key: value for key, 
value in zip(steps, results)} + output = dict(zip(steps, results)) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 1d8e19bb034fa..d331e3567e820 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -294,7 +294,7 @@ def batch( ] to_return: dict[int, Any] = {} - run_again = {i: input for i, input in enumerate(inputs)} + run_again = dict(enumerate(inputs)) handled_exceptions: dict[int, BaseException] = {} first_to_raise = None for runnable in self.runnables: @@ -388,7 +388,7 @@ async def abatch( ) to_return = {} - run_again = {i: input for i, input in enumerate(inputs)} + run_again = dict(enumerate(inputs)) handled_exceptions: dict[int, BaseException] = {} first_to_raise = None for runnable in self.runnables: diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index af93fac8753d6..11269417a4011 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -117,7 +117,7 @@ def get_lc_namespace(cls) -> list[str]: @property def _kwargs_retrying(self) -> dict[str, Any]: - kwargs: dict[str, Any] = dict() + kwargs: dict[str, Any] = {} if self.max_attempt_number: kwargs["stop"] = stop_after_attempt(self.max_attempt_number) diff --git a/libs/core/langchain_core/sys_info.py b/libs/core/langchain_core/sys_info.py index eacf6730fde17..f70df1f631864 100644 --- a/libs/core/langchain_core/sys_info.py +++ b/libs/core/langchain_core/sys_info.py @@ -10,7 +10,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]: from importlib import metadata sub_deps = set() - _underscored_packages = set(pkg.replace("-", "_") for pkg in packages) + _underscored_packages = {pkg.replace("-", "_") for pkg in packages} for pkg in packages: try: @@ -33,7 +33,7 @@ def _get_sub_deps(packages: Sequence[str]) -> list[str]: return sorted(sub_deps, key=lambda x: x.lower()) -def print_sys_info(*, additional_pkgs: Sequence[str] = tuple()) -> None: +def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None: """Print information about the environment for debugging purposes. 
Args: diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index b540767ab6579..1fc670cc27784 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -975,7 +975,7 @@ def _get_all_basemodel_annotations( ) and name not in fields: continue annotations[name] = param.annotation - orig_bases: tuple = getattr(cls, "__orig_bases__", tuple()) + orig_bases: tuple = getattr(cls, "__orig_bases__", ()) # cls has subscript: cls = FooBar[int] else: annotations = _get_all_basemodel_annotations( @@ -1007,11 +1007,9 @@ def _get_all_basemodel_annotations( # parent_origin = Baz, # generic_type_vars = (type vars in Baz) # generic_map = {type var in Baz: str} - generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple()) - generic_map = { - type_var: t for type_var, t in zip(generic_type_vars, get_args(parent)) - } - for field in getattr(parent_origin, "__annotations__", dict()): + generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ()) + generic_map = dict(zip(generic_type_vars, get_args(parent))) + for field in getattr(parent_origin, "__annotations__", {}): annotations[field] = _replace_type_vars( annotations[field], generic_map, default_to_bound ) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 6694b1f5bd11e..f88eb40b382a2 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -233,9 +233,7 @@ def _convert_any_typed_dicts_to_pydantic( new_arg_type = _convert_any_typed_dicts_to_pydantic( annotated_args[0], depth=depth + 1, visited=visited ) - field_kwargs = { - k: v for k, v in zip(("default", "description"), annotated_args[1:]) - } + field_kwargs = dict(zip(("default", "description"), annotated_args[1:])) if (field_desc := field_kwargs.get("description")) and not isinstance( field_desc, str ): diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index dab9b90b49180..5b98114ad6a7c 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,7 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "B", "E", "F", "I", "N", "T201", "UP",] +select = [ "B", "C4", "E", "F", "I", "N", "T201", "UP",] ignore = [ "UP007",] [tool.coverage.run] diff --git a/libs/core/tests/unit_tests/document_loaders/test_langsmith.py b/libs/core/tests/unit_tests/document_loaders/test_langsmith.py index e754ab2d37220..6c4c7a54170b5 100644 --- a/libs/core/tests/unit_tests/document_loaders/test_langsmith.py +++ b/libs/core/tests/unit_tests/document_loaders/test_langsmith.py @@ -54,5 +54,5 @@ def test_lazy_load() -> None: expected.append( Document(example.inputs["first"]["second"].upper(), metadata=metadata) ) - actual = [doc for doc in loader.lazy_load()] + actual = list(loader.lazy_load()) assert expected == actual diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 2e92cf2f18ffb..1829bf773ce3c 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -55,7 +55,7 @@ async def test_generic_fake_chat_model_stream() -> None: ] assert len({chunk.id for chunk in chunks}) == 1 - chunks = [chunk for chunk in model.stream("meow")] + chunks = list(model.stream("meow")) assert chunks == [ _any_id_ai_message_chunk(content="hello"), _any_id_ai_message_chunk(content=" "), diff --git 
a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py index 5bbbf49e392e1..0909176140602 100644 --- a/libs/core/tests/unit_tests/indexing/test_indexing.py +++ b/libs/core/tests/unit_tests/indexing/test_indexing.py @@ -185,11 +185,11 @@ def test_index_simple_delete_full( ): indexing_result = index(loader, record_manager, vector_store, cleanup="full") - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"mutated document 1", "This is another document."} assert indexing_result == { @@ -267,11 +267,11 @@ async def test_aindex_simple_delete_full( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"mutated document 1", "This is another document."} # Attempt to index again verify that nothing changes @@ -558,11 +558,11 @@ def test_incremental_delete( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"This is another document.", "This is a test document."} # Attempt to index again verify that nothing changes @@ -617,11 +617,11 @@ def test_incremental_delete( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == { "mutated document 1", "mutated document 2", @@ -685,11 +685,11 @@ def test_incremental_indexing_with_batch_size( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"1", "2", "3", "4"} @@ -735,11 +735,11 @@ def test_incremental_delete_with_batch_size( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"1", "2", "3", "4"} # Attempt to index again verify that nothing changes @@ -880,11 +880,11 @@ async def test_aincremental_delete( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == {"This is another document.", "This is a test document."} # Attempt to index again verify that nothing changes @@ -939,11 +939,11 @@ async def test_aincremental_delete( "num_updated": 0, } - doc_texts = set( + doc_texts = { # Ignoring type since doc should be in the store and not a None vector_store.get_by_ids([uid])[0].page_content # type: ignore for uid in vector_store.store - ) + } assert doc_texts == { "mutated document 1", "mutated document 2", diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index f8aab17f930eb..99d74e47cc91c 100644 --- 
a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -53,10 +53,10 @@ def test_batch_size(messages: list, messages_2: list) -> None: with collect_runs() as cb: llm.batch([messages, messages_2], {"callbacks": [cb]}) assert len(cb.traced_runs) == 2 - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) with collect_runs() as cb: llm.batch([messages], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) assert len(cb.traced_runs) == 1 with collect_runs() as cb: @@ -76,11 +76,11 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None: # so we expect batch_size to always be 1 with collect_runs() as cb: await llm.abatch([messages, messages_2], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) assert len(cb.traced_runs) == 2 with collect_runs() as cb: await llm.abatch([messages], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) assert len(cb.traced_runs) == 1 with collect_runs() as cb: @@ -146,7 +146,7 @@ def _llm_type(self) -> str: return "fake-chat-model" model = ModelWithGenerate() - chunks = [chunk for chunk in model.stream("anything")] + chunks = list(model.stream("anything")) assert chunks == [_any_id_ai_message(content="hello")] chunks = [chunk async for chunk in model.astream("anything")] @@ -183,7 +183,7 @@ def _llm_type(self) -> str: return "fake-chat-model" model = ModelWithSyncStream() - chunks = [chunk for chunk in model.stream("anything")] + chunks = list(model.stream("anything")) assert chunks == [ _any_id_ai_message_chunk(content="a"), _any_id_ai_message_chunk(content="b"), diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index 14c50a854ba47..3992420621ab4 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -262,7 +262,7 @@ def test_global_cache_stream() -> None: AIMessage(content="goodbye world"), ] model = GenericFakeChatModel(messages=iter(messages), cache=True) - chunks = [chunk for chunk in model.stream("some input")] + chunks = list(model.stream("some input")) assert len(chunks) == 3 # Assert that streaming information gets cached assert global_cache._cache != {} diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index 273bf9e0b8b28..e610838d69509 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -40,12 +40,12 @@ def test_batch_size() -> None: llm = FakeListLLM(responses=["foo"] * 3) with collect_runs() as cb: llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs) assert len(cb.traced_runs) == 3 llm = FakeListLLM(responses=["foo"]) with collect_runs() as 
cb: llm.batch(["foo"], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) assert len(cb.traced_runs) == 1 llm = FakeListLLM(responses=["foo"]) @@ -71,12 +71,12 @@ async def test_async_batch_size() -> None: llm = FakeListLLM(responses=["foo"] * 3) with collect_runs() as cb: await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs) assert len(cb.traced_runs) == 3 llm = FakeListLLM(responses=["foo"]) with collect_runs() as cb: await llm.abatch(["foo"], {"callbacks": [cb]}) - assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs]) + assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs) assert len(cb.traced_runs) == 1 llm = FakeListLLM(responses=["foo"]) @@ -142,7 +142,7 @@ def _llm_type(self) -> str: return "fake-chat-model" model = ModelWithGenerate() - chunks = [chunk for chunk in model.stream("anything")] + chunks = list(model.stream("anything")) assert chunks == ["hello"] chunks = [chunk async for chunk in model.astream("anything")] @@ -179,7 +179,7 @@ def _llm_type(self) -> str: return "fake-chat-model" model = ModelWithSyncStream() - chunks = [chunk for chunk in model.stream("anything")] + chunks = list(model.stream("anything")) assert chunks == ["a", "b"] assert type(model)._astream == BaseLLM._astream astream_chunks = [chunk async for chunk in model.astream("anything")] diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index 65e4f580862f7..a883eabc35a99 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -93,5 +93,5 @@ def parse_result( model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")])) chain = model | StrInvertCase() # inputs to models are ignored, response is hard-coded in model definition - chunks = [chunk for chunk in chain.stream("")] + chunks = list(chain.stream("")) assert chunks == ["HELLO", " ", "WORLD"] diff --git a/libs/core/tests/unit_tests/output_parsers/test_json.py b/libs/core/tests/unit_tests/output_parsers/test_json.py index 9753ff98ca385..f96f2b4d4821d 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_json.py +++ b/libs/core/tests/unit_tests/output_parsers/test_json.py @@ -596,10 +596,10 @@ class Joke(BaseModel): setup: str punchline: str - initial_joke_schema = {k: v for k, v in _schema(Joke).items()} + initial_joke_schema = dict(_schema(Joke).items()) SimpleJsonOutputParser(pydantic_object=Joke) openai_func = convert_to_openai_function(Joke) - retrieved_joke_schema = {k: v for k, v in _schema(Joke).items()} + retrieved_joke_schema = dict(_schema(Joke).items()) assert initial_joke_schema == retrieved_joke_schema assert openai_func.get("name", None) is not None diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index 92a53b2c11a61..c00eb999424cb 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -391,7 +391,7 @@ async def test_runnable_seq_streaming_chunks() -> None: } ) - chunks = [c for c in chain.stream({"foo": "foo", "bar": "bar"})] + chunks = list(chain.stream({"foo": "foo", 
"bar": "bar"})) achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] for c in chunks: assert c in achunks diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 424b61025ab55..69ea68ba9af16 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -264,7 +264,7 @@ def test_fallbacks_stream() -> None: runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks( [RunnableGenerator(_generate)] ) - assert list(runnable.stream({})) == [c for c in "foo bar"] + assert list(runnable.stream({})) == list("foo bar") with pytest.raises(ValueError): runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 1823c6efa7603..39bafbd6e510b 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1065,7 +1065,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert [ part async for part in seq.astream( - "hello", dict(metadata={"key": "value"}), my_kwarg="value" + "hello", {"metadata": {"key": "value"}}, my_kwarg="value" ) ] == [5] assert mock.call_args_list == [ @@ -1125,12 +1125,9 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: assert call in mock.call_args_list mock.reset_mock() - assert [ - part - for part in seq.stream( - "hello", dict(metadata={"key": "value"}), my_kwarg="value" - ) - ] == [5] + assert list( + seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") + ) == [5] assert mock.call_args_list == [ mocker.call("hello", my_kwarg="value"), mocker.call(5), @@ -1155,13 +1152,13 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: ) assert spy.call_args_list[0].args[1:] == ( "hello", - dict( - tags=["a-tag"], - callbacks=None, - recursion_limit=25, - configurable={"hello": "there", "__secret_key": "nahnah"}, - metadata={"hello": "there", "bye": "now"}, - ), + { + "tags": ["a-tag"], + "callbacks": None, + "recursion_limit": 25, + "configurable": {"hello": "there", "__secret_key": "nahnah"}, + "metadata": {"hello": "there", "bye": "now"}, + }, ) spy.reset_mock() @@ -1174,7 +1171,7 @@ async def test_with_config(mocker: MockerFixture) -> None: assert spy.call_args_list == [ mocker.call( "hello", - dict(tags=["a-tag"], metadata={}, configurable={}), + {"tags": ["a-tag"], "metadata": {}, "configurable": {}}, ), ] spy.reset_mock() @@ -1200,19 +1197,19 @@ async def test_with_config(mocker: MockerFixture) -> None: assert [ *fake.with_config(tags=["a-tag"]).stream( - "hello", dict(metadata={"key": "value"}) + "hello", {"metadata": {"key": "value"}} ) ] == [5] assert spy.call_args_list == [ mocker.call( "hello", - dict(tags=["a-tag"], metadata={"key": "value"}, configurable={}), + {"tags": ["a-tag"], "metadata": {"key": "value"}, "configurable": {}}, ), ] spy.reset_mock() assert fake.with_config(recursion_limit=5).batch( - ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}] ) == [5, 7] assert len(spy.call_args_list) == 2 @@ -1235,7 +1232,7 @@ async def test_with_config(mocker: MockerFixture) -> None: c for c in fake.with_config(recursion_limit=5).batch_as_completed( ["hello", "wooorld"], - [dict(tags=["a-tag"]), 
dict(metadata={"key": "value"})], + [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}], ) ) == [(0, 5), (1, 7)] @@ -1256,7 +1253,7 @@ async def test_with_config(mocker: MockerFixture) -> None: spy.reset_mock() assert fake.with_config(metadata={"a": "b"}).batch( - ["hello", "wooorld"], dict(tags=["a-tag"]) + ["hello", "wooorld"], {"tags": ["a-tag"]} ) == [5, 7] assert len(spy.call_args_list) == 2 for i, call in enumerate(spy.call_args_list): @@ -1266,7 +1263,7 @@ async def test_with_config(mocker: MockerFixture) -> None: spy.reset_mock() assert sorted( - c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"])) + c for c in fake.batch_as_completed(["hello", "wooorld"], {"tags": ["a-tag"]}) ) == [(0, 5), (1, 7)] assert len(spy.call_args_list) == 2 for i, call in enumerate(spy.call_args_list): @@ -1284,7 +1281,12 @@ async def test_with_config(mocker: MockerFixture) -> None: assert spy.call_args_list == [ mocker.call( "hello", - dict(callbacks=[handler], metadata={"a": "b"}, configurable={}, tags=[]), + { + "callbacks": [handler], + "metadata": {"a": "b"}, + "configurable": {}, + "tags": [], + }, ), ] spy.reset_mock() @@ -1293,12 +1295,12 @@ async def test_with_config(mocker: MockerFixture) -> None: part async for part in fake.with_config(metadata={"a": "b"}).astream("hello") ] == [5] assert spy.call_args_list == [ - mocker.call("hello", dict(metadata={"a": "b"}, tags=[], configurable={})), + mocker.call("hello", {"metadata": {"a": "b"}, "tags": [], "configurable": {}}), ] spy.reset_mock() assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch( - ["hello", "wooorld"], dict(metadata={"key": "value"}) + ["hello", "wooorld"], {"metadata": {"key": "value"}} ) == [ 5, 7, @@ -1306,23 +1308,23 @@ async def test_with_config(mocker: MockerFixture) -> None: assert spy.call_args_list == [ mocker.call( "hello", - dict( - metadata={"key": "value"}, - tags=["c"], - callbacks=None, - recursion_limit=5, - configurable={}, - ), + { + "metadata": {"key": "value"}, + "tags": ["c"], + "callbacks": None, + "recursion_limit": 5, + "configurable": {}, + }, ), mocker.call( "wooorld", - dict( - metadata={"key": "value"}, - tags=["c"], - callbacks=None, - recursion_limit=5, - configurable={}, - ), + { + "metadata": {"key": "value"}, + "tags": ["c"], + "callbacks": None, + "recursion_limit": 5, + "configurable": {}, + }, ), ] spy.reset_mock() @@ -1332,7 +1334,7 @@ async def test_with_config(mocker: MockerFixture) -> None: c async for c in fake.with_config( recursion_limit=5, tags=["c"] - ).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"})) + ).abatch_as_completed(["hello", "wooorld"], {"metadata": {"key": "value"}}) ] ) == [ (0, 5), @@ -1342,24 +1344,24 @@ async def test_with_config(mocker: MockerFixture) -> None: first_call = next(call for call in spy.call_args_list if call.args[0] == "hello") assert first_call == mocker.call( "hello", - dict( - metadata={"key": "value"}, - tags=["c"], - callbacks=None, - recursion_limit=5, - configurable={}, - ), + { + "metadata": {"key": "value"}, + "tags": ["c"], + "callbacks": None, + "recursion_limit": 5, + "configurable": {}, + }, ) second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld") assert second_call == mocker.call( "wooorld", - dict( - metadata={"key": "value"}, - tags=["c"], - callbacks=None, - recursion_limit=5, - configurable={}, - ), + { + "metadata": {"key": "value"}, + "tags": ["c"], + "callbacks": None, + "recursion_limit": 5, + "configurable": {}, + }, ) @@ -1367,20 +1369,20 
@@ async def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") - assert fake.invoke("hello", dict(tags=["a-tag"])) == 5 + assert fake.invoke("hello", {"tags": ["a-tag"]}) == 5 assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), + mocker.call("hello", {"tags": ["a-tag"]}), ] spy.reset_mock() - assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5] + assert [*fake.stream("hello", {"metadata": {"key": "value"}})] == [5] assert spy.call_args_list == [ - mocker.call("hello", dict(metadata={"key": "value"})), + mocker.call("hello", {"metadata": {"key": "value"}}), ] spy.reset_mock() assert fake.batch( - ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ["hello", "wooorld"], [{"tags": ["a-tag"]}, {"metadata": {"key": "value"}}] ) == [5, 7] assert len(spy.call_args_list) == 2 @@ -1398,9 +1400,9 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: spy.reset_mock() - assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] + assert fake.batch(["hello", "wooorld"], {"tags": ["a-tag"]}) == [5, 7] assert len(spy.call_args_list) == 2 - assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"} + assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"} for call in spy.call_args_list: assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} @@ -1408,7 +1410,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 assert spy.call_args_list == [ - mocker.call("hello", dict(callbacks=[])), + mocker.call("hello", {"callbacks": []}), ] spy.reset_mock() @@ -1418,19 +1420,19 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: ] spy.reset_mock() - assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [ + assert await fake.abatch(["hello", "wooorld"], {"metadata": {"key": "value"}}) == [ 5, 7, ] - assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"} + assert {call.args[0] for call in spy.call_args_list} == {"hello", "wooorld"} for call in spy.call_args_list: - assert call.args[1] == dict( - metadata={"key": "value"}, - tags=[], - callbacks=None, - recursion_limit=25, - configurable={}, - ) + assert call.args[1] == { + "metadata": {"key": "value"}, + "tags": [], + "callbacks": None, + "recursion_limit": 25, + "configurable": {}, + } async def test_prompt() -> None: @@ -1698,7 +1700,7 @@ def test_prompt_with_chat_model( chat_spy = mocker.spy(chat.__class__, "invoke") tracer = FakeTracer() assert chain.invoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1722,7 +1724,7 @@ def test_prompt_with_chat_model( {"question": "What is your name?"}, {"question": "What is your favorite color?"}, ], - dict(callbacks=[tracer]), + {"callbacks": [tracer]}, ) == [ _any_id_ai_message(content="foo"), _any_id_ai_message(content="foo"), @@ -1763,7 +1765,7 @@ def test_prompt_with_chat_model( chat_spy = mocker.spy(chat.__class__, "stream") tracer = FakeTracer() assert [ - *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) + 
*chain.stream({"question": "What is your name?"}, {"callbacks": [tracer]}) ] == [ _any_id_ai_message_chunk(content="f"), _any_id_ai_message_chunk(content="o"), @@ -1804,7 +1806,7 @@ async def test_prompt_with_chat_model_async( chat_spy = mocker.spy(chat.__class__, "ainvoke") tracer = FakeTracer() assert await chain.ainvoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1828,7 +1830,7 @@ async def test_prompt_with_chat_model_async( {"question": "What is your name?"}, {"question": "What is your favorite color?"}, ], - dict(callbacks=[tracer]), + {"callbacks": [tracer]}, ) == [ _any_id_ai_message(content="foo"), _any_id_ai_message(content="foo"), @@ -1871,7 +1873,7 @@ async def test_prompt_with_chat_model_async( assert [ a async for a in chain.astream( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) ] == [ _any_id_ai_message_chunk(content="f"), @@ -1910,9 +1912,7 @@ async def test_prompt_with_llm( llm_spy = mocker.spy(llm.__class__, "ainvoke") tracer = FakeTracer() assert ( - await chain.ainvoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) + await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]}) == "foo" ) assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -1935,7 +1935,7 @@ async def test_prompt_with_llm( {"question": "What is your name?"}, {"question": "What is your favorite color?"}, ], - dict(callbacks=[tracer]), + {"callbacks": [tracer]}, ) == ["bar", "foo"] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1966,7 +1966,7 @@ async def test_prompt_with_llm( assert [ token async for token in chain.astream( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) ] == ["bar"] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -2110,7 +2110,7 @@ async def test_prompt_with_llm_parser( parser_spy = mocker.spy(parser.__class__, "ainvoke") tracer = FakeTracer() assert await chain.ainvoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == ["bear", "dog", "cat"] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert llm_spy.call_args.args[1] == ChatPromptValue( @@ -2135,7 +2135,7 @@ async def test_prompt_with_llm_parser( {"question": "What is your name?"}, {"question": "What is your favorite color?"}, ], - dict(callbacks=[tracer]), + {"callbacks": [tracer]}, ) == [["tomato", "lettuce", "onion"], ["bear", "dog", "cat"]] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -2171,7 +2171,7 @@ async def test_prompt_with_llm_parser( assert [ token async for token in chain.astream( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) ] == [["tomato"], ["lettuce"], ["onion"]] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -2495,9 +2495,7 @@ async def passthrough(input: Any) -> Any: llm_spy = mocker.spy(llm.__class__, "ainvoke") tracer = FakeTracer() assert ( - await chain.ainvoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) + 
await chain.ainvoke({"question": "What is your name?"}, {"callbacks": [tracer]}) == "foo" ) assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -2539,7 +2537,7 @@ def test_prompt_with_chat_model_and_parser( parser_spy = mocker.spy(parser.__class__, "invoke") tracer = FakeTracer() assert chain.invoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == ["foo", "bar"] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -2608,7 +2606,7 @@ def test_combining_sequences( # Test invoke tracer = FakeTracer() assert combined_chain.invoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == ["baz", "qux"] assert tracer.runs == snapshot @@ -2658,7 +2656,7 @@ def test_seq_dict_prompt_llm( chat_spy = mocker.spy(chat.__class__, "invoke") parser_spy = mocker.spy(parser.__class__, "invoke") tracer = FakeTracer() - assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [ + assert chain.invoke("What is your name?", {"callbacks": [tracer]}) == [ "foo", "bar", ] @@ -2725,7 +2723,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> llm_spy = mocker.spy(llm.__class__, "invoke") tracer = FakeTracer() assert chain.invoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == { "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", @@ -2788,7 +2786,7 @@ async def test_router_runnable( router_spy = mocker.spy(router.__class__, "invoke") tracer = FakeTracer() assert ( - chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}) == "4" ) assert router_spy.call_args.args[1] == { @@ -2849,7 +2847,7 @@ def router(input: dict[str, Any]) -> Runnable: math_spy = mocker.spy(math_chain.__class__, "invoke") tracer = FakeTracer() assert ( - chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])) + chain.invoke({"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]}) == "4" ) assert math_spy.call_args.args[1] == { @@ -2880,7 +2878,7 @@ async def arouter(input: dict[str, Any]) -> Runnable: tracer = FakeTracer() assert ( await achain.ainvoke( - {"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]) + {"key": "math", "question": "2 + 2"}, {"callbacks": [tracer]} ) == "4" ) @@ -2934,7 +2932,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N llm_spy = mocker.spy(llm.__class__, "invoke") tracer = FakeTracer() assert chain.invoke( - {"question": "What is your name?"}, dict(callbacks=[tracer]) + {"question": "What is your name?"}, {"callbacks": [tracer]} ) == { "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", @@ -3841,7 +3839,7 @@ def _lambda(x: int) -> Union[int, Runnable]: def test_runnable_lambda_stream() -> None: """Test that stream works for both normal functions & those returning Runnable.""" # Normal output should work - output: list[Any] = [chunk for chunk in RunnableLambda(range).stream(5)] + output: list[Any] = list(RunnableLambda(range).stream(5)) assert output == [range(5)] # Runnable output should also work @@ -4015,7 +4013,7 @@ def batch( spy = mocker.spy(ControlledExceptionRunnable, "batch") tracer = FakeTracer() inputs = ["foo", 
"bar", "baz", "qux"] - outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True) + outputs = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True) assert len(outputs) == 4 assert isinstance(outputs[0], ValueError) assert isinstance(outputs[1], ValueError) @@ -4135,7 +4133,7 @@ async def abatch( tracer = FakeTracer() inputs = ["foo", "bar", "baz", "qux"] outputs = await chain.abatch( - inputs, dict(callbacks=[tracer]), return_exceptions=True + inputs, {"callbacks": [tracer]}, return_exceptions=True ) assert len(outputs) == 4 assert isinstance(outputs[0], ValueError) @@ -5080,13 +5078,13 @@ def idchain_sync(__input: dict) -> bool: chain = RunnablePassthrough.assign(urls=idchain_sync) tracer = FakeTracer() - chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) + chain.invoke({"example": [1, 2, 3]}, {"callbacks": [tracer]}) assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].child_runs[0].name == "RunnableParallel" tracer = FakeTracer() - for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): + for _ in chain.stream({"example": [1, 2, 3]}, {"callbacks": [tracer]}): pass assert tracer.runs[0].name == "RunnableAssign" @@ -5100,13 +5098,13 @@ def idchain_sync(__input: dict) -> bool: chain = RunnablePassthrough.assign(urls=idchain_sync) tracer = FakeTracer() - await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer])) + await chain.ainvoke({"example": [1, 2, 3]}, {"callbacks": [tracer]}) assert tracer.runs[0].name == "RunnableAssign" assert tracer.runs[0].child_runs[0].name == "RunnableParallel" tracer = FakeTracer() - async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])): + async for _ in chain.astream({"example": [1, 2, 3]}, {"callbacks": [tracer]}): pass assert tracer.runs[0].name == "RunnableAssign" @@ -5260,7 +5258,7 @@ async def chunk_iterator_with_addable() -> AsyncIterator[dict[str, str]]: def test_passthrough_transform_with_dicts() -> None: """Test that default transform works with dicts.""" runnable = RunnablePassthrough(lambda x: x) - chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))] + chunks = list(runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))) assert chunks == [{"foo": "a"}, {"foo": "n"}] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 378d947e6bbba..67b303b3518e4 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -2033,7 +2033,7 @@ def add_one_(x: int) -> int: async def add_one_proxy_(x: int, config: RunnableConfig) -> int: streaming = add_one.stream(x, config) - results = [result for result in streaming] + results = list(streaming) return results[0] add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore @@ -2078,7 +2078,7 @@ def add_one(x: int) -> int: def add_one_proxy(x: int, config: RunnableConfig) -> int: # Use sync streaming streaming = add_one_.stream(x, config) - results = [result for result in streaming] + results = list(streaming) return results[0] add_one_proxy_ = RunnableLambda(add_one_proxy) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 4b4815217f2e4..ac355dddf2eba 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ 
b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -1995,7 +1995,7 @@ def add_one(x: int) -> int: async def add_one_proxy(x: int, config: RunnableConfig) -> int: streaming = add_one_.stream(x, config) - results = [result for result in streaming] + results = list(streaming) return results[0] add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore @@ -2035,7 +2035,7 @@ def add_one(x: int) -> int: def add_one_proxy(x: int, config: RunnableConfig) -> int: # Use sync streaming streaming = add_one_.stream(x, config) - results = [result for result in streaming] + results = list(streaming) return results[0] add_one_proxy_ = RunnableLambda(add_one_proxy) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 90868cc18a6c1..2743783ca85ad 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -448,10 +448,10 @@ def test_message_chunk_to_message() -> None: def test_tool_calls_merge() -> None: chunks: list[dict] = [ - dict(content=""), - dict( - content="", - additional_kwargs={ + {"content": ""}, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -461,10 +461,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -474,10 +474,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -487,10 +487,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -500,10 +500,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -513,10 +513,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -526,10 +526,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 0, @@ -539,10 +539,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -552,10 +552,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -565,10 +565,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -578,10 +578,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -591,10 +591,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -604,10 +604,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { 
"index": 1, @@ -617,10 +617,10 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict( - content="", - additional_kwargs={ + }, + { + "content": "", + "additional_kwargs": { "tool_calls": [ { "index": 1, @@ -630,8 +630,8 @@ def test_tool_calls_merge() -> None: } ] }, - ), - dict(content=""), + }, + {"content": ""}, ] final = None diff --git a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py index a99ad529eda8f..1b243c0381696 100644 --- a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py @@ -98,7 +98,7 @@ async def test_tracer_chat_model_run() -> None: ], extra={}, serialized=SERIALIZED_CHAT, - inputs=dict(prompts=["Human: "]), + inputs={"prompts": ["Human: "]}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] error=None, run_type="llm", @@ -134,7 +134,7 @@ async def test_tracer_multiple_llm_runs() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] error=None, run_type="llm", @@ -272,8 +272,8 @@ async def test_tracer_nested_run() -> None: ], extra={}, serialized={"name": "tool"}, - inputs=dict(input="test"), - outputs=dict(output="test"), + inputs={"input": "test"}, + outputs={"output": "test"}, error=None, run_type="tool", trace_id=chain_uuid, @@ -291,7 +291,7 @@ async def test_tracer_nested_run() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -311,7 +311,7 @@ async def test_tracer_nested_run() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -339,7 +339,7 @@ async def test_tracer_llm_run_on_error() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, error=repr(exception), run_type="llm", @@ -370,7 +370,7 @@ async def test_tracer_llm_run_on_error_callback() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, error=repr(exception), run_type="llm", @@ -436,7 +436,7 @@ async def test_tracer_tool_run_on_error() -> None: ], extra={}, serialized={"name": "tool"}, - inputs=dict(input="test"), + inputs={"input": "test"}, outputs=None, error=repr(exception), run_type="tool", @@ -527,7 +527,7 @@ async def test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=None, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -545,7 +545,7 @@ async def test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=None, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -563,7 +563,7 @@ async def test_tracer_nested_runs_on_error() -> None: extra={}, serialized={"name": "tool"}, error=repr(exception), - inputs=dict(input="test"), + inputs={"input": "test"}, outputs=None, trace_id=chain_uuid, dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}", @@ -580,7 +580,7 @@ async def 
test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=repr(exception), - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, run_type="llm", trace_id=chain_uuid, diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index 9e444c252f704..80d7a9297492c 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -103,7 +103,7 @@ def test_tracer_chat_model_run() -> None: ], extra={}, serialized=SERIALIZED_CHAT, - inputs=dict(prompts=["Human: "]), + inputs={"prompts": ["Human: "]}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] error=None, run_type="llm", @@ -139,7 +139,7 @@ def test_tracer_multiple_llm_runs() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] error=None, run_type="llm", @@ -275,8 +275,8 @@ def test_tracer_nested_run() -> None: ], extra={}, serialized={"name": "tool"}, - inputs=dict(input="test"), - outputs=dict(output="test"), + inputs={"input": "test"}, + outputs={"output": "test"}, error=None, run_type="tool", trace_id=chain_uuid, @@ -294,7 +294,7 @@ def test_tracer_nested_run() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -314,7 +314,7 @@ def test_tracer_nested_run() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, error=repr(exception), run_type="llm", @@ -373,7 +373,7 @@ def test_tracer_llm_run_on_error_callback() -> None: ], extra={}, serialized=SERIALIZED, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, error=repr(exception), run_type="llm", @@ -439,7 +439,7 @@ def test_tracer_tool_run_on_error() -> None: ], extra={}, serialized={"name": "tool"}, - inputs=dict(input="test"), + inputs={"input": "test"}, outputs=None, error=repr(exception), run_type="tool", @@ -528,7 +528,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=None, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -546,7 +546,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=None, - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=LLMResult(generations=[[]], llm_output=None), # type: ignore[arg-type] run_type="llm", trace_id=chain_uuid, @@ -564,7 +564,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, serialized={"name": "tool"}, error=repr(exception), - inputs=dict(input="test"), + inputs={"input": "test"}, outputs=None, trace_id=chain_uuid, dotted_order=f"20230101T000000000000Z{chain_uuid}.20230101T000000000000Z{tool_uuid}", @@ -581,7 +581,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, serialized=SERIALIZED, error=repr(exception), - inputs=dict(prompts=[]), + inputs={"prompts": []}, outputs=None, run_type="llm", trace_id=chain_uuid,
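
For reference: the code hunks in this patch are mechanical rewrites driven by ruff's flake8-comprehensions rules, enabled by the "C4" entry added to `select` in libs/core/pyproject.toml. The sketch below pairs the recurring before/after patterns from the diff; it is illustrative only, and the per-pattern rule codes are my mapping onto the standard flake8-comprehensions set — the patch itself only enables the "C4" family as a whole.

    # Illustrative pairs of the patterns rewritten above; each assert
    # shows that the old form (left) and the C4 fix (right) agree.
    params = {"stop": None, "temperature": 0.7}
    keys, values = ["a", "b"], [1, 2]
    inputs = ["foo", "bar"]

    # C416: identity comprehensions just rebuild their iterable.
    assert sorted([(k, v) for k, v in params.items()]) == sorted(params.items())
    assert {k: v for k, v in zip(keys, values)} == dict(zip(keys, values))
    assert {i: x for i, x in enumerate(inputs)} == dict(enumerate(inputs))
    assert [c for c in "foo bar"] == list("foo bar")

    # C408: empty dict()/tuple() calls become literals.
    assert dict() == {} and tuple() == ()

    # C405: set(...) over a literal sequence becomes a set literal.
    assert set(("url", "path", "detail")) == {"url", "path", "detail"}

    # C401: a generator passed to set() becomes a set comprehension.
    pkgs = ["langchain-core", "langchain"]
    assert set(p.replace("-", "_") for p in pkgs) == {
        p.replace("-", "_") for p in pkgs
    }

    # C419: all()/any() consume a generator directly; the inner list
    # buffers every element and evaluates every predicate up front.
    runs = [{"batch_size": 1}, {"batch_size": 1}]
    assert all([r["batch_size"] == 1 for r in runs]) == all(
        r["batch_size"] == 1 for r in runs
    )

One caveat on the autofix output: a few replacements are mechanical rather than maximally idiomatic, e.g. `dict(self.partial_variables.items())` in prompts/prompt.py — plain `dict(self.partial_variables)` (or `.copy()`) copies a mapping just as well, without the redundant `.items()` call.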