From 3cc2ce6ac9755b4d9e5f39a69e56f0a61942d2f7 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 20 Apr 2023 16:58:04 -0700 Subject: [PATCH 01/36] callbacks changes --- langchain/callbacks/base.py | 104 +++++++++++++----------------------- 1 file changed, 37 insertions(+), 67 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 5b47b82a908c1..575dc0ddffd65 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -7,7 +7,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult -class BaseCallbackHandler(ABC): +class BaseCallbackHandler: """Base callback handler that can be used to handle callbacks from langchain.""" @property @@ -30,67 +30,54 @@ def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" return False - @abstractmethod def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> Any: """Run when LLM starts running.""" - @abstractmethod def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" - @abstractmethod def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: """Run when LLM ends running.""" - @abstractmethod def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> Any: """Run when LLM errors.""" - @abstractmethod def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> Any: """Run when chain starts running.""" - @abstractmethod def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: """Run when chain ends running.""" - @abstractmethod def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> Any: """Run when chain errors.""" - @abstractmethod def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> Any: """Run when tool starts running.""" - @abstractmethod def on_tool_end(self, output: str, **kwargs: Any) -> Any: """Run when tool ends running.""" - @abstractmethod def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> Any: """Run when tool errors.""" - @abstractmethod def on_text(self, text: str, **kwargs: Any) -> Any: """Run on arbitrary text.""" - @abstractmethod def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" - @abstractmethod def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run on agent end.""" @@ -127,6 +114,21 @@ def __init__(self, handlers: List[BaseCallbackHandler]) -> None: """Initialize callback manager.""" self.handlers: List[BaseCallbackHandler] = handlers + def _handle_event( + self, + event_name: str, + ignore_condition_name: Optional[str], + verbose: bool, + *args: Any, + **kwargs: Any + ) -> None: + for handler in self.handlers: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + if verbose or handler.always_verbose: + getattr(handler, event_name)(*args, **kwargs) + def on_llm_start( self, serialized: Dict[str, Any], @@ -135,28 +137,21 @@ def on_llm_start( **kwargs: Any ) -> None: """Run when LLM starts running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_start(serialized, prompts, **kwargs) + self._handle_event( + "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs + ) def on_llm_new_token( self, token: str, verbose: bool = False, **kwargs: Any ) -> None: """Run when LLM 
generates a new token.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_new_token(token, **kwargs) + self._handle_event("on_llm_new_token", "ignore_llm", verbose, token, **kwargs) def on_llm_end( self, response: LLMResult, verbose: bool = False, **kwargs: Any ) -> None: """Run when LLM ends running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_end(response) + self._handle_event("on_llm_end", "ignore_llm", verbose, response, **kwargs) def on_llm_error( self, @@ -165,10 +160,7 @@ def on_llm_error( **kwargs: Any ) -> None: """Run when LLM errors.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_error(error) + self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs) def on_chain_start( self, @@ -178,19 +170,15 @@ def on_chain_start( **kwargs: Any ) -> None: """Run when chain starts running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_start(serialized, inputs, **kwargs) + self._handle_event( + "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs + ) def on_chain_end( self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any ) -> None: """Run when chain ends running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_end(outputs) + self._handle_event("on_chain_end", "ignore_chain", verbose, outputs, **kwargs) def on_chain_error( self, @@ -199,10 +187,7 @@ def on_chain_error( **kwargs: Any ) -> None: """Run when chain errors.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_error(error) + self._handle_event("on_chain_error", "ignore_chain", verbose, error, **kwargs) def on_tool_start( self, @@ -212,26 +197,19 @@ def on_tool_start( **kwargs: Any ) -> None: """Run when tool starts running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_start(serialized, input_str, **kwargs) + self._handle_event( + "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs + ) def on_agent_action( self, action: AgentAction, verbose: bool = False, **kwargs: Any ) -> None: """Run when tool starts running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_agent_action(action, **kwargs) + self._handle_event("on_agent_action", "ignore_agent", verbose, action, **kwargs) def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None: """Run when tool ends running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_end(output, **kwargs) + self._handle_event("on_tool_end", "ignore_agent", verbose, output, **kwargs) def on_tool_error( self, @@ -240,25 +218,17 @@ def on_tool_error( **kwargs: Any ) -> None: """Run when tool errors.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_error(error) + self._handle_event("on_tool_error", "ignore_agent", verbose, error, **kwargs) def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: """Run on additional input from chains and agents.""" - for handler in 
self.handlers: - if verbose or handler.always_verbose: - handler.on_text(text, **kwargs) + self._handle_event("on_text", None, verbose, text, **kwargs) def on_agent_finish( self, finish: AgentFinish, verbose: bool = False, **kwargs: Any ) -> None: """Run on agent end.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_agent_finish(finish, **kwargs) + self._handle_event("on_agent_finish", None, verbose, finish, **kwargs) def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" @@ -328,7 +298,7 @@ async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: """Run on agent end.""" -async def _handle_event_for_handler( +async def _ahandle_event_for_handler( handler: BaseCallbackHandler, event_name: str, ignore_condition_name: Optional[str], @@ -370,7 +340,7 @@ async def _handle_event( """Generic event handler for AsyncCallbackManager.""" await asyncio.gather( *( - _handle_event_for_handler( + _ahandle_event_for_handler( handler, event_name, ignore_condition_name, verbose, *args, **kwargs ) for handler in self.handlers From fa4a4f29409ce356b8c78fd446e7614060d50e66 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 20 Apr 2023 17:06:00 -0700 Subject: [PATCH 02/36] cr --- langchain/callbacks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 575dc0ddffd65..a3116dcbf2d93 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -228,7 +228,7 @@ def on_agent_finish( self, finish: AgentFinish, verbose: bool = False, **kwargs: Any ) -> None: """Run on agent end.""" - self._handle_event("on_agent_finish", None, verbose, finish, **kwargs) + self._handle_event("on_agent_finish", "ignore_agent", verbose, finish, **kwargs) def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" From 675e27c1360765e448160402733fe56b56174b70 Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Sat, 22 Apr 2023 21:40:59 -0700 Subject: [PATCH 03/36] Callbacks Refactor [2/n]: refactor `CallbackManager` code to own file (#3341) --- docs/ecosystem/gpt4all.md | 6 +- langchain/callbacks/__init__.py | 13 +- langchain/callbacks/base.py | 559 ++++++--------- langchain/callbacks/manager.py | 652 ++++++++++++++++++ langchain/callbacks/shared.py | 127 ---- langchain/callbacks/tracers/__init__.py | 11 +- langchain/callbacks/tracers/base.py | 51 -- langchain/callbacks/tracers/langchain.py | 6 +- .../chat_models/test_anthropic.py | 2 +- .../chat_models/test_openai.py | 2 +- .../chat_models/test_promptlayer_openai.py | 2 +- .../integration_tests/llms/test_anthropic.py | 2 +- tests/integration_tests/llms/test_openai.py | 2 +- tests/unit_tests/agents/test_agent.py | 2 +- .../callbacks/fake_callback_handler.py | 336 +++++---- .../callbacks/test_callback_manager.py | 252 ++++--- tests/unit_tests/chains/test_base.py | 2 +- tests/unit_tests/llms/test_callbacks.py | 2 +- 18 files changed, 1267 insertions(+), 762 deletions(-) create mode 100644 langchain/callbacks/manager.py delete mode 100644 langchain/callbacks/shared.py diff --git a/docs/ecosystem/gpt4all.md b/docs/ecosystem/gpt4all.md index 36422eb529c14..ea4704d8f76f1 100644 --- a/docs/ecosystem/gpt4all.md +++ b/docs/ecosystem/gpt4all.md @@ -28,13 +28,15 @@ To stream the model's predictions, add in a CallbackManager. 
```python from langchain.llms import GPT4All -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + # There are many CallbackHandlers supported, such as # from langchain.callbacks.streamlit import StreamlitCallbackHandler callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) -model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True) +model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, + verbose=True) # Generate text. Tokens are streamed through the callback manager. model("Once upon a time, ") diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index c6137bf9bc7e1..232910ad32797 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -5,24 +5,22 @@ from langchain.callbacks.aim_callback import AimCallbackHandler from langchain.callbacks.base import ( - AsyncCallbackManager, BaseCallbackHandler, BaseCallbackManager, - CallbackManager, ) from langchain.callbacks.clearml_callback import ClearMLCallbackHandler from langchain.callbacks.comet_ml_callback import CometCallbackHandler +from langchain.callbacks.manager import CallbackManager from langchain.callbacks.openai_info import OpenAICallbackHandler -from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler -from langchain.callbacks.tracers import SharedLangChainTracer +from langchain.callbacks.tracers import LangChainTracer from langchain.callbacks.wandb_callback import WandbCallbackHandler def get_callback_manager() -> BaseCallbackManager: """Return the shared callback manager.""" - return SharedCallbackManager() + return CallbackManager([]) def set_handler(handler: BaseCallbackHandler) -> None: @@ -48,7 +46,7 @@ def set_default_callback_manager() -> None: def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: """Set tracing callback manager.""" - handler = SharedLangChainTracer() + handler = LangChainTracer() callback = get_callback_manager() callback.set_handlers([handler, StdOutCallbackHandler()]) if session_name is None: @@ -71,10 +69,7 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: __all__ = [ - "CallbackManager", - "AsyncCallbackManager", "OpenAICallbackHandler", - "SharedCallbackManager", "StdOutCallbackHandler", "AimCallbackHandler", "WandbCallbackHandler", diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index a3116dcbf2d93..bd070104c7327 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -1,479 +1,352 @@ -"""Base callback handler that can be used to handle callbacks from langchain.""" -import asyncio -import functools -from abc import ABC, abstractmethod +"""Base callback handler that can be used to handle callbacks in langchain.""" +from __future__ import annotations + +import copy from typing import Any, Dict, List, Optional, Union from langchain.schema import AgentAction, AgentFinish, LLMResult -class BaseCallbackHandler: - """Base callback handler that can be used to handle callbacks from langchain.""" - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return False - - @property - def 
ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return False - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return False - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return False +class LLMManagerMixin: + """Mixin for LLM callbacks.""" - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + def on_llm_new_token( + self, + token: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> Any: - """Run when LLM starts running.""" - - def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: + def on_llm_end( + self, + response: LLMResult, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run when LLM ends running.""" def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> Any: """Run when LLM errors.""" - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: - """Run when chain starts running.""" - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: +class ChainManagerMixin: + """Mixin for chain callbacks.""" + + def on_chain_end( + self, + outputs: Dict[str, Any], + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run when chain ends running.""" def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> Any: """Run when chain errors.""" - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: - """Run when tool starts running.""" - - def on_tool_end(self, output: str, **kwargs: Any) -> Any: - """Run when tool ends running.""" - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + def on_agent_action( + self, + action: AgentAction, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> Any: - """Run when tool errors.""" - - def on_text(self, text: str, **kwargs: Any) -> Any: - """Run on arbitrary text.""" - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + def on_agent_finish( + self, + finish: AgentFinish, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run on agent end.""" -class BaseCallbackManager(BaseCallbackHandler, ABC): - """Base callback manager that can be used to handle callbacks from LangChain.""" - - @property - def is_async(self) -> bool: - """Whether the callback manager is async.""" - return False - - @abstractmethod - def add_handler(self, callback: BaseCallbackHandler) -> None: - """Add a handler to the callback manager.""" - - @abstractmethod - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - - def set_handler(self, handler: BaseCallbackHandler) -> None: - """Set handler as the 
only handler on the callback manager.""" - self.set_handlers([handler]) - - @abstractmethod - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" +class ToolManagerMixin: + """Mixin for tool callbacks.""" + def on_tool_end( + self, + output_str: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run when tool ends running.""" -class CallbackManager(BaseCallbackManager): - """Callback manager that can be used to handle callbacks from langchain.""" + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors.""" - def __init__(self, handlers: List[BaseCallbackHandler]) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - def _handle_event( - self, - event_name: str, - ignore_condition_name: Optional[str], - verbose: bool, - *args: Any, - **kwargs: Any - ) -> None: - for handler in self.handlers: - if ignore_condition_name is None or not getattr( - handler, ignore_condition_name - ): - if verbose or handler.always_verbose: - getattr(handler, event_name)(*args, **kwargs) +class CallbackManagerMixin: + """Mixin for callback manager.""" def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - verbose: bool = False, - **kwargs: Any - ) -> None: + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run when LLM starts running.""" - self._handle_event( - "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs - ) - - def on_llm_new_token( - self, token: str, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when LLM generates a new token.""" - self._handle_event("on_llm_new_token", "ignore_llm", verbose, token, **kwargs) - - def on_llm_end( - self, response: LLMResult, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when LLM ends running.""" - self._handle_event("on_llm_end", "ignore_llm", verbose, response, **kwargs) - - def on_llm_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when LLM errors.""" - self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs) def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - verbose: bool = False, - **kwargs: Any - ) -> None: + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run when chain starts running.""" - self._handle_event( - "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs - ) - - def on_chain_end( - self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any - ) -> None: - """Run when chain ends running.""" - self._handle_event("on_chain_end", "ignore_chain", verbose, outputs, **kwargs) - - def on_chain_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when chain errors.""" - self._handle_event("on_chain_error", "ignore_chain", verbose, error, **kwargs) def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - verbose: bool = False, - **kwargs: Any - ) -> None: + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run when tool starts running.""" - self._handle_event( - "on_tool_start", 
"ignore_agent", verbose, serialized, input_str, **kwargs - ) - def on_agent_action( - self, action: AgentAction, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - self._handle_event("on_agent_action", "ignore_agent", verbose, action, **kwargs) - def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None: - """Run when tool ends running.""" - self._handle_event("on_tool_end", "ignore_agent", verbose, output, **kwargs) +class RunManagerMixin: + """Mixin for run manager.""" - def on_tool_error( + def on_text( self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when tool errors.""" - self._handle_event("on_tool_error", "ignore_agent", verbose, error, **kwargs) - - def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: - """Run on additional input from chains and agents.""" - self._handle_event("on_text", None, verbose, text, **kwargs) - - def on_agent_finish( - self, finish: AgentFinish, verbose: bool = False, **kwargs: Any - ) -> None: - """Run on agent end.""" - self._handle_event("on_agent_finish", "ignore_agent", verbose, finish, **kwargs) - - def add_handler(self, handler: BaseCallbackHandler) -> None: - """Add a handler to the callback manager.""" - self.handlers.append(handler) - - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - self.handlers.remove(handler) - - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - self.handlers = handlers - - -class AsyncCallbackHandler(BaseCallbackHandler): - """Async callback handler that can be used to handle callbacks from langchain.""" - - async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - - async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - - async def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - - async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" - - async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - - async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" - - async def on_text(self, text: str, **kwargs: Any) -> None: + text: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: """Run on arbitrary text.""" - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None: - """Run on agent action.""" - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" +class BaseCallbackHandler( + LLMManagerMixin, + ChainManagerMixin, + ToolManagerMixin, + CallbackManagerMixin, + RunManagerMixin, +): + """Base callback handler that can be used to handle callbacks from langchain.""" + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return False -async def _ahandle_event_for_handler( - handler: BaseCallbackHandler, - event_name: str, - ignore_condition_name: Optional[str], - verbose: bool, - *args: Any, - **kwargs: Any -) -> None: - if ignore_condition_name is None or not getattr(handler, ignore_condition_name): - if verbose or handler.always_verbose: - event = getattr(handler, event_name) - if asyncio.iscoroutinefunction(event): - await event(*args, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, functools.partial(event, *args, **kwargs) - ) - - -class AsyncCallbackManager(BaseCallbackManager): - """Async callback manager that can be used to handle callbacks from LangChain.""" + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return False @property - def is_async(self) -> bool: - """Return whether the handler is async.""" - return True + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return False - def __init__(self, handlers: List[BaseCallbackHandler]) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - async def _handle_event( - self, - event_name: str, - ignore_condition_name: Optional[str], - verbose: bool, - *args: Any, - **kwargs: Any - ) -> None: - """Generic event handler for AsyncCallbackManager.""" - await asyncio.gather( - *( - _ahandle_event_for_handler( - handler, event_name, ignore_condition_name, verbose, *args, **kwargs - ) - for handler in self.handlers - ) - ) +class AsyncCallbackHandler(BaseCallbackHandler): + """Async callback handler that can be used to handle callbacks from langchain.""" async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - 
verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when LLM starts running.""" - await self._handle_event( - "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs - ) async def on_llm_new_token( - self, token: str, verbose: bool = False, **kwargs: Any + self, + token: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" - await self._handle_event( - "on_llm_new_token", "ignore_llm", verbose, token, **kwargs - ) async def on_llm_end( - self, response: LLMResult, verbose: bool = False, **kwargs: Any + self, + response: LLMResult, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when LLM ends running.""" - await self._handle_event( - "on_llm_end", "ignore_llm", verbose, response, **kwargs - ) async def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when LLM errors.""" - await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs) async def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when chain starts running.""" - await self._handle_event( - "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs - ) async def on_chain_end( - self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any + self, + outputs: Dict[str, Any], + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when chain ends running.""" - await self._handle_event( - "on_chain_end", "ignore_chain", verbose, outputs, **kwargs - ) async def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when chain errors.""" - await self._handle_event( - "on_chain_error", "ignore_chain", verbose, error, **kwargs - ) async def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when tool starts running.""" - await self._handle_event( - "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs - ) async def on_tool_end( - self, output: str, verbose: bool = False, **kwargs: Any + self, + output: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when tool ends running.""" - await self._handle_event( - "on_tool_end", "ignore_agent", verbose, output, **kwargs - ) async def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run when tool errors.""" - await self._handle_event( - "on_tool_error", "ignore_agent", verbose, error, **kwargs - ) - async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: - """Run when text is printed.""" - await self._handle_event("on_text", None, 
verbose, text, **kwargs) + async def on_text( + self, + text: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" async def on_agent_action( - self, action: AgentAction, verbose: bool = False, **kwargs: Any + self, + action: AgentAction, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: """Run on agent action.""" - await self._handle_event( - "on_agent_action", "ignore_agent", verbose, action, **kwargs - ) async def on_agent_finish( - self, finish: AgentFinish, verbose: bool = False, **kwargs: Any + self, + finish: AgentFinish, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, ) -> None: - """Run when agent finishes.""" - await self._handle_event( - "on_agent_finish", "ignore_agent", verbose, finish, **kwargs + """Run on agent end.""" + + +class BaseCallbackManager(CallbackManagerMixin): + """Base callback manager that can be used to handle callbacks from LangChain.""" + + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler] = None, + parent_run_id: Optional[str] = None, + ) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers + self.inheritable_handlers: List[BaseCallbackHandler] = ( + inheritable_handlers or [] ) + self.parent_run_id: Optional[str] = parent_run_id + + @property + def is_async(self) -> bool: + """Whether the callback manager is async.""" + return False - def add_handler(self, handler: BaseCallbackHandler) -> None: + def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Add a handler to the callback manager.""" self.handlers.append(handler) + if inherit: + self.inheritable_handlers.append(handler) def remove_handler(self, handler: BaseCallbackHandler) -> None: """Remove a handler from the callback manager.""" self.handlers.remove(handler) + self.inheritable_handlers.remove(handler) - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: + def set_handlers(self, handlers: List[BaseCallbackHandler], inherit=True) -> None: """Set handlers as the only handlers on the callback manager.""" - self.handlers = handlers + self.handlers = [] + self.inheritable_handlers = [] + for handler in handlers: + self.add_handler(handler, inherit=inherit) + + def set_handler(self, handler: BaseCallbackHandler, inherit=True) -> None: + """Set handler as the only handler on the callback manager.""" + self.set_handlers([handler], inherit=inherit) + + def __copy__(self): + return self.__class__( + self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id + ) + + def __deepcopy__(self, memo): + return self.__class__( + [copy.deepcopy(handler, memo) for handler in self.handlers], + [copy.deepcopy(handler, memo) for handler in self.inheritable_handlers], + self.parent_run_id, + ) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py new file mode 100644 index 0000000000000..fdce94d3f5823 --- /dev/null +++ b/langchain/callbacks/manager.py @@ -0,0 +1,652 @@ +from __future__ import annotations + +import asyncio +import copy +import functools +import logging +import os +import uuid +from typing import Any, Dict, List, Optional, Type, TypeVar, Union + +from langchain.callbacks.base import ( + BaseCallbackHandler, + BaseCallbackManager, + ChainManagerMixin, + LLMManagerMixin, + RunManagerMixin, + ToolManagerMixin, +) +from langchain.callbacks.stdout 
import StdOutCallbackHandler +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.schema import AgentAction, AgentFinish, LLMResult + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] - %(message)s", + handlers=[logging.StreamHandler()], +) + + +def _handle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + for handler in handlers: + try: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + getattr(handler, event_name)(*args, **kwargs) + except Exception as e: + logging.error(f"Error in {event_name} callback: {e}") + + +async def _ahandle_event_for_handler( + handler: BaseCallbackHandler, + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + try: + if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + event = getattr(handler, event_name) + if asyncio.iscoroutinefunction(event): + await event(*args, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(event, *args, **kwargs) + ) + except Exception as e: + logging.error(f"Error in {event_name} callback: {e}") + + +async def _ahandle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for AsyncCallbackManager.""" + await asyncio.gather( + *( + _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + for handler in handlers + ) + ) + + +class BaseRunManager(RunManagerMixin): + """Base class for run manager (a bound callback manager).""" + + def __init__( + self, + run_id: str, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler], + parent_run_id: str, + ) -> None: + """Initialize run manager.""" + self.run_id = run_id + self.handlers = handlers + self.inheritable_handlers = inheritable_handlers + self.parent_run_id = parent_run_id + + +class RunManager(BaseRunManager): + """Sync Run Manager.""" + + def on_text(self, text: str, **kwargs: Any) -> Any: + """Run when text is received.""" + _handle_event(self.handlers, "on_text", None, False, text, **kwargs) + + +class AsyncRunManager(BaseRunManager): + """Async Run Manager.""" + + async def on_text(self, text: str, **kwargs: Any) -> Any: + """Run when text is received.""" + await _ahandle_event(self.handlers, "on_text", None, False, text, **kwargs) + + +class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): + """Callback manager for LLM run.""" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + _handle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + _handle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + _handle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class 
AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): + """Async callback manager for LLM run.""" + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + await _ahandle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + await _ahandle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + await _ahandle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManagerForChainRun(RunManager, ChainManagerMixin): + """Callback manager for chain run.""" + + def get_child(self) -> CallbackManager: + """Get a child callback manager.""" + manager = CallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + _handle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + _handle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received.""" + _handle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received.""" + _handle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): + """Async callback manager for chain run.""" + + def get_child(self) -> AsyncCallbackManager: + """Get a child callback manager.""" + manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + await _ahandle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + await _ahandle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received.""" + await _ahandle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + 
parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received.""" + await _ahandle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManagerForToolRun(RunManager, ToolManagerMixin): + """Callback manager for tool run.""" + + def get_child(self) -> CallbackManager: + """Get a child callback manager.""" + manager = CallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + _handle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + _handle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): + """Async callback manager for tool run.""" + + def get_child(self) -> AsyncCallbackManager: + """Get a child callback manager.""" + manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + await _ahandle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + await _ahandle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManager(BaseCallbackManager): + """Callback manager that can be used to handle callbacks from langchain.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + run_id: Optional[str] = None, + **kwargs: Any, + ) -> CallbackManagerForLLMRun: + """Run when LLM starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + _handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + prompts, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return CallbackManagerForLLMRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: Optional[str] = None, + **kwargs: Any, + ) -> CallbackManagerForChainRun: + """Run when chain starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + _handle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return CallbackManagerForChainRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> CallbackManagerForToolRun: + """Run when tool starts running.""" + if run_id is None: + 
run_id = uuid.uuid4() + + _handle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=parent_run_id, + **kwargs, + ) + + return CallbackManagerForToolRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Optional[ + Union[BaseCallbackManager, List[BaseCallbackHandler]] + ] = None, + local_callbacks: Optional[ + Union[BaseCallbackManager, List[BaseCallbackHandler]] + ] = None, + verbose: bool = False, + ) -> Optional[BaseCallbackManager]: + """Configure the callback manager.""" + return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + + +class AsyncCallbackManager(BaseCallbackManager): + """Async callback manager that can be used to handle callbacks from LangChain.""" + + @property + def is_async(self) -> bool: + """Return whether the handler is async.""" + return True + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + run_id: Optional[str] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForLLMRun: + """Run when LLM starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + await _ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + prompts, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForLLMRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: Optional[str] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForChainRun: + """Run when chain starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + await _ahandle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForChainRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForToolRun: + """Run when tool starts running.""" + if run_id is None: + run_id = uuid.uuid4() + + await _ahandle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForToolRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Optional[ + Union[BaseCallbackManager, List[BaseCallbackHandler]] + ] = None, + local_callbacks: Optional[ + Union[BaseCallbackManager, List[BaseCallbackHandler]] + ] = None, + verbose: bool = False, + ) -> Optional[BaseCallbackManager]: + """Configure the callback manager.""" + return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + + +T = TypeVar("T", CallbackManager, AsyncCallbackManager) + + +def _configure( + callback_manager_cls: Type[T], + inheritable_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, + local_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, + verbose: bool = False, +) -> Optional[T]: + """Configure the callback manager.""" + callback_manager: Optional[T] = None + if inheritable_callbacks or local_callbacks: + if isinstance(inheritable_callbacks, list) or not 
inheritable_callbacks: + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks, + inheritable_handlers=inheritable_callbacks, + ) + else: + callback_manager = inheritable_callbacks + callback_manager = copy.deepcopy(callback_manager) + local_handlers_ = ( + local_callbacks + if isinstance(local_callbacks, list) + else (local_callbacks.handlers if local_callbacks else []) + ) + [callback_manager.add_handler(handler, False) for handler in local_handlers_] + + tracing_enabled = os.environ.get("LANGCHAIN_TRACING") is not None + if verbose or tracing_enabled: + if not callback_manager: + callback_manager = callback_manager_cls([]) + std_out_handler = StdOutCallbackHandler() + + if verbose and not any( + isinstance(handler, StdOutCallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(std_out_handler, False) + + if tracing_enabled and not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + handler = LangChainTracer() + handler.load_default_session() + callback_manager.add_handler(handler, True) + + return callback_manager diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py deleted file mode 100644 index 225b183e6b342..0000000000000 --- a/langchain/callbacks/shared.py +++ /dev/null @@ -1,127 +0,0 @@ -"""A shared CallbackManager.""" - -import threading -from typing import Any, Dict, List, Union - -from langchain.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, - CallbackManager, -) -from langchain.schema import AgentAction, AgentFinish, LLMResult - - -class Singleton: - """A thread-safe singleton class that can be inherited from.""" - - _instance = None - _lock = threading.Lock() - - def __new__(cls) -> Any: - """Create a new shared instance of the class.""" - if cls._instance is None: - with cls._lock: - # Another thread could have created the instance - # before we acquired the lock. So check that the - # instance is still nonexistent. 
- if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance - - -class SharedCallbackManager(Singleton, BaseCallbackManager): - """A thread-safe singleton CallbackManager.""" - - _callback_manager: CallbackManager = CallbackManager(handlers=[]) - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - with self._lock: - self._callback_manager.on_llm_start(serialized, prompts, **kwargs) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - with self._lock: - self._callback_manager.on_llm_end(response, **kwargs) - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - with self._lock: - self._callback_manager.on_llm_new_token(token, **kwargs) - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - with self._lock: - self._callback_manager.on_llm_error(error, **kwargs) - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - with self._lock: - self._callback_manager.on_chain_start(serialized, inputs, **kwargs) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - with self._lock: - self._callback_manager.on_chain_end(outputs, **kwargs) - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" - with self._lock: - self._callback_manager.on_chain_error(error, **kwargs) - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - with self._lock: - self._callback_manager.on_tool_start(serialized, input_str, **kwargs) - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - with self._lock: - self._callback_manager.on_agent_action(action, **kwargs) - - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - with self._lock: - self._callback_manager.on_tool_end(output, **kwargs) - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" - with self._lock: - self._callback_manager.on_tool_error(error, **kwargs) - - def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" - with self._lock: - self._callback_manager.on_text(text, **kwargs) - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" - with self._lock: - self._callback_manager.on_agent_finish(finish, **kwargs) - - def add_handler(self, callback: BaseCallbackHandler) -> None: - """Add a callback to the callback manager.""" - with self._lock: - self._callback_manager.add_handler(callback) - - def remove_handler(self, callback: BaseCallbackHandler) -> None: - """Remove a callback from the callback manager.""" - with self._lock: - self._callback_manager.remove_handler(callback) - - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - with self._lock: - self._callback_manager.handlers = handlers diff --git a/langchain/callbacks/tracers/__init__.py b/langchain/callbacks/tracers/__init__.py index 8db5367fdf583..5dd69b4871317 100644 --- 
a/langchain/callbacks/tracers/__init__.py +++ b/langchain/callbacks/tracers/__init__.py @@ -1,12 +1,5 @@ """Tracers that record execution of LangChain runs.""" -from langchain.callbacks.tracers.base import SharedTracer, Tracer -from langchain.callbacks.tracers.langchain import BaseLangChainTracer +from langchain.callbacks.tracers.langchain import LangChainTracer - -class SharedLangChainTracer(SharedTracer, BaseLangChainTracer): - """Shared tracer that records LangChain execution to LangChain endpoint.""" - - -class LangChainTracer(Tracer, BaseLangChainTracer): - """Tracer that records LangChain execution to LangChain endpoint.""" +__all__ = ["LangChainTracer"] diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 2a99c1c824c46..0b2b99dd6483b 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -1,14 +1,11 @@ """Base interfaces for tracing runs.""" from __future__ import annotations -import threading from abc import ABC, abstractmethod -from dataclasses import dataclass, field from datetime import datetime from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.shared import Singleton from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, @@ -293,51 +290,3 @@ def _session(self, value: TracerSession) -> None: "Cannot set a session while a trace is being recorded" ) self._tracer_session = value - - -@dataclass -class TracerStack(threading.local): - """A stack of runs used for logging.""" - - stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list) - execution_order: int = 1 - - -class SharedTracer(Singleton, BaseTracer, ABC): - """A thread-safe Singleton implementation of BaseTracer.""" - - _tracer_stack = TracerStack() - _tracer_session = None - - @property - def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - return self._tracer_stack.stack - - @property - def _execution_order(self) -> int: - """Get the execution order for a run.""" - return self._tracer_stack.execution_order - - @_execution_order.setter - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - self._tracer_stack.execution_order = value - - @property - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - return self._tracer_session - - @_session.setter - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - with self._lock: - # TODO: currently, we are only checking current thread's stack. - # Need to make sure that we are not in the middle of a trace - # in any thread. 
- if self._stack: - raise TracerException( - "Cannot set a session while a trace is being recorded" - ) - self._tracer_session = value diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index d25022041aab1..a45ccd3b42e90 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -8,7 +8,7 @@ import requests -from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.base import BaseTracer, Tracer from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, @@ -110,3 +110,7 @@ def _add_child_run( def _generate_id(self) -> Optional[Union[int, str]]: """Generate an id for a run.""" return None + + +class LangChainTracer(Tracer, BaseLangChainTracer): + """Tracer that records LangChain execution to LangChain endpoint.""" diff --git a/tests/integration_tests/chat_models/test_anthropic.py b/tests/integration_tests/chat_models/test_anthropic.py index 60fe58f319f8d..a7186c8be6b86 100644 --- a/tests/integration_tests/chat_models/test_anthropic.py +++ b/tests/integration_tests/chat_models/test_anthropic.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.anthropic import ChatAnthropic from langchain.schema import ( AIMessage, diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py index 06394ebca076c..432c2c88d3328 100644 --- a/tests/integration_tests/chat_models/test_openai.py +++ b/tests/integration_tests/chat_models/test_openai.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.openai import ChatOpenAI from langchain.schema import ( BaseMessage, diff --git a/tests/integration_tests/chat_models/test_promptlayer_openai.py b/tests/integration_tests/chat_models/test_promptlayer_openai.py index c9962f75a3890..ab68a0850b691 100644 --- a/tests/integration_tests/chat_models/test_promptlayer_openai.py +++ b/tests/integration_tests/chat_models/test_promptlayer_openai.py @@ -2,7 +2,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.schema import ( BaseMessage, diff --git a/tests/integration_tests/llms/test_anthropic.py b/tests/integration_tests/llms/test_anthropic.py index 8c7717cfc7d57..851a94ed00881 100644 --- a/tests/integration_tests/llms/test_anthropic.py +++ b/tests/integration_tests/llms/test_anthropic.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.llms.anthropic import Anthropic from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 9db120a598448..e10a9c0b33864 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -5,7 +5,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.llms.loading import load_llm from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import 
LLMResult diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index f30b3f05ce784..4b0f736a4a43b 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,7 +4,7 @@ from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents.tools import Tool -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.llms.base import LLM from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 921596e7563d4..6dd92a2621b19 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,10 +1,9 @@ """A fake callback handler for testing purposes.""" -from typing import Any, Dict, List, Union +from typing import Any from pydantic import BaseModel from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult class BaseFakeCallbackHandler(BaseModel): @@ -17,27 +16,6 @@ class BaseFakeCallbackHandler(BaseModel): ignore_llm_: bool = False ignore_chain_: bool = False ignore_agent_: bool = False - always_verbose_: bool = False - - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return self.always_verbose_ - - @property - def ignore_llm(self) -> bool: - """Whether to ignore LLM callbacks.""" - return self.ignore_llm_ - - @property - def ignore_chain(self) -> bool: - """Whether to ignore chain callbacks.""" - return self.ignore_chain_ - - @property - def ignore_agent(self) -> bool: - """Whether to ignore agent callbacks.""" - return self.ignore_agent_ # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -47,156 +25,282 @@ def ignore_agent(self) -> bool: llm_streams: int = 0 tool_starts: int = 0 tool_ends: int = 0 + agent_actions: int = 0 agent_ends: int = 0 -class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler): - """Fake callback handler for testing.""" +class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): + """Base fake callback handler mixin for testing.""" - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" + def on_llm_start_common(self) -> None: self.llm_starts += 1 self.starts += 1 - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - self.llm_streams += 1 - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" + def on_llm_end_common(self) -> None: self.llm_ends += 1 self.ends += 1 - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" + def on_llm_error_common(self) -> None: self.errors += 1 - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" + def on_llm_new_token_common(self) -> None: + self.llm_streams += 1 + + def on_chain_start_common(self) -> None: self.chain_starts += 1 self.starts += 1 - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" + def on_chain_end_common(self) -> None: 
self.chain_ends += 1 self.ends += 1 - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" + def on_chain_error_common(self) -> None: self.errors += 1 - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" + def on_tool_start_common(self) -> None: self.tool_starts += 1 self.starts += 1 - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" + def on_tool_end_common(self) -> None: self.tool_ends += 1 self.ends += 1 - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" + def on_tool_error_common(self) -> None: self.errors += 1 - def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent is ending.""" - self.text += 1 + def on_agent_action_common(self) -> None: + self.agent_actions += 1 + self.starts += 1 - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run when agent ends running.""" + def on_agent_finish_common(self) -> None: self.agent_ends += 1 self.ends += 1 - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - self.tool_starts += 1 - self.starts += 1 + def on_text_common(self) -> None: + self.text += 1 + + +class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake callback handler for testing.""" + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ -class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler): + def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_start_common() + + def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_new_token_common() + + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_end_common() + + def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_error_common() + + def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_start_common() + + def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_end_common() + + def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_error_common() + + def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_start_common() + + def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_end_common() + + def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_error_common() + + def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_action_common() + + def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_finish_common() + + def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_text_common() + + def __deepcopy__(self, memo): + return self + + +class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): """Fake async callback handler for testing.""" + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" 
+ return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when LLM starts running.""" - self.llm_starts += 1 - self.starts += 1 + self.on_llm_start_common() - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - self.llm_streams += 1 + async def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_new_token_common() - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - self.llm_ends += 1 - self.ends += 1 + async def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_end_common() async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when LLM errors.""" - self.errors += 1 + self.on_llm_error_common() async def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when chain starts running.""" - self.chain_starts += 1 - self.starts += 1 + self.on_chain_start_common() - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - self.chain_ends += 1 - self.ends += 1 + async def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_end_common() async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when chain errors.""" - self.errors += 1 + self.on_chain_error_common() async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when tool starts running.""" - self.tool_starts += 1 - self.starts += 1 + self.on_tool_start_common() - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - self.tool_ends += 1 - self.ends += 1 + async def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_end_common() async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when tool errors.""" - self.errors += 1 + self.on_tool_error_common() - async def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent is ending.""" - self.text += 1 + async def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_action_common() - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run when agent ends running.""" - self.agent_ends += 1 - self.ends += 1 + async def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_finish_common() - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None: - """Run on agent action.""" - self.tool_starts += 1 - self.starts += 1 + async def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_text_common() + + def __deepcopy__(self, memo): + return self diff --git 
a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 0f61fdd30f28f..df72c1ea8e6ec 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -3,13 +3,9 @@ import pytest -from langchain.callbacks.base import ( - AsyncCallbackManager, - BaseCallbackManager, - CallbackManager, -) -from langchain.callbacks.shared import SharedCallbackManager -from langchain.schema import AgentFinish, LLMResult +from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager +from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( BaseFakeCallbackHandler, FakeAsyncCallbackHandler, @@ -18,19 +14,26 @@ def _test_callback_manager( - manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler + manager: CallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - manager.on_llm_start({}, []) - manager.on_llm_end(LLMResult(generations=[])) - manager.on_llm_error(Exception()) - manager.on_chain_start({"name": "foo"}, {}) - manager.on_chain_end({}) - manager.on_chain_error(Exception()) - manager.on_tool_start({}, "") - manager.on_tool_end("") - manager.on_tool_error(Exception()) - manager.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager = manager.on_llm_start({}, []) + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) + run_manager.on_llm_new_token("foo") + run_manager.on_text("foo") + + run_manager = manager.on_chain_start({"name": "foo"}, {}) + run_manager.on_chain_end({}) + run_manager.on_chain_error(Exception()) + run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + run_manager.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager.on_text("foo") + + run_manager = manager.on_tool_start({}, "") + run_manager.on_tool_end("") + run_manager.on_tool_error(Exception()) + run_manager.on_text("foo") _check_num_calls(handlers) @@ -38,75 +41,60 @@ async def _test_callback_manager_async( manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - await manager.on_llm_start({}, []) - await manager.on_llm_end(LLMResult(generations=[])) - await manager.on_llm_error(Exception()) - await manager.on_chain_start({"name": "foo"}, {}) - await manager.on_chain_end({}) - await manager.on_chain_error(Exception()) - await manager.on_tool_start({}, "") - await manager.on_tool_end("") - await manager.on_tool_error(Exception()) - await manager.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager = await manager.on_llm_start({}, []) + await run_manager.on_llm_end(LLMResult(generations=[])) + await run_manager.on_llm_error(Exception()) + await run_manager.on_llm_new_token("foo") + await run_manager.on_text("foo") + + run_manager = await manager.on_chain_start({"name": "foo"}, {}) + await run_manager.on_chain_end({}) + await run_manager.on_chain_error(Exception()) + await run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + await run_manager.on_agent_finish(AgentFinish(log="", return_values={})) + await run_manager.on_text("foo") + + run_manager = await manager.on_tool_start({}, "") + await run_manager.on_tool_end("") + await run_manager.on_tool_error(Exception()) + await run_manager.on_text("foo") _check_num_calls(handlers) def 
_check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None: for handler in handlers: - if handler.always_verbose: - assert handler.starts == 3 - assert handler.ends == 4 - assert handler.errors == 3 - else: - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 - - -def _test_callback_manager_pass_in_verbose( - manager: BaseCallbackManager, *handlers: FakeCallbackHandler -) -> None: - """Test the CallbackManager.""" - manager.on_llm_start({}, [], verbose=True) - manager.on_llm_end(LLMResult(generations=[]), verbose=True) - manager.on_llm_error(Exception(), verbose=True) - manager.on_chain_start({"name": "foo"}, {}, verbose=True) - manager.on_chain_end({}, verbose=True) - manager.on_chain_error(Exception(), verbose=True) - manager.on_tool_start({}, "", verbose=True) - manager.on_tool_end("", verbose=True) - manager.on_tool_error(Exception(), verbose=True) - manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True) - for handler in handlers: - assert handler.starts == 3 + assert handler.starts == 4 assert handler.ends == 4 assert handler.errors == 3 + assert handler.text == 3 + assert handler.llm_starts == 1 + assert handler.llm_ends == 1 + assert handler.llm_streams == 1 -def test_callback_manager() -> None: - """Test the CallbackManager.""" - handler1 = FakeCallbackHandler(always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=False) - manager = CallbackManager([handler1, handler2]) - _test_callback_manager(manager, handler1, handler2) + assert handler.chain_starts == 1 + assert handler.chain_ends == 1 + + assert handler.tool_starts == 1 + assert handler.tool_ends == 1 -def test_callback_manager_pass_in_verbose() -> None: +def test_callback_manager() -> None: """Test the CallbackManager.""" handler1 = FakeCallbackHandler() handler2 = FakeCallbackHandler() manager = CallbackManager([handler1, handler2]) - _test_callback_manager_pass_in_verbose(manager, handler1, handler2) + _test_callback_manager(manager, handler1, handler2) def test_ignore_llm() -> None: """Test ignore llm param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_llm_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_llm_start({}, [], verbose=True) - manager.on_llm_end(LLMResult(generations=[]), verbose=True) - manager.on_llm_error(Exception(), verbose=True) + run_manager = manager.on_llm_start({}, []) + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 @@ -117,12 +105,12 @@ def test_ignore_llm() -> None: def test_ignore_chain() -> None: """Test ignore chain param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_chain_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_chain_start({"name": "foo"}, {}, verbose=True) - manager.on_chain_end({}, verbose=True) - manager.on_chain_error(Exception(), verbose=True) + run_manager = manager.on_chain_start({"name": "foo"}, {}) + run_manager.on_chain_end({}) + run_manager.on_chain_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 @@ 
-133,39 +121,24 @@ def test_ignore_chain() -> None: def test_ignore_agent() -> None: """Test ignore agent param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_agent_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_tool_start({}, "", verbose=True) - manager.on_tool_end("", verbose=True) - manager.on_tool_error(Exception(), verbose=True) - manager.on_agent_finish(AgentFinish({}, ""), verbose=True) + run_manager = manager.on_tool_start({}, "") + run_manager.on_tool_end("") + run_manager.on_tool_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 assert handler2.starts == 1 - assert handler2.ends == 2 + assert handler2.ends == 1 assert handler2.errors == 1 -def test_shared_callback_manager() -> None: - """Test the SharedCallbackManager.""" - manager1 = SharedCallbackManager() - manager2 = SharedCallbackManager() - - assert manager1 is manager2 - - handler1 = FakeCallbackHandler(always_verbose_=True) - handler2 = FakeCallbackHandler() - manager1.add_handler(handler1) - manager2.add_handler(handler2) - _test_callback_manager(manager1, handler1, handler2) - - @pytest.mark.asyncio async def test_async_callback_manager() -> None: """Test the AsyncCallbackManager.""" - handler1 = FakeAsyncCallbackHandler(always_verbose_=True) + handler1 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler() manager = AsyncCallbackManager([handler1, handler2]) await _test_callback_manager_async(manager, handler1, handler2) @@ -174,8 +147,95 @@ async def test_async_callback_manager() -> None: @pytest.mark.asyncio async def test_async_callback_manager_sync_handler() -> None: """Test the AsyncCallbackManager.""" - handler1 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler() handler2 = FakeAsyncCallbackHandler() - handler3 = FakeAsyncCallbackHandler(always_verbose_=True) + handler3 = FakeAsyncCallbackHandler() manager = AsyncCallbackManager([handler1, handler2, handler3]) await _test_callback_manager_async(manager, handler1, handler2, handler3) + + +def test_callback_manager_inheritance() -> None: + handler1, handler2, handler3, handler4 = ( + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + ) + + callback_manager1 = CallbackManager([handler1, handler2]) + assert callback_manager1.handlers == [handler1, handler2] + assert callback_manager1.inheritable_handlers == [] + + callback_manager2 = CallbackManager([]) + assert callback_manager2.handlers == [] + assert callback_manager2.inheritable_handlers == [] + + callback_manager2.set_handlers([handler1, handler2]) + assert callback_manager2.handlers == [handler1, handler2] + assert callback_manager2.inheritable_handlers == [handler1, handler2] + + callback_manager2.set_handlers([handler3, handler4], inherit=False) + assert callback_manager2.handlers == [handler3, handler4] + assert callback_manager2.inheritable_handlers == [] + + callback_manager2.add_handler(handler1) + assert callback_manager2.handlers == [handler3, handler4, handler1] + assert callback_manager2.inheritable_handlers == [handler1] + + callback_manager2.add_handler(handler2, inherit=False) + assert callback_manager2.handlers == [handler3, handler4, handler1, handler2] + assert callback_manager2.inheritable_handlers == [handler1] + + run_manager = 
callback_manager2.on_chain_start({"name": "foo"}, {}) + child_manager = run_manager.get_child() + assert child_manager.handlers == [handler1] + assert child_manager.inheritable_handlers == [handler1] + + child_manager = child_manager.on_tool_start({}, "") + assert child_manager.handlers == [handler1] + assert child_manager.inheritable_handlers == [handler1] + + child_manager = child_manager.get_child() + assert child_manager.handlers == [handler1] + assert child_manager.inheritable_handlers == [handler1] + + +def test_callback_manager_configure() -> None: + """Test callback manager configuration.""" + handler1, handler2, handler3, handler4 = ( + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + ) + + inheritable_callbacks = [handler1, handler2] + local_callbacks = [handler3, handler4] + configured_manager = CallbackManager.configure( + inheritable_callbacks=inheritable_callbacks, + local_callbacks=local_callbacks, + verbose=True, + ) + + assert len(configured_manager.handlers) == 5 + assert len(configured_manager.inheritable_handlers) == 2 + assert configured_manager.inheritable_handlers == inheritable_callbacks + assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks + assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler) + assert isinstance(configured_manager, CallbackManager) + + async_local_callbacks = AsyncCallbackManager([handler3, handler4]) + async_configured_manager = AsyncCallbackManager.configure( + inheritable_callbacks=inheritable_callbacks, + local_callbacks=async_local_callbacks, + verbose=False, + ) + + assert len(async_configured_manager.handlers) == 4 + assert len(async_configured_manager.inheritable_handlers) == 2 + assert async_configured_manager.inheritable_handlers == inheritable_callbacks + assert async_configured_manager.handlers == inheritable_callbacks + [ + handler3, + handler4, + ] + assert isinstance(async_configured_manager, AsyncCallbackManager) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 0b0aebf760f04..8c3081c2f149c 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chains.base import Chain from langchain.schema import BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index d9d52630b7fdc..2802e58e38420 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -1,5 +1,5 @@ """Test LLM callbacks.""" -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_llm import FakeLLM From 90cef7b53ac5e92467b87b06dbf12c8f8eeadabe Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Sat, 22 Apr 2023 21:49:34 -0700 Subject: [PATCH 04/36] cr --- langchain/callbacks/manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index fdce94d3f5823..7c96e4fe1af6b 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -633,13 +633,12 @@ def _configure( if verbose or tracing_enabled: if not 
callback_manager: callback_manager = callback_manager_cls([]) - std_out_handler = StdOutCallbackHandler() if verbose and not any( isinstance(handler, StdOutCallbackHandler) for handler in callback_manager.handlers ): - callback_manager.add_handler(std_out_handler, False) + callback_manager.add_handler(StdOutCallbackHandler(), False) if tracing_enabled and not any( isinstance(handler, LangChainTracer) From 4cdd19bd4e807a242db328ea1fe6dbf42b5771e6 Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Tue, 25 Apr 2023 18:20:16 -0700 Subject: [PATCH 05/36] Callbacks Refactor [2/n] update tracer to work with new callbacks mechanism (#3381) --- docs/modules/callbacks/getting_started.ipynb | 1045 ++++++++++++++--- langchain/__init__.py | 8 - langchain/agents/agent.py | 157 ++- .../agents/agent_toolkits/openapi/planner.py | 2 +- langchain/agents/chat/base.py | 3 +- langchain/agents/conversational/base.py | 2 +- langchain/agents/conversational_chat/base.py | 2 +- langchain/agents/initialize.py | 2 +- langchain/agents/load_tools.py | 4 +- langchain/agents/mrkl/base.py | 2 +- langchain/agents/tools.py | 41 +- langchain/base_language.py | 55 + langchain/callbacks/__init__.py | 65 +- langchain/callbacks/manager.py | 98 +- langchain/callbacks/tracers/base.py | 312 +++-- langchain/callbacks/tracers/langchain.py | 89 +- langchain/callbacks/tracers/schemas.py | 5 +- langchain/chains/api/base.py | 2 +- langchain/chains/base.py | 129 +- langchain/chains/constitutional_ai/base.py | 2 +- .../chains/conversational_retrieval/base.py | 3 +- langchain/chains/llm.py | 117 +- langchain/chains/llm_bash/base.py | 2 +- langchain/chains/llm_math/base.py | 94 +- langchain/chains/pal/base.py | 2 +- langchain/chains/prompt_selector.py | 2 +- langchain/chains/qa_generation/base.py | 2 +- langchain/chains/qa_with_sources/base.py | 2 +- langchain/chains/qa_with_sources/loading.py | 2 +- .../chains/question_answering/__init__.py | 2 +- langchain/chains/retrieval_qa/base.py | 3 +- langchain/chains/sql_database/base.py | 2 +- langchain/chains/summarize/__init__.py | 2 +- langchain/chat_models/anthropic.py | 32 +- langchain/chat_models/base.py | 187 ++- langchain/chat_models/openai.py | 32 +- langchain/chat_models/promptlayer_openai.py | 18 +- .../autonomous_agents/baby_agi/baby_agi.py | 2 +- .../baby_agi/task_creation.py | 2 +- .../baby_agi/task_execution.py | 2 +- .../baby_agi/task_prioritization.py | 2 +- langchain/llms/ai21.py | 8 +- langchain/llms/aleph_alpha.py | 8 +- langchain/llms/anthropic.py | 33 +- langchain/llms/bananadev.py | 8 +- langchain/llms/base.py | 222 ++-- langchain/llms/cerebriumai.py | 8 +- langchain/llms/cohere.py | 8 +- langchain/llms/deepinfra.py | 8 +- langchain/llms/fake.py | 8 +- langchain/llms/forefrontai.py | 8 +- langchain/llms/gooseai.py | 8 +- langchain/llms/gpt4all.py | 8 +- langchain/llms/huggingface_endpoint.py | 8 +- langchain/llms/huggingface_hub.py | 8 +- langchain/llms/huggingface_pipeline.py | 8 +- langchain/llms/llamacpp.py | 8 +- langchain/llms/manifest.py | 8 +- langchain/llms/modal.py | 8 +- langchain/llms/nlpcloud.py | 8 +- langchain/llms/openai.py | 63 +- langchain/llms/petals.py | 8 +- langchain/llms/promptlayer_openai.py | 32 +- langchain/llms/replicate.py | 8 +- langchain/llms/rwkv.py | 8 +- langchain/llms/sagemaker_endpoint.py | 8 +- langchain/llms/self_hosted.py | 8 +- langchain/llms/self_hosted_hugging_face.py | 8 +- langchain/llms/stochasticai.py | 8 +- langchain/llms/writer.py | 8 +- langchain/memory/entity.py | 3 +- 
langchain/memory/kg.py | 2 +- langchain/memory/summary.py | 2 +- langchain/memory/token_buffer.py | 3 +- langchain/output_parsers/fix.py | 3 +- langchain/output_parsers/retry.py | 2 +- langchain/schema.py | 39 - langchain/tools/base.py | 97 +- langchain/utilities/serpapi.py | 4 +- tests/integration_tests/callbacks/__init__.py | 0 .../callbacks/test_langchain_tracer.py | 85 ++ .../callbacks/test_openai_callback.py | 36 + .../callbacks/tracers/test_tracer.py | 276 ++--- 83 files changed, 2440 insertions(+), 1199 deletions(-) create mode 100644 langchain/base_language.py create mode 100644 tests/integration_tests/callbacks/__init__.py create mode 100644 tests/integration_tests/callbacks/test_langchain_tracer.py create mode 100644 tests/integration_tests/callbacks/test_openai_callback.py diff --git a/docs/modules/callbacks/getting_started.ipynb b/docs/modules/callbacks/getting_started.ipynb index cc74a3e8f332e..6d4a99837ab53 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -17,33 +17,7 @@ "source": [ "LangChain provides a callback system that allows you to hook into the various stages of your LLM application. This is useful for logging, [monitoring](https://python.langchain.com/en/latest/tracing.html), [streaming](https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html), and other tasks.\n", "\n", - "You can subscribe to these events by using the `callback_manager` argument available throughout the API. A `CallbackManager` is an object that manages a list of `CallbackHandlers`. The `CallbackManager` will call the appropriate method on each handler when the event is triggered." - ] - }, - { - "cell_type": "markdown", - "id": "fdb72e8d-a02a-474d-96bf-f5759432afc8", - "metadata": { - "tags": [] - }, - "source": [ - "```python\n", - "class CallbackManager(BaseCallbackHandler):\n", - " \"\"\"Base callback manager that can be used to handle callbacks from LangChain.\"\"\"\n", - "\n", - " def add_handler(self, callback: BaseCallbackHandler) -> None:\n", - " \"\"\"Add a handler to the callback manager.\"\"\"\n", - "\n", - " def remove_handler(self, handler: BaseCallbackHandler) -> None:\n", - " \"\"\"Remove a handler from the callback manager.\"\"\"\n", - "\n", - " def set_handler(self, handler: BaseCallbackHandler) -> None:\n", - " \"\"\"Set handler as the only handler on the callback manager.\"\"\"\n", - " self.set_handlers([handler])\n", - "\n", - " def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:\n", - " \"\"\"Set handlers as the only handlers on the callback manager.\"\"\"\n", - "```" + "You can subscribe to these events by using the `callbacks` argument available throughout the API. This argument list of handler objects, which are expected to implement one or more of the methods described in the API docs." ] }, { @@ -62,70 +36,57 @@ }, "source": [ "```python\n", - "class BaseCallbackHandler(ABC):\n", + "class BaseCallbackHandler:\n", " \"\"\"Base callback handler that can be used to handle callbacks from langchain.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_start(\n", " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when LLM starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n", " \"\"\"Run on new LLM token. 
Only available when streaming is enabled.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:\n", " \"\"\"Run when LLM ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when LLM errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_start(\n", " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when chain starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:\n", " \"\"\"Run when chain ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when chain errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_start(\n", " self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when tool starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_end(self, output: str, **kwargs: Any) -> Any:\n", " \"\"\"Run when tool ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when tool errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_text(self, text: str, **kwargs: Any) -> Any:\n", " \"\"\"Run on arbitrary text.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n", " \"\"\"Run on agent action.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:\n", " \"\"\"Run on agent end.\"\"\"\n", "```" @@ -136,14 +97,16 @@ "id": "d3bf3304-43fb-47ad-ae50-0637a17018a2", "metadata": {}, "source": [ - "## Creating and Using a Custom `CallbackHandler`\n", + "## Using an existing handler\n", "\n", - "By default, a shared CallbackManager with the StdOutCallbackHandler will be used by models, chains, agents, and tools. However, you can pass in your own CallbackManager with a custom CallbackHandler:" + "LangChain provides a few built-in handlers that you can use to get started. These are available in the `langchain/callbacks` module. The most basic handler is the `StdOutCallbackHandler`, which simply logs all events to `stdout`. In the future we will add more default handlers to the library. \n", + "\n", + "**Note** when the `verbose` flag on the object is set to true, the `StdOutCallbackHandler` will be invoked even without being explicitly passed in." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "id": "80532dfc-d687-4147-a0c9-1f90cc3e868c", "metadata": { "tags": [] @@ -155,16 +118,16 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "AgentAction(tool='Search', tool_input=\"US Open men's final 2019 winner\", log=' I need to find out who won the US Open men\\'s final in 2019 and then calculate his age raised to the 0.334 power.\\nAction: Search\\nAction Input: \"US Open men\\'s final 2019 winner\"')\n", - "Rafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\n", - "AgentAction(tool='Search', tool_input='Rafael Nadal age', log=' I need to find out the age of the winner\\nAction: Search\\nAction Input: \"Rafael Nadal age\"')\n", - "36 years\n", - "AgentAction(tool='Calculator', tool_input='36^0.334', log=' I now need to calculate his age raised to the 0.334 power\\nAction: Calculator\\nAction Input: 36^0.334')\n", - "Answer: 3.3098250249682484\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", "\n", - " I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -172,193 +135,929 @@ { "data": { "text/plain": [ - "\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\"" + "'\\n\\n3'" ] }, - "execution_count": 1, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from typing import Any, Dict, List, Optional, Union\n", - "\n", - "from langchain.agents import initialize_agent, load_tools\n", - "from langchain.agents import AgentType\n", - "from langchain.callbacks.base import CallbackManager, BaseCallbackHandler\n", + "from langchain.callbacks import StdOutCallbackHandler\n", + "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", - "from langchain.schema import AgentAction, AgentFinish, LLMResult\n", + "from langchain.prompts import PromptTemplate\n", "\n", - "class MyCustomCallbackHandler(BaseCallbackHandler):\n", - " \"\"\"Custom CallbackHandler.\"\"\"\n", + "handler = StdOutCallbackHandler()\n", + "llm = OpenAI()\n", + "prompt = PromptTemplate.from_template(\"1 + {number} = \")\n", "\n", - " def on_llm_start(\n", + "# First, let's explicitly set the StdOutCallbackHandler in `callbacks`\n", + "chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])\n", + "chain.run(number=2)\n", + "\n", + "# Then, let's use the `verbose` flag to achieve the same result\n", + "chain = LLMChain(llm=llm, prompt=prompt, verbose=True)\n", + "chain.run(number=2)" + ] + }, + { + "cell_type": "markdown", + "id": "389c8448-5283-49e3-8c04-dbe1522e202c", + "metadata": {}, + "source": [ + "## Creating a custom handler\n", + "\n", + "You can create a custom handler to set on the object as well. In the example below, we'll implement streaming with a custom handler." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1b2e6588-0681-4cab-937a-7cc4790cea9a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "My custom handler, token: \n", + "My custom handler, token: Why\n", + "My custom handler, token: did\n", + "My custom handler, token: the\n", + "My custom handler, token: tomato\n", + "My custom handler, token: turn\n", + "My custom handler, token: red\n", + "My custom handler, token: ?\n", + "\n", + "\n", + "My custom handler, token: Because\n", + "My custom handler, token: it\n", + "My custom handler, token: saw\n", + "My custom handler, token: the\n", + "My custom handler, token: salad\n", + "My custom handler, token: dressing\n", + "My custom handler, token: !\n", + "My custom handler, token: \n" + ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content='Why did the tomato turn red?\\n\\nBecause it saw the salad dressing!', additional_kwargs={})" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.callbacks.base import BaseCallbackHandler\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.schema import HumanMessage\n", + "\n", + "class MyCustomHandler(BaseCallbackHandler):\n", + " def on_llm_new_token(self, token: str, **kwargs) -> None:\n", + " print(f\"My custom handler, token: {token}\")\n", + "\n", + "# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n", + "# Additionally, we pass in a list with our custom handler\n", + "chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomHandler()])\n", + "\n", + "chat([HumanMessage(content=\"Tell me a joke\")])" + ] + }, + { + "cell_type": "markdown", + "id": "bc9785fa-4f71-4797-91a3-4fe7e57d0429", + "metadata": { + "tags": [] + }, + "source": [ + "## Async Callbacks\n", + "\n", + "If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` to avoid blocking the runloop. \n", + "\n", + "**Advanced** if you use a sync `CallbackHandler` while using an async method to run your llm/chain/tool/agent, it will still work. However, under the hood, it will be called with [`run_in_executor`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) which can cause issues if your `CallbackHandler` is not thread-safe." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c702e0c9-a961-4897-90c1-cdd13b6f16b2", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zzzz....\n", + "Hi! I just woke up. 
Your llm is starting\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 16:48:58,880 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/chat/completions processing_ms=210 request_id=0846181d992a4fbc954c80cf78e5bfb5 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sync handler being called in a `thread_pool_executor`: token: \n", + "Sync handler being called in a `thread_pool_executor`: token: Why\n", + "Sync handler being called in a `thread_pool_executor`: token: don\n", + "Sync handler being called in a `thread_pool_executor`: token: 't\n", + "Sync handler being called in a `thread_pool_executor`: token: scientists\n", + "Sync handler being called in a `thread_pool_executor`: token: trust\n", + "Sync handler being called in a `thread_pool_executor`: token: atoms\n", + "Sync handler being called in a `thread_pool_executor`: token: ?\n", + "Sync handler being called in a `thread_pool_executor`: token: \n", + "\n", + "\n", + "Sync handler being called in a `thread_pool_executor`: token: Because\n", + "Sync handler being called in a `thread_pool_executor`: token: they\n", + "Sync handler being called in a `thread_pool_executor`: token: make\n", + "Sync handler being called in a `thread_pool_executor`: token: up\n", + "Sync handler being called in a `thread_pool_executor`: token: everything\n", + "Sync handler being called in a `thread_pool_executor`: token: !\n", + "Sync handler being called in a `thread_pool_executor`: token: \n", + "zzzz....\n", + "Hi! I just woke up. Your llm is ending\n" + ] + }, + { + "data": { + "text/plain": [ + "LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms? \\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms? \\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import asyncio\n", + "from typing import Any, Dict, List\n", + "from langchain.schema import LLMResult\n", + "from langchain.callbacks.base import AsyncCallbackHandler\n", + "\n", + "class MyCustomSyncHandler(BaseCallbackHandler):\n", + " def on_llm_new_token(self, token: str, **kwargs) -> None:\n", + " print(f\"Sync handler being called in a `thread_pool_executor`: token: {token}\")\n", + "\n", + "class MyCustomAsyncHandler(AsyncCallbackHandler):\n", + " \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n", + "\n", + " async def on_llm_start(\n", " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", " ) -> None:\n", - " \"\"\"Print out the prompts.\"\"\"\n", - " pass\n", + " \"\"\"Run when chain starts running.\"\"\"\n", + " print(\"zzzz....\")\n", + " await asyncio.sleep(0.3)\n", + " class_name = serialized[\"name\"]\n", + " print(\"Hi! I just woke up. Your llm is starting\")\n", "\n", - " def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + " async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n", + " \"\"\"Run when chain ends running.\"\"\"\n", + " print(\"zzzz....\")\n", + " await asyncio.sleep(0.3)\n", + " print(\"Hi! I just woke up. 
Your llm is ending\")\n", "\n", - " def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + "# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n", + "# Additionally, we pass in a list with our custom handler\n", + "chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()])\n", + "\n", + "await chat.agenerate([[HumanMessage(content=\"Tell me a joke\")]])" + ] + }, + { + "cell_type": "markdown", + "id": "d26dbb34-fcc3-401c-a115-39c7620d2d65", + "metadata": {}, + "source": [ + "## Using multiple handlers, passing in handlers\n", + "\n", + "In the previous examples, we passed in callback handlers upon creation of an object by using `callbacks=`. In this case, the callbacks will be scoped to that particular object. \n", + "\n", + "However, in many cases, it is advantageous to pass in handlers instead when running the object. When we pass through `CallbackHandlers` using the `callbacks` keyword arg when executing an run, those callbacks will be issued by all nested objects involved in the execution. For example, when a handler is passed through to an `Agent`, it will be used for all callbacks related to the agent and all the objects involved in the agent's execution, in this case, the `Tools`, `LLMChain`, and `LLM`.\n", + "\n", + "This prevents us from having to manually attach the handlers to each individual nested object." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "8eec8756-1828-45cb-9699-38ac8543a150", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "on_chain_start AgentExecutor\n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) OpenAI\n", + "on_new_token I\n", + "on_new_token need\n", + "on_new_token to\n", + "on_new_token use\n", + "on_new_token a\n", + "on_new_token calculator\n", + "on_new_token to\n", + "on_new_token solve\n", + "on_new_token this\n", + "on_new_token .\n", + "on_new_token \n", + "Action\n", + "on_new_token :\n", + "on_new_token Calculator\n", + "on_new_token \n", + "Action\n", + "on_new_token Input\n", + "on_new_token :\n", + "on_new_token 2\n", + "on_new_token ^\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \n", + "on_agent_action AgentAction(tool='Calculator', tool_input='2^0.235', log=' I need to use a calculator to solve this.\\nAction: Calculator\\nAction Input: 2^0.235')\n", + "on_tool_start Calculator\n", + "on_chain_start LLMMathChain\n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) OpenAI\n", + "on_new_token \n", + "\n", + "on_new_token ```text\n", + "on_new_token \n", + "\n", + "on_new_token 2\n", + "on_new_token **\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \n", + "\n", + "on_new_token ```\n", + "\n", + "on_new_token ...\n", + "on_new_token num\n", + "on_new_token expr\n", + "on_new_token .\n", + "on_new_token evaluate\n", + "on_new_token (\"\n", + "on_new_token 2\n", + "on_new_token **\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \")\n", + "on_new_token ...\n", + "on_new_token \n", + "\n", + "on_new_token \n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) 
OpenAI\n", + "on_new_token I\n", + "on_new_token now\n", + "on_new_token know\n", + "on_new_token the\n", + "on_new_token final\n", + "on_new_token answer\n", + "on_new_token .\n", + "on_new_token \n", + "Final\n", + "on_new_token Answer\n", + "on_new_token :\n", + "on_new_token 1\n", + "on_new_token .\n", + "on_new_token 17\n", + "on_new_token 690\n", + "on_new_token 67\n", + "on_new_token 372\n", + "on_new_token 187\n", + "on_new_token 674\n", + "on_new_token \n" + ] + }, + { + "data": { + "text/plain": [ + "'1.1769067372187674'" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Dict, Union, Any, List\n", + "\n", + "from langchain.callbacks.base import BaseCallbackHandler\n", + "from langchain.schema import AgentAction\n", + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "from langchain.callbacks import tracing_enabled\n", + "from langchain.llms import OpenAI\n", + "\n", + "# First, define custom callback handler implementations\n", + "class MyCustomHandlerOne(BaseCallbackHandler):\n", + " def on_llm_start(\n", + " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_llm_start {serialized['name']}\")\n", + "\n", + " def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n", + " print(f\"on_new_token {token}\")\n", "\n", " def on_llm_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + " ) -> Any:\n", + " \"\"\"Run when LLM errors.\"\"\"\n", "\n", " def on_chain_start(\n", " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Print out that we are entering a chain.\"\"\"\n", - " class_name = serialized[\"name\"]\n", - " print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n", - "\n", - " def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n", - " \"\"\"Print out that we finished a chain.\"\"\"\n", - " print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n", - "\n", - " def on_chain_error(\n", - " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + " ) -> Any:\n", + " print(f\"on_chain_start {serialized['name']}\")\n", "\n", " def on_tool_start(\n", - " self,\n", - " serialized: Dict[str, Any],\n", - " input_str: str,\n", - " **kwargs: Any,\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + " self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_tool_start {serialized['name']}\")\n", "\n", - " def on_agent_action(\n", - " self, action: AgentAction, color: Optional[str] = None, **kwargs: Any\n", + " def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n", + " print(f\"on_agent_action {action}\")\n", + "\n", + "class MyCustomHandlerTwo(BaseCallbackHandler):\n", + " def on_llm_start(\n", + " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", " ) -> Any:\n", - " \"\"\"Run on agent action.\"\"\"\n", - " print(action)\n", - "\n", - " def on_tool_end(\n", - " self,\n", - " output: str,\n", - " color: Optional[str] = None,\n", - " observation_prefix: Optional[str] = None,\n", - " llm_prefix: Optional[str] = None,\n", - " **kwargs: Any,\n", - " ) -> None:\n", - " \"\"\"If not the final action, print out observation.\"\"\"\n", - " print(output)\n", + " 
print(f\"on_llm_start (I'm the second handler!!) {serialized['name']}\")\n", "\n", - " def on_tool_error(\n", - " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", - "\n", - " def on_text(\n", - " self,\n", - " text: str,\n", - " color: Optional[str] = None,\n", - " end: str = \"\",\n", - " **kwargs: Optional[str],\n", - " ) -> None:\n", - " \"\"\"Run when agent ends.\"\"\"\n", - " print(text)\n", + "# Instantiate the handlers\n", + "handler1 = MyCustomHandlerOne()\n", + "handler2 = MyCustomHandlerTwo()\n", "\n", - " def on_agent_finish(\n", - " self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Run on agent end.\"\"\"\n", - " print(finish.log)\n", - "manager = CallbackManager([MyCustomCallbackHandler()])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)\n", - "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, callback_manager=manager)\n", + "# Setup the agent. Only the `llm` will issue callbacks for handler2\n", + "llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])\n", + "tools = load_tools([\"llm-math\"], llm=llm)\n", "agent = initialize_agent(\n", - " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n", ")\n", - "agent.run(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")" + "\n", + "# Callbacks for handler1 will be issued by every object involved in the \n", + "# Agent execution (llm, llmchain, tool, agent executor)\n", + "agent.run(\"What is 2 raised to the 0.235 power?\", callbacks=[handler1])" ] }, { "cell_type": "markdown", - "id": "bc9785fa-4f71-4797-91a3-4fe7e57d0429", + "id": "32b29135-f852-4492-88ed-547275c72c53", + "metadata": {}, + "source": [ + "# Tracing and Token Counting" + ] + }, + { + "cell_type": "markdown", + "id": "fbb606d6-2863-46c5-8347-9f0bdb3805bb", + "metadata": {}, + "source": [ + "Tracing and token counting are two capabilities we provide which are built on our callbacks mechanism." + ] + }, + { + "cell_type": "markdown", + "id": "f62cd10c-494c-47d6-aa98-6e926cb9c456", + "metadata": {}, + "source": [ + "## Tracing" + ] + }, + { + "cell_type": "markdown", + "id": "d5a74b3f-3769-4a4f-99c7-b6a3b20a94e2", + "metadata": {}, + "source": [ + "There are two recommended ways to trace your LangChains. One is by setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. The other is to use a context manager `with tracing_enabled()` to trace a particular block of code.\n", + "\n", + "**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "f164dfd5-d987-4b6a-a7c8-019c651ce47f", "metadata": { "tags": [] }, + "outputs": [], "source": [ - "## Async Support\n", + "import os\n", + "\n", + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "from langchain.callbacks import tracing_enabled\n", + "from langchain.llms import OpenAI\n", "\n", - "If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` and `AsyncCallbackManager` to avoid blocking the runloop." 
+ "# To run the code, make sure to set OPENAI_API_KEY and SERPAPI_API_KEY\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", + "\n", + "questions = [\n", + " \"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\",\n", + " \"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\",\n", + " \"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?\",\n", + " \"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?\",\n", + " \"Who is Beyonce's husband? What is his age raised to the 0.19 power?\",\n", + "]" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "c702e0c9-a961-4897-90c1-cdd13b6f16b2", + "execution_count": 33, + "id": "6be7777e-ec1d-438f-ae33-3a93c45f808e", "metadata": { "tags": [] }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:33:40,925 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "zzzz....\n", "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "zzzz....\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:34:05,653 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", + "2023-04-25 17:34:05,673 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. 
In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] } ], "source": [ - "import asyncio\n", - "from aiohttp import ClientSession\n", + "os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n", "\n", - "from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager\n", + "# Both of the agent runs will be traced because the environment variable is set\n", + "agent.run(questions[0])\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " agent.run(questions[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "a6fd6026-dc1e-4d48-893d-3592539c7828", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:37:45,895 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", + "2023-04-25 17:37:45,982 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. 
Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\"" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " agent.run(questions[0]) # this should be traced\n", "\n", - "class MyCustomAsyncCallbackHandler(AsyncCallbackHandler):\n", - " \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n", + "agent.run(questions[1]) # this should not be traced" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "9383a351-4983-44e9-abd7-ef942e1c65c4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:07,944 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", + "2023-04-25 17:39:08,038 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", + "2023-04-25 17:39:08,039 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:10,123 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1795 request_id=9139c5a1b136a84603a4adc584bbdd9b response_code=200\n", + "2023-04-25 17:39:10,127 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1881 request_id=11729ae35c511f56238ab69a5856efcc response_code=200\n", + "2023-04-25 17:39:10,238 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1995 request_id=5c319aa337991381b80b4c4b858b7f75 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + 
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:11,863 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1070 request_id=82b718523868a00c0d3f047ac8a9ecea response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:12,611 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1829 request_id=fe0a82fe729ebc37b7983474d9418a84 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. 
Lewis Hamilton holds the record for the ...\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:15,366 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2813 request_id=396f7e8180605345ad13693d91ebfdda response_code=200\n", + "2023-04-25 17:39:15,558 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2049 request_id=fa2004f0d6934e94f09632caacda8ca4 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:17,295 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1604 request_id=f9c40e7fb3d94d936b285c3a5a0eb55f response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:18,181 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2791 request_id=9f0c27a6e995895a518d5614d5e54c61 response_code=200\n", + "2023-04-25 17:39:18,185 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2207 request_id=b5281a29649bfbeaad532391eacf954d response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n", + "Action: Search\n", + "Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:20,282 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2702 request_id=761627e38a5b6e580262357668dd635b response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:20,605 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2182 request_id=4cafeb94298befebe0da1e3f4a38ab27 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:22,431 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2555 request_id=299ae3539ed6fb681ac2f8d16e73a6bc response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32;1m\u001b[1;3m I now need to calculate 38 raised to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 38^0.23\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:24,802 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2194 request_id=e09fe0fed313ba77c9a6444c41c12f1f 
response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:39:26,912 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1963 request_id=8bedf74c7e4bfc5dfe014dcca47ce363 response_code=200\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\"" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The context manager is concurrency safe:\n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", "\n", - " async def on_chain_start(\n", - " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Run when chain starts running.\"\"\"\n", - " print(\"zzzz....\")\n", - " await asyncio.sleep(0.5)\n", - " class_name = serialized[\"name\"]\n", - " print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n", + "# start a background task\n", + "task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n", + " await asyncio.gather(*tasks)\n", "\n", - " async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n", - " \"\"\"Run when chain ends running.\"\"\"\n", - " print(\"zzzz....\")\n", - " await asyncio.sleep(0.5)\n", - " print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n", - "\n", - "manager = AsyncCallbackManager([MyCustomAsyncCallbackHandler()])\n", - "\n", - "# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - "# but you must manually close the client session at the end of your program/event loop\n", - "aiosession = ClientSession()\n", - "llm = OpenAI(temperature=0, callback_manager=manager)\n", - "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n", - "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", - "await async_agent.arun(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")\n", - "await aiosession.close()" + "await task" + ] + }, + { + "cell_type": "markdown", + "id": "254fef1b-6b6e-4352-9cf4-363fba895ac7", + "metadata": {}, + "source": [ + "## Token Counting\n", + "LangChain offers a context manager that allows you to count tokens." 
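A minimal usage sketch of the token-counting context manager described above (the next cell gives the fuller demonstration); it assumes an OpenAI API key is configured in the environment, and only calls made inside the `with` block are counted:

from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

with get_openai_callback() as cb:
    llm("What is the square root of 4?")   # counted

llm("What is the square root of 4?")       # outside the block: not counted

print(cb.total_tokens)  # tokens used by the single call made inside the block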
+ ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "5c3e0b89-2c5e-4036-bdf2-fb6b750e360c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-04-25 17:43:22,369 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=492 request_id=a13b1f276947e6e8a2179ebf7c092878 response_code=200\n", + "2023-04-25 17:43:22,376 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=491 request_id=5bd41d073db19e0002eb3d862b9fde22 response_code=200\n", + "2023-04-25 17:43:22,441 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=526 request_id=155a0aa6a078db963fda3fe3b68c463e response_code=200\n", + "2023-04-25 17:43:23,072 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=475 request_id=231e6e20ff294f2b0a46b4844d739c09 response_code=200\n", + "2023-04-25 17:43:23,681 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1088 request_id=90f4cf31a38f395d7ea98bd76d9bb36f response_code=200\n" + ] + } + ], + "source": [ + "from langchain.callbacks import get_openai_callback\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "with get_openai_callback() as cb:\n", + " llm(\"What is the square root of 4?\")\n", + "\n", + "total_tokens = cb.total_tokens\n", + "assert total_tokens > 0\n", + "\n", + "with get_openai_callback() as cb:\n", + " llm(\"What is the square root of 4?\")\n", + " llm(\"What is the square root of 4?\")\n", + "\n", + "assert cb.total_tokens == total_tokens * 2\n", + "\n", + "# You can kick off concurrent runs from within the context manager\n", + "with get_openai_callback() as cb:\n", + " await asyncio.gather(\n", + " *[llm.agenerate([\"What is the square root of 4?\"]) for _ in range(3)]\n", + " )\n", + "\n", + "assert cb.total_tokens == total_tokens * 3\n", + "\n", + "# The context manager is concurrency safe\n", + "task = asyncio.create_task(llm.agenerate([\"What is the square root of 4?\"]))\n", + "with get_openai_callback() as cb:\n", + " await llm.agenerate([\"What is the square root of 4?\"])\n", + "\n", + "await task\n", + "assert cb.total_tokens == total_tokens" ] }, { "cell_type": "code", "execution_count": null, - "id": "86be6304-e433-4048-880c-a92a73244407", + "id": "f5d6521d-1901-473d-a2cd-a4e88db7f851", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/__init__.py b/langchain/__init__.py index 2eca9091177d4..2d44e935c3286 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -5,11 +5,6 @@ from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache -from langchain.callbacks import ( - set_default_callback_manager, - set_handler, - set_tracing_callback_manager, -) from langchain.chains import ( ConversationChain, LLMBashChain, @@ -65,7 +60,6 @@ verbose: bool = False llm_cache: Optional[BaseCache] = None -set_default_callback_manager() # For backwards compatibility SerpAPIChain = SerpAPIWrapper @@ -115,7 +109,5 @@ "VectorDBQAWithSourcesChain", "QAWithSourcesChain", "PALChain", - "set_handler", - "set_tracing_callback_manager", "LlamaCpp", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index c675b5582833d..eb9b859fab341 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -13,7 +13,13 @@ from pydantic import BaseModel, root_validator from langchain.agents.tools import 
InvalidTool +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + Callbacks, +) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping @@ -23,7 +29,6 @@ from langchain.schema import ( AgentAction, AgentFinish, - BaseLanguageModel, BaseMessage, BaseOutputParser, ) @@ -46,13 +51,17 @@ def get_allowed_tools(self) -> Optional[List[str]]: @abstractmethod def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -61,13 +70,17 @@ def plan( @abstractmethod async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -170,13 +183,17 @@ def get_allowed_tools(self) -> Optional[List[str]]: @abstractmethod def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[List[AgentAction], AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -185,13 +202,17 @@ def plan( @abstractmethod async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[List[AgentAction], AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -285,38 +306,52 @@ def input_keys(self) -> List[str]: return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ output = self.llm_chain.run( - intermediate_steps=intermediate_steps, stop=self.stop, **kwargs + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, ) return self.output_parser.parse(output) async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. 
Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ output = await self.llm_chain.arun( - intermediate_steps=intermediate_steps, stop=self.stop, **kwargs + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, ) return self.output_parser.parse(output) @@ -368,37 +403,45 @@ def _construct_scratchpad( return thoughts def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - full_output = self.llm_chain.predict(**full_inputs) + full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) return self.output_parser.parse(full_output) async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - full_output = await self.llm_chain.apredict(**full_inputs) + full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) return self.output_parser.parse(full_output) def get_full_inputs( @@ -632,24 +675,27 @@ def _should_continue(self, iterations: int, time_elapsed: float) -> bool: return True - def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]: - self.callback_manager.on_agent_finish( - output, color="green", verbose=self.verbose - ) + def _return( + self, + output: AgentFinish, + intermediate_steps: list, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + if run_manager: + run_manager.on_agent_finish(output, color="green", verbose=self.verbose) final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps return final_output async def _areturn( - self, output: AgentFinish, intermediate_steps: list + self, + output: AgentFinish, + intermediate_steps: list, + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: - if self.callback_manager.is_async: - await self.callback_manager.on_agent_finish( - output, color="green", verbose=self.verbose - ) - else: - self.callback_manager.on_agent_finish( + if run_manager: + await run_manager.on_agent_finish( output, color="green", verbose=self.verbose ) final_output = output.return_values @@ -663,13 +709,18 @@ def _take_next_step( color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. 
""" # Call the LLM to see what to do. - output = self.agent.plan(intermediate_steps, **inputs) + output = self.agent.plan( + intermediate_steps, + callbacks=run_manager.get_child() if run_manager else None, + **inputs, + ) # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output @@ -680,9 +731,8 @@ def _take_next_step( actions = output result = [] for agent_action in actions: - self.callback_manager.on_agent_action( - agent_action, verbose=self.verbose, color="green" - ) + if run_manager: + run_manager.on_agent_action(agent_action, color="green") # Otherwise we lookup the tool if agent_action.tool in name_to_tool_map: tool = name_to_tool_map[agent_action.tool] @@ -696,6 +746,7 @@ def _take_next_step( agent_action.tool_input, verbose=self.verbose, color=color, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) else: @@ -704,6 +755,7 @@ def _take_next_step( agent_action.tool, verbose=self.verbose, color=None, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) result.append((agent_action, observation)) @@ -715,13 +767,18 @@ async def _atake_next_step( color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. """ # Call the LLM to see what to do. - output = await self.agent.aplan(intermediate_steps, **inputs) + output = await self.agent.aplan( + intermediate_steps, + callbacks=run_manager.get_child() if run_manager else None, + **inputs, + ) # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output @@ -734,12 +791,8 @@ async def _atake_next_step( async def _aperform_agent_action( agent_action: AgentAction, ) -> Tuple[AgentAction, str]: - if self.callback_manager.is_async: - await self.callback_manager.on_agent_action( - agent_action, verbose=self.verbose, color="green" - ) - else: - self.callback_manager.on_agent_action( + if run_manager: + await run_manager.on_agent_action( agent_action, verbose=self.verbose, color="green" ) # Otherwise we lookup the tool @@ -755,6 +808,7 @@ async def _aperform_agent_action( agent_action.tool_input, verbose=self.verbose, color=color, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) else: @@ -763,6 +817,7 @@ async def _aperform_agent_action( agent_action.tool, verbose=self.verbose, color=None, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) return agent_action, observation @@ -774,7 +829,11 @@ async def _aperform_agent_action( return list(result) - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run text through and get agent response.""" # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool for tool in self.tools} @@ -790,10 +849,16 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: # We now enter the agent loop (until it returns something). 
while self._should_continue(iterations, time_elapsed): next_step_output = self._take_next_step( - name_to_tool_map, color_mapping, inputs, intermediate_steps + name_to_tool_map, + color_mapping, + inputs, + intermediate_steps, + run_manager=run_manager, ) if isinstance(next_step_output, AgentFinish): - return self._return(next_step_output, intermediate_steps) + return self._return( + next_step_output, intermediate_steps, run_manager=run_manager + ) intermediate_steps.extend(next_step_output) if len(next_step_output) == 1: @@ -809,7 +874,11 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: ) return self._return(output, intermediate_steps) - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Run text through and get agent response.""" # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool for tool in self.tools} @@ -827,7 +896,11 @@ async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: try: while self._should_continue(iterations, time_elapsed): next_step_output = await self._atake_next_step( - name_to_tool_map, color_mapping, inputs, intermediate_steps + name_to_tool_map, + color_mapping, + inputs, + intermediate_steps, + run_manager=run_manager, ) if isinstance(next_step_output, AgentFinish): return await self._areturn(next_step_output, intermediate_steps) @@ -845,7 +918,9 @@ async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) - return await self._areturn(output, intermediate_steps) + return await self._areturn( + output, intermediate_steps, run_manager=run_manager + ) except TimeoutError: # stop early when interrupted by the async timeout output = self.agent.return_stopped_response( diff --git a/langchain/agents/agent_toolkits/openapi/planner.py b/langchain/agents/agent_toolkits/openapi/planner.py index 8865bc4222826..4545e73216e6d 100644 --- a/langchain/agents/agent_toolkits/openapi/planner.py +++ b/langchain/agents/agent_toolkits/openapi/planner.py @@ -26,12 +26,12 @@ from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.tools import Tool +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI from langchain.memory import ReadOnlySharedMemory from langchain.prompts import PromptTemplate from langchain.requests import RequestsWrapper -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.requests.tool import BaseRequestsTool diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py index 7245c10d115c0..04ceca71d0559 100644 --- a/langchain/agents/chat/base.py +++ b/langchain/agents/chat/base.py @@ -5,6 +5,7 @@ from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -13,7 +14,7 @@ HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from 
langchain.schema import AgentAction, BaseLanguageModel +from langchain.schema import AgentAction from langchain.tools import BaseTool diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py index 75018314bb667..16a43a90150a9 100644 --- a/langchain/agents/conversational/base.py +++ b/langchain/agents/conversational/base.py @@ -9,10 +9,10 @@ from langchain.agents.agent_types import AgentType from langchain.agents.conversational.output_parser import ConvoOutputParser from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index a91915c0afc4f..d9b83ecc6c644 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -12,6 +12,7 @@ SUFFIX, TEMPLATE_TOOL_RESPONSE, ) +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -24,7 +25,6 @@ from langchain.schema import ( AgentAction, AIMessage, - BaseLanguageModel, BaseMessage, BaseOutputParser, HumanMessage, diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py index 72784b8999756..9a52b15196508 100644 --- a/langchain/agents/initialize.py +++ b/langchain/agents/initialize.py @@ -4,8 +4,8 @@ from langchain.agents.agent import AgentExecutor from langchain.agents.agent_types import AgentType from langchain.agents.loading import AGENT_TO_CLASS, load_agent +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index 790127ecd6842..fe5ea878d7d9d 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool: return Tool( name="Calculator", description="Useful for when you need to answer questions about math.", - func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run, - coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun, + func=LLMMathChain(llm=llm).run, + coroutine=LLMMathChain(llm=llm).arun, ) diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 4bb1d519ed69e..4850a8fd71a4d 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -10,10 +10,10 @@ from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index c5aa094eb676f..86dcadb498597 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -4,6 +4,11 @@ 
from pydantic import BaseModel, validate_arguments +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, + Callbacks, +) from langchain.tools.base import BaseTool @@ -26,14 +31,42 @@ def args(self) -> dict: valid_keys = signature(self.func).parameters return {k: schema[k] for k in valid_keys} - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Use the tool.""" - return self.func(*args, **kwargs) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Use the tool asynchronously.""" + new_argument_supported = signature(self.coroutine).parameters.get("callbacks") if self.coroutine: - return await self.coroutine(*args, **kwargs) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) raise NotImplementedError("Tool does not support async") # TODO: this is for backwards compatibility, remove in future diff --git a/langchain/base_language.py b/langchain/base_language.py new file mode 100644 index 0000000000000..3821f46424ad2 --- /dev/null +++ b/langchain/base_language.py @@ -0,0 +1,55 @@ +"""Base class for all language models.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List, Optional + +from pydantic import BaseModel + +from langchain.callbacks.manager import Callbacks +from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string + + +class BaseLanguageModel(BaseModel, ABC): + @abstractmethod + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> LLMResult: + """Take in a list of prompt values and return an LLMResult.""" + + @abstractmethod + async def agenerate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> LLMResult: + """Take in a list of prompt values and return an LLMResult.""" + + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text.""" + # TODO: this method may not be exact. + # TODO: this method may differ based on model (eg codex). + try: + from transformers import GPT2TokenizerFast + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "This is needed in order to calculate get_num_tokens. " + "Please install it with `pip install transformers`." 
+ ) + # create a GPT-3 tokenizer instance + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + + # tokenize the text using the GPT-3 tokenizer + tokenized_text = tokenizer.tokenize(text) + + # calculate the number of tokens in the tokenized text + return len(tokenized_text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the message.""" + return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index 232910ad32797..28a00e0b5c7f3 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,7 +1,6 @@ """Callback handlers that allow listening to events in LangChain.""" -import os from contextlib import contextmanager -from typing import Generator, Optional +from typing import Generator from langchain.callbacks.aim_callback import AimCallbackHandler from langchain.callbacks.base import ( @@ -10,64 +9,17 @@ ) from langchain.callbacks.clearml_callback import ClearMLCallbackHandler from langchain.callbacks.comet_ml_callback import CometCallbackHandler -from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.manager import ( + CallbackManager, + get_openai_callback, + tracing_enabled, +) from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler from langchain.callbacks.tracers import LangChainTracer from langchain.callbacks.wandb_callback import WandbCallbackHandler - -def get_callback_manager() -> BaseCallbackManager: - """Return the shared callback manager.""" - return CallbackManager([]) - - -def set_handler(handler: BaseCallbackHandler) -> None: - """Set handler.""" - callback = get_callback_manager() - callback.set_handler(handler) - - -def set_default_callback_manager() -> None: - """Set default callback manager.""" - default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout") - if default_handler == "stdout": - set_handler(StdOutCallbackHandler()) - elif default_handler == "langchain": - session = os.environ.get("LANGCHAIN_SESSION") - set_tracing_callback_manager(session) - else: - raise ValueError( - f"LANGCHAIN_HANDLER should be one of `stdout` " - f"or `langchain`, got {default_handler}" - ) - - -def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: - """Set tracing callback manager.""" - handler = LangChainTracer() - callback = get_callback_manager() - callback.set_handlers([handler, StdOutCallbackHandler()]) - if session_name is None: - handler.load_default_session() - else: - try: - handler.load_session(session_name) - except Exception: - raise ValueError(f"session {session_name} not found") - - -@contextmanager -def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: - """Get OpenAI callback handler in a context manager.""" - handler = OpenAICallbackHandler() - manager = get_callback_manager() - manager.add_handler(handler) - yield handler - manager.remove_handler(handler) - - __all__ = [ "OpenAICallbackHandler", "StdOutCallbackHandler", @@ -77,8 +29,5 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: "CometCallbackHandler", "AsyncIteratorCallbackHandler", "get_openai_callback", - "set_tracing_callback_manager", - "set_default_callback_manager", - "set_handler", - "get_callback_manager", + "tracing_enabled", ] diff --git 
a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 7c96e4fe1af6b..69628e0ad4d46 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -6,7 +6,9 @@ import logging import os import uuid -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union from langchain.callbacks.base import ( BaseCallbackHandler, @@ -16,7 +18,9 @@ RunManagerMixin, ToolManagerMixin, ) +from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.callbacks.tracers.base import TracerSession from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -26,6 +30,36 @@ handlers=[logging.StreamHandler()], ) +Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] + +openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( + "openai_callback", default=None +) +tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( + "tracing_callback", default=None +) + + +@contextmanager +def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: + """Get OpenAI callback handler in a context manager.""" + cb = OpenAICallbackHandler() + openai_callback_var.set(cb) + yield cb + openai_callback_var.set(None) + + +@contextmanager +def tracing_enabled( + session_name: str = "default", +) -> Generator[TracerSession, None, None]: + """Get OpenAI callback handler in a context manager.""" + cb = LangChainTracer() + session = cb.load_session(session_name) + tracing_callback_var.set(cb) + yield session + tracing_callback_var.set(None) + def _handle_event( handlers: List[BaseCallbackHandler], @@ -104,7 +138,7 @@ class RunManager(BaseRunManager): def on_text(self, text: str, **kwargs: Any) -> Any: """Run when text is received.""" - _handle_event(self.handlers, "on_text", None, False, text, **kwargs) + _handle_event(self.handlers, "on_text", None, text, **kwargs) class AsyncRunManager(BaseRunManager): @@ -112,7 +146,7 @@ class AsyncRunManager(BaseRunManager): async def on_text(self, text: str, **kwargs: Any) -> Any: """Run when text is received.""" - await _ahandle_event(self.handlers, "on_text", None, False, text, **kwargs) + await _ahandle_event(self.handlers, "on_text", None, text, **kwargs) class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): @@ -415,7 +449,7 @@ def on_llm_start( ) -> CallbackManagerForLLMRun: """Run when LLM starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) _handle_event( self.handlers, @@ -441,7 +475,7 @@ def on_chain_start( ) -> CallbackManagerForChainRun: """Run when chain starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) _handle_event( self.handlers, @@ -468,7 +502,7 @@ def on_tool_start( ) -> CallbackManagerForToolRun: """Run when tool starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) _handle_event( self.handlers, @@ -477,7 +511,7 @@ def on_tool_start( serialized, input_str, run_id=run_id, - parent_run_id=parent_run_id, + parent_run_id=self.parent_run_id, **kwargs, ) @@ -517,7 +551,7 @@ async def on_llm_start( ) -> AsyncCallbackManagerForLLMRun: """Run when LLM starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) await 
_ahandle_event( self.handlers, @@ -543,7 +577,7 @@ async def on_chain_start( ) -> AsyncCallbackManagerForChainRun: """Run when chain starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) await _ahandle_event( self.handlers, @@ -570,7 +604,7 @@ async def on_tool_start( ) -> AsyncCallbackManagerForToolRun: """Run when tool starts running.""" if run_id is None: - run_id = uuid.uuid4() + run_id = str(uuid.uuid4()) await _ahandle_event( self.handlers, @@ -579,7 +613,7 @@ async def on_tool_start( serialized, input_str, run_id=run_id, - parent_run_id=parent_run_id, + parent_run_id=self.parent_run_id, **kwargs, ) @@ -610,14 +644,16 @@ def _configure( inheritable_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, local_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, verbose: bool = False, -) -> Optional[T]: +) -> T: """Configure the callback manager.""" callback_manager: Optional[T] = None if inheritable_callbacks or local_callbacks: if isinstance(inheritable_callbacks, list) or not inheritable_callbacks: callback_manager = callback_manager_cls( - handlers=inheritable_callbacks, - inheritable_handlers=inheritable_callbacks, + handlers=inheritable_callbacks if inheritable_callbacks else [], + inheritable_handlers=inheritable_callbacks + if inheritable_callbacks + else [], ) else: callback_manager = inheritable_callbacks @@ -627,13 +663,19 @@ def _configure( if isinstance(local_callbacks, list) else (local_callbacks.handlers if local_callbacks else []) ) - [callback_manager.add_handler(handler, False) for handler in local_handlers_] - - tracing_enabled = os.environ.get("LANGCHAIN_TRACING") is not None - if verbose or tracing_enabled: - if not callback_manager: - callback_manager = callback_manager_cls([]) - + [ + callback_manager.add_handler(copy.deepcopy(handler), False) + for handler in local_handlers_ + ] + + if not callback_manager: + callback_manager = callback_manager_cls([]) + tracer = tracing_callback_var.get() + open_ai = openai_callback_var.get() + tracing_enabled = ( + os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None + ) + if verbose or tracing_enabled or open_ai is not None: if verbose and not any( isinstance(handler, StdOutCallbackHandler) for handler in callback_manager.handlers @@ -644,8 +686,16 @@ def _configure( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): - handler = LangChainTracer() - handler.load_default_session() - callback_manager.add_handler(handler, True) + if tracer: + callback_manager.add_handler(copy.deepcopy(tracer), True) + else: + handler = LangChainTracer() + handler.load_default_session() + callback_manager.add_handler(handler, True) + if open_ai is not None and not any( + isinstance(handler, OpenAICallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(open_ai, True) return callback_manager diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 0b2b99dd6483b..763f710d9e5ae 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.tracers.schemas import ( @@ -13,7 +14,7 @@ TracerSession, TracerSessionCreate, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema 
import LLMResult class TracerException(Exception): @@ -23,13 +24,26 @@ class TracerException(Exception): class BaseTracer(BaseCallbackHandler, ABC): """Base interface for tracers.""" - @abstractmethod + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} + self.execution_order: int = 1 + self.session: Optional[TracerSession] = None + + @staticmethod def _add_child_run( - self, parent_run: Union[ChainRun, ToolRun], child_run: Union[LLMRun, ChainRun, ToolRun], ) -> None: """Add child run to a chain run or tool run.""" + if isinstance(child_run, LLMRun): + parent_run.child_llm_runs.append(child_run) + elif isinstance(child_run, ChainRun): + parent_run.child_chain_runs.append(child_run) + elif isinstance(child_run, ToolRun): + parent_run.child_tool_runs.append(child_run) + else: + raise TracerException(f"Invalid run type: {type(child_run)}") @abstractmethod def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: @@ -39,15 +53,11 @@ def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: def _persist_session(self, session: TracerSessionCreate) -> TracerSession: """Persist a tracing session.""" - @abstractmethod - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession: """NOT thread safe, do not call this method from multiple threads.""" session_create = TracerSessionCreate(name=name, extra=kwargs) session = self._persist_session(session_create) - self._session = session + self.session = session return session @abstractmethod @@ -58,235 +68,211 @@ def load_session(self, session_name: str) -> TracerSession: def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" - @property - @abstractmethod - def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - - @property - @abstractmethod - def _execution_order(self) -> int: - """Get the execution order for a run.""" - - @_execution_order.setter - @abstractmethod - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - - @property - @abstractmethod - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - - @_session.setter - @abstractmethod - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Start a trace for a run.""" - self._execution_order += 1 - - if self._stack: - if not ( - isinstance(self._stack[-1], ChainRun) - or isinstance(self._stack[-1], ToolRun) - ): + self.execution_order += 1 + + if run.parent_uuid: + parent_run = self.run_map[run.parent_uuid] + if parent_run: + if isinstance(parent_run, LLMRun): + raise TracerException( + "Cannot add child run to an LLM run. " + "LLM runs are not allowed to have children." + ) + self._add_child_run(parent_run, run) + else: raise TracerException( - f"Nested {run.__class__.__name__} can only be" - f" logged inside a ChainRun or ToolRun" + f"Parent run with UUID {run.parent_uuid} not found." 
) - self._add_child_run(self._stack[-1], run) - self._stack.append(run) - def _end_trace(self) -> None: + self.run_map[run.uuid] = run + + def _end_trace(self, run) -> None: """End a trace for a run.""" - run = self._stack.pop() - if not self._stack: - self._execution_order = 1 + if not run.parent_uuid: self._persist_run(run) + self.execution_order = 1 + self.run_map.pop(run.uuid) def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, + serialized: Dict[str, Any], + prompts: List[str], + run_id: str = None, + parent_run_id: str = None, + **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." - ) + if self.session is None: + self.session = self.load_default_session() + + if run_id is None: + run_id = str(uuid4()) llm_run = LLMRun( + uuid=run_id, + parent_uuid=parent_run_id, serialized=serialized, prompts=prompts, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, - session_id=self._session.id, - id=self._generate_id(), + execution_order=self.execution_order, + session_id=self.session.id, ) self._start_trace(llm_run) - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Handle a new token for an LLM run.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end( + self, response: LLMResult, run_id: str = None, **kwargs: Any + ) -> None: """End a trace for an LLM run.""" - if not self._stack or not isinstance(self._stack[-1], LLMRun): - raise TracerException("No LLMRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_llm_end callback.") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].response = response + if not self.run_map or not isinstance(self.run_map[run_id], LLMRun): + raise TracerException("No LLMRun found to be traced") - self._end_trace() + llm_run = self.run_map[run_id] + llm_run.response = response + llm_run.end_time = datetime.utcnow() + self._end_trace(llm_run) def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + run_id: str = None, + **kwargs: Any, ) -> None: """Handle an error for an LLM run.""" - if not self._stack or not isinstance(self._stack[-1], LLMRun): - raise TracerException("No LLMRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_llm_error callback.") - self._stack[-1].error = repr(error) - self._stack[-1].end_time = datetime.utcnow() + if not self.run_map or not isinstance(self.run_map[run_id], LLMRun): + raise TracerException("No LLMRun found to be traced") - self._end_trace() + llm_run = self.run_map[run_id] + llm_run.error = repr(error) + llm_run.end_time = datetime.utcnow() + self._end_trace(llm_run) def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: str = None, + parent_run_id: str = None, + **kwargs: Any, ) -> None: """Start a trace for a chain run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." 
- ) + if self.session is None: + self.session = self.load_default_session() + + if run_id is None: + run_id = str(uuid4()) chain_run = ChainRun( + uuid=run_id, + parent_uuid=parent_run_id, serialized=serialized, inputs=inputs, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, + execution_order=self.execution_order, child_runs=[], - session_id=self._session.id, - id=self._generate_id(), + session_id=self.session.id, ) self._start_trace(chain_run) - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end( + self, outputs: Dict[str, Any], run_id: str = None, **kwargs: Any + ) -> None: """End a trace for a chain run.""" - if not self._stack or not isinstance(self._stack[-1], ChainRun): - raise TracerException("No ChainRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_chain_end callback.") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].outputs = outputs + if not self.run_map or not isinstance(self.run_map[run_id], ChainRun): + raise TracerException("No ChainRun found to be traced") - self._end_trace() + chain_run = self.run_map[run_id] + chain_run.outputs = outputs + chain_run.end_time = datetime.utcnow() + self._end_trace(chain_run) def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + run_id: str = None, + **kwargs: Any, ) -> None: """Handle an error for a chain run.""" - if not self._stack or not isinstance(self._stack[-1], ChainRun): - raise TracerException("No ChainRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_chain_error callback.") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].error = repr(error) + if not self.run_map or not isinstance(self.run_map[run_id], ChainRun): + raise TracerException("No ChainRun found to be traced") - self._end_trace() + chain_run = self.run_map[run_id] + chain_run.error = repr(error) + chain_run.end_time = datetime.utcnow() + self._end_trace(chain_run) def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + self, + serialized: Dict[str, Any], + input_str: str, + run_id: str = None, + parent_run_id: str = None, + **kwargs: Any, ) -> None: """Start a trace for a tool run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." - ) + if self.session is None: + self.session = self.load_default_session() + + if run_id is None: + run_id = str(uuid4()) tool_run = ToolRun( + uuid=run_id, + parent_uuid=parent_run_id, serialized=serialized, # TODO: this is duplicate info as above, not needed. 
action=str(serialized), tool_input=input_str, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, + execution_order=self.execution_order, child_runs=[], - session_id=self._session.id, - id=self._generate_id(), + session_id=self.session.id, ) self._start_trace(tool_run) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: str, run_id: str = None, **kwargs: Any) -> None: """End a trace for a tool run.""" - if not self._stack or not isinstance(self._stack[-1], ToolRun): - raise TracerException("No ToolRun found to be traced") + if not run_id: + raise TracerException("No run_id provided for on_tool_end callback.") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].output = output + if not self.run_map or not isinstance(self.run_map[run_id], ToolRun): + raise TracerException("No ToolRun found to be traced") - self._end_trace() + tool_run = self.run_map[run_id] + tool_run.output = output + tool_run.end_time = datetime.utcnow() + self._end_trace(tool_run) def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + run_id: str = None, + **kwargs: Any, ) -> None: """Handle an error for a tool run.""" - if not self._stack or not isinstance(self._stack[-1], ToolRun): + if not run_id: + raise TracerException("No run_id provided for on_tool_error callback.") + + if not self.run_map or not isinstance(self.run_map[run_id], ToolRun): raise TracerException("No ToolRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].error = repr(error) - - self._end_trace() - - def on_text(self, text: str, **kwargs: Any) -> None: - """Handle a text message.""" - pass - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Handle an agent finish message.""" - pass - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Do nothing.""" - pass - - -class Tracer(BaseTracer, ABC): - """A non-thread safe implementation of the BaseTracer interface.""" - - def __init__(self) -> None: - """Initialize a tracer.""" - self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = [] - self._tracer_execution_order = 1 - self._tracer_session: Optional[TracerSession] = None - - @property - def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - return self._tracer_stack - - @property - def _execution_order(self) -> int: - """Get the execution order for a run.""" - return self._tracer_execution_order - - @_execution_order.setter - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - self._tracer_execution_order = value - - @property - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - return self._tracer_session - - @_session.setter - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - if self._stack: - raise TracerException( - "Cannot set a session while a trace is being recorded" - ) - self._tracer_session = value + tool_run = self.run_map[run_id] + tool_run.error = repr(error) + tool_run.end_time = datetime.utcnow() + self._end_trace(tool_run) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index a45ccd3b42e90..1cbf65e87a5f5 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,12 +3,11 @@ import logging import os -from abc import ABC from typing import 
Any, Dict, Optional, Union import requests -from langchain.callbacks.tracers.base import BaseTracer, Tracer +from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, @@ -18,14 +17,17 @@ ) -class BaseLangChainTracer(BaseTracer, ABC): +class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - always_verbose: bool = True - _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") - _headers: Dict[str, Any] = {"Content-Type": "application/json"} - if os.getenv("LANGCHAIN_API_KEY"): - _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + def __init__(self, **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self.session = self.load_default_session() + self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + self._headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" @@ -59,58 +61,43 @@ def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession session = TracerSession(id=1, **session_create.dict()) return session - def load_session(self, session_name: str) -> TracerSession: + def _load_session(self, session_name: Optional[str] = None) -> TracerSession: """Load a session from the tracer.""" try: - r = requests.get( - f"{self._endpoint}/sessions?name={session_name}", - headers=self._headers, - ) + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + tracer_session = TracerSession(**r.json()[0]) - self._session = tracer_session - return tracer_session except Exception as e: + session_type = "default" if not session_name else session_name logging.warning( - f"Failed to load session {session_name}, using empty session: {e}" + f"Failed to load {session_type} session, using empty session: {e}" ) tracer_session = TracerSession(id=1) - self._session = tracer_session - return tracer_session + + self.session = tracer_session + return tracer_session + + def load_session(self, session_name: str) -> TracerSession: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" - try: - r = requests.get( - f"{self._endpoint}/sessions", - headers=self._headers, - ) - # Use the first session result - tracer_session = TracerSession(**r.json()[0]) - self._session = tracer_session - return tracer_session - except Exception as e: - logging.warning(f"Failed to default session, using empty session: {e}") - tracer_session = TracerSession(id=1) - self._session = tracer_session - return tracer_session - - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - if isinstance(child_run, LLMRun): - parent_run.child_llm_runs.append(child_run) - elif isinstance(child_run, ChainRun): - parent_run.child_chain_runs.append(child_run) - else: - parent_run.child_tool_runs.append(child_run) + return self._load_session("default") - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return None + def 
__deepcopy__(self, memo): + """Deepcopy the tracer.""" - -class LangChainTracer(Tracer, BaseLangChainTracer): - """Tracer that records LangChain execution to LangChain endpoint.""" + # TODO: this is a hack to get tracing to work with the current backend + # we need to not use execution order, then remove this check + if self.execution_order == 1: + copy = LangChainTracer() + copy.session = self.session + copy.run_map = dict(self.run_map) + copy.execution_order = self.execution_order + return copy + else: + return self diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index bb77d747e7c8b..d5987ec7b8892 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -32,7 +32,8 @@ class TracerSession(TracerSessionBase): class BaseRun(BaseModel): """Base class for Run.""" - id: Optional[Union[int, str]] = None + uuid: str + parent_uuid: Optional[str] = None start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) extra: Optional[Dict[str, Any]] = None @@ -57,7 +58,6 @@ class ChainRun(BaseRun): child_llm_runs: List[LLMRun] = Field(default_factory=list) child_chain_runs: List[ChainRun] = Field(default_factory=list) child_tool_runs: List[ToolRun] = Field(default_factory=list) - child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) class ToolRun(BaseRun): @@ -69,7 +69,6 @@ class ToolRun(BaseRun): child_llm_runs: List[LLMRun] = Field(default_factory=list) child_chain_runs: List[ChainRun] = Field(default_factory=list) child_tool_runs: List[ToolRun] = Field(default_factory=list) - child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) ChainRun.update_forward_refs() diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 47f37b736818a..da0bd362207c0 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -5,12 +5,12 @@ from pydantic import Field, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.prompts import BasePromptTemplate from langchain.requests import TextRequestsWrapper -from langchain.schema import BaseLanguageModel class APIChain(Chain): diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 1b1837a37034a..fc1cf9d69d8e0 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -1,15 +1,23 @@ """Base interface that all chains should implement.""" +import inspect import json +import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union import yaml -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator import langchain -from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + Callbacks, +) from langchain.schema import BaseMemory @@ -21,9 +29,8 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[BaseMemory] = None - callback_manager: BaseCallbackManager = Field( - default_factory=get_callback_manager, exclude=True - ) + 
callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None verbose: bool = Field( default_factory=_get_verbosity ) # Whether to print the response text @@ -37,15 +44,16 @@ class Config: def _chain_type(self) -> str: raise NotImplementedError("Saving not supported for this chain type.") - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values @validator("verbose", pre=True, always=True) def set_verbose(cls, verbose: Optional[bool]) -> bool: @@ -82,15 +90,26 @@ def _validate_outputs(self, outputs: Dict[str, str]) -> None: ) @abstractmethod - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Run the logic of this chain and return the output.""" - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Run the logic of this chain and return the output.""" raise NotImplementedError("Async call not supported for this chain type.") def __call__( - self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -104,21 +123,31 @@ def __call__( """ inputs = self.prep_inputs(inputs) - self.callback_manager.on_chain_start( + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( {"name": self.__class__.__name__}, inputs, - verbose=self.verbose, ) try: - outputs = self._call(inputs) + outputs = ( + self._call(inputs, run_manager=run_manager) + if new_arg_supported + else self._call(inputs) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + run_manager.on_chain_error(e) raise e - self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + run_manager.on_chain_end(outputs) return self.prep_outputs(inputs, outputs, return_only_outputs) async def acall( - self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. 
@@ -132,30 +161,24 @@ async def acall( """ inputs = self.prep_inputs(inputs) - if self.callback_manager.is_async: - await self.callback_manager.on_chain_start( - {"name": self.__class__.__name__}, - inputs, - verbose=self.verbose, - ) - else: - self.callback_manager.on_chain_start( - {"name": self.__class__.__name__}, - inputs, - verbose=self.verbose, - ) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") + run_manager = await callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + inputs, + ) try: - outputs = await self._acall(inputs) + outputs = ( + await self._acall(inputs, run_manager=run_manager) + if new_arg_supported + else await self._acall(inputs) + ) except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_chain_error(e, verbose=self.verbose) - else: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + await run_manager.on_chain_error(e) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_chain_end(outputs, verbose=self.verbose) - else: - self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + await run_manager.on_chain_end(outputs) return self.prep_outputs(inputs, outputs, return_only_outputs) def prep_outputs( @@ -195,11 +218,13 @@ def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: self._validate_inputs(inputs) return inputs - def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + def apply( + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None + ) -> List[Dict[str, str]]: """Call the chain on all inputs in the list.""" - return [self(inputs) for inputs in input_list] + return [self(inputs, callbacks=callbacks) for inputs in input_list] - def run(self, *args: Any, **kwargs: Any) -> str: + def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -210,17 +235,17 @@ def run(self, *args: Any, **kwargs: Any) -> str: if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return self(args[0])[self.output_keys[0]] + return self(args[0], callbacks=callbacks)[self.output_keys[0]] if kwargs and not args: - return self(kwargs)[self.output_keys[0]] + return self(kwargs, callbacks=callbacks)[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}." 
) - async def arun(self, *args: Any, **kwargs: Any) -> str: + async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -231,10 +256,10 @@ async def arun(self, *args: Any, **kwargs: Any) -> str: if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return (await self.acall(args[0]))[self.output_keys[0]] + return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]] if kwargs and not args: - return (await self.acall(kwargs))[self.output_keys[0]] + return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index b3ff12f5ed321..e6de521306989 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -1,13 +1,13 @@ """Chain for applying constitutional principles to the outputs of another chain.""" from typing import Any, Dict, List, Optional +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.principles import PRINCIPLES from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class ConstitutionalChain(Chain): diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index b7fb299e869d9..5792e2fb236ec 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -8,6 +8,7 @@ from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -15,7 +16,7 @@ from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document +from langchain.schema import BaseMessage, BaseRetriever, Document from langchain.vectorstores.base import VectorStore # Depending on the memory type and configuration, the chat history format may differ. 
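Note (illustration, not part of the patch): the hunks above replace the shared callback_manager on Chain with per-call callbacks and a per-run manager. A minimal sketch of how a downstream chain could adopt the new interface, assuming the imports introduced in this diff; the class EchoChain, its keys, and the prompt variable "question" are hypothetical. The _call method accepts the optional run_manager, reports intermediate text through it, and forwards run_manager.get_child() to a nested LLMChain so the nested LLM events are attributed to the current run.

    from typing import Dict, List, Optional

    from langchain.callbacks.manager import CallbackManagerForChainRun
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain


    class EchoChain(Chain):
        """Hypothetical chain used only to illustrate the new callback wiring."""

        llm_chain: LLMChain
        input_key: str = "question"
        output_key: str = "answer"

        @property
        def input_keys(self) -> List[str]:
            return [self.input_key]

        @property
        def output_keys(self) -> List[str]:
            return [self.output_key]

        def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            # Report intermediate text through the per-run manager instead of
            # the old shared self.callback_manager.
            if run_manager:
                run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
            # Forward callbacks to the nested LLMChain so its events are
            # recorded as children of this chain run.
            answer = self.llm_chain.predict(
                question=inputs[self.input_key],
                callbacks=run_manager.get_child() if run_manager else None,
            )
            return {self.output_key: answer}

Under these assumptions, a handler can then be supplied per invocation, e.g. chain.run(question="...", callbacks=[handler]); Chain.__call__ merges it with any chain-level self.callbacks via CallbackManager.configure, as shown in the langchain/chains/base.py hunk earlier in this diff.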
diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index eb3963222c92c..f2b8652bfc2b9 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,11 +5,17 @@ from pydantic import Extra +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + Callbacks, +) from langchain.chains.base import Chain from langchain.input import get_colored_text from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseLanguageModel, LLMResult, PromptValue +from langchain.schema import LLMResult, PromptValue class LLMChain(Chain): @@ -53,21 +59,39 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return self.apply([inputs])[0] - - def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + return self.apply([inputs], run_manager=run_manager)[0] + + def generate( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> LLMResult: """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list) - return self.llm.generate_prompt(prompts, stop) - - async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) + return self.llm.generate_prompt( + prompts, stop, callbacks=run_manager.get_child() if run_manager else None + ) + + async def agenerate( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> LLMResult: """Generate LLM result from inputs.""" prompts, stop = await self.aprep_prompts(input_list) - return await self.llm.agenerate_prompt(prompts, stop) + return await self.llm.agenerate_prompt( + prompts, stop, callbacks=run_manager.get_child() if run_manager else None + ) def prep_prompts( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Tuple[List[PromptValue], Optional[List[str]]]: """Prepare prompts from inputs.""" stop = None @@ -79,7 +103,8 @@ def prep_prompts( prompt = self.prompt.format_prompt(**selected_inputs) _colored_text = get_colored_text(prompt.to_string(), "green") _text = "Prompt after formatting:\n" + _colored_text - self.callback_manager.on_text(_text, end="\n", verbose=self.verbose) + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." 
@@ -88,7 +113,9 @@ def prep_prompts( return prompts, stop async def aprep_prompts( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Tuple[List[PromptValue], Optional[List[str]]]: """Prepare prompts from inputs.""" stop = None @@ -100,12 +127,8 @@ async def aprep_prompts( prompt = self.prompt.format_prompt(**selected_inputs) _colored_text = get_colored_text(prompt.to_string(), "green") _text = "Prompt after formatting:\n" + _colored_text - if self.callback_manager.is_async: - await self.callback_manager.on_text( - _text, end="\n", verbose=self.verbose - ) - else: - self.callback_manager.on_text(_text, end="\n", verbose=self.verbose) + if run_manager: + await run_manager.on_text(_text, end="\n", verbose=self.verbose) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." @@ -113,14 +136,22 @@ async def aprep_prompts( prompts.append(prompt) return prompts, stop - def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + def apply( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = self.generate(input_list) + response = self.generate(input_list, run_manager=run_manager) return self.create_outputs(response) - async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + async def aapply( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = await self.agenerate(input_list) + response = await self.agenerate(input_list, run_manager=run_manager) return self.create_outputs(response) def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: @@ -131,13 +162,18 @@ def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: for generation in response.generations ] - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return (await self.aapply([inputs]))[0] + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + return (await self.aapply([inputs], run_manager=run_manager))[0] - def predict(self, **kwargs: Any) -> str: + def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. Args: + callbacks: Callbacks to pass to LLMChain **kwargs: Keys to pass to prompt template. Returns: @@ -148,12 +184,13 @@ def predict(self, **kwargs: Any) -> str: completion = llm.predict(adjective="funny") """ - return self(kwargs)[self.output_key] + return self(kwargs, callbacks=callbacks)[self.output_key] - async def apredict(self, **kwargs: Any) -> str: + async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. Args: + callbacks: Callbacks to pass to LLMChain **kwargs: Keys to pass to prompt template. 
Returns: @@ -164,31 +201,35 @@ async def apredict(self, **kwargs: Any) -> str: completion = llm.predict(adjective="funny") """ - return (await self.acall(kwargs))[self.output_key] + return (await self.acall(kwargs, callbacks=callbacks))[self.output_key] - def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]: + def predict_and_parse( + self, callbacks: Callbacks = None, **kwargs: Any + ) -> Union[str, List[str], Dict[str, str]]: """Call predict and then parse the results.""" - result = self.predict(**kwargs) + result = self.predict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) else: return result async def apredict_and_parse( - self, **kwargs: Any + self, callbacks: Callbacks = None, **kwargs: Any ) -> Union[str, List[str], Dict[str, str]]: """Call apredict and then parse the results.""" - result = await self.apredict(**kwargs) + result = await self.apredict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) else: return result def apply_and_parse( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = self.apply(input_list) + result = self.apply(input_list, run_manager=run_manager) return self._parse_result(result) def _parse_result( @@ -202,10 +243,12 @@ def _parse_result( return result async def aapply_and_parse( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = await self.aapply(input_list) + result = await self.aapply(input_list, run_manager=run_manager) return self._parse_result(result) @property diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 9a9f44b7587bb..8883255027489 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -3,11 +3,11 @@ from pydantic import Extra +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.utilities.bash import BashProcess diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index 6c5c905d52b67..ff7c000dda3fd 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -1,16 +1,20 @@ """Chain that interprets a prompt and executes python code to do math.""" import math import re -from typing import Dict, List +from typing import Dict, List, Optional import numexpr from pydantic import Extra +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LLMMathChain(Chain): @@ -68,15 +72,19 @@ def _evaluate_expression(self, expression: str) -> str: # Remove any 
leading and trailing brackets from the output return re.sub(r"^\[|\]$", "", output) - def _process_llm_result(self, llm_output: str) -> Dict[str, str]: - self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose) + def _process_llm_result( + self, llm_output: str, run_manager: Optional[CallbackManagerForChainRun] = None + ) -> Dict[str, str]: + if run_manager: + run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) + if run_manager: + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -86,30 +94,21 @@ def _process_llm_result(self, llm_output: str) -> Dict[str, str]: raise ValueError(f"unknown format from LLM: {llm_output}") return {self.output_key: answer} - async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]: - if self.callback_manager.is_async: - await self.callback_manager.on_text( - llm_output, color="green", verbose=self.verbose - ) - else: - self.callback_manager.on_text( - llm_output, color="green", verbose=self.verbose - ) + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + if run_manager: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - if self.callback_manager.is_async: - await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - await self.callback_manager.on_text( - output, color="yellow", verbose=self.verbose - ) - else: - await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - await self.callback_manager.on_text( - output, color="yellow", verbose=self.verbose - ) + if run_manager: + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -119,30 +118,35 @@ async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]: raise ValueError(f"unknown format from LLM: {llm_output}") return {self.output_key: answer} - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain( - prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager - ) - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) + if run_manager: + run_manager.on_text(inputs[self.input_key]) llm_output = llm_executor.predict( - question=inputs[self.input_key], stop=["```output"] - ) - return self._process_llm_result(llm_output) - - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain( - prompt=self.prompt, llm=self.llm, 
callback_manager=self.callback_manager + question=inputs[self.input_key], + stop=["```output"], + callbacks=run_manager.get_child() if run_manager else None, ) - if self.callback_manager.is_async: - await self.callback_manager.on_text( - inputs[self.input_key], verbose=self.verbose - ) - else: - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) + return self._process_llm_result(llm_output, run_manager=run_manager) + + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) + if run_manager: + await run_manager.on_text(inputs[self.input_key]) llm_output = await llm_executor.apredict( - question=inputs[self.input_key], stop=["```output"] + question=inputs[self.input_key], + stop=["```output"], + callbacks=run_manager.get_child() if run_manager else None, ) - return await self._aprocess_llm_result(llm_output) + return await self._aprocess_llm_result(llm_output, run_manager=run_manager) @property def _chain_type(self) -> str: diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 0d15b90be7681..6c8cfa4706e4c 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -8,12 +8,12 @@ from pydantic import Extra +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.utilities import PythonREPL diff --git a/langchain/chains/prompt_selector.py b/langchain/chains/prompt_selector.py index 190907cc1490f..e40e4f8a0b4a6 100644 --- a/langchain/chains/prompt_selector.py +++ b/langchain/chains/prompt_selector.py @@ -3,10 +3,10 @@ from pydantic import BaseModel, Field +from langchain.base_language import BaseLanguageModel from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class BasePromptSelector(BaseModel, ABC): diff --git a/langchain/chains/qa_generation/base.py b/langchain/chains/qa_generation/base.py index 66907befaec23..2dfb4b5d29818 100644 --- a/langchain/chains/qa_generation/base.py +++ b/langchain/chains/qa_generation/base.py @@ -5,11 +5,11 @@ from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 5c6317edd0c31..10b5d96f9c858 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -8,6 +8,7 @@ from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -21,7 +22,6 @@ ) from 
langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class BaseQAWithSourcesChain(Chain, ABC): diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py index c1d923aefb4e3..2ce4c56ea25ca 100644 --- a/langchain/chains/qa_with_sources/loading.py +++ b/langchain/chains/qa_with_sources/loading.py @@ -1,6 +1,7 @@ """Load question answering with sources chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain @@ -14,7 +15,6 @@ ) from langchain.chains.question_answering import map_rerank_prompt from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 2ba684f07af96..95c24f0a1ffba 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -1,6 +1,7 @@ """Load question answering chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -15,7 +16,6 @@ stuff_prompt, ) from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index dc1d68bfdf07d..7e30f4720467a 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -7,6 +7,7 @@ from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -14,7 +15,7 @@ from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.schema import BaseRetriever, Document from langchain.vectorstores.base import VectorStore diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index aa1a2e6dc5608..3d8761b211dc6 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -5,11 +5,11 @@ from pydantic import Extra, Field +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.sql_database import SQLDatabase diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index 
c31fda479f9d1..6fc835dd0f955 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -1,6 +1,7 @@ """Load summarizing chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain @@ -8,7 +9,6 @@ from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py index f56c606361eca..daed935bce7cc 100644 --- a/langchain/chat_models/anthropic.py +++ b/langchain/chat_models/anthropic.py @@ -2,6 +2,10 @@ from pydantic import Extra +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel from langchain.llms.anthropic import _AnthropicCommon from langchain.schema import ( @@ -85,7 +89,10 @@ def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: ) # trim off the trailing ' ' that might come from the "Assistant: " def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) params: Dict[str, Any] = {"prompt": prompt, **self._default_params} @@ -98,10 +105,10 @@ def _generate( for data in stream_resp: delta = data["completion"][len(completion) :] completion = data["completion"] - self.callback_manager.on_llm_new_token( - delta, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token( + delta, + ) else: response = self.client.completion(**params) completion = response["completion"] @@ -109,7 +116,10 @@ def _generate( return ChatResult(generations=[ChatGeneration(message=message)]) async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) params: Dict[str, Any] = {"prompt": prompt, **self._default_params} @@ -122,15 +132,9 @@ async def _agenerate( async for data in stream_resp: delta = data["completion"][len(completion) :] completion = data["completion"] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - delta, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( delta, - verbose=self.verbose, ) else: response = await self.client.acompletion(**params) diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 91816de1648f8..fbcc08b10ec73 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -1,21 +1,30 @@ import asyncio +import inspect +import warnings from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Dict, List, Optional -from pydantic import Extra, Field, validator +from pydantic import 
Extra, Field, root_validator import langchain -from langchain.callbacks import get_callback_manager +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) from langchain.schema import ( AIMessage, - BaseLanguageModel, BaseMessage, ChatGeneration, ChatResult, HumanMessage, LLMResult, PromptValue, + get_buffer_string, ) @@ -26,7 +35,19 @@ def _get_verbosity() -> bool: class BaseChatModel(BaseLanguageModel, ABC): verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values class Config: """Configuration for this pydantic object.""" @@ -34,98 +55,130 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return {} def generate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Top Level call""" - results = [self._generate(m, stop=stop) for m in messages] + + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + message_strings = [get_buffer_string(m) for m in messages] + run_manager = callback_manager.on_llm_start( + {"name": self.__class__.__name__}, message_strings + ) + + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) + try: + results = [ + self._generate(m, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(m, stop=stop) + for m in messages + ] + except (KeyboardInterrupt, Exception) as e: + run_manager.on_llm_error(e) + raise e llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] - return LLMResult(generations=generations, llm_output=llm_output) + output = LLMResult(generations=generations, llm_output=llm_output) + run_manager.on_llm_end(output) + return output async def agenerate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Top Level call""" - results = await asyncio.gather( - *[self._agenerate(m, stop=stop) for m in messages] + + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + message_strings = [get_buffer_string(m) for m in 
messages] + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, message_strings ) - llm_output = self._combine_llm_outputs([res.llm_output for res in results]) - generations = [res.generations for res in results] - return LLMResult(generations=generations, llm_output=llm_output) - def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None - ) -> LLMResult: - prompt_messages = [p.to_messages() for p in prompts] - prompt_strings = [p.to_string() for p in prompts] - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" ) try: - output = self.generate(prompt_messages, stop=stop) + results = await asyncio.gather( + *[ + self._agenerate(m, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._agenerate(m, stop=stop) + for m in messages + ] + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + await run_manager.on_llm_error(e) raise e - self.callback_manager.on_llm_end(output, verbose=self.verbose) + llm_output = self._combine_llm_outputs([res.llm_output for res in results]) + generations = [res.generations for res in results] + output = LLMResult(generations=generations, llm_output=llm_output) + await run_manager.on_llm_end(output) return output + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> LLMResult: + prompt_messages = [p.to_messages() for p in prompts] + return self.generate(prompt_messages, stop=stop, callbacks=callbacks) + async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_messages = [p.to_messages() for p in prompts] - prompt_strings = [p.to_string() for p in prompts] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose - ) - try: - output = await self.agenerate(prompt_messages, stop=stop) - except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) - raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(output, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(output, verbose=self.verbose) - return output + return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks) @abstractmethod def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: """Top Level call""" @abstractmethod async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: """Top Level call""" def __call__( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + 
messages: List[BaseMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> BaseMessage: - return self._generate(messages, stop=stop).generations[0].message + generation = self.generate( + [messages], stop=stop, callbacks=callbacks + ).generations[0][0] + if isinstance(generation, ChatGeneration): + return generation.message + else: + raise ValueError("Unexpected generation type") def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str: result = self([HumanMessage(content=message)], stop=stop) @@ -134,15 +187,21 @@ def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str: class SimpleChatModel(BaseChatModel): def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: - output_str = self._call(messages, stop=stop) + output_str = self._call(messages, stop=stop, run_manager=run_manager) message = AIMessage(content=output_str) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) @abstractmethod def _call( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str: """Simpler interface.""" diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 4bbf1ee816c93..cd5efb1c3241a 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -14,6 +14,10 @@ wait_exponential, ) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel from langchain.schema import ( AIMessage, @@ -242,7 +246,10 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return {"token_usage": overall_token_usage, "model_name": self.model_name} def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) if self.streaming: @@ -255,10 +262,8 @@ def _generate( role = stream_resp["choices"][0]["delta"].get("role", role) token = stream_resp["choices"][0]["delta"].get("content", "") inner_completion += token - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token(token) message = _convert_dict_to_message( {"content": inner_completion, "role": role} ) @@ -287,7 +292,10 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: return ChatResult(generations=generations, llm_output=llm_output) async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) if self.streaming: @@ -300,16 +308,8 @@ async def _agenerate( role = stream_resp["choices"][0]["delta"].get("role", role) token = stream_resp["choices"][0]["delta"].get("content", "") inner_completion += token - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - token, - 
verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + await run_manager.on_llm_new_token(token) message = _convert_dict_to_message( {"content": inner_completion, "role": role} ) diff --git a/langchain/chat_models/promptlayer_openai.py b/langchain/chat_models/promptlayer_openai.py index 38b664162d3bf..6f9b9a08db9c6 100644 --- a/langchain/chat_models/promptlayer_openai.py +++ b/langchain/chat_models/promptlayer_openai.py @@ -2,6 +2,10 @@ import datetime from typing import List, Optional +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models import ChatOpenAI from langchain.schema import BaseMessage, ChatResult @@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI): return_pl_id: Optional[bool] = False def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: """Call ChatOpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = super()._generate(messages, stop) + generated_responses = super()._generate(messages, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() message_dicts, params = super()._create_message_dicts(messages, stop) for i, generation in enumerate(generated_responses.generations): @@ -67,13 +74,16 @@ def _generate( return generated_responses async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: """Call ChatOpenAI agenerate and then call PromptLayer to log.""" from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(messages, stop) + generated_responses = await super()._agenerate(messages, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() message_dicts, params = super()._create_message_dicts(messages, stop) for i, generation in enumerate(generated_responses.generations): diff --git a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py index fffb413ace5fa..3b7ce122f2831 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py +++ b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field +from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.experimental.autonomous_agents.baby_agi.task_creation import ( TaskCreationChain, @@ -13,7 +14,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import ( TaskPrioritizationChain, ) -from langchain.schema import BaseLanguageModel from langchain.vectorstores.base import VectorStore diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_creation.py b/langchain/experimental/autonomous_agents/baby_agi/task_creation.py index 122b0dbf9889d..d3a1dc81567bf 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_creation.py +++ 
b/langchain/experimental/autonomous_agents/baby_agi/task_creation.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskCreationChain(LLMChain): diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_execution.py b/langchain/experimental/autonomous_agents/baby_agi/task_execution.py index b85619f866cb2..aac943c03fe4c 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_execution.py +++ b/langchain/experimental/autonomous_agents/baby_agi/task_execution.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskExecutionChain(LLMChain): diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py b/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py index 19e9d79a0be9a..d8b44c585d49b 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py +++ b/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskPrioritizationChain(LLMChain): diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 4ec0326a1bd50..181adb0bc0c1e 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -4,6 +4,7 @@ import requests from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -106,7 +107,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "ai21" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to AI21's complete endpoint. Args: diff --git a/langchain/llms/aleph_alpha.py b/langchain/llms/aleph_alpha.py index dd17bc44d6c9c..bcdbebf8ad754 100644 --- a/langchain/llms/aleph_alpha.py +++ b/langchain/llms/aleph_alpha.py @@ -3,6 +3,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -200,7 +201,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "alpeh_alpha" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Aleph Alpha's completion endpoint. Args: diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index e609627967ead..27f83469b1520 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -4,6 +4,10 @@ from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -142,7 +146,12 @@ def _wrap_prompt(self, prompt: str) -> str: # As a last resort, wrap the prompt ourselves to emulate instruct-style. 
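The same `_call` signature change is applied mechanically across every LLM wrapper below. As a reference for the pattern only, a minimal sketch of a custom LLM written against the new interface could look like the following; the `EchoLLM` class and its behavior are hypothetical and not part of this patch, and only the signature and the `run_manager.on_llm_new_token` call mirror what the diff introduces.

from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Hypothetical wrapper illustrating the new run_manager-aware _call."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # Streaming-style wrappers forward tokens to the per-run manager
        # instead of the old global self.callback_manager.
        if run_manager:
            run_manager.on_llm_new_token(prompt)
        return prompt

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}
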
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""Call out to Anthropic's completion endpoint. Args: @@ -171,9 +180,8 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: for data in stream_resp: delta = data["completion"][len(current_completion) :] current_completion = data["completion"] - self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) + if run_manager: + run_manager.on_llm_new_token(delta, **data) return current_completion response = self.client.completion( prompt=self._wrap_prompt(prompt), @@ -182,7 +190,12 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: ) return response["completion"] - async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> str: """Call out to Anthropic's completion endpoint asynchronously.""" stop = self._get_anthropic_stop(stop) if self.streaming: @@ -195,14 +208,8 @@ async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: async for data in stream_resp: delta = data["completion"][len(current_completion) :] current_completion = data["completion"] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) - else: - self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) + if run_manager: + await run_manager.on_llm_new_token(delta, **data) return current_completion response = await self.client.acompletion( prompt=self._wrap_prompt(prompt), diff --git a/langchain/llms/bananadev.py b/langchain/llms/bananadev.py index 697ebcc79af9b..8d95c1edd28a7 100644 --- a/langchain/llms/bananadev.py +++ b/langchain/llms/bananadev.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "banana" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Banana endpoint.""" try: import banana_dev as banana diff --git a/langchain/llms/base.py b/langchain/llms/base.py index dd2397928f0ce..cdfe366e560b8 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -1,16 +1,25 @@ """Base interface for large language models to expose.""" +import inspect import json +import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import yaml -from pydantic import Extra, Field, validator +from pydantic import Extra, Field, root_validator, validator import langchain -from langchain.callbacks import get_callback_manager +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager -from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue +from langchain.callbacks.manager import ( 
+ AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) +from langchain.schema import Generation, LLMResult, PromptValue def _get_verbosity() -> bool: @@ -59,7 +68,8 @@ class BaseLLM(BaseLanguageModel, ABC): cache: Optional[bool] = None verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None class Config: """Configuration for this pydantic object.""" @@ -67,15 +77,16 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values @validator("verbose", pre=True, always=True) def set_verbose(cls, verbose: Optional[bool]) -> bool: @@ -90,30 +101,45 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool: @abstractmethod def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompts.""" @abstractmethod async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompts.""" def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return self.generate(prompt_strings, stop=stop) + return self.generate(prompt_strings, stop=stop, callbacks=callbacks) async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return await self.agenerate(prompt_strings, stop=stop) + return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks) def generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # If string is passed in directly no errors will be raised but outputs will @@ -124,21 +150,31 @@ def generate( f" argument of type {type(prompts)}." 
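At the call-site level, the new `callbacks` field and the `raise_deprecation` validator above change how handlers are attached. A hedged usage sketch follows; `OpenAI` and `StdOutCallbackHandler` are existing classes, while the prompt and handler choice are purely illustrative.

from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.llms import OpenAI

handler = StdOutCallbackHandler()

# New style: attach handlers via `callbacks` at construction time...
llm = OpenAI(temperature=0, callbacks=[handler])

# ...or per call; constructor-level and call-level handlers are combined
# when the callback manager for the run is configured.
result = llm.generate(["Say hello."], callbacks=[handler])

# Passing callback_manager=... still works, but now triggers the
# DeprecationWarning raised by the raise_deprecation root validator.
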
) disregard_cache = self.cache is not None and not self.cache + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) if langchain.llm_cache is None or disregard_cache: # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose + run_manager = callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts ) try: - output = self._generate(prompts, stop=stop) + output = ( + self._generate(prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + run_manager.on_llm_error(e) raise e - self.callback_manager.on_llm_end(output, verbose=self.verbose) + run_manager.on_llm_end(output) return output params = self.dict() params["stop"] = stop @@ -149,15 +185,19 @@ def generate( missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose + run_manager = self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, missing_prompts ) try: - new_results = self._generate(missing_prompts, stop=stop) + new_results = ( + self._generate(missing_prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(missing_prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + run_manager.on_llm_error(e) raise e - self.callback_manager.on_llm_end(new_results, verbose=self.verbose) + run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) @@ -167,36 +207,38 @@ def generate( return LLMResult(generations=generations, llm_output=llm_output) async def agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" disregard_cache = self.cache is not None and not self.cache + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) if langchain.llm_cache is None or disregard_cache: # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
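The `new_arg_supported` probe in `generate` above is what keeps older subclasses that still implement the two-argument `_generate` working. In isolation the check behaves roughly like this; the class names are made up for illustration.

import inspect
from typing import List, Optional


class OldStyle:
    def _generate(self, prompts: List[str], stop: Optional[List[str]] = None) -> str:
        return "old"


class NewStyle:
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None, run_manager=None
    ) -> str:
        return "new"


def supports_run_manager(obj) -> bool:
    # Mirrors: inspect.signature(self._generate).parameters.get("run_manager")
    return inspect.signature(obj._generate).parameters.get("run_manager") is not None


print(supports_run_manager(OldStyle()))  # False -> call _generate(prompts, stop=stop)
print(supports_run_manager(NewStyle()))  # True  -> also pass run_manager=run_manager
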
) - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose - ) + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts + ) try: - output = await self._agenerate(prompts, stop=stop) + output = ( + await self._agenerate(prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else await self._agenerate(prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + await run_manager.on_llm_error(e, verbose=self.verbose) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(output, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(output, verbose=self.verbose) + await run_manager.on_llm_end(output, verbose=self.verbose) return output params = self.dict() params["stop"] = stop @@ -207,32 +249,22 @@ async def agenerate( missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, - missing_prompts, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, - missing_prompts, - verbose=self.verbose, - ) + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, + missing_prompts, + ) try: - new_results = await self._agenerate(missing_prompts, stop=stop) + new_results = ( + await self._agenerate( + missing_prompts, stop=stop, run_manager=run_manager + ) + if new_arg_supported + else await self._agenerate(missing_prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + await run_manager.on_llm_error(e) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end( - new_results, verbose=self.verbose - ) - else: - self.callback_manager.on_llm_end(new_results, verbose=self.verbose) + await run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) @@ -241,9 +273,15 @@ async def agenerate( generations = [existing_prompts[i] for i in range(len(prompts))] return LLMResult(generations=generations, llm_output=llm_output) - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def __call__( + self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None + ) -> str: """Check Cache and run the LLM on the given prompt and input.""" - return self.generate([prompt], stop=stop).generations[0][0].text + return ( + self.generate([prompt], stop=stop, callbacks=callbacks) + .generations[0][0] + .text + ) @property def _identifying_params(self) -> Mapping[str, Any]: @@ -307,30 +345,56 @@ class LLM(BaseLLM): """ @abstractmethod - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Run the LLM on the given prompt and input.""" - async 
def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> str: """Run the LLM on the given prompt and input.""" raise NotImplementedError("Async generation not implemented for this LLM.") def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # TODO: add caching here. generations = [] + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") for prompt in prompts: - text = self._call(prompt, stop=stop) + text = ( + self._call(prompt, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._call(prompt, stop=stop) + ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" generations = [] + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") for prompt in prompts: - text = await self._acall(prompt, stop=stop) + text = ( + await self._acall(prompt, stop=stop, run_manager=run_manager) + if new_arg_supported + else await self._acall(prompt, stop=stop) + ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) diff --git a/langchain/llms/cerebriumai.py b/langchain/llms/cerebriumai.py index 2937d7ffc9329..3da3dfbc73a0c 100644 --- a/langchain/llms/cerebriumai.py +++ b/langchain/llms/cerebriumai.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -81,7 +82,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "cerebriumai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to CerebriumAI endpoint.""" try: from cerebrium import model_api_request diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 91894d1b575b8..2eff193b78083 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -4,6 +4,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -100,7 +101,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "cohere" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Cohere's generate endpoint. 
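Because the async path above now awaits run-manager events directly (the old `is_async` branching is gone), async handlers attach the same way as sync ones. A hypothetical sketch, assuming a streaming OpenAI model and a custom handler built on `AsyncCallbackHandler`; the handler class and prompt are illustrative only.

import asyncio
from typing import Any

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.llms import OpenAI


class TokenCounter(AsyncCallbackHandler):
    """Hypothetical async handler that counts streamed tokens."""

    def __init__(self) -> None:
        self.tokens = 0

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.tokens += 1


async def main() -> None:
    handler = TokenCounter()
    llm = OpenAI(temperature=0, streaming=True, callbacks=[handler])
    await llm.agenerate(["Tell me a short joke."])
    print(handler.tokens)


# asyncio.run(main())
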
Args: diff --git a/langchain/llms/deepinfra.py b/langchain/llms/deepinfra.py index 55b4c98bb3021..f2c22823485f1 100644 --- a/langchain/llms/deepinfra.py +++ b/langchain/llms/deepinfra.py @@ -4,6 +4,7 @@ import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -60,7 +61,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "deepinfra" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to DeepInfra's inference API endpoint. Args: diff --git a/langchain/llms/fake.py b/langchain/llms/fake.py index aec4abb9766ce..3df15b9c520c0 100644 --- a/langchain/llms/fake.py +++ b/langchain/llms/fake.py @@ -1,6 +1,7 @@ """Fake LLM wrapper for testing purposes.""" from typing import Any, List, Mapping, Optional +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -15,7 +16,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "fake-list" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """First try to lookup in queries, else return 'foo' or 'bar'.""" response = self.responses[self.i] self.i += 1 diff --git a/langchain/llms/forefrontai.py b/langchain/llms/forefrontai.py index 1e34377a5ef65..8c49918abd606 100644 --- a/langchain/llms/forefrontai.py +++ b/langchain/llms/forefrontai.py @@ -4,6 +4,7 @@ import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -81,7 +82,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "forefrontai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to ForefrontAI's complete endpoint. 
Args: diff --git a/langchain/llms/gooseai.py b/langchain/llms/gooseai.py index ec7ca28dc8052..571feb2b37c1a 100644 --- a/langchain/llms/gooseai.py +++ b/langchain/llms/gooseai.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -130,7 +131,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "gooseai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the GooseAI API.""" params = self._default_params if stop is not None: diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index bf0300bb31720..780bddb9f4342 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -159,7 +160,12 @@ def _llm_type(self) -> str: """Return the type of llm.""" return "gpt4all" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""Call out to GPT4All's generate method. Args: diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py index 7f7561c866a25..66a073c172f19 100644 --- a/langchain/llms/huggingface_endpoint.py +++ b/langchain/llms/huggingface_endpoint.py @@ -4,6 +4,7 @@ import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -88,7 +89,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "huggingface_endpoint" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to HuggingFace Hub's inference endpoint. Args: diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index e7d3af993cd0a..2838b858945c2 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -3,6 +3,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -84,7 +85,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "huggingface_hub" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to HuggingFace Hub's inference endpoint. 
Args: diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index be8787dac9a9d..529cea28cc85f 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -5,6 +5,7 @@ from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -146,7 +147,12 @@ def _identifying_params(self) -> Mapping[str, Any]: def _llm_type(self) -> str: return "huggingface_pipeline" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: response = self.pipeline(prompt) if self.pipeline.task == "text-generation": # Text generation return includes the starter text. diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py index 0c83c7635da1f..a4203e0364df4 100644 --- a/langchain/llms/llamacpp.py +++ b/langchain/llms/llamacpp.py @@ -4,6 +4,7 @@ from pydantic import Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM logger = logging.getLogger(__name__) @@ -154,7 +155,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "llama.cpp" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the Llama model and return the output. Args: diff --git a/langchain/llms/manifest.py b/langchain/llms/manifest.py index f6042144b06d4..0cef977e34806 100644 --- a/langchain/llms/manifest.py +++ b/langchain/llms/manifest.py @@ -3,6 +3,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -42,7 +43,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "manifest" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to LLM through Manifest.""" if stop is not None and len(stop) != 1: raise NotImplementedError( diff --git a/langchain/llms/modal.py b/langchain/llms/modal.py index 4c159a3953a20..53f112f70f31d 100644 --- a/langchain/llms/modal.py +++ b/langchain/llms/modal.py @@ -5,6 +5,7 @@ import requests from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -69,7 +70,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "modal" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Modal endpoint.""" params = self.model_kwargs or {} response = requests.post( diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index b3b25d0b5d0fa..72f6b38e4c4d1 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -3,6 +3,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import 
get_from_dict_or_env @@ -111,7 +112,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "nlpcloud" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to NLPCloud's create endpoint. Args: diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index b98a7f5d20311..c951dfc01c085 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -29,6 +29,10 @@ wait_exponential, ) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms.base import BaseLLM from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env @@ -253,7 +257,10 @@ def _default_params(self) -> Dict[str, Any]: return {**normal_params, **self.model_kwargs} def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call out to OpenAI's endpoint with k unique prompts. @@ -286,11 +293,12 @@ def _generate( for stream_resp in completion_with_retry( self, prompt=_prompts, **params ): - self.callback_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) + if run_manager: + run_manager.on_llm_new_token( + stream_resp["choices"][0]["text"], + verbose=self.verbose, + logprobs=stream_resp["choices"][0]["logprobs"], + ) _update_response(response, stream_resp) choices.extend(response["choices"]) else: @@ -302,7 +310,10 @@ def _generate( return self.create_llm_result(choices, prompts, token_usage) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Call out to OpenAI's endpoint async with k unique prompts.""" params = self._invocation_params @@ -321,14 +332,8 @@ async def _agenerate( async for stream_resp in await acompletion_with_retry( self, prompt=_prompts, **params ): - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - else: - self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( stream_resp["choices"][0]["text"], verbose=self.verbose, logprobs=stream_resp["choices"][0]["logprobs"], @@ -704,7 +709,10 @@ def _get_chat_params( return messages, params def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) if self.streaming: @@ -713,10 +721,10 @@ def _generate( for stream_resp in completion_with_retry(self, messages=messages, **params): token = stream_resp["choices"][0]["delta"].get("content", "") response += token - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token( + token, + ) return LLMResult( generations=[[Generation(text=response)]], ) @@ -734,7 +742,10 @@ def _generate( ) async def _agenerate( - self, prompts: List[str], stop: 
Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) if self.streaming: @@ -745,15 +756,9 @@ async def _agenerate( ): token = stream_resp["choices"][0]["delta"].get("content", "") response += token - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( token, - verbose=self.verbose, ) return LLMResult( generations=[[Generation(text=response)]], diff --git a/langchain/llms/petals.py b/langchain/llms/petals.py index ed30a28f59be1..293d240cef807 100644 --- a/langchain/llms/petals.py +++ b/langchain/llms/petals.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -130,7 +131,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "petals" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the Petals API.""" params = self._default_params inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] diff --git a/langchain/llms/promptlayer_openai.py b/langchain/llms/promptlayer_openai.py index c7dd9cf3e0188..77df805177640 100644 --- a/langchain/llms/promptlayer_openai.py +++ b/langchain/llms/promptlayer_openai.py @@ -2,6 +2,10 @@ import datetime from typing import List, Optional +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms import OpenAI, OpenAIChat from langchain.schema import LLMResult @@ -33,13 +37,16 @@ class PromptLayerOpenAI(OpenAI): return_pl_id: Optional[bool] = False def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = super()._generate(prompts, stop) + generated_responses = super()._generate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -69,12 +76,15 @@ def _generate( return generated_responses async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(prompts, stop) + generated_responses = await super()._agenerate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -131,13 +141,16 @@ class PromptLayerOpenAIChat(OpenAIChat): return_pl_id: Optional[bool] = False def 
_generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = super()._generate(prompts, stop) + generated_responses = super()._generate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -167,12 +180,15 @@ def _generate( return generated_responses async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(prompts, stop) + generated_responses = await super()._agenerate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] diff --git a/langchain/llms/replicate.py b/langchain/llms/replicate.py index 42213a49741ef..3c381d314fccd 100644 --- a/langchain/llms/replicate.py +++ b/langchain/llms/replicate.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -78,7 +79,12 @@ def _llm_type(self) -> str: """Return type of model.""" return "replicate" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to replicate endpoint.""" try: import replicate as replicate_python diff --git a/langchain/llms/rwkv.py b/langchain/llms/rwkv.py index 5c27185ab66ea..0e873d48cbef9 100644 --- a/langchain/llms/rwkv.py +++ b/langchain/llms/rwkv.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -204,7 +205,12 @@ def rwkv_generate(self, prompt: str) -> str: return decoded - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""RWKV generation Args: diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index d9efe51a3c6a1..4d17160932630 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -4,6 +4,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -194,7 +195,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "sagemaker_endpoint" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: 
"""Call out to Sagemaker inference endpoint. Args: diff --git a/langchain/llms/self_hosted.py b/langchain/llms/self_hosted.py index df529d80f4786..e7e51725f5ebc 100644 --- a/langchain/llms/self_hosted.py +++ b/langchain/llms/self_hosted.py @@ -6,6 +6,7 @@ from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -208,5 +209,10 @@ def _identifying_params(self) -> Mapping[str, Any]: def _llm_type(self) -> str: return "self_hosted_llm" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) diff --git a/langchain/llms/self_hosted_hugging_face.py b/langchain/llms/self_hosted_hugging_face.py index dd62348cca427..49bd8536eee66 100644 --- a/langchain/llms/self_hosted_hugging_face.py +++ b/langchain/llms/self_hosted_hugging_face.py @@ -5,6 +5,7 @@ from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.utils import enforce_stop_tokens @@ -198,5 +199,10 @@ def _identifying_params(self) -> Mapping[str, Any]: def _llm_type(self) -> str: return "selfhosted_huggingface_pipeline" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) diff --git a/langchain/llms/stochasticai.py b/langchain/llms/stochasticai.py index 052e6efc840c6..5d2fe7300ec4f 100644 --- a/langchain/llms/stochasticai.py +++ b/langchain/llms/stochasticai.py @@ -6,6 +6,7 @@ import requests from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "stochasticai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to StochasticAI's complete endpoint. Args: diff --git a/langchain/llms/writer.py b/langchain/llms/writer.py index a3a74f5905e8e..2cec183515e3d 100644 --- a/langchain/llms/writer.py +++ b/langchain/llms/writer.py @@ -4,6 +4,7 @@ import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -117,7 +118,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "writer" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Writer's complete endpoint. 
Args: diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 8863c076fc712..ec02dd2c8ea33 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -5,6 +5,7 @@ from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( @@ -13,7 +14,7 @@ ) from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +from langchain.schema import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index 8b2b5f6ba478f..2c71a33c44e3b 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -2,6 +2,7 @@ from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.graphs import NetworkxEntityGraph from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples @@ -13,7 +14,6 @@ from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseMessage, SystemMessage, get_buffer_string, diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 4873b824b4e82..7a2d04f47c312 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -2,12 +2,12 @@ from pydantic import BaseModel, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseMessage, SystemMessage, get_buffer_string, diff --git a/langchain/memory/token_buffer.py b/langchain/memory/token_buffer.py index bb4da209d9294..c5e3c01b6378c 100644 --- a/langchain/memory/token_buffer.py +++ b/langchain/memory/token_buffer.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List +from langchain.base_language import BaseLanguageModel from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +from langchain.schema import BaseMessage, get_buffer_string class ConversationTokenBufferMemory(BaseChatMemory): diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index dfa3d639e45f7..a46b2e4ec63a3 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -2,10 +2,11 @@ from typing import TypeVar +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, OutputParserException T = TypeVar("T") diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index b1982608842a6..080d1a4906501 100644 --- a/langchain/output_parsers/retry.py +++ b/langchain/output_parsers/retry.py @@ -2,11 +2,11 @@ from typing import TypeVar +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from 
langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseOutputParser, OutputParserException, PromptValue, diff --git a/langchain/schema.py b/langchain/schema.py index 65f530948c356..af7568edd2d9d 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -171,45 +171,6 @@ def to_messages(self) -> List[BaseMessage]: """Return prompt as messages.""" -class BaseLanguageModel(BaseModel, ABC): - @abstractmethod - def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None - ) -> LLMResult: - """Take in a list of prompt values and return an LLMResult.""" - - @abstractmethod - async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None - ) -> LLMResult: - """Take in a list of prompt values and return an LLMResult.""" - - def get_num_tokens(self, text: str) -> int: - """Get the number of tokens present in the text.""" - # TODO: this method may not be exact. - # TODO: this method may differ based on model (eg codex). - try: - from transformers import GPT2TokenizerFast - except ImportError: - raise ValueError( - "Could not import transformers python package. " - "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install transformers`." - ) - # create a GPT-3 tokenizer instance - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") - - # tokenize the text using the GPT-3 tokenizer - tokenized_text = tokenizer.tokenize(text) - - # calculate the number of tokens in the tokenized text - return len(tokenized_text) - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - """Get the number of tokens in the message.""" - return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) - - class BaseMemory(BaseModel, ABC): """Base interface for memory in chains.""" diff --git a/langchain/tools/base.py b/langchain/tools/base.py index bc17386610c12..a370305f0b56e 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,13 +1,21 @@ """Base implementation for tools or skills.""" +import inspect +import warnings from abc import ABC, abstractmethod from inspect import signature from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union -from pydantic import BaseModel, Extra, Field, validate_arguments, validator +from pydantic import BaseModel, Extra, Field, root_validator, validate_arguments -from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForToolRun, + CallbackManager, + CallbackManagerForToolRun, + Callbacks, +) def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]: @@ -28,7 +36,8 @@ class BaseTool(ABC, BaseModel): """Pydantic model class to validate and parse the tool's input arguments.""" return_direct: bool = False verbose: bool = False - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None class Config: """Configuration for this pydantic object.""" @@ -60,15 +69,16 @@ def _parse_input( if input_args is not None: input_args.validate(tool_input) - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. 
- - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values @abstractmethod def _run(self, *args: Any, **kwargs: Any) -> str: @@ -84,6 +94,7 @@ def run( verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", + callbacks: Callbacks = None, **kwargs: Any, ) -> str: """Run the tool.""" @@ -92,22 +103,22 @@ def run( verbose_ = verbose else: verbose_ = self.verbose - self.callback_manager.on_tool_start( + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, verbose=verbose_ + ) + run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, color=start_color, **kwargs, ) try: tool_args, tool_kwargs = _to_args_and_kwargs(tool_input) - observation = self._run(*tool_args, **tool_kwargs) + observation = self._run(*tool_args, run_manager=run_manager, **tool_kwargs) except (Exception, KeyboardInterrupt) as e: - self.callback_manager.on_tool_error(e, verbose=verbose_) + run_manager.on_tool_error(e) raise e - self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs - ) + run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs) return observation async def arun( @@ -116,6 +127,7 @@ async def arun( verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", + callbacks: Callbacks = None, **kwargs: Any, ) -> str: """Run the tool asynchronously.""" @@ -124,42 +136,27 @@ async def arun( verbose_ = verbose else: verbose_ = self.verbose - if self.callback_manager.is_async: - await self.callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, - color=start_color, - **kwargs, - ) - else: - self.callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, - color=start_color, - **kwargs, - ) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, verbose=verbose_ + ) + run_manager = await callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + **kwargs, + ) try: # We then call the tool on the tool input to get an observation args, kwargs = _to_args_and_kwargs(tool_input) - observation = await self._arun(*args, **kwargs) + observation = await self._arun(*args, run_manager=run_manager, **kwargs) except (Exception, KeyboardInterrupt) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_tool_error(e, verbose=verbose_) - else: - self.callback_manager.on_tool_error(e, verbose=verbose_) + await run_manager.on_tool_error(e) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs - ) - else: - 
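On the tool side, `run()` above now builds a per-run manager via `callback_manager.on_tool_start` and passes it straight into `_run`, so tools written against this commit need a `run_manager` parameter. A minimal hypothetical tool adapted to the new interface might look like this; the class name and behavior are illustrative only.

from typing import Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool


class ReverseTool(BaseTool):
    """Hypothetical tool showing the run_manager-aware _run/_arun signatures."""

    name = "reverse"
    description = "Reverses the input string."

    def _run(
        self,
        tool_input: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        return tool_input[::-1]

    async def _arun(
        self,
        tool_input: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return tool_input[::-1]


# Per-call handlers now flow through the `callbacks` argument:
# ReverseTool().run("hello", callbacks=[my_handler])
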
self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs - ) + await run_manager.on_tool_end( + observation, color=color, name=self.name, **kwargs + ) return observation - def __call__(self, tool_input: str) -> str: + def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: """Make tool callable.""" - return self.run(tool_input) + return self.run(tool_input, callbacks=callbacks) diff --git a/langchain/utilities/serpapi.py b/langchain/utilities/serpapi.py index 70a14b5de4169..7dd756e50174a 100644 --- a/langchain/utilities/serpapi.py +++ b/langchain/utilities/serpapi.py @@ -76,7 +76,7 @@ def validate_environment(cls, values: Dict) -> Dict: ) return values - async def arun(self, query: str) -> str: + async def arun(self, query: str, **kwargs: Any) -> str: """Use aiohttp to run query through SerpAPI and parse result.""" def construct_url_and_params() -> Tuple[str, Dict[str, str]]: @@ -99,7 +99,7 @@ def construct_url_and_params() -> Tuple[str, Dict[str, str]]: return self._process_response(res) - def run(self, query: str) -> str: + def run(self, query: str, **kwargs: Any) -> str: """Run query through SerpAPI and parse result.""" return self._process_response(self.results(query)) diff --git a/tests/integration_tests/callbacks/__init__.py b/tests/integration_tests/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py new file mode 100644 index 0000000000000..2fccf6d0b21e5 --- /dev/null +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -0,0 +1,85 @@ +import asyncio +import os +import time + +import pytest +from aiohttp import ClientSession + +from langchain.agents import AgentType, initialize_agent, load_tools +from langchain.callbacks import tracing_enabled +from langchain.llms import OpenAI + +questions = [ + "Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?", + "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?", + "Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?", + "Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?", + "Who is Beyonce's husband? 
What is his age raised to the 0.19 power?", +] + + +def test_tracing_sequential(): + os.environ["LANGCHAIN_TRACING"] = "true" + + def generate_serially(): + for q in questions[:3]: + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(q) + + s = time.perf_counter() + generate_serially() + elapsed = time.perf_counter() - s + print(f"Serial executed in {elapsed:0.2f} seconds.") + + +@pytest.mark.asyncio +async def test_tracing_concurrent(): + os.environ["LANGCHAIN_TRACING"] = "true" + aiosession = ClientSession() + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + tasks = [agent.arun(q) for q in questions[:3]] + await asyncio.gather(*tasks) + await aiosession.close() + + +def test_tracing_context_manager(): + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + with tracing_enabled() as session: + assert session + agent.run(questions[0]) # this should be traced + + agent.run(questions[0]) # this should not be traced + + +@pytest.mark.asyncio +async def test_tracing_context_manager_async(): + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + + # start a background task + task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced + with tracing_enabled() as session: + assert session + tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced + await asyncio.gather(*tasks) + + await task diff --git a/tests/integration_tests/callbacks/test_openai_callback.py b/tests/integration_tests/callbacks/test_openai_callback.py new file mode 100644 index 0000000000000..74edc700e06cf --- /dev/null +++ b/tests/integration_tests/callbacks/test_openai_callback.py @@ -0,0 +1,36 @@ +import asyncio + +import pytest + +from langchain import OpenAI +from langchain.callbacks import get_openai_callback + + +@pytest.mark.asyncio +async def test_openai_callback(): + llm = OpenAI(temperature=0) + with get_openai_callback() as cb: + llm("What is the square root of 4?") + + total_tokens = cb.total_tokens + assert total_tokens > 0 + + with get_openai_callback() as cb: + llm("What is the square root of 4?") + llm("What is the square root of 4?") + + assert cb.total_tokens == total_tokens * 2 + + with get_openai_callback() as cb: + await asyncio.gather( + *[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)] + ) + + assert cb.total_tokens == total_tokens * 3 + + task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"])) + with get_openai_callback() as cb: + await llm.agenerate(["What is the square root of 4?"]) + + await task + assert cb.total_tokens == total_tokens diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index ab18d53e7ea60..32a74fbe54028 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ 
b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -1,9 +1,9 @@ """Test Tracer classes.""" from __future__ import annotations -import threading from datetime import datetime from typing import List, Optional, Union +from uuid import uuid4 import pytest from freezegun import freeze_time @@ -12,9 +12,7 @@ BaseTracer, ChainRun, LLMRun, - SharedTracer, ToolRun, - Tracer, TracerException, TracerSession, ) @@ -27,7 +25,7 @@ @freeze_time("2023-01-01") def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: return ChainRun( - id=None, + uuid="chain_uuid", error=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -37,9 +35,11 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: inputs={}, outputs={}, session_id=TEST_SESSION_ID, - child_runs=[ + child_chain_runs=[], + child_tool_runs=[ ToolRun( - id=None, + uuid="tool_uuid", + parent_uuid="chain_uuid", start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -50,9 +50,12 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: action="{}", session_id=TEST_SESSION_ID, error=None, - child_runs=[ + child_chain_runs=[], + child_tool_runs=[], + child_llm_runs=[ LLMRun( - id=None, + uuid="llm_uuid1", + parent_uuid="tool_uuid", error=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -65,8 +68,11 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: ) ], ), + ], + child_llm_runs=[ LLMRun( - id=None, + uuid="llm_uuid2", + parent_uuid="chain_uuid", error=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -83,27 +89,25 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: def _perform_nested_run(tracer: BaseTracer) -> None: """Perform a nested run.""" - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_tool_end("test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_chain_end(outputs={}) - - -def _add_child_run( - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], -) -> None: - """Add child run to a chain run or tool run.""" - parent_run.child_runs.append(child_run) - - -def _generate_id() -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return None + chain_uuid = "chain_uuid" + tool_uuid = "tool_uuid" + llm_uuid1 = "llm_uuid1" + llm_uuid2 = "llm_uuid2" + + tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) + tracer.on_tool_start( + serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) def load_session(session_name: str) -> TracerSession: @@ -121,7 +125,7 @@ def load_default_session() -> TracerSession: return TracerSession(id=1, name="default", start_time=datetime.utcnow()) -class FakeTracer(Tracer): +class FakeTracer(BaseTracer): """Fake tracer that records LangChain execution.""" def __init__(self) -> None: @@ -133,58 +137,6 @@ def _persist_run(self, run: Union[LLMRun, ChainRun, 
ToolRun]) -> None: """Persist a run.""" self.runs.append(run) - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - _add_child_run(parent_run, child_run) - - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return _generate_id() - - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: - """Persist a tracing session.""" - return _persist_session(session) - - def load_session(self, session_name: str) -> TracerSession: - """Load a tracing session.""" - return load_session(session_name) - - def load_default_session(self) -> TracerSession: - """Load a tracing session.""" - return load_default_session() - - -class FakeSharedTracer(SharedTracer): - """Fake shared tracer that records LangChain execution.""" - - runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] - - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - with self._lock: - self.runs.append(run) - - def remove_runs(self) -> None: - """Remove all runs.""" - with self._lock: - self.runs = [] - - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - _add_child_run(parent_run, child_run) - - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return _generate_id() - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: """Persist a tracing session.""" return _persist_session(session) @@ -201,8 +153,10 @@ def load_default_session(self) -> TracerSession: @freeze_time("2023-01-01") def test_tracer_llm_run() -> None: """Test tracer on an LLM run.""" + uuid = str(uuid4()) compare_run = LLMRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -216,20 +170,11 @@ def test_tracer_llm_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] -@freeze_time("2023-01-01") -def test_tracer_llm_run_errors_no_session() -> None: - """Test tracer on an LLM run without a session.""" - tracer = FakeTracer() - - with pytest.raises(TracerException): - tracer.on_llm_start(serialized={}, prompts=[]) - - @freeze_time("2023-01-01") def test_tracer_llm_run_errors_no_start() -> None: """Test tracer on an LLM run without a start.""" @@ -237,14 +182,16 @@ def test_tracer_llm_run_errors_no_start() -> None: tracer.new_session() with pytest.raises(TracerException): - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=str(uuid4())) @freeze_time("2023-01-01") def test_tracer_multiple_llm_runs() -> None: """Test the tracer with multiple runs.""" + uuid = str(uuid4()) compare_run = LLMRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -260,8 +207,8 @@ def test_tracer_multiple_llm_runs() -> None: tracer.new_session() num_runs = 10 for _ in range(num_runs): - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_start(serialized={}, 
prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] * num_runs @@ -269,8 +216,10 @@ def test_tracer_multiple_llm_runs() -> None: @freeze_time("2023-01-01") def test_tracer_chain_run() -> None: """Test tracer on a Chain run.""" + uuid = str(uuid4()) compare_run = ChainRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -284,16 +233,18 @@ def test_tracer_chain_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_chain_end(outputs={}) + tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) + tracer.on_chain_end(outputs={}, run_id=uuid) assert tracer.runs == [compare_run] @freeze_time("2023-01-01") def test_tracer_tool_run() -> None: """Test tracer on a Tool run.""" + uuid = str(uuid4()) compare_run = ToolRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -308,8 +259,8 @@ def test_tracer_tool_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_tool_end("test") + tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) + tracer.on_tool_end("test", run_id=uuid) assert tracer.runs == [compare_run] @@ -318,17 +269,19 @@ def test_tracer_nested_run() -> None: """Test tracer on a nested run.""" tracer = FakeTracer() tracer.new_session() - _perform_nested_run(tracer) - assert tracer.runs == [_get_compare_run()] + [_perform_nested_run(tracer) for _ in range(10)] + assert tracer.runs == [_get_compare_run()] * 10 @freeze_time("2023-01-01") def test_tracer_llm_run_on_error() -> None: """Test tracer on an LLM run with an error.""" exception = Exception("test") + uuid = str(uuid4()) compare_run = LLMRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -342,8 +295,8 @@ def test_tracer_llm_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_error(exception) + tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -351,9 +304,11 @@ def test_tracer_llm_run_on_error() -> None: def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error.""" exception = Exception("test") + uuid = str(uuid4()) compare_run = ChainRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -367,8 +322,8 @@ def test_tracer_chain_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_chain_error(exception) + tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) + tracer.on_chain_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -376,9 +331,11 @@ def test_tracer_chain_run_on_error() -> None: def test_tracer_tool_run_on_error() -> None: """Test tracer on a Tool run with an error.""" exception = Exception("test") + uuid = str(uuid4()) compare_run = ToolRun( - id=None, + uuid=uuid, + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -393,8 +350,8 @@ def test_tracer_tool_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test") - 
tracer.on_tool_error(exception) + tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) + tracer.on_tool_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -405,21 +362,34 @@ def test_tracer_nested_runs_on_error() -> None: tracer = FakeTracer() tracer.new_session() + chain_uuid = "chain_uuid" + tool_uuid = "tool_uuid" + llm_uuid1 = "llm_uuid1" + llm_uuid2 = "llm_uuid2" + llm_uuid3 = "llm_uuid3" for _ in range(3): - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_error(exception) - tracer.on_tool_error(exception) - tracer.on_chain_error(exception) + tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_tool_start( + serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid + ) + tracer.on_llm_error(exception, run_id=llm_uuid3) + tracer.on_tool_error(exception, run_id=tool_uuid) + tracer.on_chain_error(exception, run_id=chain_uuid) compare_run = ChainRun( - id=None, + uuid=chain_uuid, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -429,9 +399,10 @@ def test_tracer_nested_runs_on_error() -> None: error=repr(exception), inputs={}, outputs=None, - child_runs=[ + child_llm_runs=[ LLMRun( - id=None, + uuid=llm_uuid1, + parent_uuid=chain_uuid, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -443,7 +414,8 @@ def test_tracer_nested_runs_on_error() -> None: response=LLMResult(generations=[[]], llm_output=None), ), LLMRun( - id=None, + uuid=llm_uuid2, + parent_uuid=chain_uuid, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -454,8 +426,12 @@ def test_tracer_nested_runs_on_error() -> None: prompts=[], response=LLMResult(generations=[[]], llm_output=None), ), + ], + child_chain_runs=[], + child_tool_runs=[ ToolRun( - id=None, + uuid=tool_uuid, + parent_uuid=chain_uuid, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -466,9 +442,10 @@ def test_tracer_nested_runs_on_error() -> None: tool_input="test", output=None, action="{}", - child_runs=[ + child_llm_runs=[ LLMRun( - id=None, + uuid=llm_uuid3, + parent_uuid=tool_uuid, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -480,43 +457,10 @@ def test_tracer_nested_runs_on_error() -> None: response=None, ) ], - child_llm_runs=[], child_chain_runs=[], child_tool_runs=[], ), ], - child_llm_runs=[], - child_chain_runs=[], - child_tool_runs=[], ) assert tracer.runs == [compare_run] * 3 - - -@freeze_time("2023-01-01") -def test_shared_tracer_nested_run() -> None: - """Test shared tracer on a nested run.""" - tracer = FakeSharedTracer() - tracer.new_session() - tracer.remove_runs() - _perform_nested_run(tracer) - assert tracer.runs == [_get_compare_run()] - - -@freeze_time("2023-01-01") -def 
test_shared_tracer_nested_run_multithreaded() -> None: - """Test shared tracer on a nested run.""" - tracer = FakeSharedTracer() - tracer.remove_runs() - tracer.new_session() - threads = [] - num_threads = 10 - for _ in range(num_threads): - thread = threading.Thread(target=_perform_nested_run, args=(tracer,)) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() - - assert tracer.runs == [_get_compare_run()] * num_threads From 7bcdc66b996c49448ce61fb2ffd9fb4d6b7429c3 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Wed, 26 Apr 2023 11:34:21 -0700 Subject: [PATCH 06/36] fix notebook and warnings --- docs/modules/callbacks/getting_started.ipynb | 310 +++---------------- langchain/callbacks/manager.py | 12 +- langchain/callbacks/tracers/langchain.py | 4 +- 3 files changed, 49 insertions(+), 277 deletions(-) diff --git a/docs/modules/callbacks/getting_started.ipynb b/docs/modules/callbacks/getting_started.ipynb index 6d4a99837ab53..a6ad10fe42a48 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -106,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "80532dfc-d687-4147-a0c9-1f90cc3e868c", "metadata": { "tags": [] @@ -138,7 +138,7 @@ "'\\n\\n3'" ] }, - "execution_count": 5, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "id": "1b2e6588-0681-4cab-937a-7cc4790cea9a", "metadata": { "tags": [] @@ -192,9 +192,7 @@ "My custom handler, token: turn\n", "My custom handler, token: red\n", "My custom handler, token: ?\n", - "\n", - "\n", - "My custom handler, token: Because\n", + "My custom handler, token: Because\n", "My custom handler, token: it\n", "My custom handler, token: saw\n", "My custom handler, token: the\n", @@ -207,10 +205,10 @@ { "data": { "text/plain": [ - "AIMessage(content='Why did the tomato turn red?\\n\\nBecause it saw the salad dressing!', additional_kwargs={})" + "AIMessage(content='Why did the tomato turn red? Because it saw the salad dressing!', additional_kwargs={})" ] }, - "execution_count": 11, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -247,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 3, "id": "c702e0c9-a961-4897-90c1-cdd13b6f16b2", "metadata": { "tags": [] @@ -258,20 +256,7 @@ "output_type": "stream", "text": [ "zzzz....\n", - "Hi! I just woke up. Your llm is starting\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 16:48:58,880 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/chat/completions processing_ms=210 request_id=0846181d992a4fbc954c80cf78e5bfb5 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Hi! I just woke up. 
Your llm is starting\n", "Sync handler being called in a `thread_pool_executor`: token: \n", "Sync handler being called in a `thread_pool_executor`: token: Why\n", "Sync handler being called in a `thread_pool_executor`: token: don\n", @@ -280,7 +265,6 @@ "Sync handler being called in a `thread_pool_executor`: token: trust\n", "Sync handler being called in a `thread_pool_executor`: token: atoms\n", "Sync handler being called in a `thread_pool_executor`: token: ?\n", - "Sync handler being called in a `thread_pool_executor`: token: \n", "\n", "\n", "Sync handler being called in a `thread_pool_executor`: token: Because\n", @@ -297,10 +281,10 @@ { "data": { "text/plain": [ - "LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms? \\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms? \\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})" + "LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})" ] }, - "execution_count": 20, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "id": "8eec8756-1828-45cb-9699-38ac8543a150", "metadata": { "tags": [] @@ -461,7 +445,7 @@ "'1.1769067372187674'" ] }, - "execution_count": 25, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -554,14 +538,17 @@ "id": "d5a74b3f-3769-4a4f-99c7-b6a3b20a94e2", "metadata": {}, "source": [ - "There are two recommended ways to trace your LangChains. One is by setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. The other is to use a context manager `with tracing_enabled()` to trace a particular block of code.\n", + "There are two recommended ways to trace your LangChains:\n", + "\n", + "1. Setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. \n", + "2. Using a context manager `with tracing_enabled()` to trace a particular block of code.\n", "\n", "**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager." 
] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "id": "f164dfd5-d987-4b6a-a7c8-019c651ce47f", "metadata": { "tags": [] @@ -592,19 +579,12 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 6, "id": "6be7777e-ec1d-438f-ae33-3a93c45f808e", "metadata": { "tags": [] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:33:40,925 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -627,21 +607,7 @@ "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:34:05,653 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", - "2023-04-25 17:34:05,673 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", @@ -676,20 +642,12 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 7, "id": "a6fd6026-dc1e-4d48-893d-3592539c7828", "metadata": { "tags": [] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:37:45,895 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", - "2023-04-25 17:37:45,982 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -729,7 +687,7 @@ "Action Input: 29^0.23\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", - "Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", + "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -737,15 +695,17 @@ { "data": { "text/plain": [ - "\"Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\"" + "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\"" ] }, - "execution_count": 34, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Now, we unset the environment variable and use a context manager.\n", + "\n", "if \"LANGCHAIN_TRACING\" in os.environ:\n", " del os.environ[\"LANGCHAIN_TRACING\"]\n", "with tracing_enabled() as session:\n", @@ -757,21 +717,12 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 8, "id": "9383a351-4983-44e9-abd7-ef942e1c65c4", "metadata": { "tags": [] }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:07,944 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object 
has no attribute '_endpoint'\n", - "2023-04-25 17:39:08,038 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n", - "2023-04-25 17:39:08,039 [WARNING] - Failed to load default session, using empty session: 'LangChainTracer' object has no attribute '_endpoint'\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -783,186 +734,31 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:10,123 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1795 request_id=9139c5a1b136a84603a4adc584bbdd9b response_code=200\n", - "2023-04-25 17:39:10,127 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1881 request_id=11729ae35c511f56238ab69a5856efcc response_code=200\n", - "2023-04-25 17:39:10,238 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1995 request_id=5c319aa337991381b80b4c4b858b7f75 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", "Action: Search\n", - "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", + "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", "Action: Search\n", - "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", "Action: Search\n", - "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. 
In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:11,863 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1070 request_id=82b718523868a00c0d3f047ac8a9ecea response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. Lewis Hamilton holds the record for the ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", "Action: Search\n", - "Action Input: \"Harry Styles age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:12,611 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1829 request_id=fe0a82fe729ebc37b7983474d9418a84 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", "Action: Search\n", - "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. 
Lewis Hamilton holds the record for the ...\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:15,366 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2813 request_id=396f7e8180605345ad13693d91ebfdda response_code=200\n", - "2023-04-25 17:39:15,558 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2049 request_id=fa2004f0d6934e94f09632caacda8ca4 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action Input: \"Harry Styles age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n", + "Action: Search\n", + "Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", "Action: Calculator\n", - "Action Input: 29^0.23\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", "Action: Calculator\n", - "Action Input: 36^0.334\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:17,295 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1604 request_id=f9c40e7fb3d94d936b285c3a5a0eb55f response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:18,181 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2791 request_id=9f0c27a6e995895a518d5614d5e54c61 response_code=200\n", - "2023-04-25 17:39:18,185 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2207 request_id=b5281a29649bfbeaad532391eacf954d response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n", - "Action: Search\n", - "Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:20,282 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2702 request_id=761627e38a5b6e580262357668dd635b response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:20,605 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2182 request_id=4cafeb94298befebe0da1e3f4a38ab27 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Action Input: 29^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 
17:39:22,431 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2555 request_id=299ae3539ed6fb681ac2f8d16e73a6bc response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I now need to calculate 38 raised to the 0.23 power\n", "Action: Calculator\n", - "Action Input: 38^0.23\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:24,802 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=2194 request_id=e09fe0fed313ba77c9a6444c41c12f1f response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:39:26,912 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1963 request_id=8bedf74c7e4bfc5dfe014dcca47ce363 response_code=200\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", + "Action Input: 38^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, @@ -972,7 +768,7 @@ "\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\"" ] }, - "execution_count": 36, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -1003,24 +799,12 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 9, "id": "5c3e0b89-2c5e-4036-bdf2-fb6b750e360c", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-04-25 17:43:22,369 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=492 request_id=a13b1f276947e6e8a2179ebf7c092878 response_code=200\n", - "2023-04-25 17:43:22,376 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=491 request_id=5bd41d073db19e0002eb3d862b9fde22 response_code=200\n", - "2023-04-25 17:43:22,441 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=526 request_id=155a0aa6a078db963fda3fe3b68c463e response_code=200\n", - "2023-04-25 17:43:23,072 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=475 request_id=231e6e20ff294f2b0a46b4844d739c09 response_code=200\n", - "2023-04-25 17:43:23,681 [INFO] - message='OpenAI API response' path=https://api.openai.com/v1/completions processing_ms=1088 request_id=90f4cf31a38f395d7ea98bd76d9bb36f response_code=200\n" - ] - } - ], + "outputs": [], "source": [ "from langchain.callbacks import get_openai_callback\n", "\n", @@ -1053,14 +837,6 @@ "await task\n", "assert cb.total_tokens == total_tokens" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5d6521d-1901-473d-a2cd-a4e88db7f851", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 69628e0ad4d46..7ee31c296230a 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -24,12 +24,6 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.schema import AgentAction, AgentFinish, LLMResult -logging.basicConfig( - level=logging.INFO, - 
format="%(asctime)s [%(levelname)s] - %(message)s", - handlers=[logging.StreamHandler()], -) - Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( @@ -75,7 +69,8 @@ def _handle_event( ): getattr(handler, event_name)(*args, **kwargs) except Exception as e: - logging.error(f"Error in {event_name} callback: {e}") + # TODO: switch this to use logging + print(f"Error in {event_name} callback: {e}") async def _ahandle_event_for_handler( @@ -95,7 +90,8 @@ async def _ahandle_event_for_handler( None, functools.partial(event, *args, **kwargs) ) except Exception as e: - logging.error(f"Error in {event_name} callback: {e}") + # TODO: switch this to use logging + print(f"Error in {event_name} callback: {e}") async def _ahandle_event( diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 1cbf65e87a5f5..2c5118be91f22 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -20,14 +20,14 @@ class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - def __init__(self, **kwargs: Any) -> None: + def __init__(self, session_name="default", **kwargs: Any) -> None: """Initialize the LangChain tracer.""" super().__init__(**kwargs) - self.session = self.load_default_session() self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") self._headers: Dict[str, Any] = {"Content-Type": "application/json"} if os.getenv("LANGCHAIN_API_KEY"): self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + self.session = self.load_session(session_name) def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" From 6fec15b6fb5fef823dd70b6827465f8641451a7f Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Wed, 26 Apr 2023 11:37:36 -0700 Subject: [PATCH 07/36] write to different session --- docs/modules/callbacks/getting_started.ipynb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/modules/callbacks/getting_started.ipynb b/docs/modules/callbacks/getting_started.ipynb index a6ad10fe42a48..acd4ee9676daa 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -642,7 +642,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "id": "a6fd6026-dc1e-4d48-893d-3592539c7828", "metadata": { "tags": [] @@ -663,7 +663,7 @@ "Action: Search\n", "Action Input: \"Rafael Nadal age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", "Action: Calculator\n", "Action Input: 36^0.334\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", @@ -698,7 +698,7 @@ "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\"" ] }, - "execution_count": 7, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -708,7 +708,9 @@ "\n", "if \"LANGCHAIN_TRACING\" in os.environ:\n", " del os.environ[\"LANGCHAIN_TRACING\"]\n", - "with tracing_enabled() as session:\n", + "\n", + "# here, we are writing traces to \"my_test_session\"\n", + "with tracing_enabled(\"my_test_session\") as session:\n", " assert session\n", " 
agent.run(questions[0]) # this should be traced\n", "\n", From 50668693d7bb64998dc1528e04898324addb2448 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Wed, 26 Apr 2023 19:19:39 -0700 Subject: [PATCH 08/36] fix execution order issue --- langchain/callbacks/tracers/base.py | 48 ++++++++++++++++--- langchain/callbacks/tracers/langchain.py | 14 ------ langchain/callbacks/tracers/schemas.py | 1 + .../callbacks/tracers/test_tracer.py | 16 +++++++ 4 files changed, 58 insertions(+), 21 deletions(-) diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 763f710d9e5ae..fbe9e6e1de513 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -27,7 +27,6 @@ class BaseTracer(BaseCallbackHandler, ABC): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} - self.execution_order: int = 1 self.session: Optional[TracerSession] = None @staticmethod @@ -70,8 +69,6 @@ def load_default_session(self) -> TracerSession: def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Start a trace for a run.""" - self.execution_order += 1 - if run.parent_uuid: parent_run = self.run_map[run.parent_uuid] if parent_run: @@ -92,9 +89,32 @@ def _end_trace(self, run) -> None: """End a trace for a run.""" if not run.parent_uuid: self._persist_run(run) - self.execution_order = 1 + else: + parent_run = self.run_map.get(run.parent_uuid) + if parent_run is None: + raise TracerException( + f"Parent run with UUID {run.parent_uuid} not found." + ) + if isinstance(parent_run, LLMRun): + raise TracerException("LLM Runs are not allowed to have children. ") + if run.child_execution_order > parent_run.child_execution_order: + parent_run.child_execution_order = run.child_execution_order self.run_map.pop(run.uuid) + def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: + """Get the execution order for a run.""" + if parent_run_id is None: + return 1 + + parent_run = self.run_map.get(parent_run_id) + if parent_run is None: + raise TracerException(f"Parent run with UUID {parent_run_id} not found.") + + if isinstance(parent_run, LLMRun): + raise TracerException("LLM Runs are not allowed to have children. 
") + + return parent_run.child_execution_order + 1 + def on_llm_start( self, serialized: Dict[str, Any], @@ -110,6 +130,7 @@ def on_llm_start( if run_id is None: run_id = str(uuid4()) + execution_order = self._get_execution_order(parent_run_id) llm_run = LLMRun( uuid=run_id, parent_uuid=parent_run_id, @@ -117,7 +138,8 @@ def on_llm_start( prompts=prompts, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self.execution_order, + execution_order=execution_order, + child_execution_order=execution_order, session_id=self.session.id, ) self._start_trace(llm_run) @@ -170,6 +192,7 @@ def on_chain_start( if run_id is None: run_id = str(uuid4()) + execution_order = self._get_execution_order(parent_run_id) chain_run = ChainRun( uuid=run_id, parent_uuid=parent_run_id, @@ -177,7 +200,8 @@ def on_chain_start( inputs=inputs, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self.execution_order, + execution_order=execution_order, + child_execution_order=execution_order, child_runs=[], session_id=self.session.id, ) @@ -231,6 +255,7 @@ def on_tool_start( if run_id is None: run_id = str(uuid4()) + execution_order = self._get_execution_order(parent_run_id) tool_run = ToolRun( uuid=run_id, parent_uuid=parent_run_id, @@ -240,7 +265,8 @@ def on_tool_start( tool_input=input_str, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self.execution_order, + execution_order=execution_order, + child_execution_order=execution_order, child_runs=[], session_id=self.session.id, ) @@ -276,3 +302,11 @@ def on_tool_error( tool_run.error = repr(error) tool_run.end_time = datetime.utcnow() self._end_trace(tool_run) + + def __deepcopy__(self, memo: dict) -> BaseTracer: + """Deepcopy the tracer.""" + return self + + def __copy__(self) -> BaseTracer: + """Copy the tracer.""" + return self diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 2c5118be91f22..b67223f6453c0 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -87,17 +87,3 @@ def load_session(self, session_name: str) -> TracerSession: def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" return self._load_session("default") - - def __deepcopy__(self, memo): - """Deepcopy the tracer.""" - - # TODO: this is a hack to get tracing to work with the current backend - # we need to not use execution order, then remove this check - if self.execution_order == 1: - copy = LangChainTracer() - copy.session = self.session - copy.run_map = dict(self.run_map) - copy.execution_order = self.execution_order - return copy - else: - return self diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index d5987ec7b8892..fc96908f5c0a7 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -38,6 +38,7 @@ class BaseRun(BaseModel): end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) extra: Optional[Dict[str, Any]] = None execution_order: int + child_execution_order: int serialized: Dict[str, Any] session_id: int error: Optional[str] = None diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 32a74fbe54028..d0297a6f95212 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -31,6 +31,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: 
end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=4, serialized={}, inputs={}, outputs={}, @@ -44,6 +45,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: end_time=datetime.utcnow(), extra={}, execution_order=2, + child_execution_order=3, serialized={}, tool_input="test", output="test", @@ -61,6 +63,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: end_time=datetime.utcnow(), extra={}, execution_order=3, + child_execution_order=3, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -78,6 +81,7 @@ def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: end_time=datetime.utcnow(), extra={}, execution_order=4, + child_execution_order=4, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -161,6 +165,7 @@ def test_tracer_llm_run() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -196,6 +201,7 @@ def test_tracer_multiple_llm_runs() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -224,6 +230,7 @@ def test_tracer_chain_run() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, inputs={}, outputs={}, @@ -249,6 +256,7 @@ def test_tracer_tool_run() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, tool_input="test", output="test", @@ -286,6 +294,7 @@ def test_tracer_llm_run_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=None, @@ -313,6 +322,7 @@ def test_tracer_chain_run_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, inputs={}, outputs=None, @@ -340,6 +350,7 @@ def test_tracer_tool_run_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, tool_input="test", output=None, @@ -394,6 +405,7 @@ def test_tracer_nested_runs_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), @@ -407,6 +419,7 @@ def test_tracer_nested_runs_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=2, + child_execution_order=2, serialized={}, session_id=TEST_SESSION_ID, error=None, @@ -420,6 +433,7 @@ def test_tracer_nested_runs_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=3, + child_execution_order=3, serialized={}, session_id=TEST_SESSION_ID, error=None, @@ -436,6 +450,7 @@ def test_tracer_nested_runs_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=4, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), @@ -450,6 +465,7 @@ def test_tracer_nested_runs_on_error() -> None: end_time=datetime.utcnow(), extra={}, execution_order=5, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), From e953d2cf93cc9c53be60b19c73f6a520cc90b0a8 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 12:26:58 -0700 Subject: [PATCH 09/36] mypy --- langchain/callbacks/base.py | 92 +++++++++++++++++++----------- langchain/callbacks/manager.py | 100 +++++++++++++++++++++++---------- 2 files changed, 130 insertions(+), 62 
deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index bd070104c7327..9389c769b6ec9 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -13,7 +13,8 @@ class LLMManagerMixin: def on_llm_new_token( self, token: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -22,7 +23,8 @@ def on_llm_new_token( def on_llm_end( self, response: LLMResult, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -31,7 +33,8 @@ def on_llm_end( def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -44,7 +47,8 @@ class ChainManagerMixin: def on_chain_end( self, outputs: Dict[str, Any], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -53,7 +57,8 @@ def on_chain_end( def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -62,7 +67,8 @@ def on_chain_error( def on_agent_action( self, action: AgentAction, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -71,7 +77,8 @@ def on_agent_action( def on_agent_finish( self, finish: AgentFinish, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -84,7 +91,8 @@ class ToolManagerMixin: def on_tool_end( self, output_str: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -93,7 +101,8 @@ def on_tool_end( def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -107,7 +116,8 @@ def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -117,7 +127,8 @@ def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -127,7 +138,8 @@ def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -140,7 +152,8 @@ class RunManagerMixin: def on_text( self, text: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -179,7 +192,8 @@ async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -188,7 +202,8 @@ async def on_llm_start( async def on_llm_new_token( self, token: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -197,7 +212,8 @@ async def on_llm_new_token( async def on_llm_end( self, response: LLMResult, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -206,7 +222,8 @@ async def on_llm_end( async def on_llm_error( self, 
error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -216,7 +233,8 @@ async def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -225,7 +243,8 @@ async def on_chain_start( async def on_chain_end( self, outputs: Dict[str, Any], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -234,7 +253,8 @@ async def on_chain_end( async def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -244,7 +264,8 @@ async def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -252,8 +273,9 @@ async def on_tool_start( async def on_tool_end( self, - output: str, - run_id: Optional[str] = None, + output_str: str, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -262,7 +284,8 @@ async def on_tool_end( async def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -271,7 +294,8 @@ async def on_tool_error( async def on_text( self, text: str, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -280,7 +304,8 @@ async def on_text( async def on_agent_action( self, action: AgentAction, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -289,7 +314,8 @@ async def on_agent_action( async def on_agent_finish( self, finish: AgentFinish, - run_id: Optional[str] = None, + *, + run_id: str, parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: @@ -302,7 +328,7 @@ class BaseCallbackManager(CallbackManagerMixin): def __init__( self, handlers: List[BaseCallbackHandler], - inheritable_handlers: List[BaseCallbackHandler] = None, + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, parent_run_id: Optional[str] = None, ) -> None: """Initialize callback manager.""" @@ -328,23 +354,25 @@ def remove_handler(self, handler: BaseCallbackHandler) -> None: self.handlers.remove(handler) self.inheritable_handlers.remove(handler) - def set_handlers(self, handlers: List[BaseCallbackHandler], inherit=True) -> None: + def set_handlers( + self, handlers: List[BaseCallbackHandler], inherit: bool = True + ) -> None: """Set handlers as the only handlers on the callback manager.""" self.handlers = [] self.inheritable_handlers = [] for handler in handlers: self.add_handler(handler, inherit=inherit) - def set_handler(self, handler: BaseCallbackHandler, inherit=True) -> None: + def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Set handler as the only handler on the callback manager.""" self.set_handlers([handler], inherit=inherit) - def __copy__(self): + def __copy__(self) -> "BaseCallbackManager": return self.__class__( self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id ) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict) -> "BaseCallbackManager": return self.__class__( 
[copy.deepcopy(handler, memo) for handler in self.handlers], [copy.deepcopy(handler, memo) for handler in self.inheritable_handlers], diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 7ee31c296230a..b7adf77f3010c 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -3,7 +3,6 @@ import asyncio import copy import functools -import logging import os import uuid from contextlib import contextmanager @@ -120,7 +119,7 @@ def __init__( run_id: str, handlers: List[BaseCallbackHandler], inheritable_handlers: List[BaseCallbackHandler], - parent_run_id: str, + parent_run_id: Optional[str] = None, ) -> None: """Initialize run manager.""" self.run_id = run_id @@ -132,29 +131,60 @@ def __init__( class RunManager(BaseRunManager): """Sync Run Manager.""" - def on_text(self, text: str, **kwargs: Any) -> Any: + def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: """Run when text is received.""" - _handle_event(self.handlers, "on_text", None, text, **kwargs) + _handle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) class AsyncRunManager(BaseRunManager): """Async Run Manager.""" - async def on_text(self, text: str, **kwargs: Any) -> Any: + async def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: """Run when text is received.""" - await _ahandle_event(self.handlers, "on_text", None, text, **kwargs) + await _ahandle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): """Callback manager for LLM run.""" - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + def on_llm_new_token( + self, + token: str, + *, + run_id: str, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> None: """Run when LLM generates a new token.""" _handle_event( self.handlers, "on_llm_new_token", "ignore_llm", - token, + token=token, run_id=self.run_id, parent_run_id=self.parent_run_id, **kwargs, @@ -192,7 +222,14 @@ def on_llm_error( class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """Async callback manager for LLM run.""" - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + async def on_llm_new_token( + self, + token: str, + *, + run_id: Optional[str] = None, + parent_run_id: Optional[str] = None, + **kwargs: Any, + ) -> None: """Run when LLM generates a new token.""" await _ahandle_event( self.handlers, @@ -366,13 +403,17 @@ def get_child(self) -> CallbackManager: manager.set_handlers(self.inheritable_handlers) return manager - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end( + self, + output_str: str, + **kwargs: Any, + ) -> None: """Run when tool ends running.""" _handle_event( self.handlers, "on_tool_end", "ignore_agent", - output, + output_str, run_id=self.run_id, parent_run_id=self.parent_run_id, **kwargs, @@ -404,13 +445,13 @@ def get_child(self) -> AsyncCallbackManager: manager.set_handlers(self.inheritable_handlers) return manager - async def on_tool_end(self, output: str, **kwargs: Any) -> None: + async def on_tool_end(self, output_str: str, **kwargs: Any) -> None: """Run when tool ends running.""" await _ahandle_event( self.handlers, "on_tool_end", "ignore_agent", - output, + output_str, run_id=self.run_id, parent_run_id=self.parent_run_id, **kwargs, @@ -637,48 +678,47 @@ def configure( def _configure( callback_manager_cls: Type[T], - 
inheritable_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, - local_callbacks: Optional[Union[T, List[BaseCallbackHandler]]] = None, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, verbose: bool = False, ) -> T: """Configure the callback manager.""" - callback_manager: Optional[T] = None + callback_manager = callback_manager_cls([]) if inheritable_callbacks or local_callbacks: - if isinstance(inheritable_callbacks, list) or not inheritable_callbacks: + if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: + inheritable_callbacks_: List[BaseCallbackHandler] = inheritable_callbacks or [] callback_manager = callback_manager_cls( - handlers=inheritable_callbacks if inheritable_callbacks else [], - inheritable_handlers=inheritable_callbacks - if inheritable_callbacks - else [], + handlers=inheritable_callbacks_, + inheritable_handlers=inheritable_callbacks_, ) else: - callback_manager = inheritable_callbacks + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks.handlers, + inheritable_handlers=inheritable_callbacks.inheritable_handlers, + parent_run_id=inheritable_callbacks.parent_run_id, + ) callback_manager = copy.deepcopy(callback_manager) local_handlers_ = ( local_callbacks if isinstance(local_callbacks, list) else (local_callbacks.handlers if local_callbacks else []) ) - [ + for handler in local_handlers_: callback_manager.add_handler(copy.deepcopy(handler), False) - for handler in local_handlers_ - ] - if not callback_manager: - callback_manager = callback_manager_cls([]) tracer = tracing_callback_var.get() open_ai = openai_callback_var.get() - tracing_enabled = ( + tracing_enabled_ = ( os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None ) - if verbose or tracing_enabled or open_ai is not None: + if verbose or tracing_enabled_ or open_ai is not None: if verbose and not any( isinstance(handler, StdOutCallbackHandler) for handler in callback_manager.handlers ): callback_manager.add_handler(StdOutCallbackHandler(), False) - if tracing_enabled and not any( + if tracing_enabled_ and not any( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): From 6cd653deb4b6c7074bc75a0e769f579318d62bda Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 14:16:31 -0700 Subject: [PATCH 10/36] cr --- langchain/callbacks/base.py | 4 +- langchain/callbacks/manager.py | 12 +-- langchain/callbacks/tracers/base.py | 76 ++++++++----------- langchain/callbacks/tracers/langchain.py | 2 +- .../callbacks/test_langchain_tracer.py | 29 +++---- .../callbacks/test_openai_callback.py | 3 +- 6 files changed, 55 insertions(+), 71 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 9389c769b6ec9..07c5913fcbb5c 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -90,7 +90,7 @@ class ToolManagerMixin: def on_tool_end( self, - output_str: str, + output: str, *, run_id: str, parent_run_id: Optional[str] = None, @@ -273,7 +273,7 @@ async def on_tool_start( async def on_tool_end( self, - output_str: str, + output: str, *, run_id: str, parent_run_id: Optional[str] = None, diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index b7adf77f3010c..1c1ec4dea4213 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -405,7 +405,7 @@ def get_child(self) -> CallbackManager: def on_tool_end( self, - output_str: str, + output: str, **kwargs: Any, ) -> None: """Run when tool 
ends running.""" @@ -413,7 +413,7 @@ def on_tool_end( self.handlers, "on_tool_end", "ignore_agent", - output_str, + output, run_id=self.run_id, parent_run_id=self.parent_run_id, **kwargs, @@ -445,13 +445,13 @@ def get_child(self) -> AsyncCallbackManager: manager.set_handlers(self.inheritable_handlers) return manager - async def on_tool_end(self, output_str: str, **kwargs: Any) -> None: + async def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" await _ahandle_event( self.handlers, "on_tool_end", "ignore_agent", - output_str, + output, run_id=self.run_id, parent_run_id=self.parent_run_id, **kwargs, @@ -686,7 +686,9 @@ def _configure( callback_manager = callback_manager_cls([]) if inheritable_callbacks or local_callbacks: if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: - inheritable_callbacks_: List[BaseCallbackHandler] = inheritable_callbacks or [] + inheritable_callbacks_: List[BaseCallbackHandler] = ( + inheritable_callbacks or [] + ) callback_manager = callback_manager_cls( handlers=inheritable_callbacks_, inheritable_handlers=inheritable_callbacks_, diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index fbe9e6e1de513..0b6b21934ac6c 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -85,7 +85,7 @@ def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: self.run_map[run.uuid] = run - def _end_trace(self, run) -> None: + def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """End a trace for a run.""" if not run.parent_uuid: self._persist_run(run) @@ -119,8 +119,9 @@ def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - run_id: str = None, - parent_run_id: str = None, + *, + run_id: str, + parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" @@ -144,17 +145,15 @@ def on_llm_start( ) self._start_trace(llm_run) - def on_llm_end( - self, response: LLMResult, run_id: str = None, **kwargs: Any - ) -> None: + def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None: """End a trace for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_end callback.") - if not self.run_map or not isinstance(self.run_map[run_id], LLMRun): + llm_run = self.run_map.get(run_id) + if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") - llm_run = self.run_map[run_id] llm_run.response = response llm_run.end_time = datetime.utcnow() self._end_trace(llm_run) @@ -162,17 +161,18 @@ def on_llm_end( def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: str = None, + *, + run_id: str, **kwargs: Any, ) -> None: """Handle an error for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_error callback.") - if not self.run_map or not isinstance(self.run_map[run_id], LLMRun): + llm_run = self.run_map.get(run_id) + if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") - llm_run = self.run_map[run_id] llm_run.error = repr(error) llm_run.end_time = datetime.utcnow() self._end_trace(llm_run) @@ -181,17 +181,15 @@ def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - run_id: str = None, - parent_run_id: str = None, + *, + run_id: str, + parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: """Start a trace for a chain run.""" if self.session 
is None: self.session = self.load_default_session() - if run_id is None: - run_id = str(uuid4()) - execution_order = self._get_execution_order(parent_run_id) chain_run = ChainRun( uuid=run_id, @@ -208,16 +206,13 @@ def on_chain_start( self._start_trace(chain_run) def on_chain_end( - self, outputs: Dict[str, Any], run_id: str = None, **kwargs: Any + self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any ) -> None: """End a trace for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_end callback.") - - if not self.run_map or not isinstance(self.run_map[run_id], ChainRun): + chain_run = self.run_map.get(run_id) + if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") - chain_run = self.run_map[run_id] chain_run.outputs = outputs chain_run.end_time = datetime.utcnow() self._end_trace(chain_run) @@ -225,17 +220,15 @@ def on_chain_end( def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: str = None, + *, + run_id: str, **kwargs: Any, ) -> None: """Handle an error for a chain run.""" - if not run_id: - raise TracerException("No run_id provided for on_chain_error callback.") - - if not self.run_map or not isinstance(self.run_map[run_id], ChainRun): + chain_run = self.run_map.get(run_id) + if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") - chain_run = self.run_map[run_id] chain_run.error = repr(error) chain_run.end_time = datetime.utcnow() self._end_trace(chain_run) @@ -244,17 +237,15 @@ def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - run_id: str = None, - parent_run_id: str = None, + *, + run_id: str, + parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: """Start a trace for a tool run.""" if self.session is None: self.session = self.load_default_session() - if run_id is None: - run_id = str(uuid4()) - execution_order = self._get_execution_order(parent_run_id) tool_run = ToolRun( uuid=run_id, @@ -272,15 +263,12 @@ def on_tool_start( ) self._start_trace(tool_run) - def on_tool_end(self, output: str, run_id: str = None, **kwargs: Any) -> None: + def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None: """End a trace for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_end callback.") - - if not self.run_map or not isinstance(self.run_map[run_id], ToolRun): + tool_run = self.run_map.get(run_id) + if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") - tool_run = self.run_map[run_id] tool_run.output = output tool_run.end_time = datetime.utcnow() self._end_trace(tool_run) @@ -288,17 +276,15 @@ def on_tool_end(self, output: str, run_id: str = None, **kwargs: Any) -> None: def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], - run_id: str = None, + *, + run_id: str, **kwargs: Any, ) -> None: """Handle an error for a tool run.""" - if not run_id: - raise TracerException("No run_id provided for on_tool_error callback.") - - if not self.run_map or not isinstance(self.run_map[run_id], ToolRun): + tool_run = self.run_map.get(run_id) + if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") - tool_run = self.run_map[run_id] tool_run.error = repr(error) tool_run.end_time = datetime.utcnow() self._end_trace(tool_run) diff --git a/langchain/callbacks/tracers/langchain.py 
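An illustrative sketch, not taken from the patch itself: the tracer hunks above make run_id a required keyword-only argument and look runs up in run_map rather than relying on a stack, so a handler can keep per-run state keyed by that id. A minimal handler written against those signatures might look like the following; the class name and the timing logic are assumptions for illustration, not code from this PR.

import time
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult


class RunTimerHandler(BaseCallbackHandler):
    """Hypothetical handler that records wall-clock duration per LLM run."""

    def __init__(self) -> None:
        self._starts: Dict[str, float] = {}
        self.durations: Dict[str, float] = {}

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: str,
        parent_run_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # State is keyed by run_id, so concurrent runs no longer clobber each other.
        self._starts[run_id] = time.perf_counter()

    def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None:
        started = self._starts.pop(run_id, None)
        if started is not None:
            self.durations[run_id] = time.perf_counter() - started
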
b/langchain/callbacks/tracers/langchain.py index b67223f6453c0..80e7d2d2d62d3 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -20,7 +20,7 @@ class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - def __init__(self, session_name="default", **kwargs: Any) -> None: + def __init__(self, session_name: str = "default", **kwargs: Any) -> None: """Initialize the LangChain tracer.""" super().__init__(**kwargs) self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index 2fccf6d0b21e5..c17ebeb77fcd8 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -1,3 +1,4 @@ +"""Integration tests for the langchain tracer module.""" import asyncio import os import time @@ -18,26 +19,20 @@ ] -def test_tracing_sequential(): +def test_tracing_sequential() -> None: os.environ["LANGCHAIN_TRACING"] = "true" - def generate_serially(): - for q in questions[:3]: - llm = OpenAI(temperature=0) - tools = load_tools(["llm-math", "serpapi"], llm=llm) - agent = initialize_agent( - tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True - ) - agent.run(q) - - s = time.perf_counter() - generate_serially() - elapsed = time.perf_counter() - s - print(f"Serial executed in {elapsed:0.2f} seconds.") + for q in questions[:3]: + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(q) @pytest.mark.asyncio -async def test_tracing_concurrent(): +async def test_tracing_concurrent() -> None: os.environ["LANGCHAIN_TRACING"] = "true" aiosession = ClientSession() llm = OpenAI(temperature=0) @@ -50,7 +45,7 @@ async def test_tracing_concurrent(): await aiosession.close() -def test_tracing_context_manager(): +def test_tracing_context_manager() -> None: llm = OpenAI(temperature=0) tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( @@ -66,7 +61,7 @@ def test_tracing_context_manager(): @pytest.mark.asyncio -async def test_tracing_context_manager_async(): +async def test_tracing_context_manager_async() -> None: llm = OpenAI(temperature=0) async_tools = load_tools(["llm-math", "serpapi"], llm=llm) agent = initialize_agent( diff --git a/tests/integration_tests/callbacks/test_openai_callback.py b/tests/integration_tests/callbacks/test_openai_callback.py index 74edc700e06cf..91a4a30aa1d1f 100644 --- a/tests/integration_tests/callbacks/test_openai_callback.py +++ b/tests/integration_tests/callbacks/test_openai_callback.py @@ -1,3 +1,4 @@ +"""Integration tests for the langchain tracer module.""" import asyncio import pytest @@ -7,7 +8,7 @@ @pytest.mark.asyncio -async def test_openai_callback(): +async def test_openai_callback() -> None: llm = OpenAI(temperature=0) with get_openai_callback() as cb: llm("What is the square root of 4?") From 8ae809af67d3076ed1312f576b3de0efd3cfa884 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 17:50:01 -0700 Subject: [PATCH 11/36] mypy --- langchain/agents/tools.py | 2 +- langchain/callbacks/manager.py | 36 ++++-------- .../callbacks/fake_callback_handler.py | 4 +- .../callbacks/test_callback_manager.py | 56 +++++++++---------- .../callbacks/tracers/test_tracer.py 
| 3 +- 5 files changed, 44 insertions(+), 57 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 86dcadb498597..927d803f6daab 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -56,8 +56,8 @@ async def _arun( **kwargs: Any, ) -> str: """Use the tool asynchronously.""" - new_argument_supported = signature(self.coroutine).parameters.get("callbacks") if self.coroutine: + new_argument_supported = signature(self.coroutine).parameters.get("callbacks") return ( await self.coroutine( *args, diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 1c1ec4dea4213..435773ca2e236 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -7,7 +7,7 @@ import uuid from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, Sequence from langchain.callbacks.base import ( BaseCallbackHandler, @@ -23,7 +23,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.schema import AgentAction, AgentFinish, LLMResult -Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] +Callbacks = Optional[Union[Sequence[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None @@ -174,9 +174,6 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, token: str, - *, - run_id: str, - parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token.""" @@ -225,9 +222,6 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, token: str, - *, - run_id: Optional[str] = None, - parent_run_id: Optional[str] = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token.""" @@ -559,14 +553,10 @@ def on_tool_start( @classmethod def configure( cls, - inheritable_callbacks: Optional[ - Union[BaseCallbackManager, List[BaseCallbackHandler]] - ] = None, - local_callbacks: Optional[ - Union[BaseCallbackManager, List[BaseCallbackHandler]] - ] = None, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, verbose: bool = False, - ) -> Optional[BaseCallbackManager]: + ) -> CallbackManager: """Configure the callback manager.""" return _configure(cls, inheritable_callbacks, local_callbacks, verbose) @@ -661,14 +651,10 @@ async def on_tool_start( @classmethod def configure( cls, - inheritable_callbacks: Optional[ - Union[BaseCallbackManager, List[BaseCallbackHandler]] - ] = None, - local_callbacks: Optional[ - Union[BaseCallbackManager, List[BaseCallbackHandler]] - ] = None, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, verbose: bool = False, - ) -> Optional[BaseCallbackManager]: + ) -> AsyncCallbackManager: """Configure the callback manager.""" return _configure(cls, inheritable_callbacks, local_callbacks, verbose) @@ -685,8 +671,8 @@ def _configure( """Configure the callback manager.""" callback_manager = callback_manager_cls([]) if inheritable_callbacks or local_callbacks: - if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: - inheritable_callbacks_: List[BaseCallbackHandler] = ( + if isinstance(inheritable_callbacks, Sequence) or inheritable_callbacks is None: + inheritable_callbacks_ = ( inheritable_callbacks 
or [] ) callback_manager = callback_manager_cls( @@ -702,7 +688,7 @@ def _configure( callback_manager = copy.deepcopy(callback_manager) local_handlers_ = ( local_callbacks - if isinstance(local_callbacks, list) + if isinstance(local_callbacks, Sequence) else (local_callbacks.handlers if local_callbacks else []) ) for handler in local_handlers_: diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 6dd92a2621b19..ef2b8171c2c33 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -189,7 +189,7 @@ def on_text( ) -> Any: self.on_text_common() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": return self @@ -302,5 +302,5 @@ async def on_text( ) -> None: self.on_text_common() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": return self diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index df72c1ea8e6ec..b87bccc723c5e 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -23,17 +23,17 @@ def _test_callback_manager( run_manager.on_llm_new_token("foo") run_manager.on_text("foo") - run_manager = manager.on_chain_start({"name": "foo"}, {}) - run_manager.on_chain_end({}) - run_manager.on_chain_error(Exception()) - run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) - run_manager.on_agent_finish(AgentFinish(log="", return_values={})) - run_manager.on_text("foo") - - run_manager = manager.on_tool_start({}, "") - run_manager.on_tool_end("") - run_manager.on_tool_error(Exception()) - run_manager.on_text("foo") + run_manager_chain = manager.on_chain_start({"name": "foo"}, {}) + run_manager_chain.on_chain_end({}) + run_manager_chain.on_chain_error(Exception()) + run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager_chain.on_text("foo") + + run_manager_tool = manager.on_tool_start({}, "") + run_manager_tool.on_tool_end("") + run_manager_tool.on_tool_error(Exception()) + run_manager_tool.on_text("foo") _check_num_calls(handlers) @@ -47,17 +47,17 @@ async def _test_callback_manager_async( await run_manager.on_llm_new_token("foo") await run_manager.on_text("foo") - run_manager = await manager.on_chain_start({"name": "foo"}, {}) - await run_manager.on_chain_end({}) - await run_manager.on_chain_error(Exception()) - await run_manager.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) - await run_manager.on_agent_finish(AgentFinish(log="", return_values={})) - await run_manager.on_text("foo") - - run_manager = await manager.on_tool_start({}, "") - await run_manager.on_tool_end("") - await run_manager.on_tool_error(Exception()) - await run_manager.on_text("foo") + run_manager_chain = await manager.on_chain_start({"name": "foo"}, {}) + await run_manager_chain.on_chain_end({}) + await run_manager_chain.on_chain_error(Exception()) + await run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + await run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={})) + await run_manager_chain.on_text("foo") + + run_manager_tool = await manager.on_tool_start({}, "") + await run_manager_tool.on_tool_end("") + await run_manager_tool.on_tool_error(Exception()) + 
await run_manager_tool.on_text("foo") _check_num_calls(handlers) @@ -191,13 +191,13 @@ def test_callback_manager_inheritance() -> None: assert child_manager.handlers == [handler1] assert child_manager.inheritable_handlers == [handler1] - child_manager = child_manager.on_tool_start({}, "") - assert child_manager.handlers == [handler1] - assert child_manager.inheritable_handlers == [handler1] + run_manager_tool = child_manager.on_tool_start({}, "") + assert run_manager_tool.handlers == [handler1] + assert run_manager_tool.inheritable_handlers == [handler1] - child_manager = child_manager.get_child() - assert child_manager.handlers == [handler1] - assert child_manager.inheritable_handlers == [handler1] + child_manager2 = run_manager_tool.get_child() + assert child_manager2.handlers == [handler1] + assert child_manager2.inheritable_handlers == [handler1] def test_callback_manager_configure() -> None: diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index d0297a6f95212..54a8e3527841b 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -277,7 +277,8 @@ def test_tracer_nested_run() -> None: """Test tracer on a nested run.""" tracer = FakeTracer() tracer.new_session() - [_perform_nested_run(tracer) for _ in range(10)] + for _ in range(10): + _perform_nested_run(tracer) assert tracer.runs == [_get_compare_run()] * 10 From 1fc3941430bb95f9e553cbd53b498e9c591cdd5f Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 18:05:44 -0700 Subject: [PATCH 12/36] mypy --- langchain/agents/tools.py | 4 ++- langchain/callbacks/manager.py | 12 ++++----- langchain/llms/base.py | 2 +- langchain/llms/gpt4all.py | 27 ++++++++++--------- .../callbacks/test_callback_manager.py | 11 +++++--- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 927d803f6daab..fad11582ed46d 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -57,7 +57,9 @@ async def _arun( ) -> str: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get("callbacks") + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) return ( await self.coroutine( *args, diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 435773ca2e236..0eae27aba5aac 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -7,7 +7,7 @@ import uuid from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, Sequence +from typing import Any, Dict, Generator, List, Optional, Sequence, Type, TypeVar, Union from langchain.callbacks.base import ( BaseCallbackHandler, @@ -23,7 +23,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.schema import AgentAction, AgentFinish, LLMResult -Callbacks = Optional[Union[Sequence[BaseCallbackHandler], BaseCallbackManager]] +Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( "openai_callback", default=None @@ -671,10 +671,8 @@ def _configure( """Configure the callback manager.""" callback_manager = callback_manager_cls([]) if inheritable_callbacks or local_callbacks: - if isinstance(inheritable_callbacks, Sequence) or 
inheritable_callbacks is None: - inheritable_callbacks_ = ( - inheritable_callbacks or [] - ) + if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: + inheritable_callbacks_ = inheritable_callbacks or [] callback_manager = callback_manager_cls( handlers=inheritable_callbacks_, inheritable_handlers=inheritable_callbacks_, @@ -688,7 +686,7 @@ def _configure( callback_manager = copy.deepcopy(callback_manager) local_handlers_ = ( local_callbacks - if isinstance(local_callbacks, Sequence) + if isinstance(local_callbacks, list) else (local_callbacks.handlers if local_callbacks else []) ) for handler in local_handlers_: diff --git a/langchain/llms/base.py b/langchain/llms/base.py index cdfe366e560b8..dcab983d9b199 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -185,7 +185,7 @@ def generate( missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - run_manager = self.callback_manager.on_llm_start( + run_manager = callback_manager.on_llm_start( {"name": self.__class__.__name__}, missing_prompts ) try: diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index 780bddb9f4342..15f09266fcf1a 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -161,10 +161,10 @@ def _llm_type(self) -> str: return "gpt4all" def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str: r"""Call out to GPT4All's generate method. @@ -181,14 +181,17 @@ def _call( prompt = "Once upon a time, " response = model(prompt, n_predict=55) """ - text_callback = partial( - self.callback_manager.on_llm_new_token, verbose=self.verbose - ) - text = self.client.generate( - prompt, - new_text_callback=text_callback, - **self._default_params, - ) + if run_manager: + text_callback = partial( + run_manager.on_llm_new_token, verbose=self.verbose + ) + text = self.client.generate( + prompt, + new_text_callback=text_callback, + **self._default_params, + ) + else: + text = self.client.generate(prompt, **self._default_params) if stop is not None: text = enforce_stop_tokens(text, stop) return text diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index b87bccc723c5e..59a2fae81cc93 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,8 +1,9 @@ """Test CallbackManager.""" -from typing import Tuple +from typing import Tuple, List import pytest +from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -50,7 +51,9 @@ async def _test_callback_manager_async( run_manager_chain = await manager.on_chain_start({"name": "foo"}, {}) await run_manager_chain.on_chain_end({}) await run_manager_chain.on_chain_error(Exception()) - await run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + await run_manager_chain.on_agent_action( + AgentAction(tool_input="foo", log="", tool="") + ) await run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={})) await run_manager_chain.on_text("foo") @@ -209,8 +212,8 @@ def test_callback_manager_configure() -> None: FakeCallbackHandler(), ) - 
inheritable_callbacks = [handler1, handler2] - local_callbacks = [handler3, handler4] + inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2] + local_callbacks: List[BaseCallbackHandler] = [handler3, handler4] configured_manager = CallbackManager.configure( inheritable_callbacks=inheritable_callbacks, local_callbacks=local_callbacks, From 15c0fa5e7e52bb001584302b6010d7a2e7ad92e6 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 18:14:51 -0700 Subject: [PATCH 13/36] cr --- langchain/llms/gpt4all.py | 12 +++++------- tests/unit_tests/callbacks/test_callback_manager.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index 15f09266fcf1a..ff3a6a5dfa6b3 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -161,10 +161,10 @@ def _llm_type(self) -> str: return "gpt4all" def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str: r"""Call out to GPT4All's generate method. @@ -182,9 +182,7 @@ def _call( response = model(prompt, n_predict=55) """ if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) + text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose) text = self.client.generate( prompt, new_text_callback=text_callback, diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 59a2fae81cc93..6a21598574147 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,5 +1,5 @@ """Test CallbackManager.""" -from typing import Tuple, List +from typing import List, Tuple import pytest From 5dcb44ee1d3cb0833c4859be84ab84f77195493a Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 19:48:47 -0700 Subject: [PATCH 14/36] fix llm chain --- langchain/chains/llm.py | 62 ++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index f2b8652bfc2b9..db29cd93e677f 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -7,7 +7,9 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( + AsyncCallbackManager, AsyncCallbackManagerForChainRun, + CallbackManager, CallbackManagerForChainRun, Callbacks, ) @@ -64,7 +66,8 @@ def _call( inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: - return self.apply([inputs], run_manager=run_manager)[0] + response = self.generate([inputs], run_manager=run_manager) + return self.create_outputs(response)[0] def generate( self, @@ -137,22 +140,44 @@ async def aprep_prompts( return prompts, stop def apply( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = self.generate(input_list, run_manager=run_manager) - return self.create_outputs(response) + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + run_manager = callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + {"input_list": 
input_list}, + ) + try: + response = self.generate(input_list, run_manager=run_manager) + except (KeyboardInterrupt, Exception) as e: + run_manager.on_chain_error(e) + raise e + outputs = self.create_outputs(response) + run_manager.on_chain_end({"outputs": outputs}) + return outputs async def aapply( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = await self.agenerate(input_list, run_manager=run_manager) - return self.create_outputs(response) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + run_manager = await callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + {"input_list": input_list}, + ) + try: + response = await self.agenerate(input_list, run_manager=run_manager) + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise e + outputs = self.create_outputs(response) + await run_manager.on_chain_end({"outputs": outputs}) + return outputs def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: """Create outputs from response.""" @@ -167,7 +192,8 @@ async def _acall( inputs: Dict[str, Any], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, str]: - return (await self.aapply([inputs], run_manager=run_manager))[0] + response = await self.agenerate([inputs], run_manager=run_manager) + return self.create_outputs(response)[0] def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. @@ -224,12 +250,10 @@ async def apredict_and_parse( return result def apply_and_parse( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = self.apply(input_list, run_manager=run_manager) + result = self.apply(input_list, callbacks=callbacks) return self._parse_result(result) def _parse_result( @@ -243,12 +267,10 @@ def _parse_result( return result async def aapply_and_parse( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = await self.aapply(input_list, run_manager=run_manager) + result = await self.aapply(input_list, callbacks=callbacks) return self._parse_result(result) @property From da27d8713da5a3ce2ad41dcaadc377b7fae19830 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 23:19:09 -0700 Subject: [PATCH 15/36] fix most tests --- langchain/agents/tools.py | 10 +- langchain/tools/base.py | 15 ++- tests/unit_tests/agents/test_agent.py | 133 +++++---------------- tests/unit_tests/agents/test_react.py | 8 +- tests/unit_tests/chains/test_base.py | 8 +- tests/unit_tests/chains/test_hyde.py | 14 ++- tests/unit_tests/chains/test_natbot.py | 8 +- tests/unit_tests/chains/test_sequential.py | 9 +- tests/unit_tests/llms/fake_llm.py | 8 +- tests/unit_tests/llms/test_callbacks.py | 15 --- 10 files changed, 97 insertions(+), 131 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 
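A short usage sketch, not part of the diff: with apply and aapply now taking a callbacks argument and configuring their own manager, per-call handlers can be attached without building a CallbackManager up front. Something along these lines should work against the interfaces shown above; the prompt text and handler choice are placeholders, and an OpenAI API key is assumed to be configured.

from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["title"],
    template="Write a one-line tagline for a play called {title}.",
)
chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

# Handlers passed here apply only to this call; any handlers already set on the
# chain are still honored, because apply() merges both when it configures its manager.
outputs = chain.apply(
    [{"title": "The Refactored Callback"}],
    callbacks=[StdOutCallbackHandler()],
)
print(outputs)
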
fad11582ed46d..0a4e3172163bb 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -87,11 +87,17 @@ class InvalidTool(BaseTool): name = "invalid_tool" description = "Called when tool name is invalid." - def _run(self, tool_name: str) -> str: + def _run( + self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Use the tool.""" return f"{tool_name} is not a valid tool, try another one." - async def _arun(self, tool_name: str) -> str: + async def _arun( + self, + tool_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return f"{tool_name} is not a valid tool, try another one." diff --git a/langchain/tools/base.py b/langchain/tools/base.py index a370305f0b56e..16ca3eda86fb3 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -106,6 +106,8 @@ def run( callback_manager = CallbackManager.configure( callbacks, self.callbacks, verbose=verbose_ ) + # TODO: maybe also pass through run_manager is _run supports kwargs + new_arg_supported = inspect.signature(self._run).parameters.get("run_manager") run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), @@ -114,7 +116,11 @@ def run( ) try: tool_args, tool_kwargs = _to_args_and_kwargs(tool_input) - observation = self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + observation = ( + self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else self._run(*tool_args, **tool_kwargs) + ) except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e @@ -139,6 +145,7 @@ async def arun( callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, verbose=verbose_ ) + new_arg_supported = inspect.signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), @@ -148,7 +155,11 @@ async def arun( try: # We then call the tool on the tool input to get an observation args, kwargs = _to_args_and_kwargs(tool_input) - observation = await self._arun(*args, run_manager=run_manager, **kwargs) + observation = ( + await self._arun(*args, run_manager=run_manager, **kwargs) + if new_arg_supported + else await self._arun(*args, **kwargs) + ) except (Exception, KeyboardInterrupt) as e: await run_manager.on_tool_error(e) raise e diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 4b0f736a4a43b..1fb94cf21a49b 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,7 +4,7 @@ from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents.tools import Tool -from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun from langchain.llms.base import LLM from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -15,7 +15,12 @@ class FakeListLLM(LLM): responses: List[str] i: int = -1 - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 print(f"=== Mock 
Response #{self.i} ===") @@ -82,135 +87,57 @@ def test_agent_stopped_early() -> None: assert output == "Agent stopped due to iteration limit or time limit." -def test_agent_with_callbacks_global() -> None: +def test_agent_with_callbacks() -> None: """Test react chain with callbacks by setting verbose globally.""" - import langchain + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() - langchain.verbose = True - handler = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler]) tool = "Search" responses = [ f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", "Oh well\nFinal Answer: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) + # Only fake LLM gets callbacks for handler2 + fake_llm = FakeListLLM(responses=responses, callbacks=[handler2]) tools = [ Tool( name="Search", func=lambda x: x, description="Useful for searching", - callback_manager=manager, ), ] agent = initialize_agent( tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - callback_manager=manager, ) - output = agent.run("when was langchain made") + output = agent.run("when was langchain made", callbacks=[handler1]) assert output == "curses foiled again" # 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run - assert handler.chain_starts == handler.chain_ends == 3 - assert handler.llm_starts == handler.llm_ends == 2 - assert handler.tool_starts == 2 - assert handler.tool_ends == 1 - # 1 extra agent action - assert handler.starts == 7 - # 1 extra agent end - assert handler.ends == 7 - assert handler.errors == 0 - # during LLMChain - assert handler.text == 2 - - -def test_agent_with_callbacks_local() -> None: - """Test react chain with callbacks by setting verbose locally.""" - import langchain - - langchain.verbose = False - handler = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler]) - tool = "Search" - responses = [ - f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", - "Oh well\nFinal Answer: curses foiled again", - ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) - tools = [ - Tool( - name="Search", - func=lambda x: x, - description="Useful for searching", - callback_manager=manager, - ), - ] - agent = initialize_agent( - tools, - fake_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - callback_manager=manager, - ) - - agent.agent.llm_chain.verbose = True # type: ignore - - output = agent.run("when was langchain made") - assert output == "curses foiled again" - - # 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run - assert handler.chain_starts == handler.chain_ends == 3 - assert handler.llm_starts == handler.llm_ends == 2 - assert handler.tool_starts == 2 - assert handler.tool_ends == 1 + assert handler1.chain_starts == handler1.chain_ends == 3 + assert handler1.llm_starts == handler1.llm_ends == 2 + assert handler1.tool_starts == 1 + assert handler1.tool_ends == 1 # 1 extra agent action - assert handler.starts == 7 + assert handler1.starts == 7 # 1 extra agent end - assert handler.ends == 7 - assert handler.errors == 0 + assert handler1.ends == 7 + assert handler1.errors == 0 # during LLMChain - assert handler.text == 2 - - -def test_agent_with_callbacks_not_verbose() -> None: - """Test react chain with callbacks but not verbose.""" - import langchain - - langchain.verbose = False - handler = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler]) - tool = "Search" - responses 
= [ - f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", - "Oh well\nFinal Answer: curses foiled again", - ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager) - tools = [ - Tool( - name="Search", - func=lambda x: x, - description="Useful for searching", - ), - ] - agent = initialize_agent( - tools, - fake_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - callback_manager=manager, + assert handler1.text == 2 + + assert handler2.llm_starts == 2 + assert handler2.llm_ends == 2 + assert ( + handler2.chain_starts + == handler2.tool_starts + == handler2.tool_ends + == handler2.chain_ends + == 0 ) - output = agent.run("when was langchain made") - assert output == "curses foiled again" - - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 - def test_agent_tool_return_direct() -> None: """Test agent using tools that return directly.""" diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 2689ea36eff05..8f2a3ff2fc50d 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -4,6 +4,7 @@ from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM @@ -32,7 +33,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "fake_list" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 return self.responses[self.i] diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 8c3081c2f149c..84c8247c5f2e7 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.manager import CallbackManager +from langchain.callbacks.manager import CallbackManager, CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.schema import BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -49,7 +49,11 @@ def output_keys(self) -> List[str]: """Output key of bar.""" return self.the_output_keys - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: if self.be_correct: return {"bar": "baz"} else: diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index cc3e6ae42f801..dd2ade83c1825 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -3,6 +3,10 @@ import numpy as np +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.embeddings.base import Embeddings @@ -28,12 +32,18 @@ class FakeLLM(BaseLLM): n: int = 1 def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, 
+ run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index fd30901af3ff5..b8c6538e2867e 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -2,6 +2,7 @@ from typing import Any, List, Mapping, Optional +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chains.natbot.base import NatBotChain from langchain.llms.base import LLM @@ -9,7 +10,12 @@ class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Return `foo` if longer than 10000 words, else `bar`.""" if len(prompt) > 10000: return "foo" diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index 2ef0e7d4eeae0..19e7df106676e 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -1,8 +1,9 @@ """Test pipeline functionality.""" -from typing import Dict, List +from typing import Dict, List, Optional import pytest +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.memory.simple import SimpleMemory @@ -24,7 +25,11 @@ def output_keys(self) -> List[str]: """Input keys this chain returns.""" return self.output_variables - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: outputs = {} for var in self.output_variables: variables = [inputs[k] for k in self.input_variables] diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index cc12a7cab7b5e..8815cc0b82809 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -3,6 +3,7 @@ from pydantic import validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -28,7 +29,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "fake" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: if self.sequential_responses: return self._get_next_response_in_sequence diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 2802e58e38420..78480816edc81 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -13,18 +13,3 @@ def test_llm_with_callbacks() -> None: assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 - - -def test_llm_with_callbacks_not_verbose() -> None: - """Test LLM callbacks but not verbose.""" - import langchain - - 
langchain.verbose = False - - handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler])) - output = llm("foo") - assert output == "foo" - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 From 2ed4649e50e68f68e932c04e24659dbb9026d275 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Thu, 27 Apr 2023 23:23:58 -0700 Subject: [PATCH 16/36] fix baby agi --- .../experimental/autonomous_agents/baby_agi/baby_agi.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py index 3b7ce122f2831..ba87e5edede24 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py +++ b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py @@ -1,9 +1,11 @@ +"""BabyAGI agent.""" from collections import deque from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.experimental.autonomous_agents.baby_agi.task_creation import ( TaskCreationChain, @@ -112,7 +114,11 @@ def execute_task(self, objective: str, task: str, k: int = 5) -> str: objective=objective, context="\n".join(context), task=task ) - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run the agent.""" objective = inputs["objective"] first_task = inputs.get("first_task", "Make a todo list") From 0e81e83466d9bed27029be5ab0692a4c6b9d1c24 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 28 Apr 2023 18:37:29 +0100 Subject: [PATCH 17/36] Nc/callbacks docs (#3717) --- docs/ecosystem/aim_tracking.ipynb | 12 +- docs/ecosystem/clearml_tracking.ipynb | 10 +- docs/ecosystem/comet_tracking.ipynb | 31 ++- docs/ecosystem/gpt4all.md | 9 +- docs/ecosystem/wandb_tracking.ipynb | 14 +- .../examples/async_agent.ipynb | 19 +- .../custom_agent_with_tool_retrieval.ipynb | 1 + docs/modules/callbacks/getting_started.ipynb | 40 +++- .../modules/chains/generic/custom_chain.ipynb | 188 ++++++++++++++++++ .../index_examples/chat_vector_db.ipynb | 3 +- .../models/chat/examples/streaming.ipynb | 3 +- .../modules/models/chat/getting_started.ipynb | 3 +- .../models/llms/examples/streaming_llm.ipynb | 17 +- .../models/llms/integrations/gpt4all.ipynb | 5 +- docs/tracing.md | 14 +- 15 files changed, 286 insertions(+), 83 deletions(-) create mode 100644 docs/modules/chains/generic/custom_chain.ipynb diff --git a/docs/ecosystem/aim_tracking.ipynb b/docs/ecosystem/aim_tracking.ipynb index fa9755f98692e..c7b1cc62ffec0 100644 --- a/docs/ecosystem/aim_tracking.ipynb +++ b/docs/ecosystem/aim_tracking.ipynb @@ -61,7 +61,6 @@ "from datetime import datetime\n", "\n", "from langchain.llms import OpenAI\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks import AimCallbackHandler, StdOutCallbackHandler" ] }, @@ -109,8 +108,8 @@ " experiment_name=\"scenario 1: OpenAI LLM\",\n", ")\n", "\n", - "manager = CallbackManager([StdOutCallbackHandler(), aim_callback])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "callbacks = [StdOutCallbackHandler(), aim_callback]\n", + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -177,7 +176,7 @@ "Title: 
{title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [\n", " {\"title\": \"documentary about good video games that push the boundary of game design\"},\n", @@ -249,13 +248,12 @@ ], "source": [ "# scenario 3 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", + " callbacks=callbacks,\n", ")\n", "agent.run(\n", " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n", diff --git a/docs/ecosystem/clearml_tracking.ipynb b/docs/ecosystem/clearml_tracking.ipynb index 20b118c6960ef..0fb33c2dd8ba9 100644 --- a/docs/ecosystem/clearml_tracking.ipynb +++ b/docs/ecosystem/clearml_tracking.ipynb @@ -79,7 +79,6 @@ "source": [ "from datetime import datetime\n", "from langchain.callbacks import ClearMLCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "# Setup and use the ClearML Callback\n", @@ -93,9 +92,9 @@ " complexity_metrics=True,\n", " stream_logs=True\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), clearml_callback])\n", + "callbacks = [StdOutCallbackHandler(), clearml_callback]\n", "# Get the OpenAI model ready to go\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -523,13 +522,12 @@ "from langchain.agents import AgentType\n", "\n", "# SCENARIO 2 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", + " callbacks=callbacks,\n", ")\n", "agent.run(\n", " \"Who is the wife of the person who sang summer of 69?\"\n", diff --git a/docs/ecosystem/comet_tracking.ipynb b/docs/ecosystem/comet_tracking.ipynb index 4d33bd00ab55b..9b49ff886b006 100644 --- a/docs/ecosystem/comet_tracking.ipynb +++ b/docs/ecosystem/comet_tracking.ipynb @@ -121,7 +121,6 @@ "from datetime import datetime\n", "\n", "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "comet_callback = CometCallbackHandler(\n", @@ -131,8 +130,8 @@ " tags=[\"llm\"],\n", " visualizations=[\"dep\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks, verbose=True)\n", "\n", "llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n", "print(\"LLM result\", llm_result)\n", @@ -153,7 +152,6 @@ "outputs": [], "source": [ 
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", "from langchain.prompts import PromptTemplate\n", @@ -164,15 +162,14 @@ " stream_logs=True,\n", " tags=[\"synopsis-chain\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks)\n", "\n", "template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n", "Title: {title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n", "print(synopsis_chain.apply(test_prompts))\n", @@ -194,7 +191,6 @@ "source": [ "from langchain.agents import initialize_agent, load_tools\n", "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "comet_callback = CometCallbackHandler(\n", @@ -203,15 +199,15 @@ " stream_logs=True,\n", " tags=[\"agent\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks)\n", "\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=\"zero-shot-react-description\",\n", - " callback_manager=manager,\n", + " callbacks=callbacks,\n", " verbose=True,\n", ")\n", "agent.run(\n", @@ -255,7 +251,6 @@ "from rouge_score import rouge_scorer\n", "\n", "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", "from langchain.prompts import PromptTemplate\n", @@ -298,10 +293,10 @@ " tags=[\"custom_metrics\"],\n", " custom_metrics=rouge_score.compute_metric,\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9)\n", "\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)\n", "\n", "test_prompts = [\n", " {\n", @@ -323,7 +318,7 @@ " \"\"\"\n", " }\n", "]\n", - "print(synopsis_chain.apply(test_prompts))\n", + "print(synopsis_chain.apply(test_prompts, callbacks=callbacks))\n", "comet_callback.flush_tracker(synopsis_chain, finish=True)" ] } diff --git a/docs/ecosystem/gpt4all.md b/docs/ecosystem/gpt4all.md index ea4704d8f76f1..7dc5a0252becc 100644 --- 
a/docs/ecosystem/gpt4all.md +++ b/docs/ecosystem/gpt4all.md @@ -3,6 +3,7 @@ This page covers how to use the `GPT4All` wrapper within LangChain. The tutorial is divided into two parts: installation and setup, followed by usage with an example. ## Installation and Setup + - Install the Python package with `pip install pyllamacpp` - Download a [GPT4All model](https://github.com/nomic-ai/pyllamacpp#supported-model) and place it in your desired directory @@ -28,18 +29,16 @@ To stream the model's predictions, add in a CallbackManager. ```python from langchain.llms import GPT4All -from langchain.callbacks.manager import CallbackManager from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler # There are many CallbackHandlers supported, such as # from langchain.callbacks.streamlit import StreamlitCallbackHandler -callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) -model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, - verbose=True) +callbacks = [StreamingStdOutCallbackHandler()] +model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8) # Generate text. Tokens are streamed through the callback manager. -model("Once upon a time, ") +model("Once upon a time, ", callbacks=callbacks) ``` ## Model File diff --git a/docs/ecosystem/wandb_tracking.ipynb b/docs/ecosystem/wandb_tracking.ipynb index 9ead0230f2c37..78e4fb6a80bff 100644 --- a/docs/ecosystem/wandb_tracking.ipynb +++ b/docs/ecosystem/wandb_tracking.ipynb @@ -50,7 +50,6 @@ "source": [ "from datetime import datetime\n", "from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI" ] }, @@ -196,8 +195,8 @@ " name=\"llm\",\n", " tags=[\"test\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "callbacks = [StdOutCallbackHandler(), wandb_callback]\n", + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -484,7 +483,7 @@ "Title: {title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [\n", " {\n", @@ -577,16 +576,15 @@ ], "source": [ "# SCENARIO 3 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", ")\n", "agent.run(\n", - " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n", + " \"Who is Leo DiCaprio's girlfriend? 
What is her current age raised to the 0.43 power?\",\n", + " callbacks=callbacks,\n", ")\n", "wandb_callback.flush_tracker(agent, reset=False, finish=True)" ] diff --git a/docs/modules/agents/agent_executors/examples/async_agent.ipynb b/docs/modules/agents/agent_executors/examples/async_agent.ipynb index 3ef46cb4d69d0..925700ed33dfe 100644 --- a/docs/modules/agents/agent_executors/examples/async_agent.ipynb +++ b/docs/modules/agents/agent_executors/examples/async_agent.ipynb @@ -42,7 +42,6 @@ "from langchain.agents import AgentType\n", "from langchain.llms import OpenAI\n", "from langchain.callbacks.stdout import StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.tracers import LangChainTracer\n", "from aiohttp import ClientSession\n", "\n", @@ -307,14 +306,14 @@ " # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", " # but you must manually close the client session at the end of your program/event loop\n", " aiosession = ClientSession()\n", + " callbacks = [StdOutCallbackHandler()]\n", " for _ in questions:\n", - " manager = CallbackManager([StdOutCallbackHandler()])\n", - " llm = OpenAI(temperature=0, callback_manager=manager)\n", - " async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n", + " llm = OpenAI(temperature=0)\n", + " async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n", " agents.append(\n", - " initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", + " initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n", " )\n", - " tasks = [async_agent.arun(q) for async_agent, q in zip(agents, questions)]\n", + " tasks = [async_agent.arun(q, callbacks=callbacks) for async_agent, q in zip(agents, questions)]\n", " await asyncio.gather(*tasks)\n", " await aiosession.close()\n", "\n", @@ -376,14 +375,14 @@ "aiosession = ClientSession()\n", "tracer = LangChainTracer()\n", "tracer.load_default_session()\n", - "manager = CallbackManager([StdOutCallbackHandler(), tracer])\n", + "callbacks = [StdOutCallbackHandler(), tracer]\n", "\n", "# Pass the manager into the llm if you want llm calls traced.\n", - "llm = OpenAI(temperature=0, callback_manager=manager)\n", + "llm = OpenAI(temperature=0)\n", "\n", "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n", - "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", - "await async_agent.arun(questions[0])\n", + "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n", + "await async_agent.arun(questions[0], callbacks=callbacks)\n", "await aiosession.close()" ] } diff --git a/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb b/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb index c81cc19e73045..6bbb4ad43a9c2 100644 --- a/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb +++ b/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb @@ -373,6 +373,7 @@ "metadata": {}, "outputs": [], "source": [ + "tools = get_tools(\"whats the weather?\")\n", "tool_names = [tool.name for tool in tools]\n", "agent = LLMSingleActionAgent(\n", " llm_chain=llm_chain, \n", diff --git a/docs/modules/callbacks/getting_started.ipynb 
b/docs/modules/callbacks/getting_started.ipynb index acd4ee9676daa..74e8883afb906 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -92,6 +92,27 @@ "```" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cbccd7d1", + "metadata": {}, + "source": [ + "## How to use callbacks\n", + "\n", + "The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:\n", + "\n", + "- **Constructor callbacks**: defined in the constructor, e.g. `LLMChain(callbacks=[handler])`, which will be used for all calls made on that object, and will be scoped to that object only, e.g. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.\n", + "- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, e.g. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (e.g. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).\n", + "\n", + "The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, e.g. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.\n", + "\n", + "### When do you want to use each of these?\n", + "\n", + "- Constructor callbacks are most useful for use cases such as logging, monitoring, etc., which are _not specific to a single request_, but rather to the entire chain. For example, if you want to log all the requests made to an LLMChain, you would pass a handler to the constructor.\n", + "- Request callbacks are most useful for use cases such as streaming, where you want to stream the output of a single request to a specific websocket connection, or other similar use cases. 
For example, if you want to stream the output of a single request to a websocket, you would pass a handler to the `call()` method" + ] + }, { "cell_type": "markdown", "id": "d3bf3304-43fb-47ad-ae50-0637a17018a2", @@ -106,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "80532dfc-d687-4147-a0c9-1f90cc3e868c", "metadata": { "tags": [] @@ -129,6 +150,13 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", + "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, @@ -138,7 +166,7 @@ "'\\n\\n3'" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -159,7 +187,11 @@ "\n", "# Then, let's use the `verbose` flag to achieve the same result\n", "chain = LLMChain(llm=llm, prompt=prompt, verbose=True)\n", - "chain.run(number=2)" + "chain.run(number=2)\n", + "\n", + "# Finally, let's use the request `callbacks` to achieve the same result\n", + "chain = LLMChain(llm=llm, prompt=prompt)\n", + "chain.run(number=2, callbacks=[handler])" ] }, { @@ -857,7 +889,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.10" } }, "nbformat": 4, diff --git a/docs/modules/chains/generic/custom_chain.ipynb b/docs/modules/chains/generic/custom_chain.ipynb new file mode 100644 index 0000000000000..9f71fe3375e8a --- /dev/null +++ b/docs/modules/chains/generic/custom_chain.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "593f7553-7038-498e-96d4-8255e5ce34f0", + "metadata": {}, + "source": [ + "# Creating a custom Chain\n", + "\n", + "To implement your own custom chain you can subclass `BaseChain` and implement the following methods:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c19c736e-ca74-4726-bb77-0a849bcc2960", + "metadata": { + "tags": [], + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from typing import Any, Dict, List, Optional\n", + "\n", + "from pydantic import Extra\n", + "\n", + "from langchain.base_language import BaseLanguageModel\n", + "from langchain.callbacks.manager import (\n", + " AsyncCallbackManagerForChainRun,\n", + " CallbackManagerForChainRun,\n", + ")\n", + "from langchain.chains.base import Chain\n", + "from langchain.prompts.base import BasePromptTemplate\n", + "\n", + "\n", + "class MyCustomChain(Chain):\n", + " \"\"\"\n", + " An example of a custom chain.\n", + " \"\"\"\n", + "\n", + " prompt: BasePromptTemplate\n", + " \"\"\"Prompt object to use.\"\"\"\n", + " llm: BaseLanguageModel\n", + " output_key: str = \"text\" #: :meta private:\n", + "\n", + " class Config:\n", + " \"\"\"Configuration for this pydantic object.\"\"\"\n", + "\n", + " extra = Extra.forbid\n", + " arbitrary_types_allowed = True\n", + "\n", + " @property\n", + " def input_keys(self) -> List[str]:\n", + " \"\"\"Will be whatever keys the prompt expects.\n", + "\n", + " :meta private:\n", + " \"\"\"\n", + " return self.prompt.input_variables\n", + "\n", + " @property\n", + " def output_keys(self) -> List[str]:\n", + " \"\"\"Will always return text key.\n", + "\n", + " :meta private:\n", + " \"\"\"\n", + " return [self.output_key]\n", + "\n", + " def _call(\n", + " self,\n", + " 
inputs: Dict[str, Any],\n", + " run_manager: Optional[CallbackManagerForChainRun] = None,\n", + " ) -> Dict[str, str]:\n", + " # Your custom chain logic goes here\n", + " # This is just an example that mimics LLMChain\n", + " prompt_value = self.prompt.format_prompt(**inputs)\n", + " \n", + " # Whenever you call a language model, or another chain, you should pass\n", + " # a callback manager to it. This allows the inner run to be tracked by\n", + " # any callbacks that are registered on the outer run.\n", + " # You can always obtain a callback manager for this by calling\n", + " # `run_manager.get_child()` as shown below.\n", + " response = self.llm.generate_prompt(\n", + " [prompt_value],\n", + " callbacks=run_manager.get_child() if run_manager else None\n", + " )\n", + "\n", + " # If you want to log something about this run, you can do so by calling\n", + " # methods on the `run_manager`, as shown below. This will trigger any\n", + " # callbacks that are registered for that event.\n", + " if run_manager:\n", + " run_manager.on_text(\"Log something about this run\")\n", + " \n", + " return {self.output_key: response.generations[0][0].text}\n", + "\n", + " async def _acall(\n", + " self,\n", + " inputs: Dict[str, Any],\n", + " run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n", + " ) -> Dict[str, str]:\n", + " # Your custom chain logic goes here\n", + " # This is just an example that mimics LLMChain\n", + " prompt_value = self.prompt.format_prompt(**inputs)\n", + " \n", + " # Whenever you call a language model, or another chain, you should pass\n", + " # a callback manager to it. This allows the inner run to be tracked by\n", + " # any callbacks that are registered on the outer run.\n", + " # You can always obtain a callback manager for this by calling\n", + " # `run_manager.get_child()` as shown below.\n", + " response = await self.llm.agenerate_prompt(\n", + " [prompt_value],\n", + " callbacks=run_manager.get_child() if run_manager else None\n", + " )\n", + "\n", + " # If you want to log something about this run, you can do so by calling\n", + " # methods on the `run_manager`, as shown below. This will trigger any\n", + " # callbacks that are registered for that event.\n", + " if run_manager:\n", + " await run_manager.on_text(\"Log something about this run\")\n", + " \n", + " return {self.output_key: response.generations[0][0].text}\n", + "\n", + " @property\n", + " def _chain_type(self) -> str:\n", + " return \"my_custom_chain\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "18361f89", + "metadata": { + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n", + "Log something about this run\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Why did the callback function feel lonely? 
Because it was always waiting for someone to call it back!'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.callbacks.stdout import StdOutCallbackHandler\n", + "from langchain.chat_models.openai import ChatOpenAI\n", + "from langchain.prompts.prompt import PromptTemplate\n", + "\n", + "\n", + "chain = MyCustomChain(\n", + " prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n", + " llm=ChatOpenAI()\n", + ")\n", + "\n", + "chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/modules/chains/index_examples/chat_vector_db.ipynb b/docs/modules/chains/index_examples/chat_vector_db.ipynb index b5e28b5191e8b..1cfb81c2c59ac 100644 --- a/docs/modules/chains/index_examples/chat_vector_db.ipynb +++ b/docs/modules/chains/index_examples/chat_vector_db.ipynb @@ -487,7 +487,6 @@ "outputs": [], "source": [ "from langchain.chains.llm import LLMChain\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n", "from langchain.chains.question_answering import load_qa_chain\n", @@ -495,7 +494,7 @@ "# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n", "# and a separate, non-streaming llm for question generation\n", "llm = OpenAI(temperature=0)\n", - "streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "streaming_llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "\n", "question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n", "doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n", diff --git a/docs/modules/models/chat/examples/streaming.ipynb b/docs/modules/models/chat/examples/streaming.ipynb index 22b27e0cfe098..e7d0894e21080 100644 --- a/docs/modules/models/chat/examples/streaming.ipynb +++ b/docs/modules/models/chat/examples/streaming.ipynb @@ -80,9 +80,8 @@ } ], "source": [ - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", - "chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])" ] }, diff --git a/docs/modules/models/chat/getting_started.ipynb b/docs/modules/models/chat/getting_started.ipynb index 113d652e61f56..cee995ec72cd3 100644 --- a/docs/modules/models/chat/getting_started.ipynb +++ b/docs/modules/models/chat/getting_started.ipynb @@ -373,9 +373,8 @@ } ], "source": [ - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", - "chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], 
temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])\n" ] }, diff --git a/docs/modules/models/llms/examples/streaming_llm.ipynb b/docs/modules/models/llms/examples/streaming_llm.ipynb index c48d1ee5b8de5..e10a79d791715 100644 --- a/docs/modules/models/llms/examples/streaming_llm.ipynb +++ b/docs/modules/models/llms/examples/streaming_llm.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07", "metadata": { "tags": [] @@ -21,14 +21,13 @@ "source": [ "from langchain.llms import OpenAI, Anthropic\n", "from langchain.chat_models import ChatOpenAI\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "from langchain.schema import HumanMessage" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "77f60a4b-f786-41f2-972e-e5bb8a48dcd5", "metadata": { "tags": [] @@ -79,7 +78,7 @@ } ], "source": [ - "llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = llm(\"Write me a song about sparkling water.\")" ] }, @@ -95,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "a35373f1-9ee6-4753-a343-5aee749b8527", "metadata": { "tags": [] @@ -136,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "22665f16-e05b-473c-a4bd-ad75744ea024", "metadata": { "tags": [] @@ -191,7 +190,7 @@ } ], "source": [ - "chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])" ] }, @@ -205,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc", "metadata": { "tags": [] @@ -245,7 +244,7 @@ } ], "source": [ - "llm = Anthropic(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "llm = Anthropic(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "llm(\"Write me a song about sparkling water.\")" ] } diff --git a/docs/modules/models/llms/integrations/gpt4all.ipynb b/docs/modules/models/llms/integrations/gpt4all.ipynb index 81083afcf94aa..73bbd9b96bb9d 100644 --- a/docs/modules/models/llms/integrations/gpt4all.ipynb +++ b/docs/modules/models/llms/integrations/gpt4all.ipynb @@ -40,7 +40,6 @@ "source": [ "from langchain import PromptTemplate, LLMChain\n", "from langchain.llms import GPT4All\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" ] }, @@ -124,9 +123,9 @@ "outputs": [], "source": [ "# Callbacks support token-wise streaming\n", - "callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n", + "callbacks = [StreamingStdOutCallbackHandler()]\n", "# Verbose is required to pass to the callback manager\n", - "llm = GPT4All(model=local_path, callback_manager=callback_manager, verbose=True)" + "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)" ] }, { diff 
--git a/docs/tracing.md b/docs/tracing.md index 3214140e444e1..591282085f6e0 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -6,16 +6,15 @@ First, you should install tracing and set up your environment properly. You can use either a locally hosted version of this (uses Docker) or a cloud hosted version (in closed alpha). If you're interested in using the hosted platform, please fill out the form [here](https://forms.gle/tRCEMSeopZf6TE3b6). - - [Locally Hosted Setup](./tracing/local_installation.md) - [Cloud Hosted Setup](./tracing/hosted_installation.md) ## Tracing Walkthrough -When you first access the UI, you should see a page with your tracing sessions. -An initial one "default" should already be created for you. -A session is just a way to group traces together. -If you click on a session, it will take you to a page with no recorded traces that says "No Runs." +When you first access the UI, you should see a page with your tracing sessions. +An initial one "default" should already be created for you. +A session is just a way to group traces together. +If you click on a session, it will take you to a page with no recorded traces that says "No Runs." You can create a new session with the new session form. ![](tracing/homepage.png) @@ -35,7 +34,7 @@ We can keep on clicking further and further down to explore deeper and deeper. ![](tracing/explore.png) -We can also click on the "Explore" button of the top level run to dive even deeper. +We can also click on the "Explore" button of the top level run to dive even deeper. Here, we can see the inputs and outputs in full, as well as all the nested traces. ![](tracing/explore_trace.png) @@ -46,11 +45,12 @@ For example, here is the lowest level trace with the exact inputs/outputs to the ![](tracing/explore_llm.png) ## Changing Sessions + 1. To initially record traces to a session other than `"default"`, you can set the `LANGCHAIN_SESSION` environment variable to the name of the session you want to record to: ```python import os -os.environ["LANGCHAIN_HANDLER"] = "langchain" +os.environ["LANGCHAIN_TRACING"] = "true" os.environ["LANGCHAIN_SESSION"] = "my_session" # Make sure this session actually exists. You can create a new session in the UI. 
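# Purely illustrative: a hypothetical run recorded under "my_session", assuming an
# OpenAI API key is configured and the tracing backend described above is running.
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
llm("What is 2 + 2?")  # this call, and any runs nested under it, will show up under "my_session"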
``` From 1b48ea8d73af8c0dc9bf55b7019aedb04e6d7713 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 13:45:14 -0700 Subject: [PATCH 18/36] cr --- langchain/tools/sql_database/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index d9d6cf63e2112..6e85087b574be 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -6,7 +6,7 @@ from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER From 18138c6fc1488babde7d1b0f3f110c1113f8ca6a Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 14:27:01 -0700 Subject: [PATCH 19/36] cr --- langchain/llms/pipelineai.py | 8 +++++++- langchain/llms/predictionguard.py | 8 +++++++- langchain/tools/base.py | 12 +++++++----- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/langchain/llms/pipelineai.py b/langchain/llms/pipelineai.py index 2a87962294e4e..3a29d64ec6a9b 100644 --- a/langchain/llms/pipelineai.py +++ b/langchain/llms/pipelineai.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "pipeline_ai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Pipeline Cloud endpoint.""" try: from pipeline import PipelineCloud diff --git a/langchain/llms/predictionguard.py b/langchain/llms/predictionguard.py index c5ba6165bbc43..4309cae556320 100644 --- a/langchain/llms/predictionguard.py +++ b/langchain/llms/predictionguard.py @@ -4,6 +4,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -73,7 +74,12 @@ def _llm_type(self) -> str: """Return type of llm.""" return "predictionguard" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Prediction Guard's model proxy. Args: prompt: The prompt to pass into the model. 
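The hunks above give each LLM integration's `_call` an optional `run_manager` parameter. A rough sketch of the shape an integration takes after this change is below; the `MyLLM` class is hypothetical and only illustrates the signature, and a real integration would call its backend instead of echoing the prompt:

```python
from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class MyLLM(LLM):
    """Toy LLM shown only to illustrate the run_manager-aware _call signature."""

    @property
    def _llm_type(self) -> str:
        return "my_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # A real integration would call its API here. run_manager, when provided,
        # carries the callbacks for this specific run; it defaults to None so the
        # new argument stays backwards compatible.
        return prompt
```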
diff --git a/langchain/tools/base.py b/langchain/tools/base.py index dda55e7a911d4..261014c59cd8f 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,7 +1,6 @@ """Base implementation for tools or skills.""" from __future__ import annotations -import inspect import warnings from abc import ABC, abstractmethod from inspect import signature @@ -14,7 +13,6 @@ create_model, root_validator, validate_arguments, - validator, ) from pydantic.main import ModelMetaclass @@ -173,7 +171,11 @@ def raise_deprecation(cls, values: Dict) -> Dict: return values @abstractmethod - def _run(self, *args: Any, **kwargs: Any) -> Any: + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: """Use the tool.""" @abstractmethod @@ -207,7 +209,7 @@ def run( callbacks, self.callbacks, verbose=verbose_ ) # TODO: maybe also pass through run_manager is _run supports kwargs - new_arg_supported = inspect.signature(self._run).parameters.get("run_manager") + new_arg_supported = signature(self._run).parameters.get("run_manager") run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), @@ -245,7 +247,7 @@ async def arun( callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, verbose=verbose_ ) - new_arg_supported = inspect.signature(self._arun).parameters.get("run_manager") + new_arg_supported = signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), From 50f68959001ea9fbbf89a7981dcaafbde394b729 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Fri, 28 Apr 2023 14:41:47 -0700 Subject: [PATCH 20/36] Chains callbacks refactor (#3683) Will keep adding chains to this branch, just pushing now for visibility --- .../agents/agent_toolkits/powerbi/toolkit.py | 2 +- langchain/callbacks/manager.py | 8 + langchain/chains/api/base.py | 38 +++-- langchain/chains/api/openapi/chain.py | 30 +++- .../chains/api/openapi/requests_chain.py | 9 +- .../chains/api/openapi/response_chain.py | 5 +- langchain/chains/base.py | 8 +- langchain/chains/combine_documents/base.py | 39 ++++- .../chains/combine_documents/map_reduce.py | 30 +++- .../chains/combine_documents/map_rerank.py | 13 +- langchain/chains/combine_documents/refine.py | 15 +- langchain/chains/combine_documents/stuff.py | 11 +- langchain/chains/constitutional_ai/base.py | 18 +- .../chains/conversational_retrieval/base.py | 32 +++- langchain/chains/graph_qa/base.py | 31 ++-- langchain/chains/hyde/base.py | 20 ++- langchain/chains/llm_bash/base.py | 110 ++++++------- langchain/chains/llm_bash/prompt.py | 40 ++++- langchain/chains/llm_checker/base.py | 155 +++++++++++++----- langchain/chains/llm_math/base.py | 81 +++++---- langchain/chains/llm_requests.py | 14 +- .../chains/llm_summarization_checker/base.py | 142 +++++++++++----- langchain/chains/mapreduce.py | 34 +++- langchain/chains/moderation.py | 7 +- langchain/chains/natbot/base.py | 43 +++-- langchain/chains/pal/base.py | 46 ++++-- langchain/chains/qa_generation/base.py | 14 +- langchain/chains/qa_with_sources/base.py | 26 ++- langchain/chains/query_constructor/base.py | 3 +- langchain/chains/retrieval_qa/base.py | 22 ++- langchain/chains/sequential.py | 60 +++++-- langchain/chains/sql_database/base.py | 98 +++++++---- langchain/chains/transform.py | 9 +- 
tests/integration_tests/chains/test_pal.py | 4 +- .../chains/test_sql_database.py | 10 +- tests/unit_tests/chains/test_llm_bash.py | 4 +- 36 files changed, 877 insertions(+), 354 deletions(-) diff --git a/langchain/agents/agent_toolkits/powerbi/toolkit.py b/langchain/agents/agent_toolkits/powerbi/toolkit.py index 00056563b1abc..812e3d01963f1 100644 --- a/langchain/agents/agent_toolkits/powerbi/toolkit.py +++ b/langchain/agents/agent_toolkits/powerbi/toolkit.py @@ -4,10 +4,10 @@ from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.powerbi.prompt import QUESTION_TO_QUERY from langchain.tools.powerbi.tool import ( diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 0eae27aba5aac..caf87fbe6e454 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -111,6 +111,9 @@ async def _ahandle_event( ) +BRM = TypeVar("BRM", bound="BaseRunManager") + + class BaseRunManager(RunManagerMixin): """Base class for run manager (a bound callback manager).""" @@ -127,6 +130,11 @@ def __init__( self.inheritable_handlers = inheritable_handlers self.parent_run_id = parent_run_id + @classmethod + def get_noop_manager(cls: Type[BRM]) -> BRM: + """Return a manager that doesn't perform any operations.""" + return cls("", [], []) + class RunManager(BaseRunManager): """Sync Run Manager.""" diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index da0bd362207c0..e5af03a0097fb 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -6,6 +6,10 @@ from pydantic import Field, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -61,16 +65,21 @@ def validate_api_answer_prompt(cls, values: Dict) -> Dict: ) return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.question_key] api_url = self.api_request_chain.predict( - question=question, api_docs=self.api_docs - ) - self.callback_manager.on_text( - api_url, color="green", end="\n", verbose=self.verbose + question=question, + api_docs=self.api_docs, + callbacks=_run_manager.get_child(), ) + _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose) api_response = self.requests_wrapper.get(api_url) - self.callback_manager.on_text( + _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose ) answer = self.api_answer_chain.predict( @@ -78,19 +87,27 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: api_docs=self.api_docs, api_url=api_url, api_response=api_response, + callbacks=_run_manager.get_child(), ) return {self.output_key: answer} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + 
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs[self.question_key] api_url = await self.api_request_chain.apredict( - question=question, api_docs=self.api_docs + question=question, + api_docs=self.api_docs, + callbacks=_run_manager.get_child(), ) - self.callback_manager.on_text( + await _run_manager.on_text( api_url, color="green", end="\n", verbose=self.verbose ) api_response = await self.requests_wrapper.aget(api_url) - self.callback_manager.on_text( + await _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose ) answer = await self.api_answer_chain.apredict( @@ -98,6 +115,7 @@ async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: api_docs=self.api_docs, api_url=api_url, api_response=api_response, + callbacks=_run_manager.get_child(), ) return {self.output_key: answer} diff --git a/langchain/chains/api/openapi/chain.py b/langchain/chains/api/openapi/chain.py index 0f06276ae297c..8f19227133bc4 100644 --- a/langchain/chains/api/openapi/chain.py +++ b/langchain/chains/api/openapi/chain.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field from requests import Response +from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.api.openapi.requests_chain import APIRequesterChain from langchain.chains.api.openapi.response_chain import APIResponderChain from langchain.chains.base import Chain @@ -106,16 +107,21 @@ def _get_output(self, output: str, intermediate_steps: dict) -> dict: else: return {self.output_key: output} - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() intermediate_steps = {} instructions = inputs[self.instructions_key] instructions = instructions[: self.max_text_length] _api_arguments = self.api_request_chain.predict_and_parse( - instructions=instructions + instructions=instructions, callbacks=_run_manager.get_child() ) api_arguments = cast(str, _api_arguments) intermediate_steps["request_args"] = api_arguments - self.callback_manager.on_text( + _run_manager.on_text( api_arguments, color="green", end="\n", verbose=self.verbose ) if api_arguments.startswith("ERROR"): @@ -141,18 +147,17 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: response_text = f"Error with message {str(e)}" response_text = response_text[: self.max_text_length] intermediate_steps["response_text"] = response_text - self.callback_manager.on_text( + _run_manager.on_text( response_text, color="blue", end="\n", verbose=self.verbose ) if self.api_response_chain is not None: _answer = self.api_response_chain.predict_and_parse( response=response_text, instructions=instructions, + callbacks=_run_manager.get_child(), ) answer = cast(str, _answer) - self.callback_manager.on_text( - answer, color="yellow", end="\n", verbose=self.verbose - ) + _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose) return self._get_output(answer, intermediate_steps) else: return self._get_output(response_text, intermediate_steps) @@ -188,6 +193,7 @@ def from_api_operation( verbose: bool = False, return_intermediate_steps: bool = False, raw_response: bool = False, + callbacks: Callbacks = None, **kwargs: Any # TODO: Handle async ) -> "OpenAPIEndpointChain": @@ 
-198,12 +204,17 @@ def from_api_operation( path_params=operation.path_params, ) requests_chain = APIRequesterChain.from_llm_and_typescript( - llm, typescript_definition=operation.to_typescript(), verbose=verbose + llm, + typescript_definition=operation.to_typescript(), + verbose=verbose, + callbacks=callbacks, ) if raw_response: response_chain = None else: - response_chain = APIResponderChain.from_llm(llm, verbose=verbose) + response_chain = APIResponderChain.from_llm( + llm, verbose=verbose, callbacks=callbacks + ) _requests = requests or Requests() return cls( api_request_chain=requests_chain, @@ -213,5 +224,6 @@ def from_api_operation( param_mapping=param_mapping, verbose=verbose, return_intermediate_steps=return_intermediate_steps, + callbacks=callbacks, **kwargs, ) diff --git a/langchain/chains/api/openapi/requests_chain.py b/langchain/chains/api/openapi/requests_chain.py index acc1e4c36902c..4bc8bd83df5c5 100644 --- a/langchain/chains/api/openapi/requests_chain.py +++ b/langchain/chains/api/openapi/requests_chain.py @@ -2,6 +2,7 @@ import json import re +from typing import Any from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE from langchain.chains.llm import LLMChain @@ -36,7 +37,11 @@ class APIRequesterChain(LLMChain): @classmethod def from_llm_and_typescript( - cls, llm: BaseLLM, typescript_definition: str, verbose: bool = True + cls, + llm: BaseLLM, + typescript_definition: str, + verbose: bool = True, + **kwargs: Any, ) -> LLMChain: """Get the request parser.""" output_parser = APIRequesterOutputParser() @@ -46,4 +51,4 @@ def from_llm_and_typescript( partial_variables={"schema": typescript_definition}, input_variables=["instructions"], ) - return cls(prompt=prompt, llm=llm, verbose=verbose) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/api/openapi/response_chain.py b/langchain/chains/api/openapi/response_chain.py index 8cabbb0af682e..a1d7c5a1de00d 100644 --- a/langchain/chains/api/openapi/response_chain.py +++ b/langchain/chains/api/openapi/response_chain.py @@ -2,6 +2,7 @@ import json import re +from typing import Any from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE from langchain.chains.llm import LLMChain @@ -35,7 +36,7 @@ class APIResponderChain(LLMChain): """Get the response parser.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain: """Get the response parser.""" output_parser = APIResponderOutputParser() prompt = PromptTemplate( @@ -43,4 +44,4 @@ def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain: output_parser=output_parser, input_variables=["response", "instructions"], ) - return cls(prompt=prompt, llm=llm, verbose=verbose) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index fc1cf9d69d8e0..c1e8e9b256657 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -92,16 +92,16 @@ def _validate_outputs(self, outputs: Dict[str, str]) -> None: @abstractmethod def _call( self, - inputs: Dict[str, str], + inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: + ) -> Dict[str, Any]: """Run the logic of this chain and return the output.""" async def _acall( self, - inputs: Dict[str, str], + inputs: Dict[str, Any], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: + ) -> Dict[str, Any]: 
"""Run the logic of this chain and return the output.""" raise NotImplementedError("Async call not supported for this chain type.") diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 731a55283d1ad..338ea26a8431e 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -5,6 +5,10 @@ from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate @@ -68,19 +72,33 @@ async def acombine_docs( ) -> Tuple[str, dict]: """Combine documents into a single string asynchronously.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, List[Document]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - output, extra_return_dict = self.combine_docs(docs, **other_keys) + output, extra_return_dict = self.combine_docs( + docs, callbacks=_run_manager.get_child(), **other_keys + ) extra_return_dict[self.output_key] = output return extra_return_dict - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, List[Document]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - output, extra_return_dict = await self.acombine_docs(docs, **other_keys) + output, extra_return_dict = await self.acombine_docs( + docs, callbacks=_run_manager.get_child(), **other_keys + ) extra_return_dict[self.output_key] = output return extra_return_dict @@ -108,10 +126,17 @@ def output_keys(self) -> List[str]: """ return self.combine_docs_chain.output_keys - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() document = inputs[self.input_key] docs = self.text_splitter.create_documents([document]) # Other keys are assumed to be needed for LLM prediction - other_keys = {k: v for k, v in inputs.items() if k != self.input_key} + other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key} other_keys[self.combine_docs_chain.input_key] = docs - return self.combine_docs_chain(other_keys, return_only_outputs=True) + return self.combine_docs_chain( + other_keys, return_only_outputs=True, callbacks=_run_manager.get_child() + ) diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index b439870df7a8b..8b2925de975f9 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -6,6 +6,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import 
BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document @@ -129,7 +130,11 @@ def _collapse_chain(self) -> BaseCombineDocumentsChain: return self.combine_document_chain def combine_docs( - self, docs: List[Document], token_max: int = 3000, **kwargs: Any + self, + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. @@ -138,12 +143,15 @@ def combine_docs( """ results = self.llm_chain.apply( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{self.document_variable_name: d.page_content, **kwargs} for d in docs], + callbacks=callbacks, + ) + return self._process_results( + results, docs, token_max, callbacks=callbacks, **kwargs ) - return self._process_results(results, docs, token_max, **kwargs) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. @@ -152,15 +160,17 @@ async def acombine_docs( """ results = await self.llm_chain.aapply( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) - return self._process_results(results, docs, **kwargs) + return self._process_results(results, docs, callbacks=callbacks, **kwargs) def _process_results( self, results: List[Dict], docs: List[Document], token_max: int = 3000, + callbacks: Callbacks = None, **kwargs: Any, ) -> Tuple[str, dict]: question_result_key = self.llm_chain.output_key @@ -173,7 +183,9 @@ def _process_results( num_tokens = length_func(result_docs, **kwargs) def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: - return self._collapse_chain.run(input_documents=docs, **kwargs) + return self._collapse_chain.run( + input_documents=docs, callbacks=callbacks, **kwargs + ) while num_tokens is not None and num_tokens > token_max: new_result_doc_list = _split_list_of_docs( @@ -191,7 +203,9 @@ def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: extra_return_dict = {"intermediate_steps": _results} else: extra_return_dict = {} - output = self.combine_document_chain.run(input_documents=result_docs, **kwargs) + output = self.combine_document_chain.run( + input_documents=result_docs, callbacks=callbacks, **kwargs + ) return output, extra_return_dict @property diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index 35f198a967ac1..ad8409c343481 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -6,6 +6,7 @@ from pydantic import Extra, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document @@ -89,19 +90,22 @@ def get_default_document_variable_name(cls, values: Dict) -> Dict: ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: """Combine documents in a map rerank manner. 
Combine by mapping first chain over all documents, then reranking the results. """ results = self.llm_chain.apply_and_parse( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) return self._process_results(docs, results) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine documents in a map rerank manner. @@ -109,7 +113,8 @@ async def acombine_docs( """ results = await self.llm_chain.aapply_and_parse( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) return self._process_results(docs, results) diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 7d1ae7ff03855..4b480090589f4 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -6,6 +6,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, format_document, @@ -85,29 +86,31 @@ def get_default_document_variable_name(cls, values: Dict) -> Dict: ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" inputs = self._construct_initial_inputs(docs, **kwargs) - res = self.initial_llm_chain.predict(**inputs) + res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs) refine_steps = [res] for doc in docs[1:]: base_inputs = self._construct_refine_inputs(doc, res) inputs = {**base_inputs, **kwargs} - res = self.refine_llm_chain.predict(**inputs) + res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs) refine_steps.append(res) return self._construct_result(refine_steps, res) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" inputs = self._construct_initial_inputs(docs, **kwargs) - res = await self.initial_llm_chain.apredict(**inputs) + res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs) refine_steps = [res] for doc in docs[1:]: base_inputs = self._construct_refine_inputs(doc, res) inputs = {**base_inputs, **kwargs} - res = await self.refine_llm_chain.apredict(**inputs) + res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs) refine_steps.append(res) return self._construct_result(refine_steps, res) diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 9d0a141c1bbcd..d1f051271127c 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -4,6 +4,7 @@ from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, format_document, 
@@ -75,19 +76,21 @@ def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: prompt = self.llm_chain.prompt.format(**inputs) return self.llm_chain.llm.get_num_tokens(prompt) - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: """Stuff all documents into one prompt and pass to LLM.""" inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. - return self.llm_chain.predict(**inputs), {} + return self.llm_chain.predict(callbacks=callbacks, **inputs), {} async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Stuff all documents into one prompt and pass to LLM.""" inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. - return await self.llm_chain.apredict(**inputs), {} + return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {} @property def _chain_type(self) -> str: diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index ecd733598df3a..007b20925a6b1 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.principles import PRINCIPLES @@ -86,11 +87,16 @@ def output_keys(self) -> List[str]: """Defines the output keys.""" return ["output"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() response = self.chain.run(**inputs) input_prompt = self.chain.prompt.format(**inputs) - self.callback_manager.on_text( + _run_manager.on_text( text="Initial response: " + response + "\n\n", verbose=self.verbose, color="yellow", @@ -103,6 +109,7 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: input_prompt=input_prompt, output_from_model=response, critique_request=constitutional_principle.critique_request, + callbacks=_run_manager.get_child(), ) critique = self._parse_critique( output_string=raw_critique, @@ -116,22 +123,23 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: critique_request=constitutional_principle.critique_request, critique=critique, revision_request=constitutional_principle.revision_request, + callbacks=_run_manager.get_child(), ).strip() response = revision - self.callback_manager.on_text( + _run_manager.on_text( text=f"Applying {constitutional_principle.name}..." 
+ "\n\n", verbose=self.verbose, color="green", ) - self.callback_manager.on_text( + _run_manager.on_text( text="Critique: " + critique + "\n\n", verbose=self.verbose, color="blue", ) - self.callback_manager.on_text( + _run_manager.on_text( text="Updated response: " + revision + "\n\n", verbose=self.verbose, color="yellow", diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index e4a9f268a0e22..ce7e7115babdc 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -9,6 +9,10 @@ from pydantic import Extra, Field, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -82,14 +86,20 @@ def output_keys(self) -> List[str]: def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: """Get docs.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if chat_history_str: + callbacks = _run_manager.get_child() new_question = self.question_generator.run( - question=question, chat_history=chat_history_str + question=question, chat_history=chat_history_str, callbacks=callbacks ) else: new_question = question @@ -97,7 +107,9 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs) + answer = self.combine_docs_chain.run( + input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: @@ -107,13 +119,19 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: """Get docs.""" - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if chat_history_str: + callbacks = _run_manager.get_child() new_question = await self.question_generator.arun( - question=question, chat_history=chat_history_str + question=question, chat_history=chat_history_str, callbacks=callbacks ) else: new_question = question @@ -121,7 +139,9 @@ async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs) + answer = await self.combine_docs_chain.arun( + 
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index addf72f821bf4..112338ae05199 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -1,10 +1,11 @@ """Question answering over a graph.""" from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import Field +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT from langchain.chains.llm import LLMChain @@ -51,18 +52,25 @@ def from_llm( qa_chain = LLMChain(llm=llm, prompt=qa_prompt) entity_chain = LLMChain(llm=llm, prompt=entity_prompt) - return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs) + return cls( + qa_chain=qa_chain, + entity_extraction_chain=entity_chain, + **kwargs, + ) - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Extract entities, look up info and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] entity_string = self.entity_extraction_chain.run(question) - self.callback_manager.on_text( - "Entities Extracted:", end="\n", verbose=self.verbose - ) - self.callback_manager.on_text( + _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose) + _run_manager.on_text( entity_string, color="green", end="\n", verbose=self.verbose ) entities = get_entities(entity_string) @@ -70,9 +78,10 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: for entity in entities: triplets = self.graph.get_entity_knowledge(entity) context += "\n".join(triplets) - self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - self.callback_manager.on_text( - context, color="green", end="\n", verbose=self.verbose + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose) + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=_run_manager.get_child(), ) - result = self.qa_chain({"question": question, "context": context}) return {self.output_key: result[self.qa_chain.output_key]} diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index f2f9747032c31..3cd6170ee7313 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -4,11 +4,12 @@ """ from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional import numpy as np from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain @@ -57,18 +58,27 @@ def embed_query(self, text: str) -> List[float]: embeddings = self.embed_documents(documents) return self.combine_embeddings(embeddings) - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Call the internal 
llm chain.""" - return self.llm_chain._call(inputs) + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + return self.llm_chain(inputs, callbacks=_run_manager.get_child()) @classmethod def from_llm( - cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str + cls, + llm: BaseLLM, + base_embeddings: Embeddings, + prompt_key: str, + **kwargs: Any, ) -> HypotheticalDocumentEmbedder: """Load and use LLMChain for a specific prompt key.""" prompt = PROMPT_MAP[prompt_key] llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(base_embeddings=base_embeddings, llm_chain=llm_chain) + return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs) @property def _chain_type(self) -> str: diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 197a4c573c3e8..61894bffe1468 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -1,48 +1,24 @@ """Chain that interprets a prompt and executes bash code to perform bash operations.""" +from __future__ import annotations + import logging -import re -from typing import Any, Dict, List +import warnings +from typing import Any, Dict, List, Optional -from pydantic import Extra, Field +from pydantic import Extra, Field, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseOutputParser, OutputParserException +from langchain.schema import OutputParserException from langchain.utilities.bash import BashProcess logger = logging.getLogger(__name__) -class BashOutputParser(BaseOutputParser): - """Parser for bash output.""" - - def parse(self, text: str) -> List[str]: - if "```bash" in text: - return self.get_code_blocks(text) - else: - raise OutputParserException( - f"Failed to parse bash output. Got: {text}", - ) - - @staticmethod - def get_code_blocks(t: str) -> List[str]: - """Get multiple code blocks from the LLM result.""" - code_blocks: List[str] = [] - # Bash markdown code blocks - pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) - for match in pattern.finditer(t): - matched = match.group(1).strip() - if matched: - code_blocks.extend( - [line for line in matched.split("\n") if line.strip()] - ) - - return code_blocks - - class LLMBashChain(Chain): """Chain that interprets a prompt and executes bash code to perform bash operations. @@ -50,15 +26,16 @@ class LLMBashChain(Chain): .. 
code-block:: python from langchain import LLMBashChain, OpenAI - llm_bash = LLMBashChain(llm=OpenAI()) + llm_bash = LLMBashChain.from_llm(OpenAI()) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: prompt: BasePromptTemplate = PROMPT - output_parser: BaseOutputParser = Field(default_factory=BashOutputParser) + """[Deprecated]""" bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: class Config: @@ -67,6 +44,26 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMBashChain with an llm is deprecated. " + "Please instantiate with llm_chain or using the from_llm class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + + @root_validator + def validate_prompt(cls, values: Dict) -> Dict: + if values["llm_chain"].prompt.output_parser is None: + raise ValueError( + "The prompt used by llm_chain is expected to have an output_parser." + ) + return values + @property def input_keys(self) -> List[str]: """Expect input key. @@ -83,30 +80,33 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) - - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) - - t = llm_executor.predict(question=inputs[self.input_key]) - self.callback_manager.on_text(t, color="green", verbose=self.verbose) + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key], verbose=self.verbose) + + t = self.llm_chain.predict( + question=inputs[self.input_key], callbacks=_run_manager.get_child() + ) + _run_manager.on_text(t, color="green", verbose=self.verbose) t = t.strip() try: - command_list = self.output_parser.parse(t) + command_list = self.llm_chain.prompt.output_parser.parse(t) # type: ignore[union-attr] except OutputParserException as e: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + _run_manager.on_chain_error(e, verbose=self.verbose) raise e if self.verbose: - self.callback_manager.on_text("\nCode: ", verbose=self.verbose) - self.callback_manager.on_text( + _run_manager.on_text("\nCode: ", verbose=self.verbose) + _run_manager.on_text( str(command_list), color="yellow", verbose=self.verbose ) - output = self.bash_process.run(command_list) - - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) + _run_manager.on_text("\nAnswer: ", verbose=self.verbose) + _run_manager.on_text(output, color="yellow", verbose=self.verbose) return {self.output_key: output} @property @@ -114,11 +114,11 @@ def _chain_type(self) -> str: return "llm_bash_chain" @classmethod - def from_bash_process( + def from_llm( cls, - bash_process: BashProcess, llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, **kwargs: Any, - ) -> "LLMBashChain": - """Create a 
LLMBashChain from a BashProcess.""" - return cls(llm=llm, bash_process=bash_process, **kwargs) + ) -> LLMBashChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/chains/llm_bash/prompt.py b/langchain/chains/llm_bash/prompt.py index 27dcbe57aae63..363b55058913e 100644 --- a/langchain/chains/llm_bash/prompt.py +++ b/langchain/chains/llm_bash/prompt.py @@ -1,5 +1,11 @@ # flake8: noqa +from __future__ import annotations + +import re +from typing import List + from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseOutputParser, OutputParserException _PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format: @@ -19,4 +25,36 @@ Question: {question}""" -PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE) + +class BashOutputParser(BaseOutputParser): + """Parser for bash output.""" + + def parse(self, text: str) -> List[str]: + if "```bash" in text: + return self.get_code_blocks(text) + else: + raise OutputParserException( + f"Failed to parse bash output. Got: {text}", + ) + + @staticmethod + def get_code_blocks(t: str) -> List[str]: + """Get multiple code blocks from the LLM result.""" + code_blocks: List[str] = [] + # Bash markdown code blocks + pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) + for match in pattern.finditer(t): + matched = match.group(1).strip() + if matched: + code_blocks.extend( + [line for line in matched.split("\n") if line.strip()] + ) + + return code_blocks + + +PROMPT = PromptTemplate( + input_variables=["question"], + template=_PROMPT_TEMPLATE, + output_parser=BashOutputParser(), +) diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index 0702818ae935e..ae2101e02b582 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -1,10 +1,12 @@ """Chain for question-answering with self-verification.""" +from __future__ import annotations +import warnings +from typing import Any, Dict, List, Optional -from typing import Dict, List - -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_checker.prompt import ( @@ -18,6 +20,48 @@ from langchain.prompts import PromptTemplate +def _load_question_to_checked_assertions_chain( + llm: BaseLLM, + create_draft_answer_prompt: PromptTemplate, + list_assertions_prompt: PromptTemplate, + check_assertions_prompt: PromptTemplate, + revised_answer_prompt: PromptTemplate, +) -> SequentialChain: + create_draft_answer_chain = LLMChain( + llm=llm, + prompt=create_draft_answer_prompt, + output_key="statement", + ) + list_assertions_chain = LLMChain( + llm=llm, + prompt=list_assertions_prompt, + output_key="assertions", + ) + check_assertions_chain = LLMChain( + llm=llm, + prompt=check_assertions_prompt, + output_key="checked_assertions", + ) + revised_answer_chain = LLMChain( + llm=llm, + prompt=revised_answer_prompt, + output_key="revised_statement", + ) + chains = [ + create_draft_answer_chain, + list_assertions_chain, + check_assertions_chain, + revised_answer_chain, + ] + question_to_checked_assertions_chain = SequentialChain( + chains=chains, + 
input_variables=["question"], + output_variables=["revised_statement"], + verbose=True, + ) + return question_to_checked_assertions_chain + + class LLMCheckerChain(Chain): """Chain for question-answering with self-verification. @@ -26,16 +70,21 @@ class LLMCheckerChain(Chain): from langchain import OpenAI, LLMCheckerChain llm = OpenAI(temperature=0.7) - checker_chain = LLMCheckerChain(llm=llm) + checker_chain = LLMCheckerChain.from_llm(llm) """ - llm: BaseLLM - """LLM wrapper to use.""" + question_to_checked_assertions_chain: SequentialChain + + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT + """[Deprecated]""" list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT + """[Deprecated]""" check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT + """[Deprecated]""" revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT - """Prompt to use when questioning the documents.""" + """[Deprecated] Prompt to use when questioning the documents.""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -45,6 +94,34 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMCheckerChain with an llm is deprecated. " + "Please instantiate with question_to_checked_assertions_chain " + "or using the from_llm class method." + ) + if ( + "question_to_checked_assertions_chain" not in values + and values["llm"] is not None + ): + question_to_checked_assertions_chain = ( + _load_question_to_checked_assertions_chain( + values["llm"], + values.get( + "create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT + ), + values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT), + values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), + values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT), + ) + ) + values[ + "question_to_checked_assertions_chain" + ] = question_to_checked_assertions_chain + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. 
@@ -61,43 +138,43 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] - create_draft_answer_chain = LLMChain( - llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement" - ) - list_assertions_chain = LLMChain( - llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions" - ) - check_assertions_chain = LLMChain( - llm=self.llm, - prompt=self.check_assertions_prompt, - output_key="checked_assertions", - ) - - revised_answer_chain = LLMChain( - llm=self.llm, - prompt=self.revised_answer_prompt, - output_key="revised_statement", - ) - - chains = [ - create_draft_answer_chain, - list_assertions_chain, - check_assertions_chain, - revised_answer_chain, - ] - - question_to_checked_assertions_chain = SequentialChain( - chains=chains, - input_variables=["question"], - output_variables=["revised_statement"], - verbose=True, + output = self.question_to_checked_assertions_chain( + {"question": question}, callbacks=_run_manager.get_child() ) - output = question_to_checked_assertions_chain({"question": question}) return {self.output_key: output["revised_statement"]} @property def _chain_type(self) -> str: return "llm_checker_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLLM, + create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT, + list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT, + check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, + revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT, + **kwargs: Any, + ) -> LLMCheckerChain: + question_to_checked_assertions_chain = ( + _load_question_to_checked_assertions_chain( + llm, + create_draft_answer_prompt, + list_assertions_prompt, + check_assertions_prompt, + revised_answer_prompt, + ) + ) + return cls( + question_to_checked_assertions_chain=question_to_checked_assertions_chain, + **kwargs, + ) diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index ff7c000dda3fd..1037658d60f4a 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -1,10 +1,13 @@ """Chain that interprets a prompt and executes python code to do math.""" +from __future__ import annotations + import math import re -from typing import Dict, List, Optional +import warnings +from typing import Any, Dict, List, Optional import numexpr -from pydantic import Extra +from pydantic import Extra, root_validator from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( @@ -24,13 +27,14 @@ class LLMMathChain(Chain): .. 
code-block:: python from langchain import LLMMathChain, OpenAI - llm_math = LLMMathChain(llm=OpenAI()) + llm_math = LLMMathChain.from_llm(OpenAI()) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" prompt: BasePromptTemplate = PROMPT - """Prompt to use to translate to python if neccessary.""" + """[Deprecated] Prompt to use to translate to python if necessary.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: @@ -40,6 +44,19 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMMathChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + @property def input_keys(self) -> List[str]: """Expect input key. @@ -73,18 +90,16 @@ def _evaluate_expression(self, expression: str) -> str: return re.sub(r"^\[|\]$", "", output) def _process_llm_result( - self, llm_output: str, run_manager: Optional[CallbackManagerForChainRun] = None + self, llm_output: str, run_manager: CallbackManagerForChainRun ) -> Dict[str, str]: - if run_manager: - run_manager.on_text(llm_output, color="green", verbose=self.verbose) + run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - if run_manager: - run_manager.on_text("\nAnswer: ", verbose=self.verbose) - run_manager.on_text(output, color="yellow", verbose=self.verbose) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -97,18 +112,16 @@ def _process_llm_result( async def _aprocess_llm_result( self, llm_output: str, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + run_manager: AsyncCallbackManagerForChainRun, ) -> Dict[str, str]: - if run_manager: - await run_manager.on_text(llm_output, color="green", verbose=self.verbose) + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - if run_manager: - await run_manager.on_text("\nAnswer: ", verbose=self.verbose) - await run_manager.on_text(output, color="yellow", verbose=self.verbose) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -123,31 +136,39 @@ def _call( inputs: Dict[str, str], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: - llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) - if run_manager: - run_manager.on_text(inputs[self.input_key]) - llm_output = llm_executor.predict( + _run_manager = run_manager or 
CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + llm_output = self.llm_chain.predict( question=inputs[self.input_key], stop=["```output"], - callbacks=run_manager.get_child() if run_manager else None, + callbacks=_run_manager.get_child(), ) - return self._process_llm_result(llm_output, run_manager=run_manager) + return self._process_llm_result(llm_output, _run_manager) async def _acall( self, inputs: Dict[str, str], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, str]: - llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) - if run_manager: - await run_manager.on_text(inputs[self.input_key]) - llm_output = await llm_executor.apredict( + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + llm_output = await self.llm_chain.apredict( question=inputs[self.input_key], stop=["```output"], - callbacks=run_manager.get_child() if run_manager else None, + callbacks=_run_manager.get_child(), ) - return await self._aprocess_llm_result(llm_output, run_manager=run_manager) + return await self._aprocess_llm_result(llm_output, _run_manager) @property def _chain_type(self) -> str: return "llm_math_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ) -> LLMMathChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py index 4abab04175340..d9c05744598fa 100644 --- a/langchain/chains/llm_requests.py +++ b/langchain/chains/llm_requests.py @@ -1,10 +1,11 @@ """Chain that hits a URL and then uses an LLM to parse results.""" from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains import LLMChain from langchain.chains.base import Chain from langchain.requests import TextRequestsWrapper @@ -61,9 +62,14 @@ def validate_environment(cls, values: Dict) -> Dict: ) return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: from bs4 import BeautifulSoup + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} url = inputs[self.input_key] @@ -71,7 +77,9 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # extract the text from the html soup = BeautifulSoup(res, "html.parser") other_keys[self.requests_key] = soup.get_text()[: self.text_length] - result = self.llm_chain.predict(**other_keys) + result = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **other_keys + ) return {self.output_key: result} @property diff --git a/langchain/chains/llm_summarization_checker/base.py b/langchain/chains/llm_summarization_checker/base.py index d69eecb8ae537..e44a5cc76540e 100644 --- a/langchain/chains/llm_summarization_checker/base.py +++ b/langchain/chains/llm_summarization_checker/base.py @@ -1,10 +1,14 @@ """Chain for summarization with self-verification.""" +from __future__ import annotations + +import warnings from pathlib import Path -from typing import Dict, List +from 
typing import Any, Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain @@ -27,6 +31,48 @@ ) +def _load_sequential_chain( + llm: BaseLLM, + create_assertions_prompt: PromptTemplate, + check_assertions_prompt: PromptTemplate, + revised_summary_prompt: PromptTemplate, + are_all_true_prompt: PromptTemplate, + verbose: bool = False, +) -> SequentialChain: + chain = SequentialChain( + chains=[ + LLMChain( + llm=llm, + prompt=create_assertions_prompt, + output_key="assertions", + verbose=verbose, + ), + LLMChain( + llm=llm, + prompt=check_assertions_prompt, + output_key="checked_assertions", + verbose=verbose, + ), + LLMChain( + llm=llm, + prompt=revised_summary_prompt, + output_key="revised_summary", + verbose=verbose, + ), + LLMChain( + llm=llm, + output_key="all_true", + prompt=are_all_true_prompt, + verbose=verbose, + ), + ], + input_variables=["summary"], + output_variables=["all_true", "revised_summary"], + verbose=verbose, + ) + return chain + + class LLMSummarizationCheckerChain(Chain): """Chain for question-answering with self-verification. @@ -35,16 +81,21 @@ class LLMSummarizationCheckerChain(Chain): from langchain import OpenAI, LLMSummarizationCheckerChain llm = OpenAI(temperature=0.0) - checker_chain = LLMSummarizationCheckerChain(llm=llm) + checker_chain = LLMSummarizationCheckerChain.from_llm(llm) """ - llm: BaseLLM - """LLM wrapper to use.""" + sequential_chain: SequentialChain + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT + """[Deprecated]""" check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT + """[Deprecated]""" revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT + """[Deprecated]""" are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT + """[Deprecated]""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -57,6 +108,25 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMSummarizationCheckerChain with an llm is " + "deprecated. Please instantiate with" + " sequential_chain argument or using the from_llm class method." + ) + if "sequential_chain" not in values and values["llm"] is not None: + values["sequential_chain"] = _load_sequential_chain( + values["llm"], + values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT), + values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), + values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT), + values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT), + verbose=values.get("verbose", False), + ) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. 
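The summarization checker follows the same shape, and the deprecation path can be exercised directly. A rough sketch (capturing the UserWarning emitted by raise_deprecation; the summary text is only illustrative):

    import warnings

    from langchain import OpenAI, LLMSummarizationCheckerChain

    llm = OpenAI(temperature=0.0)

    # New path: _load_sequential_chain builds the four LLMChain steps once.
    checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True)

    # Old path: still supported, but converted by the pre root_validator and flagged.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_chain = LLMSummarizationCheckerChain(llm=llm)
    assert any("deprecated" in str(w.message) for w in caught)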
@@ -73,46 +143,21 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() all_true = False count = 0 output = None original_input = inputs[self.input_key] chain_input = original_input - while not all_true and count < self.max_checks: - chain = SequentialChain( - chains=[ - LLMChain( - llm=self.llm, - prompt=self.create_assertions_prompt, - output_key="assertions", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - prompt=self.check_assertions_prompt, - output_key="checked_assertions", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - prompt=self.revised_summary_prompt, - output_key="revised_summary", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - output_key="all_true", - prompt=self.are_all_true_prompt, - verbose=self.verbose, - ), - ], - input_variables=["summary"], - output_variables=["all_true", "revised_summary"], - verbose=self.verbose, + output = self.sequential_chain( + {"summary": chain_input}, callbacks=_run_manager.get_child() ) - output = chain({"summary": chain_input}) count += 1 if output["all_true"].strip() == "True": @@ -131,3 +176,24 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: @property def _chain_type(self) -> str: return "llm_summarization_checker_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLLM, + create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT, + check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, + revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT, + are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT, + verbose: bool = False, + **kwargs: Any, + ) -> LLMSummarizationCheckerChain: + chain = _load_sequential_chain( + llm, + create_assertions_prompt, + check_assertions_prompt, + revised_summary_prompt, + are_all_true_prompt, + verbose=verbose, + ) + return cls(sequential_chain=chain, verbose=verbose, **kwargs) diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 062a9431d652a..f1b66b49b8f80 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -5,10 +5,11 @@ """ from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -32,16 +33,26 @@ class MapReduceChain(Chain): @classmethod def from_params( - cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter + cls, + llm: BaseLLM, + prompt: BasePromptTemplate, + text_splitter: TextSplitter, + callbacks: Callbacks = None, + **kwargs: Any, ) -> MapReduceChain: """Construct a map-reduce chain that uses the chain for map and reduce.""" - llm_chain = LLMChain(llm=llm, prompt=prompt) - reduce_chain = StuffDocumentsChain(llm_chain=llm_chain) + llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks) + reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks) combine_documents_chain = MapReduceDocumentsChain( - llm_chain=llm_chain, combine_document_chain=reduce_chain + llm_chain=llm_chain, + 
combine_document_chain=reduce_chain, + callbacks=callbacks, ) return cls( - combine_documents_chain=combine_documents_chain, text_splitter=text_splitter + combine_documents_chain=combine_documents_chain, + text_splitter=text_splitter, + callbacks=callbacks, + **kwargs, ) class Config: @@ -66,9 +77,16 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() # Split the larger text into smaller chunks. texts = self.text_splitter.split_text(inputs[self.input_key]) docs = [Document(page_content=text) for text in texts] - outputs = self.combine_documents_chain.run(input_documents=docs) + outputs = self.combine_documents_chain.run( + input_documents=docs, callbacks=_run_manager.get_child() + ) return {self.output_key: outputs} diff --git a/langchain/chains/moderation.py b/langchain/chains/moderation.py index 1e76c4360416b..96528a766d312 100644 --- a/langchain/chains/moderation.py +++ b/langchain/chains/moderation.py @@ -3,6 +3,7 @@ from pydantic import root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.utils import get_from_dict_or_env @@ -84,7 +85,11 @@ def _moderate(self, text: str, results: dict) -> str: return error_str return text - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: text = inputs[self.input_key] results = self.client.create(text) output = self._moderate(text, results["results"][0]) diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 369f0f45bfb26..47c80616c292f 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -1,10 +1,12 @@ """Implement an LLM driven browser.""" from __future__ import annotations -from typing import Dict, List +import warnings +from typing import Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT @@ -18,14 +20,15 @@ class NatBotChain(Chain): Example: .. code-block:: python - from langchain import NatBotChain, OpenAI - natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.") + from langchain import NatBotChain + natbot = NatBotChain.from_default("Buy me a new hat.") """ - llm: BaseLLM - """LLM wrapper to use.""" + llm_chain: LLMChain objective: str """Objective that NatBot is tasked with completing.""" + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" input_url_key: str = "url" #: :meta private: input_browser_content_key: str = "browser_content" #: :meta private: previous_command: str = "" #: :meta private: @@ -37,11 +40,24 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an NatBotChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_default " + "class method." 
+ ) + if "llm_chain" not in values and values["llm"] is not None: + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT) + return values + @classmethod def from_default(cls, objective: str) -> NatBotChain: - """Load with default LLM.""" + """Load with default LLMChain.""" llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) - return cls(llm=llm, objective=objective) + llm_chain = LLMChain(llm=llm, prompt=PROMPT) + return cls(llm_chain=llm_chain, objective=objective) @property def input_keys(self) -> List[str]: @@ -59,15 +75,20 @@ def output_keys(self) -> List[str]: """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() url = inputs[self.input_url_key] browser_content = inputs[self.input_browser_content_key] - llm_cmd = llm_executor.predict( + llm_cmd = self.llm_chain.predict( objective=self.objective, url=url[:100], previous_command=self.previous_command, browser_content=browser_content[:4500], + callbacks=_run_manager.get_child(), ) llm_cmd = llm_cmd.strip() self.previous_command = llm_cmd diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 6c8cfa4706e4c..275680a8bb16c 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -4,11 +4,13 @@ """ from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT @@ -20,8 +22,11 @@ class PALChain(Chain): """Implements Program-Aided Language Models.""" - llm: BaseLanguageModel - prompt: BasePromptTemplate + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated]""" + prompt: BasePromptTemplate = MATH_PROMPT + """[Deprecated]""" stop: str = "\n\n" get_answer_expr: str = "print(solution())" python_globals: Optional[Dict[str, Any]] = None @@ -35,6 +40,19 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an PALChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the one of " + "the class method constructors from_math_prompt, " + "from_colored_object_prompt." + ) + if "llm_chain" not in values and values["llm"] is not None: + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MATH_PROMPT) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. 
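The _call signature change repeated across these files is the core of the patch: run_manager replaces self.callback_manager, and child callbacks are what keep nested chain and LLM runs attached to the parent run. A custom chain written against the new interface would look roughly like this (EchoThenLLMChain and its "text" prompt variable are made up for illustration; only the pattern is taken from this patch):

    from typing import Any, Dict, List, Optional

    from langchain.callbacks.manager import CallbackManagerForChainRun
    from langchain.chains.base import Chain
    from langchain.chains.llm import LLMChain


    class EchoThenLLMChain(Chain):
        """Toy chain showing the new run_manager plumbing."""

        llm_chain: LLMChain  # prompt is assumed to expect a single "text" variable
        input_key: str = "text"  #: :meta private:
        output_key: str = "result"  #: :meta private:

        @property
        def input_keys(self) -> List[str]:
            return [self.input_key]

        @property
        def output_keys(self) -> List[str]:
            return [self.output_key]

        def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            # Fall back to a no-op manager so the body never branches on None.
            _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
            _run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
            # get_child() hands nested calls a callback manager scoped to this run.
            output = self.llm_chain.predict(
                callbacks=_run_manager.get_child(), **inputs
            )
            return {self.output_key: output}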
@@ -54,12 +72,16 @@ def output_keys(self) -> List[str]: else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) - code = llm_chain.predict(stop=[self.stop], **inputs) - self.callback_manager.on_text( - code, color="green", end="\n", verbose=self.verbose + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + code = self.llm_chain.predict( + stop=[self.stop], callbacks=_run_manager.get_child(), **inputs ) + _run_manager.on_text(code, color="green", end="\n", verbose=self.verbose) repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals) res = repl.run(code + f"\n{self.get_answer_expr}") output = {self.output_key: res.strip()} @@ -70,9 +92,9 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: @classmethod def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain: """Load PAL from math prompt.""" + llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT) return cls( - llm=llm, - prompt=MATH_PROMPT, + llm_chain=llm_chain, stop="\n\n", get_answer_expr="print(solution())", **kwargs, @@ -83,9 +105,9 @@ def from_colored_object_prompt( cls, llm: BaseLanguageModel, **kwargs: Any ) -> PALChain: """Load PAL from colored object prompt.""" + llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT) return cls( - llm=llm, - prompt=COLORED_OBJECT_PROMPT, + llm_chain=llm_chain, stop="\n\n\n", get_answer_expr="print(answer)", **kwargs, diff --git a/langchain/chains/qa_generation/base.py b/langchain/chains/qa_generation/base.py index 2dfb4b5d29818..1c0ae6b978478 100644 --- a/langchain/chains/qa_generation/base.py +++ b/langchain/chains/qa_generation/base.py @@ -6,6 +6,7 @@ from pydantic import Field from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR @@ -45,11 +46,14 @@ def input_keys(self) -> List[str]: def output_keys(self) -> List[str]: return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, List]: docs = self.text_splitter.create_documents([inputs[self.input_key]]) - results = self.llm_chain.generate([{"text": d.page_content} for d in docs]) + results = self.llm_chain.generate( + [{"text": d.page_content} for d in docs], run_manager=run_manager + ) qa = [json.loads(res[0].text) for res in results.generations] return {self.output_key: qa} - - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: - raise NotImplementedError diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 10b5d96f9c858..96048a0a23384 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -9,6 +9,10 @@ from pydantic import Extra, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from 
langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -114,9 +118,16 @@ def validate_naming(cls, values: Dict) -> Dict: def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: """Get docs to run questioning over.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() docs = self._get_docs(inputs) - answer = self.combine_documents_chain.run(input_documents=docs, **inputs) + answer = self.combine_documents_chain.run( + input_documents=docs, callbacks=_run_manager.get_child(), **inputs + ) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: @@ -133,9 +144,16 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: """Get docs to run questioning over.""" - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() docs = await self._aget_docs(inputs) - answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs) + answer = await self.combine_documents_chain.arun( + input_documents=docs, callbacks=_run_manager.get_child(), **inputs + ) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 3fb80c4332f98..dd5062a9b5b32 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -5,6 +5,7 @@ from typing import Any, Callable, List, Optional, Sequence from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain +from langchain.base_language import BaseLanguageModel from langchain.chains.query_constructor.ir import ( Comparator, Operator, @@ -20,7 +21,7 @@ ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.structured import parse_json_markdown -from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, OutputParserException class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index 7e30f4720467a..2255f957123f9 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -8,6 +8,10 @@ from pydantic import Extra, Field, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -93,7 +97,11 @@ def from_chain_type( def _get_docs(self, question: str) -> List[Document]: """Get documents to do question answering over.""" - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run 
get_relevant_text and llm on input query. If chain has 'return_source_documents' as 'True', returns @@ -105,11 +113,12 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: res = indexqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] docs = self._get_docs(question) answer = self.combine_documents_chain.run( - input_documents=docs, question=question + input_documents=docs, question=question, callbacks=_run_manager.get_child() ) if self.return_source_documents: @@ -121,7 +130,11 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: async def _aget_docs(self, question: str) -> List[Document]: """Get documents to do question answering over.""" - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run get_relevant_text and llm on input query. If chain has 'return_source_documents' as 'True', returns @@ -133,11 +146,12 @@ async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]: res = indexqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] docs = await self._aget_docs(question) answer = await self.combine_documents_chain.arun( - input_documents=docs, question=question + input_documents=docs, question=question, callbacks=_run_manager.get_child() ) if self.return_source_documents: diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index b21dfac5083af..f94b5bc584bc2 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -1,8 +1,12 @@ """Chain pipeline where the outputs of one step feed directly into next.""" -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.input import get_color_mapping @@ -86,17 +90,31 @@ def validate_chains(cls, values: Dict) -> Dict: return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: known_values = inputs.copy() + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() for i, chain in enumerate(self.chains): - outputs = chain(known_values, return_only_outputs=True) + callbacks = _run_manager.get_child() + outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks) known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: known_values = inputs.copy() + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() for i, chain in enumerate(self.chains): - outputs = await chain.acall(known_values, return_only_outputs=True) + outputs = await chain.acall( + known_values, return_only_outputs=True, callbacks=callbacks + ) 
known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} @@ -147,31 +165,37 @@ def validate_chains(cls, values: Dict) -> Dict: ) return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = chain.run(_input) + _input = chain.run(_input, callbacks=_run_manager.get_child()) if self.strip_outputs: _input = _input.strip() - self.callback_manager.on_text( + _run_manager.on_text( _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose ) return {self.output_key: _input} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = await chain.arun(_input) + _input = await chain.arun(_input, callbacks=callbacks) if self.strip_outputs: _input = _input.strip() - if self.callback_manager.is_async: - await self.callback_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose - ) - else: - self.callback_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose - ) + await _run_manager.on_text( + _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + ) return {self.output_key: _input} diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 3d8761b211dc6..d73d34a30afce 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -1,11 +1,13 @@ """Chain for interacting with SQL Database.""" from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional -from pydantic import Extra, Field +from pydantic import Extra, Field, root_validator from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS @@ -21,15 +23,16 @@ class SQLDatabaseChain(Chain): from langchain import SQLDatabaseChain, OpenAI, SQLDatabase db = SQLDatabase(...) 
- db_chain = SQLDatabaseChain(llm=OpenAI(), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(), db) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" database: SQLDatabase = Field(exclude=True) """SQL Database to connect to.""" prompt: Optional[BasePromptTemplate] = None - """Prompt to use to translate natural language to SQL.""" + """[Deprecated] Prompt to use to translate natural language to SQL.""" top_k: int = 5 """Number of results to return from the query""" input_key: str = "query" #: :meta private: @@ -45,6 +48,22 @@ class Config: extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an SQLDatabaseChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + database = values["database"] + prompt = values.get("prompt") or SQL_PROMPTS.get( + database.dialect, PROMPT + ) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. @@ -64,11 +83,14 @@ def output_keys(self) -> List[str]: else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - prompt = self.prompt or SQL_PROMPTS.get(self.database.dialect, PROMPT) - llm_chain = LLMChain(llm=self.llm, prompt=prompt) + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() input_text = f"{inputs[self.input_key]}\nSQLQuery:" - self.callback_manager.on_text(input_text, verbose=self.verbose) + _run_manager.on_text(input_text, verbose=self.verbose) # If not present, then defaults to None which is all tables. 
table_names_to_use = inputs.get("table_names_to_use") table_info = self.database.get_table_info(table_names=table_names_to_use) @@ -80,24 +102,26 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: "stop": ["\nSQLResult:"], } intermediate_steps = [] - sql_cmd = llm_chain.predict(**llm_inputs) + sql_cmd = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **llm_inputs + ) intermediate_steps.append(sql_cmd) - self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose) + _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose) result = self.database.run(sql_cmd) intermediate_steps.append(result) - self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose) - self.callback_manager.on_text(result, color="yellow", verbose=self.verbose) + _run_manager.on_text("\nSQLResult: ", verbose=self.verbose) + _run_manager.on_text(result, color="yellow", verbose=self.verbose) # If return direct, we just set the final result equal to the sql query if self.return_direct: final_result = result else: - self.callback_manager.on_text("\nAnswer:", verbose=self.verbose) + _run_manager.on_text("\nAnswer:", verbose=self.verbose) input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" llm_inputs["input"] = input_text - final_result = llm_chain.predict(**llm_inputs) - self.callback_manager.on_text( - final_result, color="green", verbose=self.verbose + final_result = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **llm_inputs ) + _run_manager.on_text(final_result, color="green", verbose=self.verbose) chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: chain_result["intermediate_steps"] = intermediate_steps @@ -107,6 +131,18 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def _chain_type(self) -> str: return "sql_database_chain" + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + db: SQLDatabase, + prompt: Optional[BasePromptTemplate] = None, + **kwargs: Any, + ) -> SQLDatabaseChain: + prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT) + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, database=db, **kwargs) + class SQLDatabaseSequentialChain(Chain): """Chain for querying SQL database that is a sequential chain. @@ -118,6 +154,10 @@ class SQLDatabaseSequentialChain(Chain): This is useful in cases where the number of tables in the database is large. """ + decider_chain: LLMChain + sql_chain: SQLDatabaseChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: return_intermediate_steps: bool = False @classmethod @@ -138,11 +178,6 @@ def from_llm( ) return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs) - decider_chain: LLMChain - sql_chain: SQLDatabaseChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - @property def input_keys(self) -> List[str]: """Return the singular input key. 
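End to end, the SQL chains are now built with from_llm and accept callbacks at call time, which the run manager fans out to the inner llm_chain. A rough sketch (the in-memory URI and the handler are placeholders; whether on_text actually fires also depends on the verbose gating in the new managers):

    from typing import Any

    from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
    from langchain.callbacks.base import BaseCallbackHandler


    class PrintSQLHandler(BaseCallbackHandler):
        # Only the hooks of interest need overriding here.
        def on_text(self, text: str, **kwargs: Any) -> None:
            print(f"[sql trace] {text}")


    db = SQLDatabase.from_uri("sqlite:///:memory:")  # placeholder database
    db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db, verbose=True)

    # Handlers passed here are attached to this run and forwarded to nested
    # calls via _run_manager.get_child().
    result = db_chain.run("How many tables are there?", callbacks=[PrintSQLHandler()])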
@@ -162,25 +197,32 @@ def output_keys(self) -> List[str]: else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _table_names = self.sql_chain.database.get_usable_table_names() table_names = ", ".join(_table_names) llm_inputs = { "query": inputs[self.input_key], "table_names": table_names, } - table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs) - self.callback_manager.on_text( - "Table names to use:", end="\n", verbose=self.verbose + table_names_to_use = self.decider_chain.predict_and_parse( + callbacks=_run_manager.get_child(), **llm_inputs ) - self.callback_manager.on_text( + _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose) + _run_manager.on_text( str(table_names_to_use), color="yellow", verbose=self.verbose ) new_inputs = { self.sql_chain.input_key: inputs[self.input_key], "table_names_to_use": table_names_to_use, } - return self.sql_chain(new_inputs, return_only_outputs=True) + return self.sql_chain( + new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True + ) @property def _chain_type(self) -> str: diff --git a/langchain/chains/transform.py b/langchain/chains/transform.py index eb5cb314a8960..90947b2b698d6 100644 --- a/langchain/chains/transform.py +++ b/langchain/chains/transform.py @@ -1,6 +1,7 @@ """Chain that runs an arbitrary python function.""" -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -35,5 +36,9 @@ def output_keys(self) -> List[str]: """ return self.output_variables - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: return self.transform(inputs) diff --git a/tests/integration_tests/chains/test_pal.py b/tests/integration_tests/chains/test_pal.py index 9bbf6f8d88622..cb03d80cae1e1 100644 --- a/tests/integration_tests/chains/test_pal.py +++ b/tests/integration_tests/chains/test_pal.py @@ -6,7 +6,7 @@ def test_math_prompt() -> None: """Test math prompt.""" - llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512) + llm = OpenAI(temperature=0, max_tokens=512) pal_chain = PALChain.from_math_prompt(llm) question = ( "Jan has three times the number of pets as Marcia. 
" @@ -19,7 +19,7 @@ def test_math_prompt() -> None: def test_colored_object_prompt() -> None: """Test colored object prompt.""" - llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512) + llm = OpenAI(temperature=0, max_tokens=512) pal_chain = PALChain.from_colored_object_prompt(llm) question = ( "On the desk, you see two blue booklets, " diff --git a/tests/integration_tests/chains/test_sql_database.py b/tests/integration_tests/chains/test_sql_database.py index 3518866c2e6c8..f19ec02594e71 100644 --- a/tests/integration_tests/chains/test_sql_database.py +++ b/tests/integration_tests/chains/test_sql_database.py @@ -27,7 +27,7 @@ def test_sql_database_run() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("What company does Harrison work at?") expected_output = " Harrison works at Foo." assert output == expected_output @@ -41,7 +41,7 @@ def test_sql_database_run_update() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("Update Harrison's workplace to Bar") expected_output = " Harrison's workplace has been updated to Bar." assert output == expected_output @@ -59,9 +59,7 @@ def test_sql_database_sequential_chain_run() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseSequentialChain.from_llm( - llm=OpenAI(temperature=0), database=db - ) + db_chain = SQLDatabaseSequentialChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("What company does Harrison work at?") expected_output = " Harrison works at Foo." assert output == expected_output @@ -77,7 +75,7 @@ def test_sql_database_sequential_chain_intermediate_steps() -> None: conn.execute(stmt) db = SQLDatabase(engine) db_chain = SQLDatabaseSequentialChain.from_llm( - llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True + OpenAI(temperature=0), db, return_intermediate_steps=True ) output = db_chain("What company does Harrison work at?") expected_output = " Harrison works at Foo." 
diff --git a/tests/unit_tests/chains/test_llm_bash.py b/tests/unit_tests/chains/test_llm_bash.py index 3e20e3567c233..ecbfb3a8a53c4 100644 --- a/tests/unit_tests/chains/test_llm_bash.py +++ b/tests/unit_tests/chains/test_llm_bash.py @@ -3,8 +3,8 @@ import pytest -from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain -from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE +from langchain.chains.llm_bash.base import LLMBashChain +from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser from langchain.schema import OutputParserException from tests.unit_tests.llms.fake_llm import FakeLLM From 83cda5e83a174f9cad72dd94a0cede5d60f18271 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 15:25:19 -0700 Subject: [PATCH 21/36] lint --- langchain/agents/tools.py | 1 - langchain/callbacks/__init__.py | 8 ------- langchain/callbacks/manager.py | 2 +- langchain/callbacks/tracers/schemas.py | 2 +- .../agents/trajectory_eval_chain.py | 7 +++++- .../generative_agents/generative_agent.py | 2 +- .../experimental/generative_agents/memory.py | 3 ++- .../document_compressors/chain_extract.py | 3 ++- .../document_compressors/chain_filter.py | 3 ++- langchain/retrievers/self_query/base.py | 5 ++-- langchain/tools/base.py | 2 -- .../callbacks/test_langchain_tracer.py | 23 ++++++++++++++----- tests/integration_tests/llms/test_llamacpp.py | 4 +--- tests/unit_tests/agents/test_agent.py | 2 +- .../callbacks/tracers/test_tracer.py | 2 +- 15 files changed, 38 insertions(+), 31 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index cbe19a01c8ef7..0f94313861210 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -8,7 +8,6 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, - Callbacks, ) from langchain.tools.base import BaseTool, StructuredTool diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index 28a00e0b5c7f3..a85c375e29a02 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,23 +1,15 @@ """Callback handlers that allow listening to events in LangChain.""" -from contextlib import contextmanager -from typing import Generator from langchain.callbacks.aim_callback import AimCallbackHandler -from langchain.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, -) from langchain.callbacks.clearml_callback import ClearMLCallbackHandler from langchain.callbacks.comet_ml_callback import CometCallbackHandler from langchain.callbacks.manager import ( - CallbackManager, get_openai_callback, tracing_enabled, ) from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler -from langchain.callbacks.tracers import LangChainTracer from langchain.callbacks.wandb_callback import WandbCallbackHandler __all__ = [ diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index caf87fbe6e454..60024ef7d410a 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -7,7 +7,7 @@ import uuid from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Generator, List, Optional, Sequence, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union from langchain.callbacks.base import ( BaseCallbackHandler, diff --git 
a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index fc96908f5c0a7..ce6368ff9636a 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field diff --git a/langchain/evaluation/agents/trajectory_eval_chain.py b/langchain/evaluation/agents/trajectory_eval_chain.py index f6f9cf088b368..d79171bb7bae4 100644 --- a/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/langchain/evaluation/agents/trajectory_eval_chain.py @@ -1,6 +1,7 @@ """A chain for evaluating ReAct style agents.""" from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chat_models import ChatOpenAI @@ -94,7 +95,11 @@ def output_keys(self) -> List[str]: return ["score", "reasoning"] return ["score"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: raw_output = self.eval_chain.run( {"tool_descriptions": self._tools_description, **inputs} ) diff --git a/langchain/experimental/generative_agents/generative_agent.py b/langchain/experimental/generative_agents/generative_agent.py index ac5d951ae668f..64780da81a6f4 100644 --- a/langchain/experimental/generative_agents/generative_agent.py +++ b/langchain/experimental/generative_agents/generative_agent.py @@ -5,9 +5,9 @@ from pydantic import BaseModel, Field from langchain import LLMChain +from langchain.base_language import BaseLanguageModel from langchain.experimental.generative_agents.memory import GenerativeAgentMemory from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel class GenerativeAgent(BaseModel): diff --git a/langchain/experimental/generative_agents/memory.py b/langchain/experimental/generative_agents/memory.py index 8719d1bf3f680..5f1d65f423024 100644 --- a/langchain/experimental/generative_agents/memory.py +++ b/langchain/experimental/generative_agents/memory.py @@ -3,9 +3,10 @@ from typing import Any, Dict, List, Optional from langchain import LLMChain +from langchain.base_language import BaseLanguageModel from langchain.prompts import PromptTemplate from langchain.retrievers import TimeWeightedVectorStoreRetriever -from langchain.schema import BaseLanguageModel, BaseMemory, Document +from langchain.schema import BaseMemory, Document logger = logging.getLogger(__name__) diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py index 6f638559443f1..ea1b2c3fbe5fc 100644 --- a/langchain/retrievers/document_compressors/chain_extract.py +++ b/langchain/retrievers/document_compressors/chain_extract.py @@ -2,13 +2,14 @@ from typing import Any, Callable, Dict, Optional, Sequence from langchain import LLMChain, PromptTemplate +from langchain.base_language import BaseLanguageModel from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, ) from langchain.retrievers.document_compressors.chain_extract_prompt import ( prompt_template, ) -from langchain.schema import BaseLanguageModel, BaseOutputParser, Document +from langchain.schema import 
BaseOutputParser, Document def default_get_input(query: str, doc: Document) -> Dict[str, Any]: diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py index f5e33e6bf65ab..6cb0d2fe8456c 100644 --- a/langchain/retrievers/document_compressors/chain_filter.py +++ b/langchain/retrievers/document_compressors/chain_filter.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from langchain import BasePromptTemplate, LLMChain, PromptTemplate +from langchain.base_language import BaseLanguageModel from langchain.output_parsers.boolean import BooleanOutputParser from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, @@ -9,7 +10,7 @@ from langchain.retrievers.document_compressors.chain_filter_prompt import ( prompt_template, ) -from langchain.schema import BaseLanguageModel, Document +from langchain.schema import Document def _get_default_chain_prompt() -> PromptTemplate: diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index a9d7cad4412e0..7c9ced716626c 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -4,13 +4,14 @@ from pydantic import BaseModel, Field, root_validator from langchain import LLMChain +from langchain.base_language import BaseLanguageModel from langchain.chains.query_constructor.base import ( load_query_constructor_chain, ) from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import AttributeInfo from langchain.retrievers.self_query.pinecone import PineconeTranslator -from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.schema import BaseRetriever, Document from langchain.vectorstores import Pinecone, VectorStore @@ -69,7 +70,7 @@ def get_relevant_documents(self, query: str) -> List[Document]: """ inputs = self.llm_chain.prep_inputs(query) structured_query = cast( - StructuredQuery, self.llm_chain.predict_and_parse(**inputs) + StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs) ) if self.verbose: print(structured_query) diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 261014c59cd8f..3388910df74df 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -19,9 +19,7 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( AsyncCallbackManager, - AsyncCallbackManagerForToolRun, CallbackManager, - CallbackManagerForToolRun, Callbacks, ) diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index c17ebeb77fcd8..ffcb8c467446a 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -1,7 +1,6 @@ """Integration tests for the langchain tracer module.""" import asyncio import os -import time import pytest from aiohttp import ClientSession @@ -11,11 +10,23 @@ from langchain.llms import OpenAI questions = [ - "Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?", - "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?", - "Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?", - "Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?", - "Who is Beyonce's husband? 
What is his age raised to the 0.19 power?", + ( + "Who won the US Open men's final in 2019? " + "What is his age raised to the 0.334 power?" + ), + ( + "Who is Olivia Wilde's boyfriend? " + "What is his current age raised to the 0.23 power?" + ), + ( + "Who won the most recent formula 1 grand prix? " + "What is their age raised to the 0.23 power?" + ), + ( + "Who won the US Open women's final in 2019? " + "What is her age raised to the 0.34 power?" + ), + ("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"), ] diff --git a/tests/integration_tests/llms/test_llamacpp.py b/tests/integration_tests/llms/test_llamacpp.py index 7ea2881f2ec53..e1a28594a118f 100644 --- a/tests/integration_tests/llms/test_llamacpp.py +++ b/tests/integration_tests/llms/test_llamacpp.py @@ -5,7 +5,6 @@ from urllib.request import urlretrieve from langchain.llms import LlamaCpp -from langchain.callbacks.base import CallbackManager from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -61,10 +60,9 @@ def test_llamacpp_streaming_callback() -> None: OFF_BY_ONE = 1 # There may be an off by one error in the upstream code! callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) llm = LlamaCpp( model_path=get_model(), - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, max_tokens=MAX_TOKENS, ) diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 1fb94cf21a49b..3a03f03f15b50 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,7 +4,7 @@ from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents.tools import Tool -from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 54a8e3527841b..c60373d63042d 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import datetime -from typing import List, Optional, Union +from typing import List, Union from uuid import uuid4 import pytest From 9c876bdb7314c2e2fb3406101025752d09652c11 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Fri, 28 Apr 2023 15:44:13 -0700 Subject: [PATCH 22/36] update chain notebooks (#3740) --- docs/modules/chains/examples/llm_bash.ipynb | 30 ++--- .../modules/chains/examples/llm_checker.ipynb | 22 +--- docs/modules/chains/examples/llm_math.ipynb | 116 ++---------------- .../examples/llm_summarization_checker.ipynb | 61 ++++----- docs/modules/chains/examples/pal.ipynb | 32 ++--- docs/modules/chains/examples/sqlite.ipynb | 14 +-- 6 files changed, 87 insertions(+), 188 deletions(-) diff --git a/docs/modules/chains/examples/llm_bash.ipynb b/docs/modules/chains/examples/llm_bash.ipynb index c2cb0fe6362d2..dab1f6e45b05c 100644 --- a/docs/modules/chains/examples/llm_bash.ipynb +++ b/docs/modules/chains/examples/llm_bash.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -37,7 +37,7 @@ "'Hello World\\n'" ] }, - "execution_count": 1, + 
"execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -50,7 +50,7 @@ "\n", "text = \"Please write a bash script that prints 'Hello World' to the console.\"\n", "\n", - "bash_chain = LLMBashChain(llm=llm, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, verbose=True)\n", "\n", "bash_chain.run(text)" ] @@ -65,11 +65,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from langchain.prompts.prompt import PromptTemplate\n", + "from langchain.chains.llm_bash.prompt import BashOutputParser\n", "\n", "_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n", "Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n", @@ -88,12 +89,12 @@ "That is the format. Begin!\n", "Question: {question}\"\"\"\n", "\n", - "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)" + "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE, output_parser=BashOutputParser())" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -120,13 +121,13 @@ "'Hello World\\n'" ] }, - "execution_count": 3, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "bash_chain = LLMBashChain(llm=llm, prompt=PROMPT, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, prompt=PROMPT, verbose=True)\n", "\n", "text = \"Please write a bash script that prints 'Hello World' to the console.\"\n", "\n", @@ -134,7 +135,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -177,7 +177,7 @@ "'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'" ] }, - "execution_count": 4, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -187,7 +187,7 @@ "\n", "\n", "persistent_process = BashProcess(persistent=True)\n", - "bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, bash_process=persistent_process, verbose=True)\n", "\n", "text = \"List the current directory then move up a level.\"\n", "\n", @@ -196,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -224,7 +224,7 @@ "'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'" ] }, - "execution_count": 5, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -258,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_checker.ipynb b/docs/modules/chains/examples/llm_checker.ipynb index a6bc0b73e3455..38ed1b64a4f1e 100644 --- a/docs/modules/chains/examples/llm_checker.ipynb +++ 
b/docs/modules/chains/examples/llm_checker.ipynb @@ -23,28 +23,16 @@ "\n", "\n", "\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n", - "\u001b[1mChain 0\u001b[0m:\n", - "{'statement': '\\nNone. Mammals do not lay eggs.'}\n", "\n", - "\u001b[1mChain 1\u001b[0m:\n", - "{'assertions': '\\n• Mammals reproduce using live birth\\n• Mammals do not lay eggs\\n• Animals that lay eggs are not mammals'}\n", + "\u001b[1m> Finished chain.\u001b[0m\n", "\n", - "\u001b[1mChain 2\u001b[0m:\n", - "{'checked_assertions': '\\n1. True\\n\\n2. True\\n\\n3. False - Mammals are a class of animals that includes animals that lay eggs, such as monotremes (platypus and echidna).'}\n", - "\n", - "\u001b[1mChain 3\u001b[0m:\n", - "{'revised_statement': ' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'}\n", - "\n", - "\n", - "\u001b[1m> Finished SequentialChain chain.\u001b[0m\n", - "\n", - "\u001b[1m> Finished LLMCheckerChain chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'" + "' No mammal lays the biggest eggs. The Elephant Bird, which was a species of giant bird, laid the largest eggs of any bird.'" ] }, "execution_count": 1, @@ -60,7 +48,7 @@ "\n", "text = \"What type of mammal lays the biggest eggs?\"\n", "\n", - "checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n", + "checker_chain = LLMCheckerChain.from_llm(llm, verbose=True)\n", "\n", "checker_chain.run(text)" ] @@ -89,7 +77,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_math.ipynb b/docs/modules/chains/examples/llm_math.ipynb index 29eaaea1c2fb1..c46f825e8f9df 100644 --- a/docs/modules/chains/examples/llm_math.ipynb +++ b/docs/modules/chains/examples/llm_math.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "44e9ba31", "metadata": {}, "outputs": [ @@ -24,23 +24,22 @@ "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(13, .3432))\n", + "```text\n", + "13 ** .3432\n", "```\n", + "...numexpr.evaluate(\"13 ** .3432\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'Answer: 2.4116004626599237\\n'" + "'Answer: 2.4116004626599237'" ] }, - "execution_count": 1, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -49,102 +48,7 @@ "from langchain import OpenAI, LLMMathChain\n", "\n", "llm = OpenAI(temperature=0)\n", - "llm_math = LLMMathChain(llm=llm, verbose=True)\n", - "\n", - "llm_math.run(\"What is 13 raised to the .3432 power?\")" - ] - }, - { - "cell_type": "markdown", - "id": "2bdd5fc6", - "metadata": {}, - "source": [ - "## Customize Prompt\n", - "You can also customize the prompt that is used. 
Here is an example prompting it to use numpy" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "76be17b0", - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.prompts.prompt import PromptTemplate\n", - "\n", - "_PROMPT_TEMPLATE = \"\"\"You are GPT-3, and you can't do math.\n", - "\n", - "You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.\n", - "\n", - "So we hooked you up to a Python 3 kernel, and now you can execute code. If you execute code, you must print out the final answer using the print function. You MUST use the python package numpy to answer your question. You must import numpy as np.\n", - "\n", - "\n", - "Question: ${{Question with hard calculation.}}\n", - "```python\n", - "${{Code that prints what you need to know}}\n", - "print(${{code}})\n", - "```\n", - "```output\n", - "${{Output of your code}}\n", - "```\n", - "Answer: ${{Answer}}\n", - "\n", - "Begin.\n", - "\n", - "Question: What is 37593 * 67?\n", - "\n", - "```python\n", - "import numpy as np\n", - "print(np.multiply(37593, 67))\n", - "```\n", - "```output\n", - "2518731\n", - "```\n", - "Answer: 2518731\n", - "\n", - "Question: {question}\"\"\"\n", - "\n", - "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "0c42faa0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n", - "\n", - "```python\n", - "import numpy as np\n", - "print(np.power(13, .3432))\n", - "```\n", - "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n", - "\u001b[0m\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'Answer: 2.4116004626599237\\n'" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "llm_math = LLMMathChain(llm=llm, prompt=PROMPT, verbose=True)\n", + "llm_math = LLMMathChain.from_llm(llm, verbose=True)\n", "\n", "llm_math.run(\"What is 13 raised to the .3432 power?\")" ] @@ -152,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c62951b", + "id": "e978bb8e", "metadata": {}, "outputs": [], "source": [] @@ -174,7 +78,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_summarization_checker.ipynb b/docs/modules/chains/examples/llm_summarization_checker.ipynb index 7448f84f5f435..7436616e4cbac 100644 --- a/docs/modules/chains/examples/llm_summarization_checker.ipynb +++ b/docs/modules/chains/examples/llm_summarization_checker.ipynb @@ -221,11 +221,11 @@ "\n", "• The light from these galaxies has been traveling for over 13 billion years to reach us. - True \n", "\n", - "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n", + "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. 
The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n", "\n", "• Exoplanets were first discovered in 1992. - True \n", "\n", - "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n", + "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n", "\"\"\"\n", "\n", "Original Summary:\n", @@ -296,11 +296,11 @@ "\n", "• The light from these galaxies has been traveling for over 13 billion years to reach us. - True \n", "\n", - "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n", + "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n", "\n", "• Exoplanets were first discovered in 1992. - True \n", "\n", - "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n", + "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n", "\"\"\"\n", "Result:\u001b[0m\n", "\n", @@ -312,7 +312,7 @@ "Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n", "• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n", "• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\n", - "• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail than ever before.\n", + "• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail when it is launched in 2023.\n", "These discoveries can spark a child's imagination about the infinite wonders of the universe.\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -321,7 +321,7 @@ { "data": { "text/plain": [ - "'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. 
The JWST will allow us to see them in greater detail than ever before.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'" + "'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail when it is launched in 2023.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'" ] }, "execution_count": 1, @@ -334,7 +334,7 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=2)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=2)\n", "text = \"\"\"\n", "Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n", "• In 2023, The JWST spotted a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n", @@ -407,7 +407,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -428,7 +429,8 @@ "- It is considered the northern branch of the Norwegian Sea. True\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n", "\"\"\"\n", "\n", @@ -443,7 +445,7 @@ "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n", "\n", @@ -555,7 +557,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. 
If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -574,7 +577,8 @@ "- It is considered the northern branch of the Norwegian Sea. False - It is considered the northern branch of the Atlantic Ocean.\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n", "\"\"\"\n", @@ -583,14 +587,20 @@ "\n", "The output should have the same structure and formatting as the original summary.\n", "\n", - "Summary:\u001b[0m\n", + "Summary:\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n", "\n", @@ -701,7 +711,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -718,7 +729,8 @@ "- It is considered the northern branch of the Atlantic Ocean. False - The Greenland Sea is considered part of the Arctic Ocean, not the Atlantic Ocean.\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "\n", "\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the country of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Atlantic Ocean.\n", @@ -735,7 +747,7 @@ "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". 
If any of the assertions are false, return \"False\".\n", "\n", @@ -813,14 +825,14 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=3)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=3)\n", "text = \"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\"\n", "checker_chain.run(text)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -1077,7 +1089,7 @@ "'Birds are not mammals, but they are a class of their own. They lay eggs, unlike mammals which give birth to live young.'" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -1087,17 +1099,10 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, max_checks=3, verbose=True)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, max_checks=3, verbose=True)\n", "text = \"Mammals can lay eggs, birds can lay eggs, therefore birds are mammals.\"\n", "checker_chain.run(text)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/modules/chains/examples/pal.ipynb b/docs/modules/chains/examples/pal.ipynb index 36b580729cdad..94942ccbeecb5 100644 --- a/docs/modules/chains/examples/pal.ipynb +++ b/docs/modules/chains/examples/pal.ipynb @@ -28,7 +28,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)" + "llm = OpenAI(temperature=0, max_tokens=512)" ] }, { @@ -63,7 +63,9 @@ "cell_type": "code", "execution_count": 4, "id": "3ef64b27", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -71,17 +73,17 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mdef solution():\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mdef solution():\n", " \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. 
If Cindy has four pets, how many total pets do the three have?\"\"\"\n", " cindy_pets = 4\n", " marcia_pets = cindy_pets + 2\n", " jan_pets = marcia_pets * 3\n", " total_pets = cindy_pets + marcia_pets + jan_pets\n", " result = total_pets\n", - " return result\u001B[0m\n", + " return result\u001b[0m\n", "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -139,8 +141,8 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n", "objects = []\n", "objects += [('booklet', 'blue')] * 2\n", "objects += [('booklet', 'purple')] * 2\n", @@ -151,9 +153,9 @@ "\n", "# Count number of purple objects\n", "num_purple = len([object for object in objects if object[1] == 'purple'])\n", - "answer = num_purple\u001B[0m\n", + "answer = num_purple\u001b[0m\n", "\n", - "\u001B[1m> Finished PALChain chain.\u001B[0m\n" + "\u001b[1m> Finished PALChain chain.\u001b[0m\n" ] }, { @@ -212,8 +214,8 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n", "objects = []\n", "objects += [('booklet', 'blue')] * 2\n", "objects += [('booklet', 'purple')] * 2\n", @@ -224,9 +226,9 @@ "\n", "# Count number of purple objects\n", "num_purple = len([object for object in objects if object[1] == 'purple'])\n", - "answer = num_purple\u001B[0m\n", + "answer = num_purple\u001b[0m\n", "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] } ], @@ -280,7 +282,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index b3b23eb43d410..472ac99e960ed 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -73,7 +73,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)" ] }, { @@ -175,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)" ] }, { @@ -230,7 +230,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)" ] }, { @@ -285,7 +285,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, top_k=3)" ] }, { @@ -407,7 +407,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)" ] }, { @@ -569,7 +569,7 @@ } ], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", + "db_chain = 
SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", "db_chain.run(\"What are some example tracks by Bach?\")" ] }, @@ -681,7 +681,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.1" } }, "nbformat": 4, From 43410e4904749ca842d9d890f0157382751049d0 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 15:47:44 -0700 Subject: [PATCH 23/36] fix test --- tests/unit_tests/chains/test_base.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 84c8247c5f2e7..e1e33a4bd3c24 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -147,25 +147,10 @@ def test_run_with_callback() -> None: """Test run method works when callback manager is passed.""" handler = FakeCallbackHandler() chain = FakeChain( - callback_manager=CallbackManager(handlers=[handler]), verbose=True + callbacks=[handler], ) output = chain.run("bar") assert output == "baz" assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 - - -def test_run_with_callback_not_verbose() -> None: - """Test run method works when callback manager is passed and not verbose.""" - import langchain - - langchain.verbose = False - - handler = FakeCallbackHandler() - chain = FakeChain(callback_manager=CallbackManager(handlers=[handler])) - output = chain.run("bar") - assert output == "baz" - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 From 9c988ae39e74390023dc74c5fcc20cf558e73415 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 15:56:22 -0700 Subject: [PATCH 24/36] cr --- tests/unit_tests/llms/test_callbacks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 78480816edc81..ce0cf77f495f1 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -1,5 +1,4 @@ """Test LLM callbacks.""" -from langchain.callbacks.manager import CallbackManager from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_llm import FakeLLM @@ -7,7 +6,7 @@ def test_llm_with_callbacks() -> None: """Test LLM callbacks.""" handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True) + llm = FakeLLM(callbacks=[handler], verbose=True) output = llm("foo") assert output == "foo" assert handler.starts == 1 From bd9ac67afb79ed1c12cebbb06dddb3a5f0447b14 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Fri, 28 Apr 2023 16:06:52 -0700 Subject: [PATCH 25/36] nb nit (#3744) --- .../models/llms/examples/custom_llm.ipynb | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/docs/modules/models/llms/examples/custom_llm.ipynb b/docs/modules/models/llms/examples/custom_llm.ipynb index 1375d63929192..4db92f0477d49 100644 --- a/docs/modules/models/llms/examples/custom_llm.ipynb +++ b/docs/modules/models/llms/examples/custom_llm.ipynb @@ -22,18 +22,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "id": "a65696a0", "metadata": {}, "outputs": [], "source": [ - "from langchain.llms.base import LLM\n", - "from typing import Optional, List, Mapping, Any" + "from typing import Any, List, Mapping, Optional\n", + "\n", + "from 
langchain.callbacks.manager import CallbackManagerForLLMRun\n", + "from langchain.llms.base import LLM" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "id": "d5ceff02", "metadata": {}, "outputs": [], @@ -46,7 +48,12 @@ " def _llm_type(self) -> str:\n", " return \"custom\"\n", " \n", - " def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n", + " def _call(\n", + " self,\n", + " prompt: str,\n", + " stop: Optional[List[str]] = None,\n", + " run_manager: Optional[CallbackManagerForLLMRun] = None,\n", + " ) -> str:\n", " if stop is not None:\n", " raise ValueError(\"stop kwargs are not permitted.\")\n", " return prompt[:self.n]\n", @@ -67,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "id": "10e5ece6", "metadata": {}, "outputs": [], @@ -77,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "8cd49199", "metadata": {}, "outputs": [ @@ -87,7 +94,7 @@ "'This is a '" ] }, - "execution_count": 4, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -106,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "id": "9c33fa19", "metadata": {}, "outputs": [ From e60489e9392860da35179ae05d8e536c94bf9fbe Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 16:52:33 -0700 Subject: [PATCH 26/36] fix lint --- langchain/chains/llm_bash/base.py | 2 +- tests/unit_tests/chains/test_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 61894bffe1468..4c8cb2ae56972 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -94,7 +94,7 @@ def _call( _run_manager.on_text(t, color="green", verbose=self.verbose) t = t.strip() try: - command_list = self.llm_chain.prompt.output_parser.parse(t) # type: ignore[union-attr] + command_list = self.llm_chain.prompt.output_parser.parse(t) # type: ignore[union-attr] # noqa: E501 except OutputParserException as e: _run_manager.on_chain_error(e, verbose=self.verbose) raise e diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index e1e33a4bd3c24..b852510c5b1fe 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.manager import CallbackManager, CallbackManagerForChainRun +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.schema import BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler From 9dad0514e721cf9abe5d43db7dd4250861a23b00 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Fri, 28 Apr 2023 17:14:20 -0700 Subject: [PATCH 27/36] fix test warnings (#3753) --- langchain/chains/natbot/base.py | 13 +++++++++---- tests/unit_tests/chains/test_llm_bash.py | 4 ++-- tests/unit_tests/chains/test_llm_checker.py | 2 +- tests/unit_tests/chains/test_llm_math.py | 2 +- .../chains/test_llm_summarization_checker.py | 4 +++- tests/unit_tests/chains/test_natbot.py | 6 +++--- 6 files changed, 19 insertions(+), 12 deletions(-) diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 47c80616c292f..452f78600acb5 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -2,7 +2,7 @@ from __future__ import annotations import warnings -from 
typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator @@ -45,7 +45,7 @@ def raise_deprecation(cls, values: Dict) -> Dict: if "llm" in values: warnings.warn( "Directly instantiating an NatBotChain with an llm is deprecated. " - "Please instantiate with llm_chain argument or using the from_default " + "Please instantiate with llm_chain argument or using the from_llm " "class method." ) if "llm_chain" not in values and values["llm"] is not None: @@ -53,11 +53,16 @@ def raise_deprecation(cls, values: Dict) -> Dict: return values @classmethod - def from_default(cls, objective: str) -> NatBotChain: + def from_default(cls, objective: str, **kwargs: Any) -> NatBotChain: """Load with default LLMChain.""" llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) + return cls.from_llm(llm, objective, **kwargs) + + @classmethod + def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain: + """Load from LLM.""" llm_chain = LLMChain(llm=llm, prompt=PROMPT) - return cls(llm_chain=llm_chain, objective=objective) + return cls(llm_chain=llm_chain, objective=objective, **kwargs) @property def input_keys(self) -> List[str]: diff --git a/tests/unit_tests/chains/test_llm_bash.py b/tests/unit_tests/chains/test_llm_bash.py index ecbfb3a8a53c4..e6ee11d09f3d6 100644 --- a/tests/unit_tests/chains/test_llm_bash.py +++ b/tests/unit_tests/chains/test_llm_bash.py @@ -43,7 +43,7 @@ def test_simple_question() -> None: prompt = _PROMPT_TEMPLATE.format(question=question) queries = {prompt: "```bash\nexpr 1 + 1\n```"} fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") + fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") output = fake_llm_bash_chain.run(question) assert output == "2\n" @@ -71,7 +71,7 @@ def test_parsing_error() -> None: """ } fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") + fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") with pytest.raises(OutputParserException): fake_llm_bash_chain.run(question) diff --git a/tests/unit_tests/chains/test_llm_checker.py b/tests/unit_tests/chains/test_llm_checker.py index 0c9b9343550a9..cc2ceb9909c2b 100644 --- a/tests/unit_tests/chains/test_llm_checker.py +++ b/tests/unit_tests/chains/test_llm_checker.py @@ -33,7 +33,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain: ): "I still don't know.", } fake_llm = FakeLLM(queries=queries) - return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a") + return LLMCheckerChain.from_llm(fake_llm, input_key="q", output_key="a") def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None: diff --git a/tests/unit_tests/chains/test_llm_math.py b/tests/unit_tests/chains/test_llm_math.py index c412436c665cb..4e3887ab9b09a 100644 --- a/tests/unit_tests/chains/test_llm_math.py +++ b/tests/unit_tests/chains/test_llm_math.py @@ -17,7 +17,7 @@ def fake_llm_math_chain() -> LLMMathChain: _PROMPT_TEMPLATE.format(question="foo"): "foo", } fake_llm = FakeLLM(queries=queries) - return LLMMathChain(llm=fake_llm, input_key="q", output_key="a") + return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a") def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None: diff --git a/tests/unit_tests/chains/test_llm_summarization_checker.py b/tests/unit_tests/chains/test_llm_summarization_checker.py index 
81e4a8fa1c645..aa82cead6bee4 100644 --- a/tests/unit_tests/chains/test_llm_summarization_checker.py +++ b/tests/unit_tests/chains/test_llm_summarization_checker.py @@ -32,7 +32,9 @@ def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain: ): "True", } fake_llm = FakeLLM(queries=queries) - return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a") + return LLMSummarizationCheckerChain.from_llm( + fake_llm, input_key="q", output_key="a" + ) def test_simple_text( diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index b8c6538e2867e..77c29808433a2 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -34,7 +34,7 @@ def _identifying_params(self) -> Mapping[str, Any]: def test_proper_inputs() -> None: """Test that natbot shortens inputs correctly.""" - nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing") + nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing") url = "foo" * 10000 browser_content = "foo" * 10000 output = nat_bot_chain.execute(url, browser_content) @@ -43,8 +43,8 @@ def test_proper_inputs() -> None: def test_variable_key_naming() -> None: """Test that natbot handles variable key naming correctly.""" - nat_bot_chain = NatBotChain( - llm=FakeLLM(), + nat_bot_chain = NatBotChain.from_llm( + FakeLLM(), objective="testing", input_url_key="u", input_browser_content_key="b", From 5f78219de69db14e83db3a4e188e07a95cac2dbf Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Fri, 28 Apr 2023 18:30:33 -0700 Subject: [PATCH 28/36] fix some docs, add session variable --- .../examples/async_agent.ipynb | 277 ++++++------------ docs/modules/callbacks/getting_started.ipynb | 3 +- docs/tracing/agent_with_tracing.ipynb | 257 ++++++++++++---- langchain/agents/load_tools.py | 4 +- langchain/callbacks/manager.py | 7 +- 5 files changed, 299 insertions(+), 249 deletions(-) diff --git a/docs/modules/agents/agent_executors/examples/async_agent.ipynb b/docs/modules/agents/agent_executors/examples/async_agent.ipynb index 925700ed33dfe..cc9a92b3f9e7f 100644 --- a/docs/modules/agents/agent_executors/examples/async_agent.ipynb +++ b/docs/modules/agents/agent_executors/examples/async_agent.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "id": "da5df06c-af6f-4572-b9f5-0ab971c16487", "metadata": { "tags": [] @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "id": "fd4c294e-b1d6-44b8-b32e-2765c017e503", "metadata": { "tags": [] @@ -72,16 +72,15 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", "Action: Search\n", "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", "Action: Search\n", "Action Input: \"Rafael Nadal age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n", + "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", "Action: Calculator\n", "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", "\n", @@ -92,18 +91,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", "Action: Search\n", "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n", + "Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", "Action: Search\n", - "Action Input: \"Jason Sudeikis age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", "Action: Calculator\n", - "Action Input: 47^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -112,17 +110,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", "Action: Search\n", "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n", + "Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", "Action: Search\n", - "Action Input: \"Max Verstappen Age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 
0.23 power\n", + "Action Input: \"Michael Schumacher age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m54 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.23 power\n", "Action: Calculator\n", - "Action Input: 25^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n", + "Action Input: 54^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n", + "Final Answer: Michael Schumacher, aged 54, raised to the 0.23 power is 2.502940725307012.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -131,18 +129,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", "Action: Search\n", "Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 6–3, 7–5 to win the women's singles tennis title at the 2019 US Open. It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n", + "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n", "Action: Search\n", "Action Input: \"Bianca Andreescu age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n", "Action: Calculator\n", "Action Input: 22^0.34\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", - "Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Bianca Andreescu, aged 22, won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.86.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -159,35 +156,32 @@ "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", "Action: Calculator\n", "Action Input: 53^0.19\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "Serial executed in 65.11 seconds.\n" + "Serial executed in 52.47 seconds.\n" ] } ], "source": [ - "def generate_serially():\n", - " for q in questions:\n", - " llm = OpenAI(temperature=0)\n", - " tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", - " agent = initialize_agent(\n", - " tools, llm, 
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", - " )\n", - " agent.run(q)\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", "\n", "s = time.perf_counter()\n", - "generate_serially()\n", + "for q in questions:\n", + " agent.run(q)\n", "elapsed = time.perf_counter() - s\n", "print(f\"Serial executed in {elapsed:0.2f} seconds.\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "076d7b85-45ec-465d-8b31-c2ad119c3438", "metadata": { "tags": [] @@ -201,10 +195,10 @@ "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", - "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", "\n", + "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", "\n", @@ -212,179 +206,94 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", "Action: Search\n", - "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n", + "Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", "Action: Search\n", - "Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", "Action: Search\n", - "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", "Action: Search\n", - "Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n", + "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n", + "Action: Search\n", + "Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n", "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 6–3, 7–5 to win the women's singles tennis title at the 2019 US Open. It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n", + "Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n", "Action: Search\n", - "Action Input: \"Jason Sudeikis age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n", + "Action Input: \"Bianca Andreescu age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n", "Action: Search\n", "Action Input: \"How old is Jay-Z?\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m53 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", - "Action: Search\n", - "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\u001b[0m\n", "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n", - "Action: Search\n", - "Action Input: \"Max Verstappen Age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n", - "Action: Search\n", - "Action Input: \"Bianca Andreescu age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", - "Action: Calculator\n", - "Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", "Action: Search\n", - "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n", "Action: Calculator\n", - "Action Input: 47^0.23\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.23 power\n", + "Action Input: 22^0.34\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", "Action: Calculator\n", - "Action Input: 25^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n", + "Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", "Action: Calculator\n", - "Action Input: 22^0.34\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n", + "Action Input: 29^0.23\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Michael Schumacher age\"\u001b[0m\n", + "Observation: \n", + "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", + "Thought:\u001b[33;1m\u001b[1;3m54 years\u001b[0m\n", "Thought:\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", "Action: Calculator\n", - "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n", - "\n", + "Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I now need to calculate the age raised to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 
54^0.23\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n", - "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer.\n", - "Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n", "\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n", + "Thought:\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "Concurrent executed in 12.38 seconds.\n" + "Concurrent executed in 14.49 seconds.\n" ] } ], "source": [ - "async def generate_concurrently():\n", - " agents = []\n", - " # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - " # but you must manually close the client session at the end of your program/event loop\n", - " aiosession = ClientSession()\n", - " callbacks = [StdOutCallbackHandler()]\n", - " for _ in questions:\n", - " llm = OpenAI(temperature=0)\n", - " async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n", - " agents.append(\n", - " initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n", - " )\n", - " tasks = [async_agent.arun(q, callbacks=callbacks) for async_agent, q in zip(agents, questions)]\n", - " await asyncio.gather(*tasks)\n", - " await aiosession.close()\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", "\n", "s = time.perf_counter()\n", - "# If running this outside of Jupyter, use asyncio.run(generate_concurrently())\n", - "await generate_concurrently()\n", + "# If running this outside of Jupyter, use asyncio.run or loop.run_until_complete\n", + "tasks = [agent.arun(q) for q in questions]\n", + "await asyncio.gather(*tasks)\n", "elapsed = time.perf_counter() - s\n", "print(f\"Concurrent executed in {elapsed:0.2f} seconds.\")" ] - }, - { - "cell_type": "markdown", - "id": "97ef285c-4a43-4a4e-9698-cd52a1bc56c9", - "metadata": {}, - "source": [ - "## Using Tracing with Asynchronous Agents\n", - "\n", - "To use tracing with async agents, you must pass in a custom `CallbackManager` with `LangChainTracer` to each agent running asynchronously. This way, you avoid collisions while the trace is being collected." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "44bda05a-d33e-4e91-9a71-a0f3f96aae95", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", - "Action: Search\n", - "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n", - "Action: Search\n", - "Action Input: \"Rafael Nadal age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n", - "Action: Calculator\n", - "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - } - ], - "source": [ - "# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - "# but you must manually close the client session at the end of your program/event loop\n", - "aiosession = ClientSession()\n", - "tracer = LangChainTracer()\n", - "tracer.load_default_session()\n", - "callbacks = [StdOutCallbackHandler(), tracer]\n", - "\n", - "# Pass the manager into the llm if you want llm calls traced.\n", - "llm = OpenAI(temperature=0)\n", - "\n", - "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n", - "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)\n", - "await async_agent.arun(questions[0], callbacks=callbacks)\n", - "await aiosession.close()" - ] } ], "metadata": { @@ -403,7 +312,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/modules/callbacks/getting_started.ipynb b/docs/modules/callbacks/getting_started.ipynb index 74e8883afb906..109907bd7c425 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -93,7 +93,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "cbccd7d1", "metadata": {}, @@ -889,7 +888,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/tracing/agent_with_tracing.ipynb b/docs/tracing/agent_with_tracing.ipynb index 26b2b9d0c3027..7facae9553a0e 100644 --- a/docs/tracing/agent_with_tracing.ipynb +++ b/docs/tracing/agent_with_tracing.ipynb @@ -5,7 +5,14 @@ "id": "5371a9bb", "metadata": {}, "source": [ - "# Tracing Walkthrough" + "# Tracing Walkthrough\n", + "\n", + "There are two recommended ways to trace your LangChains:\n", + "\n", + "1. Setting the `LANGCHAIN_TRACING` environment variable to \"true\".\n", + "1. Using a context manager with tracing_enabled() to trace a particular block of code.\n", + "\n", + "**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager." 
] }, { @@ -18,24 +25,22 @@ "outputs": [], "source": [ "import os\n", - "os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n", - "\n", - "## Uncomment this if using hosted setup.\n", + "os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n", "\n", + "## Uncomment below if using hosted setup.\n", "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n", "\n", - "## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n", - "\n", + "## Uncomment below if you want traces to be recorded to \"my_session\" instead of \"default\".\n", "# os.environ[\"LANGCHAIN_SESSION\"] = \"my_session\" \n", "\n", "## Better to set this environment variable in the terminal\n", - "## Uncomment this if using hosted version. Replace \"my_api_key\" with your actual API Key.\n", - "\n", + "## Uncomment below if using hosted version. Replace \"my_api_key\" with your actual API Key.\n", "# os.environ[\"LANGCHAIN_API_KEY\"] = \"my_api_key\" \n", "\n", "import langchain\n", "from langchain.agents import Tool, initialize_agent, load_tools\n", "from langchain.agents import AgentType\n", + "from langchain.callbacks import tracing_enabled\n", "from langchain.chat_models import ChatOpenAI\n", "from langchain.llms import OpenAI" ] @@ -73,8 +78,7 @@ "\u001b[32;1m\u001b[1;3m I need to use a calculator to solve this.\n", "Action: Calculator\n", "Action Input: 2^.123243\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", "Final Answer: 1.0891804557407723\u001b[0m\n", "\n", @@ -104,7 +108,9 @@ "cell_type": "code", "execution_count": 4, "id": "4829eb1d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -113,52 +119,11 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mQuestion: What is 2 raised to .123243 power?\n", - "Thought: I need a calculator to solve this problem.\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\u001b[0m\n", - "Observation: calculator is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made a mistake, I need to use the correct tool for this question.\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: calculator is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made a mistake, the tool name is actually \"calc\" instead of \"calculator\".\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calc\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: calc is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made another mistake, the tool name is actually \"Calculator\" instead of \"calc\".\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"Calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3mThe final answer is 1.0891804557407723.\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ 
.123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", "Final Answer: 1.0891804557407723\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -186,8 +151,182 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "76abfd82", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", + "Final Answer: 1.0891804557407723\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 5 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", + "Final Answer: 1.2193914912400514\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "# Both of the agent runs will be traced because the environment variable is set\n", + "agent.run(\"What is 2 raised to .123243 power?\")\n", + "with tracing_enabled() as session:\n", + " agent.run(\"What is 5 raised to .123243 power?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fe833c33-033f-4806-be0c-cc3d147db13d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 5 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", + "Final Answer: 1.2193914912400514\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. 
\n", + "Final Answer: 1.0891804557407723\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'1.0891804557407723'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now, we unset the environment variable and use a context manager.\n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + "\n", + "# here, we are writing traces to \"my_test_session\"\n", + "with tracing_enabled(\"my_session\") as session:\n", + " assert session\n", + " agent.run(\"What is 5 raised to .123243 power?\") # this should be traced\n", + "\n", + "agent.run(\"What is 2 raised to .123243 power?\") # this should not be traced" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b34105a4-be8e-46e4-8abe-01adba3ba727", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 3^0.123\u001b[0m\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2^0.123\u001b[0m\u001b[32;1m\u001b[1;3mAny number raised to the power of 0 is 1, but I'm not sure about a decimal power.\n", + "Action: Calculator\n", + "Action Input: 1^.123\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.1446847956963533\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0889970153361064\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0\u001b[0m\n", + "Thought:\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'1.0'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The context manager is concurrency safe:\n", + "import asyncio \n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + " \n", + "questions = [f\"What is {i} raised to .123 power?\" for i in range(1,4)]\n", + "\n", + "# start a background task\n", + "task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n", + " await asyncio.gather(*tasks)\n", + "\n", + "await task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e46c85b-2ac0-4661-abed-9c2bf3036820", "metadata": {}, "outputs": [], "source": [] @@ -209,7 +348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index 1d44852a86057..e11440e8b6768 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool: return Tool( name="Calculator", description="Useful for when you need to answer questions about math.", - func=LLMMathChain(llm=llm).run, - 
coroutine=LLMMathChain(llm=llm).arun, + func=LLMMathChain.from_llm(llm=llm).run, + coroutine=LLMMathChain.from_llm(llm=llm).arun, ) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 60024ef7d410a..0d0a8bd993af5 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -705,6 +705,9 @@ def _configure( tracing_enabled_ = ( os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None ) + tracer_session = os.environ.get("LANGCHAIN_SESSION") + if tracer_session is None: + tracer_session = "default" if verbose or tracing_enabled_ or open_ai is not None: if verbose and not any( isinstance(handler, StdOutCallbackHandler) @@ -717,10 +720,10 @@ def _configure( for handler in callback_manager.handlers ): if tracer: - callback_manager.add_handler(copy.deepcopy(tracer), True) + callback_manager.add_handler(tracer, True) else: handler = LangChainTracer() - handler.load_default_session() + handler.load_session(tracer_session) callback_manager.add_handler(handler, True) if open_ai is not None and not any( isinstance(handler, OpenAICallbackHandler) From 290fe752f1948e9638a49c63272e5ab196edf7aa Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 28 Apr 2023 18:32:37 -0700 Subject: [PATCH 29/36] Add RunManager to Tools Arguments (#3746) Co-authored-by: Ankush Gola <9536492+agola11@users.noreply.github.com> --- .../agent_toolkits/file_management/toolkit.py | 2 +- langchain/chains/llm_bash/base.py | 3 +- .../document_compressors/chain_extract.py | 4 +- .../document_compressors/chain_filter.py | 4 +- langchain/retrievers/self_query/base.py | 4 +- langchain/tools/arxiv/tool.py | 18 ++++- langchain/tools/base.py | 65 ++++++++++++++++--- langchain/tools/bing_search/tool.py | 30 +++++++-- langchain/tools/ddg_search/tool.py | 30 +++++++-- langchain/tools/file_management/copy.py | 25 +++++-- langchain/tools/file_management/delete.py | 23 +++++-- .../tools/file_management/file_search.py | 25 +++++-- langchain/tools/file_management/list_dir.py | 23 +++++-- langchain/tools/file_management/move.py | 25 +++++-- langchain/tools/file_management/read.py | 23 +++++-- langchain/tools/file_management/utils.py | 14 +--- langchain/tools/file_management/write.py | 27 ++++++-- langchain/tools/google_places/tool.py | 18 ++++- langchain/tools/google_search/tool.py | 30 +++++++-- langchain/tools/human/tool.py | 18 ++++- langchain/tools/ifttt.py | 18 ++++- langchain/tools/interaction/tool.py | 14 +++- langchain/tools/jira/tool.py | 18 ++++- langchain/tools/json/tool.py | 30 +++++++-- langchain/tools/playwright/base.py | 12 +++- langchain/tools/playwright/click.py | 9 ++- langchain/tools/playwright/current_page.py | 8 ++- .../tools/playwright/extract_hyperlinks.py | 9 ++- langchain/tools/playwright/extract_text.py | 7 +- langchain/tools/playwright/get_elements.py | 6 +- langchain/tools/playwright/navigate.py | 9 ++- langchain/tools/playwright/navigate_back.py | 8 ++- langchain/tools/plugin.py | 16 ++++- langchain/tools/powerbi/tool.py | 54 ++++++++++++--- langchain/tools/python/tool.py | 32 +++++++-- langchain/tools/requests/tool.py | 58 +++++++++++++---- langchain/tools/searx_search/tool.py | 30 +++++++-- langchain/tools/shell/tool.py | 18 ++++- langchain/tools/sql_database/tool.py | 59 +++++++++++++---- langchain/tools/vectorstore/tool.py | 30 +++++++-- langchain/tools/wikipedia/tool.py | 18 ++++- langchain/tools/wolfram_alpha/tool.py | 18 ++++- langchain/tools/zapier/tool.py | 26 ++++++-- 
tests/unit_tests/agents/test_tools.py | 5 +- tests/unit_tests/chains/test_base.py | 2 - tests/unit_tests/tools/test_signatures.py | 41 ++++++++++++ 46 files changed, 792 insertions(+), 174 deletions(-) create mode 100644 tests/unit_tests/tools/test_signatures.py diff --git a/langchain/agents/agent_toolkits/file_management/toolkit.py b/langchain/agents/agent_toolkits/file_management/toolkit.py index 17ae4f3a768c0..cc7d77f72a92b 100644 --- a/langchain/agents/agent_toolkits/file_management/toolkit.py +++ b/langchain/agents/agent_toolkits/file_management/toolkit.py @@ -54,7 +54,7 @@ def get_tools(self) -> List[BaseTool]: tools: List[BaseTool] = [] for tool in allowed_tools: tool_cls = _FILE_TOOLS[tool] - tools.append(tool_cls(root_dir=self.root_dir)) + tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore return tools diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 4c8cb2ae56972..468c0ba75075b 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -94,7 +94,8 @@ def _call( _run_manager.on_text(t, color="green", verbose=self.verbose) t = t.strip() try: - command_list = self.llm_chain.prompt.output_parser.parse(t) # type: ignore[union-attr] # noqa: E501 + parser = self.llm_chain.prompt.output_parser + command_list = parser.parse(t) # type: ignore[union-attr] except OutputParserException as e: _run_manager.on_chain_error(e, verbose=self.verbose) raise e diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py index ea1b2c3fbe5fc..bf8c366607379 100644 --- a/langchain/retrievers/document_compressors/chain_extract.py +++ b/langchain/retrievers/document_compressors/chain_extract.py @@ -3,9 +3,7 @@ from langchain import LLMChain, PromptTemplate from langchain.base_language import BaseLanguageModel -from langchain.retrievers.document_compressors.base import ( - BaseDocumentCompressor, -) +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_extract_prompt import ( prompt_template, ) diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py index 6cb0d2fe8456c..245e005108f5c 100644 --- a/langchain/retrievers/document_compressors/chain_filter.py +++ b/langchain/retrievers/document_compressors/chain_filter.py @@ -4,9 +4,7 @@ from langchain import BasePromptTemplate, LLMChain, PromptTemplate from langchain.base_language import BaseLanguageModel from langchain.output_parsers.boolean import BooleanOutputParser -from langchain.retrievers.document_compressors.base import ( - BaseDocumentCompressor, -) +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_filter_prompt import ( prompt_template, ) diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index 7c9ced716626c..b74dfacadc229 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -5,9 +5,7 @@ from langchain import LLMChain from langchain.base_language import BaseLanguageModel -from langchain.chains.query_constructor.base import ( - load_query_constructor_chain, -) +from langchain.chains.query_constructor.base import load_query_constructor_chain from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import 
AttributeInfo from langchain.retrievers.self_query.pinecone import PineconeTranslator diff --git a/langchain/tools/arxiv/tool.py b/langchain/tools/arxiv/tool.py index 83c211311e3ab..76513e27a187c 100644 --- a/langchain/tools/arxiv/tool.py +++ b/langchain/tools/arxiv/tool.py @@ -1,5 +1,11 @@ """Tool for the Arxiv API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.arxiv import ArxivAPIWrapper @@ -18,10 +24,18 @@ class ArxivQueryRun(BaseTool): ) api_wrapper: ArxivAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Arxiv tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Arxiv tool asynchronously.""" raise NotImplementedError("ArxivAPIWrapper does not support async") diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 3388910df74df..fded77c4d4dad 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -19,7 +19,9 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( AsyncCallbackManager, + AsyncCallbackManagerForToolRun, CallbackManager, + CallbackManagerForToolRun, Callbacks, ) @@ -118,7 +120,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): that after the tool is called, the AgentExecutor will stop looping. """ verbose: bool = False - """Whether to print the tool's output to the console.""" + """Whether to log the tool's progress.""" + callbacks: Callbacks = None """Callbacks to be called during tool execution.""" callback_manager: Optional[BaseCallbackManager] = None @@ -174,11 +177,23 @@ def _run( *args: Any, **kwargs: Any, ) -> Any: - """Use the tool.""" + """Use the tool. + + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ @abstractmethod - async def _arun(self, *args: Any, **kwargs: Any) -> Any: - """Use the tool asynchronously.""" + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously. 
+ + Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, @@ -279,9 +294,9 @@ class StructuredTool(BaseTool): description: str = "" args_schema: Type[BaseModel] = Field(..., description="The tool schema.") """The input arguments' schema.""" - func: Callable[..., str] + func: Callable[..., Any] """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[str]]] = None + coroutine: Optional[Callable[..., Awaitable[Any]]] = None """The asynchronous version of the function.""" @property @@ -289,14 +304,44 @@ def args(self) -> dict: """The tool's input arguments.""" return self.args_schema.schema()["properties"] - def _run(self, *args: Any, **kwargs: Any) -> Any: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool.""" - return self.func(*args, **kwargs) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) - async def _arun(self, *args: Any, **kwargs: Any) -> Any: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Use the tool asynchronously.""" if self.coroutine: - return await self.coroutine(*args, **kwargs) + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) raise NotImplementedError("Tool does not support async") @classmethod diff --git a/langchain/tools/bing_search/tool.py b/langchain/tools/bing_search/tool.py index dd57295c7178d..3340a55af689b 100644 --- a/langchain/tools/bing_search/tool.py +++ b/langchain/tools/bing_search/tool.py @@ -1,5 +1,11 @@ """Tool for the Bing search API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.bing_search import BingSearchAPIWrapper @@ -15,11 +21,19 @@ class BingSearchRun(BaseTool): ) api_wrapper: BingSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("BingSearchRun does not support async") @@ -36,10 +50,18 @@ class BingSearchResults(BaseTool): num_results: int = 4 api_wrapper: BingSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the 
tool asynchronously.""" raise NotImplementedError("BingSearchResults does not support async") diff --git a/langchain/tools/ddg_search/tool.py b/langchain/tools/ddg_search/tool.py index 5948756fdfbde..109431e20c490 100644 --- a/langchain/tools/ddg_search/tool.py +++ b/langchain/tools/ddg_search/tool.py @@ -1,10 +1,14 @@ """Tool for the DuckDuckGo search API.""" import warnings -from typing import Any +from typing import Any, Optional from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper @@ -22,11 +26,19 @@ class DuckDuckGoSearchRun(BaseTool): default_factory=DuckDuckGoSearchAPIWrapper ) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("DuckDuckGoSearch does not support async") @@ -45,11 +57,19 @@ class DuckDuckGoSearchResults(BaseTool): default_factory=DuckDuckGoSearchAPIWrapper ) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("DuckDuckGoSearchResults does not support async") diff --git a/langchain/tools/file_management/copy.py b/langchain/tools/file_management/copy.py index 0fb23c7ec329e..5231c7d4157c0 100644 --- a/langchain/tools/file_management/copy.py +++ b/langchain/tools/file_management/copy.py @@ -1,11 +1,16 @@ import shutil -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -17,12 +22,17 @@ class FileCopyInput(BaseModel): destination_path: str = Field(..., description="Path to save the copied file") -class CopyFileTool(BaseFileTool): +class CopyFileTool(BaseFileToolMixin, BaseTool): name: str = "copy_file" args_schema: Type[BaseModel] = FileCopyInput description: str = "Create a copy of a file in a specified location" - def _run(self, source_path: str, destination_path: str) -> str: + def _run( + self, + source_path: str, + destination_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: source_path_ = self.get_relative_path(source_path) except FileValidationError: @@ -41,6 +51,11 @@ def _run(self, source_path: str, destination_path: str) -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, source_path: str, destination_path: str) -> str: + async def _arun( + self, + source_path: str, + destination_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError 
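The per-tool hunks in this commit all apply the same mechanical change: `_run` and `_arun` gain an optional `run_manager` argument typed with the new `CallbackManagerForToolRun` / `AsyncCallbackManagerForToolRun` classes. As a reference while reading the remaining file diffs, here is a minimal sketch of a custom tool written against the new signatures. The tool itself (`WordCountTool`, its name and logic) is hypothetical and not part of this patch; the sketch assumes only the imports and conventions this commit already introduces.

```python
from typing import Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool


class WordCountTool(BaseTool):
    """Hypothetical example tool, shown only to illustrate the new signatures."""

    name: str = "word_count"
    description: str = "Count the number of words in a piece of text."

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        # run_manager is optional; when present, run_manager.get_child() can be
        # forwarded as callbacks to any chains or LLMs this tool invokes, as
        # StructuredTool._run does in the base.py hunk above.
        return str(len(query.split()))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        # Mirrors the built-in tools in this commit that do not support async.
        raise NotImplementedError("word_count does not support async")
```

Existing tools that ignore the argument keep working unchanged, since `run_manager` defaults to `None` in every signature touched by this patch.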
diff --git a/langchain/tools/file_management/delete.py b/langchain/tools/file_management/delete.py index 218cf606ef81e..bf00e707c1103 100644 --- a/langchain/tools/file_management/delete.py +++ b/langchain/tools/file_management/delete.py @@ -1,11 +1,16 @@ import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -16,12 +21,16 @@ class FileDeleteInput(BaseModel): file_path: str = Field(..., description="Path of the file to delete") -class DeleteFileTool(BaseFileTool): +class DeleteFileTool(BaseFileToolMixin, BaseTool): name: str = "file_delete" args_schema: Type[BaseModel] = FileDeleteInput description: str = "Delete a file" - def _run(self, file_path: str) -> str: + def _run( + self, + file_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: file_path_ = self.get_relative_path(file_path) except FileValidationError: @@ -34,6 +43,10 @@ def _run(self, file_path: str) -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str) -> str: + async def _arun( + self, + file_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/file_search.py b/langchain/tools/file_management/file_search.py index 7e2f1d9302234..ce67f59d9c9ec 100644 --- a/langchain/tools/file_management/file_search.py +++ b/langchain/tools/file_management/file_search.py @@ -1,12 +1,17 @@ import fnmatch import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -24,14 +29,19 @@ class FileSearchInput(BaseModel): ) -class FileSearchTool(BaseFileTool): +class FileSearchTool(BaseFileToolMixin, BaseTool): name: str = "file_search" args_schema: Type[BaseModel] = FileSearchInput description: str = ( "Recursively search for files in a subdirectory that match the regex pattern" ) - def _run(self, pattern: str, dir_path: str = ".") -> str: + def _run( + self, + pattern: str, + dir_path: str = ".", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: dir_path_ = self.get_relative_path(dir_path) except FileValidationError: @@ -50,6 +60,11 @@ def _run(self, pattern: str, dir_path: str = ".") -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, dir_path: str, pattern: str) -> str: + async def _arun( + self, + dir_path: str, + pattern: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/list_dir.py b/langchain/tools/file_management/list_dir.py index ff5cb8a143f26..f013257da18e2 100644 --- a/langchain/tools/file_management/list_dir.py +++ b/langchain/tools/file_management/list_dir.py @@ -1,11 +1,16 @@ import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field 
+from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -16,12 +21,16 @@ class DirectoryListingInput(BaseModel): dir_path: str = Field(default=".", description="Subdirectory to list.") -class ListDirectoryTool(BaseFileTool): +class ListDirectoryTool(BaseFileToolMixin, BaseTool): name: str = "list_directory" args_schema: Type[BaseModel] = DirectoryListingInput description: str = "List files and directories in a specified folder" - def _run(self, dir_path: str = ".") -> str: + def _run( + self, + dir_path: str = ".", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: dir_path_ = self.get_relative_path(dir_path) except FileValidationError: @@ -35,6 +44,10 @@ def _run(self, dir_path: str = ".") -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, dir_path: str) -> str: + async def _arun( + self, + dir_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/move.py b/langchain/tools/file_management/move.py index ccf8879620a6f..b4cc1a9454357 100644 --- a/langchain/tools/file_management/move.py +++ b/langchain/tools/file_management/move.py @@ -1,11 +1,16 @@ import shutil -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -17,12 +22,17 @@ class FileMoveInput(BaseModel): destination_path: str = Field(..., description="New path for the moved file") -class MoveFileTool(BaseFileTool): +class MoveFileTool(BaseFileToolMixin, BaseTool): name: str = "move_file" args_schema: Type[BaseModel] = FileMoveInput description: str = "Move or rename a file from one location to another" - def _run(self, source_path: str, destination_path: str) -> str: + def _run( + self, + source_path: str, + destination_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: source_path_ = self.get_relative_path(source_path) except FileValidationError: @@ -44,6 +54,11 @@ def _run(self, source_path: str, destination_path: str) -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, source_path: str, destination_path: str) -> str: + async def _arun( + self, + source_path: str, + destination_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/read.py b/langchain/tools/file_management/read.py index d243a9e3a34ba..86d6191d6f621 100644 --- a/langchain/tools/file_management/read.py +++ b/langchain/tools/file_management/read.py @@ -1,10 +1,15 @@ -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, 
FileValidationError, ) @@ -15,12 +20,16 @@ class ReadFileInput(BaseModel): file_path: str = Field(..., description="name of file") -class ReadFileTool(BaseFileTool): +class ReadFileTool(BaseFileToolMixin, BaseTool): name: str = "read_file" args_schema: Type[BaseModel] = ReadFileInput description: str = "Read file from disk" - def _run(self, file_path: str) -> str: + def _run( + self, + file_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: read_path = self.get_relative_path(file_path) except FileValidationError: @@ -34,6 +43,10 @@ def _run(self, file_path: str) -> str: except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str) -> str: + async def _arun( + self, + file_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/utils.py b/langchain/tools/file_management/utils.py index c8efefb4fc009..788823fecd739 100644 --- a/langchain/tools/file_management/utils.py +++ b/langchain/tools/file_management/utils.py @@ -1,11 +1,9 @@ import sys from pathlib import Path -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel -from langchain.tools.base import BaseTool - def is_relative_to(path: Path, root: Path) -> bool: """Check if path is relative to root.""" @@ -29,8 +27,8 @@ class FileValidationError(ValueError): """Error for paths outside the root directory.""" -class BaseFileTool(BaseTool, BaseModel): - """Input for ReadFileTool.""" +class BaseFileToolMixin(BaseModel): + """Mixin for file system tools.""" root_dir: Optional[str] = None """The final path will be chosen relative to root_dir if specified.""" @@ -41,12 +39,6 @@ def get_relative_path(self, file_path: str) -> Path: return Path(file_path) return get_validated_relative_path(Path(self.root_dir), file_path) - def _run(self, *args: Any, **kwargs: Any) -> str: - raise NotImplementedError - - async def _arun(self, *args: Any, **kwargs: Any) -> str: - raise NotImplementedError - def get_validated_relative_path(root: Path, user_path: str) -> Path: """Resolve a relative path, raising an error if not within the root directory.""" diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py index 865bcbe7da742..fcebe1c7a1a96 100644 --- a/langchain/tools/file_management/write.py +++ b/langchain/tools/file_management/write.py @@ -1,10 +1,15 @@ -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -19,12 +24,18 @@ class WriteFileInput(BaseModel): ) -class WriteFileTool(BaseFileTool): +class WriteFileTool(BaseFileToolMixin, BaseTool): name: str = "write_file" args_schema: Type[BaseModel] = WriteFileInput description: str = "Write file to disk" - def _run(self, file_path: str, text: str, append: bool = False) -> str: + def _run( + self, + file_path: str, + text: str, + append: bool = False, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: write_path = self.get_relative_path(file_path) except FileValidationError: @@ -38,6 +49,12 @@ def _run(self, file_path: str, text: str, append: bool = False) -> str: except Exception as e: return 
"Error: " + str(e) - async def _arun(self, file_path: str, text: str, append: bool = False) -> str: + async def _arun( + self, + file_path: str, + text: str, + append: bool = False, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/google_places/tool.py b/langchain/tools/google_places/tool.py index 31ae39dae80b5..cce83642826d3 100644 --- a/langchain/tools/google_places/tool.py +++ b/langchain/tools/google_places/tool.py @@ -1,7 +1,13 @@ """Tool for the Google search API.""" +from typing import Optional + from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.google_places_api import GooglePlacesAPIWrapper @@ -18,10 +24,18 @@ class GooglePlacesTool(BaseTool): ) api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("GooglePlacesRun does not support async") diff --git a/langchain/tools/google_search/tool.py b/langchain/tools/google_search/tool.py index 1945a3df8b009..71288e19c8191 100644 --- a/langchain/tools/google_search/tool.py +++ b/langchain/tools/google_search/tool.py @@ -1,5 +1,11 @@ """Tool for the Google search API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.google_search import GoogleSearchAPIWrapper @@ -15,11 +21,19 @@ class GoogleSearchRun(BaseTool): ) api_wrapper: GoogleSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("GoogleSearchRun does not support async") @@ -36,10 +50,18 @@ class GoogleSearchResults(BaseTool): num_results: int = 4 api_wrapper: GoogleSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("GoogleSearchRun does not support async") diff --git a/langchain/tools/human/tool.py b/langchain/tools/human/tool.py index de2cce81cfa0a..a207c6b179b0d 100644 --- a/langchain/tools/human/tool.py +++ b/langchain/tools/human/tool.py @@ -1,9 +1,13 @@ """Tool for asking human input.""" -from typing import Callable +from typing import Callable, Optional from pydantic import Field +from langchain.callbacks.manager import ( + 
AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -24,11 +28,19 @@ class HumanInputRun(BaseTool): prompt_func: Callable[[str], None] = Field(default_factory=lambda: _print_func) input_func: Callable = Field(default_factory=lambda: input) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Human input tool.""" self.prompt_func(query) return self.input_func() - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Human tool asynchronously.""" raise NotImplementedError("Human tool does not support async") diff --git a/langchain/tools/ifttt.py b/langchain/tools/ifttt.py index 8d3d943af0c59..e42c232f46342 100644 --- a/langchain/tools/ifttt.py +++ b/langchain/tools/ifttt.py @@ -32,8 +32,14 @@ - Copy the IFTTT key value from there. The URL is of the form https://maker.ifttt.com/use/YOUR_IFTTT_KEY. Grab the YOUR_IFTTT_KEY value. """ +from typing import Optional + import requests +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -48,10 +54,18 @@ class IFTTTWebhook(BaseTool): url: str - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: body = {"this": tool_input} response = requests.post(self.url, data=body) return response.text - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("Not implemented.") diff --git a/langchain/tools/interaction/tool.py b/langchain/tools/interaction/tool.py index ee2b51ca4cac7..096c885db2f21 100644 --- a/langchain/tools/interaction/tool.py +++ b/langchain/tools/interaction/tool.py @@ -1,6 +1,12 @@ """Tools for interacting with the user.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + RunManager, +) from langchain.tools.base import BaseTool @@ -14,9 +20,13 @@ class StdInInquireTool(BaseTool): " question (to disambiguate) or a request for more context." 
) - def _run(self, prompt: str) -> str: + def _run(self, prompt: str, run_manager: Optional[RunManager] = None) -> str: """Prompt the user for more input.""" return input(f"\n{prompt}") - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support async") diff --git a/langchain/tools/jira/tool.py b/langchain/tools/jira/tool.py index 86861759d2bd8..6c75ca9155adc 100644 --- a/langchain/tools/jira/tool.py +++ b/langchain/tools/jira/tool.py @@ -28,8 +28,14 @@ ) ``` """ +from typing import Optional + from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.jira import JiraAPIWrapper @@ -40,10 +46,18 @@ class JiraAction(BaseTool): name = "" description = "" - def _run(self, instructions: str) -> str: + def _run( + self, + instructions: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Atlassian Jira API to run an operation.""" return self.api_wrapper.run(self.mode, instructions) - async def _arun(self, _: str) -> str: + async def _arun( + self, + _: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Atlassian Jira API to run an operation.""" raise NotImplementedError("JiraAction does not support async") diff --git a/langchain/tools/json/tool.py b/langchain/tools/json/tool.py index 9f1bdac737cb0..6f6473d51e6b4 100644 --- a/langchain/tools/json/tool.py +++ b/langchain/tools/json/tool.py @@ -5,10 +5,14 @@ import json import re from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -88,10 +92,18 @@ class JsonListKeysTool(BaseTool): """ spec: JsonSpec - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: return self.spec.keys(tool_input) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return self._run(tool_input) @@ -106,8 +118,16 @@ class JsonGetValueTool(BaseTool): """ spec: JsonSpec - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: return self.spec.value(tool_input) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return self._run(tool_input) diff --git a/langchain/tools/playwright/base.py b/langchain/tools/playwright/base.py index 95db7f92d897a..79857b45b4e42 100644 --- a/langchain/tools/playwright/base.py +++ b/langchain/tools/playwright/base.py @@ -1,9 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from pydantic import Field, root_validator +from langchain.callbacks.manager import ( + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.tools.playwright.utils import create_playwright_browser, run_async @@ -28,7 +31,12 
@@ def check_args(cls, values: dict) -> dict: ) return values - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Use the tool.""" return run_async(self._arun(*args, **kwargs)) diff --git a/langchain/tools/playwright/click.py b/langchain/tools/playwright/click.py index 0d963d35aec05..d11cdde46793b 100644 --- a/langchain/tools/playwright/click.py +++ b/langchain/tools/playwright/click.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( get_current_page, @@ -21,7 +22,11 @@ class ClickTool(BaseBrowserTool): description: str = "Click on an element with the given CSS selector" args_schema: Type[BaseModel] = ClickToolInput - async def _arun(self, selector: str) -> str: + async def _arun( + self, + selector: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" page = await get_current_page(self.browser) # Navigate to the desired webpage before using this tool diff --git a/langchain/tools/playwright/current_page.py b/langchain/tools/playwright/current_page.py index bde0ff8acbcf8..fe6593fa3674a 100644 --- a/langchain/tools/playwright/current_page.py +++ b/langchain/tools/playwright/current_page.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( get_current_page, @@ -15,7 +16,10 @@ class CurrentWebPageTool(BaseBrowserTool): description: str = "Returns the URL of the current page" args_schema: Type[BaseModel] = BaseModel - async def _arun(self) -> str: + async def _arun( + self, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" page = await get_current_page(self.browser) return str(page.url) diff --git a/langchain/tools/playwright/extract_hyperlinks.py b/langchain/tools/playwright/extract_hyperlinks.py index 9e792f198c0a8..6d903e0101c97 100644 --- a/langchain/tools/playwright/extract_hyperlinks.py +++ b/langchain/tools/playwright/extract_hyperlinks.py @@ -1,10 +1,11 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Type +from typing import TYPE_CHECKING, Optional, Type from pydantic import BaseModel, Field, root_validator +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import get_current_page @@ -40,7 +41,11 @@ def check_args(cls, values: dict) -> dict: ) return values - async def _arun(self, absolute_urls: bool = False) -> str: + async def _arun( + self, + absolute_urls: bool = False, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" from urllib.parse import urljoin diff --git a/langchain/tools/playwright/extract_text.py b/langchain/tools/playwright/extract_text.py index 0ced6d35d3949..5b20b7c38ab5b 100644 --- a/langchain/tools/playwright/extract_text.py +++ b/langchain/tools/playwright/extract_text.py @@ -1,9 +1,10 @@ from __future__ import 
annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, root_validator +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import get_current_page @@ -25,7 +26,9 @@ def check_args(cls, values: dict) -> dict: ) return values - async def _arun(self) -> str: + async def _arun( + self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None + ) -> str: """Use the tool.""" # Use Beautiful Soup since it's faster than looping through the elements from bs4 import BeautifulSoup diff --git a/langchain/tools/playwright/get_elements.py b/langchain/tools/playwright/get_elements.py index 2a90112d40c27..910c022cf7fb9 100644 --- a/langchain/tools/playwright/get_elements.py +++ b/langchain/tools/playwright/get_elements.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import get_current_page @@ -53,7 +54,10 @@ class GetElementsTool(BaseBrowserTool): args_schema: Type[BaseModel] = GetElementsToolInput async def _arun( - self, selector: str, attributes: Sequence[str] = ["innerText"] + self, + selector: str, + attributes: Sequence[str] = ["innerText"], + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" page = await get_current_page(self.browser) diff --git a/langchain/tools/playwright/navigate.py b/langchain/tools/playwright/navigate.py index cac357195b122..cd1e4dbc1028d 100644 --- a/langchain/tools/playwright/navigate.py +++ b/langchain/tools/playwright/navigate.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( get_current_page, @@ -21,7 +22,11 @@ class NavigateTool(BaseBrowserTool): description: str = "Navigate a browser to the specified URL" args_schema: Type[BaseModel] = NavigateToolInput - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" page = await get_current_page(self.browser) response = await page.goto(url) diff --git a/langchain/tools/playwright/navigate_back.py b/langchain/tools/playwright/navigate_back.py index 114fc81c77a84..1a93e2636b918 100644 --- a/langchain/tools/playwright/navigate_back.py +++ b/langchain/tools/playwright/navigate_back.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel +from langchain.callbacks.manager import AsyncCallbackManagerForToolRun from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( get_current_page, @@ -17,7 +18,10 @@ class NavigateBackTool(BaseBrowserTool): description: str = "Navigate back to the previous page in the browser history" args_schema: Type[BaseModel] = BaseModel - async def _arun(self) -> str: + async def _arun( + self, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" page = await get_current_page(self.browser) response = await page.go_back() diff --git 
a/langchain/tools/plugin.py b/langchain/tools/plugin.py index 8f5fadd8e27d8..0510eaaa2ce28 100644 --- a/langchain/tools/plugin.py +++ b/langchain/tools/plugin.py @@ -7,6 +7,10 @@ import yaml from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -72,10 +76,18 @@ def from_plugin_url(cls, url: str) -> AIPluginTool: api_spec=api_spec, ) - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_spec - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return self.api_spec diff --git a/langchain/tools/powerbi/tool.py b/langchain/tools/powerbi/tool.py index 67efe42339888..633f99d355f49 100644 --- a/langchain/tools/powerbi/tool.py +++ b/langchain/tools/powerbi/tool.py @@ -3,6 +3,10 @@ from pydantic import Field, validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains.llm import LLMChain from langchain.tools.base import BaseTool from langchain.tools.powerbi.prompt import ( @@ -45,7 +49,11 @@ def _check_cache(self, tool_input: str) -> Optional[str]: self.session_cache[tool_input] = BAD_REQUEST_RESPONSE_ESCALATED return self.session_cache[tool_input] - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): return cache @@ -67,7 +75,11 @@ def _run(self, tool_input: str) -> str: ) return self.session_cache[tool_input] - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): return cache @@ -107,11 +119,19 @@ class Config: arbitrary_types_allowed = True - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for tables in a comma-separated list.""" return self.powerbi.get_table_info(tool_input.split(", ")) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.powerbi.aget_table_info(tool_input.split(", ")) @@ -127,11 +147,21 @@ class Config: arbitrary_types_allowed = True - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names()) - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names()) @@ -171,7 +201,11 @@ def validate_llm_chain_input_variables( # pylint: disable=E0213 ) return llm_chain - def _run(self, tool_input: str) -> str: + def _run( + 
self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the LLM to check the query.""" return self.llm_chain.predict( tool_input=tool_input, @@ -180,7 +214,11 @@ def _run(self, tool_input: str) -> str: examples=self.examples, ) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.llm_chain.apredict( tool_input=tool_input, tables=self.powerbi.get_table_names(), diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 607fec222d568..2e67d6701c74b 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -3,10 +3,14 @@ import ast import sys from io import StringIO -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities import PythonREPL @@ -28,13 +32,23 @@ class PythonREPLTool(BaseTool): python_repl: PythonREPL = Field(default_factory=_get_default_python_repl) sanitize_input: bool = True - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool.""" if self.sanitize_input: query = query.strip().strip("```") return self.python_repl.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async") @@ -64,7 +78,11 @@ def validate_python_version(cls, values: Dict) -> Dict: ) return values - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" try: if self.sanitize_input: @@ -91,6 +109,10 @@ def _run(self, query: str) -> str: except Exception as e: return "{}: {}".format(type(e).__name__, str(e)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async") diff --git a/langchain/tools/requests/tool.py b/langchain/tools/requests/tool.py index 1bfc8bc7c1543..64b25303c5862 100644 --- a/langchain/tools/requests/tool.py +++ b/langchain/tools/requests/tool.py @@ -1,9 +1,13 @@ # flake8: noqa """Tools for making requests to an API endpoint.""" import json -from typing import Any, Dict +from typing import Any, Dict, Optional from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.requests import TextRequestsWrapper from langchain.tools.base import BaseTool @@ -31,11 +35,17 @@ class RequestsGetTool(BaseRequestsTool, BaseTool): name = "requests_get" description = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request." 
- def _run(self, url: str) -> str: + def _run( + self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" return self.requests_wrapper.get(_clean_url(url)) - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" return await self.requests_wrapper.aget(_clean_url(url)) @@ -52,7 +62,9 @@ class RequestsPostTool(BaseRequestsTool, BaseTool): The output will be the text response of the POST request. """ - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -60,7 +72,11 @@ def _run(self, text: str) -> str: except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -83,7 +99,9 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool): The output will be the text response of the PATCH request. """ - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -91,7 +109,11 @@ def _run(self, text: str) -> str: except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -114,7 +136,9 @@ class RequestsPutTool(BaseRequestsTool, BaseTool): The output will be the text response of the PUT request. """ - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -122,7 +146,11 @@ def _run(self, text: str) -> str: except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -139,10 +167,18 @@ class RequestsDeleteTool(BaseRequestsTool, BaseTool): name = "requests_delete" description = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request." 
- def _run(self, url: str) -> str: + def _run( + self, + url: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Run the tool.""" return self.requests_wrapper.delete(_clean_url(url)) - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" return await self.requests_wrapper.adelete(_clean_url(url)) diff --git a/langchain/tools/searx_search/tool.py b/langchain/tools/searx_search/tool.py index a91f7e279c508..e3ea04b5b496f 100644 --- a/langchain/tools/searx_search/tool.py +++ b/langchain/tools/searx_search/tool.py @@ -1,6 +1,12 @@ """Tool for the SearxNG search API.""" +from typing import Optional + from pydantic import Extra +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.searx_search import SearxSearchWrapper @@ -16,11 +22,19 @@ class SearxSearchRun(BaseTool): ) wrapper: SearxSearchWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return await self.wrapper.arun(query) @@ -42,10 +56,18 @@ class Config: extra = Extra.allow - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return (await self.wrapper.aresults(query, self.num_results)).__str__() diff --git a/langchain/tools/shell/tool.py b/langchain/tools/shell/tool.py index 8f9ecaefb031f..42e19038158b2 100644 --- a/langchain/tools/shell/tool.py +++ b/langchain/tools/shell/tool.py @@ -1,10 +1,14 @@ import asyncio import platform import warnings -from typing import List, Type +from typing import List, Optional, Type from pydantic import BaseModel, Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.bash import BashProcess @@ -60,11 +64,19 @@ class ShellTool(BaseTool): args_schema: Type[BaseModel] = ShellInput """Schema for input arguments.""" - def _run(self, commands: List[str]) -> str: + def _run( + self, + commands: List[str], + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Run commands and return final output.""" return self.process.run(commands) - async def _arun(self, commands: List[str]) -> str: + async def _arun( + self, + commands: List[str], + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run commands asynchronously and return final output.""" return await asyncio.get_event_loop().run_in_executor( None, self.process.run, commands diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index 6e85087b574be..2e677c6c814e9 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -1,12 
+1,17 @@ # flake8: noqa """Tools for interacting with a SQL database.""" -from pydantic import BaseModel, Extra, Field, validator, root_validator -from typing import Any, Dict +from typing import Any, Dict, Optional +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.base_language import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER @@ -35,11 +40,19 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): If an error is returned, rewrite the query, check the query, and try again. """ - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" return self.db.run_no_throw(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("QuerySqlDbTool does not support async") @@ -54,11 +67,19 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Example Input: "table1, table2, table3" """ - def _run(self, table_names: str) -> str: + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for tables in a comma-separated list.""" return self.db.get_table_info_no_throw(table_names.split(", ")) - async def _arun(self, table_name: str) -> str: + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("SchemaSqlDbTool does not support async") @@ -68,11 +89,19 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): name = "list_tables_sql_db" description = "Input is an empty string, output is a comma separated list of tables in the database." 
- def _run(self, tool_input: str = "") -> str: + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for a specific table.""" return ", ".join(self.db.get_usable_table_names()) - async def _arun(self, tool_input: str = "") -> str: + async def _arun( + self, + tool_input: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("ListTablesSqlDbTool does not support async") @@ -106,9 +135,17 @@ def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the LLM to check the query.""" return self.llm_chain.predict(query=query, dialect=self.db.dialect) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.llm_chain.apredict(query=query, dialect=self.db.dialect) diff --git a/langchain/tools/vectorstore/tool.py b/langchain/tools/vectorstore/tool.py index 1dd18fd2023e8..983224b4a911d 100644 --- a/langchain/tools/vectorstore/tool.py +++ b/langchain/tools/vectorstore/tool.py @@ -1,10 +1,14 @@ """Tools for interacting with vectorstores.""" import json -from typing import Any, Dict +from typing import Any, Dict, Optional from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI @@ -42,14 +46,22 @@ def get_description(name: str, description: str) -> str: ) return template.format(name=name, description=description) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" chain = RetrievalQA.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) return chain.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorStoreQATool does not support async") @@ -70,13 +82,21 @@ def get_description(name: str, description: str) -> str: ) return template.format(name=name, description=description) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" chain = RetrievalQAWithSourcesChain.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) return json.dumps(chain({chain.question_key: query}, return_only_outputs=True)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async") diff --git a/langchain/tools/wikipedia/tool.py b/langchain/tools/wikipedia/tool.py index 5bede75b21689..af398d7f93fad 100644 --- a/langchain/tools/wikipedia/tool.py +++ b/langchain/tools/wikipedia/tool.py @@ -1,5 +1,11 @@ """Tool for the Wikipedia API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + 
AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.wikipedia import WikipediaAPIWrapper @@ -16,10 +22,18 @@ class WikipediaQueryRun(BaseTool): ) api_wrapper: WikipediaAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Wikipedia tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Wikipedia tool asynchronously.""" raise NotImplementedError("WikipediaQueryRun does not support async") diff --git a/langchain/tools/wolfram_alpha/tool.py b/langchain/tools/wolfram_alpha/tool.py index ecac7b8f46394..a243d22f6ec99 100644 --- a/langchain/tools/wolfram_alpha/tool.py +++ b/langchain/tools/wolfram_alpha/tool.py @@ -1,5 +1,11 @@ """Tool for the Wolfram Alpha API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -16,10 +22,18 @@ class WolframAlphaQueryRun(BaseTool): ) api_wrapper: WolframAlphaAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the WolframAlpha tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the WolframAlpha tool asynchronously.""" raise NotImplementedError("WolframAlphaQueryRun does not support async") diff --git a/langchain/tools/zapier/tool.py b/langchain/tools/zapier/tool.py index f6ec020e2f2e1..f68a3562f7a70 100644 --- a/langchain/tools/zapier/tool.py +++ b/langchain/tools/zapier/tool.py @@ -81,6 +81,10 @@ from pydantic import Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT from langchain.utilities.zapier import ZapierNLAWrapper @@ -119,11 +123,17 @@ def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - def _run(self, instructions: str) -> str: + def _run( + self, instructions: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" return self.api_wrapper.run_as_str(self.action_id, instructions, self.params) - async def _arun(self, _: str) -> str: + async def _arun( + self, + _: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" raise NotImplementedError("ZapierNLAListActions does not support async") @@ -148,11 +158,19 @@ class ZapierNLAListActions(BaseTool): ) api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) - def _run(self, _: str) -> str: + def _run( + self, + _: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" return self.api_wrapper.list_as_str() - async def _arun(self, _: str) -> str: + async def _arun( + 
self, + _: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" raise NotImplementedError("ZapierNLAListActions does not support async") diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 965d44db110e9..055720d8689a2 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -171,7 +171,7 @@ def structured_tool_input( def test_structured_args_decorator_no_infer_schema() -> None: """Test functionality with structured arguments parsed as a decorator.""" - @tool + @tool(infer_schema=False) def structured_tool_input( arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None ) -> str: @@ -182,7 +182,8 @@ def structured_tool_input( assert structured_tool_input.name == "structured_tool_input" args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} expected_result = "1, 0.001, {'foo': 'bar'}" - assert structured_tool_input.run(args) == expected_result + with pytest.raises(ValueError): + assert structured_tool_input.run(args) == expected_result def test_structured_single_str_decorator_no_infer_schema() -> None: diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index b852510c5b1fe..1e5022b89c40e 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -25,11 +25,9 @@ def load_memory_variables( def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Pass.""" - pass def clear(self) -> None: """Pass.""" - pass class FakeChain(Chain): diff --git a/tests/unit_tests/tools/test_signatures.py b/tests/unit_tests/tools/test_signatures.py new file mode 100644 index 0000000000000..6a7a912e050f6 --- /dev/null +++ b/tests/unit_tests/tools/test_signatures.py @@ -0,0 +1,41 @@ +"""Test base tool child implementations.""" + + +import inspect +import re +from typing import List, Type + +import pytest + +from langchain.tools.base import BaseTool + + +def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]: + subclasses = [] + for subclass in cls.__subclasses__(): + if not getattr( + subclass, "__abstract__", None + ) and not subclass.__name__.startswith("_"): + subclasses.append(subclass) + subclasses.extend(get_non_abstract_subclasses(subclass)) + return subclasses + + +@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool)) # type: ignore +def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None: + """Test that tools defined in this repo accept a run manager argument.""" + # This wouldn't be necessary if the BaseTool had a strict API. 
+ if cls._run is not BaseTool._arun: + run_func = cls._run + params = inspect.signature(run_func).parameters + assert "run_manager" in params + pattern = re.compile(r"(?!Async)CallbackManagerForToolRun") + assert bool(re.search(pattern, str(params["run_manager"].annotation))) + assert params["run_manager"].default is None + + if cls._arun is not BaseTool._arun: + run_func = cls._arun + params = inspect.signature(run_func).parameters + assert "run_manager" in params + assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation) + assert params["run_manager"].default is None From 20ba88816dd75c10926f317bf27496c967b5cced Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 28 Apr 2023 18:36:22 -0700 Subject: [PATCH 30/36] Call Manager for New Tools (#3755) Couple additional tools landed today --- .../agents/toolkits/examples/playwright.ipynb | 89 +++++++++----- .../agents/tools/examples/sceneXplain.ipynb | 116 ++++++++++++++++++ .../agent_toolkits/playwright/toolkit.py | 55 ++++++--- langchain/tools/__init__.py | 5 +- langchain/tools/base.py | 18 ++- langchain/tools/google_places/tool.py | 6 +- langchain/tools/playwright/__init__.py | 2 - langchain/tools/playwright/base.py | 65 +++++----- langchain/tools/playwright/click.py | 23 +++- langchain/tools/playwright/current_page.py | 23 +++- .../tools/playwright/extract_hyperlinks.py | 48 +++++--- langchain/tools/playwright/extract_text.py | 29 ++++- langchain/tools/playwright/get_elements.py | 50 +++++++- langchain/tools/playwright/navigate.py | 23 +++- langchain/tools/playwright/navigate_back.py | 27 +++- langchain/tools/playwright/utils.py | 24 +++- langchain/tools/plugin.py | 13 +- langchain/tools/scenexplain/__init__.py | 1 + langchain/tools/scenexplain/tool.py | 42 +++++++ langchain/utilities/scenexplain.py | 68 ++++++++++ tests/unit_tests/tools/test_signatures.py | 10 +- 21 files changed, 603 insertions(+), 134 deletions(-) create mode 100644 docs/modules/agents/tools/examples/sceneXplain.ipynb create mode 100644 langchain/tools/scenexplain/__init__.py create mode 100644 langchain/tools/scenexplain/tool.py create mode 100644 langchain/utilities/scenexplain.py diff --git a/docs/modules/agents/toolkits/examples/playwright.ipynb b/docs/modules/agents/toolkits/examples/playwright.ipynb index 0d628c07acab0..e4025d85ffe86 100644 --- a/docs/modules/agents/toolkits/examples/playwright.ipynb +++ b/docs/modules/agents/toolkits/examples/playwright.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -20,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -34,48 +33,71 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.agents.agent_toolkits import PlayWrightBrowserToolkit\n", + "from langchain.tools.playwright.utils import (\n", + " create_async_playwright_browser,\n", + " create_sync_playwright_browser,# A synchronous browser is available, though it isn't compatible with jupyter.\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "from langchain.agents.agent_toolkits import PlayWrightBrowserToolkit" + "# This import is required only for jupyter notebooks, since they have their own eventloop\n", + "import nest_asyncio\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Instantiating a Browser Toolkit\n", + "\n", + "It's always recommended to instantiate using the `from_browser` method so that the " ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[ClickTool(name='click_element', description='Click on an element with the given CSS selector', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " NavigateTool(name='navigate_browser', description='Navigate a browser to the specified URL', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " NavigateBackTool(name='previous_webpage', description='Navigate back to the previous page in the browser history', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " ExtractTextTool(name='extract_text', description='Extract all the text on the current webpage', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " ExtractHyperlinksTool(name='extract_hyperlinks', description='Extract all hyperlinks on the current webpage', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " GetElementsTool(name='get_elements', description='Retrieve elements in the current web page matching the given CSS selector', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>),\n", - " CurrentWebPageTool(name='current_webpage', description='Returns the URL of the current page', args_schema=, return_direct=False, verbose=False, callback_manager=, browser= version=112.0.5615.29>)]" + "[ClickTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='click_element', description='Click on an element with the given CSS selector', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " NavigateTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='navigate_browser', description='Navigate a browser to the specified URL', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " NavigateBackTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='previous_webpage', description='Navigate back to the previous page in the browser history', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " ExtractTextTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='extract_text', description='Extract all the text on the current webpage', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " ExtractHyperlinksTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='extract_hyperlinks', description='Extract all hyperlinks on the current webpage', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " GetElementsTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='get_elements', description='Retrieve elements in the current web page matching the given CSS selector', args_schema=, return_direct=False, verbose=False, callback_manager=),\n", + " CurrentWebPageTool(sync_browser=None, async_browser= version=112.0.5615.29>, name='current_webpage', description='Returns the URL of the current page', args_schema=, return_direct=False, verbose=False, callback_manager=)]" ] }, - "execution_count": 22, + "execution_count": 4, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ - "# This import is required only for jupyter notebooks, since they have their own eventloop\n", - "import nest_asyncio\n", - "nest_asyncio.apply()\n", - "\n", - "toolkit = PlayWrightBrowserToolkit()\n", + "async_browser = create_async_playwright_browser()\n", + "toolkit = PlayWrightBrowserToolkit.from_browser(async_browser=async_browser)\n", "tools = toolkit.get_tools()\n", "tools" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -86,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -95,55 +117,55 @@ "'Navigating to https://web.archive.org/web/20230428131116/https://www.cnn.com/world returned status code 200'" ] }, - "execution_count": 24, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "navigate_tool.run({\"url\": \"https://web.archive.org/web/20230428131116/https://www.cnn.com/world\"})" + "await navigate_tool.arun({\"url\": \"https://web.archive.org/web/20230428131116/https://www.cnn.com/world\"})" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'[{\"innerText\": \"As US and Philippine defense ties grow, China warns over Taiwan tensions\"}, {\"innerText\": \"Almost two-thirds of elephant habitat lost across Asia, study finds\"}, {\"innerText\": \"\\\\u2018We don\\\\u2019t sleep \\\\u2026 I would call it fainting\\\\u2019: Working as a doctor in Sudan\\\\u2019s crisis\"}, {\"innerText\": \"Kenya arrests second pastor to face criminal charges \\\\u2018related to mass killing of his followers\\\\u2019\"}, {\"innerText\": \"Ocean census aims to discover 100,000 previously unknown marine species\"}, {\"innerText\": \"Iran\\\\u2019s Navy seizes Marshall Islands-flagged ship\"}, {\"innerText\": \"German landlord wins right to sunbathe naked despite complaints from tenants\"}, {\"innerText\": \"Single people should be \\\\u2018valued\\\\u2019 as Jesus was single, Church of England says\"}, {\"innerText\": \"Turkey\\\\u2019s Erdogan cancels public appearances after falling ill as election nears\"}, {\"innerText\": \"Drought-stricken Spain braces for exceptionally high temperatures expected to break April records\"}, {\"innerText\": \"With Zelensky call, Xi Jinping steps up bid to broker peace \\\\u2013 but does he have a plan?\"}, {\"innerText\": \"Indian and Chinese defense ministers to meet face to face\"}, {\"innerText\": \"Pope to allow women to vote at global bishops meeting\"}, {\"innerText\": \"Catastrophic drought that\\\\u2019s pushed millions into crisis made 100 times more likely by climate change, analysis finds\"}, {\"innerText\": \"\\\\u2018Bring Ya Ya home\\\\u2019: How a panda in the US turbocharged Chinese nationalist sentiment\"}, {\"innerText\": \"\\\\u2018Often they shoot at each other\\\\u2019: Ukrainian drone operator details chaos in Russian ranks\"}, {\"innerText\": \"U.S. 
talk show host Jerry Springer dies at 79\"}, {\"innerText\": \"Girl to get life-saving treatment for rare immune disease\"}, {\"innerText\": \"Wall Street Journal editor discusses reporter\\\\u2019s arrest in Moscow\"}, {\"innerText\": \"Belgium destroys shipment of American beer after taking issue with \\\\u2018Champagne of Beer\\\\u2019 slogan\"}, {\"innerText\": \"UK Prime Minister Rishi Sunak rocked by resignation of top ally Raab over bullying allegations\"}, {\"innerText\": \"Coronation mishaps King Charles III will want to avoid\"}, {\"innerText\": \"Russian jet accidentally drops bomb on Russian city of Belgorod, state media says\"}, {\"innerText\": \"Queen Camilla\\\\u2019s son, Tom Parker Bowles, says his mother \\\\u2018married the person she loved\\\\u2019\"}, {\"innerText\": \"These Iranian activists fled for freedom. The regime still managed to find them\"}, {\"innerText\": \"A divided Israel stands at a perilous crossroads on its 75th birthday\"}, {\"innerText\": \"Palestinian reporter breaks barriers by reporting in Hebrew on Israeli TV\"}, {\"innerText\": \"One-fifth of water pollution comes from textile dyes. But a shellfish-inspired solution could clean it up\"}, {\"innerText\": \"\\\\u2018People sacrificed their lives for just\\\\u00a010 dollars\\\\u2019: At least 78 killed in Yemen crowd surge\"}, {\"innerText\": \"Israeli police say two men shot near Jewish tomb in Jerusalem in suspected \\\\u2018terror attack\\\\u2019\"}, {\"innerText\": \"Houthis try to reassure skeptics they won\\\\u2019t seek full control of Yemen, as Saudis eye exit\"}, {\"innerText\": \"The week in 33 photos\"}, {\"innerText\": \"Hong Kong\\\\u2019s endangered turtles\"}, {\"innerText\": \"In pictures: Britain\\\\u2019s Queen Camilla\"}, {\"innerText\": \"In pictures: Charles and Camilla\"}, {\"innerText\": \"For years, a UK mining giant was untouchable in Zambia for pollution until a former miner\\\\u2019s son took them on\"}, {\"innerText\": \"Former Sudanese minister Ahmed Haroun wanted on war crimes charges freed from Khartoum prison\"}, {\"innerText\": \"WHO warns of \\\\u2018biological risk\\\\u2019 after Sudan fighters seize lab, as violence mars US-brokered ceasefire\"}, {\"innerText\": \"Rival generals are battling for control in Sudan. Here\\\\u2019s a simple guide to the fighting\"}, {\"innerText\": \"How Colombia\\\\u2019s Petro, a former leftwing guerrilla, found his opening in Washington\"}, {\"innerText\": \"Bolsonaro accidentally created Facebook post questioning Brazil election results, say his attorneys\"}, {\"innerText\": \"Crowd kills over a dozen suspected gang members in Haiti\"}, {\"innerText\": \"Thousands of tequila bottles containing liquid meth seized\"}, {\"innerText\": \"Why send a US stealth submarine to South Korea \\\\u2013 and tell the world about it?\"}, {\"innerText\": \"Fukushima\\\\u2019s fishing industry survived a nuclear disaster. 
12 years on, it fears Tokyo\\\\u2019s next move may finish it off\"}, {\"innerText\": \"Singapore executes man for trafficking two pounds of cannabis\"}, {\"innerText\": \"Conservative Thai party looks to woo voters with promise to legalize sex toys\"}, {\"innerText\": \"Watch planes take off in Japan \\\\u2014 from an onsen\"}, {\"innerText\": \"Bilt\\\\u2019s May Rent Day promotion: Fly to Europe for as few as 6,000 Bilt points\"}, {\"innerText\": \"Cabeau just dropped the Evolution Earth, a new eco-minded travel pillow\"}, {\"innerText\": \"Nemo\\\\u2019s Garden: The future of farming could be under the sea\"}, {\"innerText\": \"Cadence\\\\u2019s cult-favorite travel capsules are now available in more sizes\"}, {\"innerText\": \"Judy Blume\\\\u2019s books were formative for generations of readers. Here\\\\u2019s why they endure\"}, {\"innerText\": \"Craft, salvage and sustainability take center stage at Milan Design Week\"}, {\"innerText\": \"Life-sized chocolate King Charles III sculpture unveiled to celebrate coronation\"}, {\"innerText\": \"Rock legend Freddie Mercury\\\\u2019s personal possessions are going up for auction\"}, {\"innerText\": \"John Travolta\\\\u2019s white \\\\u2018Saturday Night Fever\\\\u2019 suit fetches $260K at auction\"}, {\"innerText\": \"The South is in the crosshairs of severe weather again, as the multi-day threat of large hail and tornadoes continues\"}, {\"innerText\": \"Spring snowmelt has cities along the Mississippi bracing for flooding in homes and businesses\"}, {\"innerText\": \"Know the difference between a tornado watch, a tornado warning and a tornado emergency\"}, {\"innerText\": \"Large hail drops on parts of Texas and Florida as South remains at risk of severe storms\"}, {\"innerText\": \"House Republicans adopt bill raising U.S. debt limit and cutting spending\"}, {\"innerText\": \"Judge puts hold on Missouri rule limiting gender-affirming care\"}, {\"innerText\": \"Eleven people killed in suspected Maoist militant attack in central India\"}, {\"innerText\": \"Prosecutors tell judge intel the Air National Guardsman took \\\\u2018far exceeds\\\\u2019 what has been reported\"}, {\"innerText\": \"The son of a Sudanese doctor killed in a mortar attack speaks with Rosemary Church\"}, {\"innerText\": \"Melting snow worsens flooding along the Mississippi River\"}, {\"innerText\": \"Writer E. Jean Carroll testifies in civil suit against Donald Trump\"}, {\"innerText\": \"Nepalese authorities issue record number of Everest permits\"}, {\"innerText\": \"Cruise passenger disappears overboard during trip from Australia to Hawaii\"}, {\"innerText\": \"Watch South Korean president sing \\\\u2018American Pie\\\\u2019 for Biden\"}, {\"innerText\": \"See Russian fighter jet on fire after blowing up mid-flight\"}, {\"innerText\": \"Disney Sues Florida Governor Ron DeSantis\"}, {\"innerText\": \"Yasmeen Lari, \\\\u2018starchitect\\\\u2019 turned social engineer, wins one of architecture\\\\u2019s most coveted prizes\"}, {\"innerText\": \"A massive, newly restored Frank Lloyd Wright mansion is up for sale\"}, {\"innerText\": \"Are these the most sustainable architectural projects in the world?\"}, {\"innerText\": \"Step inside a $72 million London townhouse in a converted army barracks\"}, {\"innerText\": \"A 3D-printing company is preparing to build on the lunar surface. 
But first, a moonshot at home\"}, {\"innerText\": \"Carolina Panthers select QB Bryce Young with first pick of NFL Draft\"}, {\"innerText\": \"Brittney Griner says she\\\\u2019ll \\\\u2018never go overseas again\\\\u2019 to play unless it\\\\u2019s for the Olympics after being detained in Russia\"}, {\"innerText\": \"Pel\\\\u00e9 added to Portuguese dictionary as an adjective for \\\\u2018out of the ordinary\\\\u2019\"}, {\"innerText\": \"Players reimbursing fans and the interim manager getting sacked: How Tottenham Hotspur fell into disrepair\"}, {\"innerText\": \"This CNN Hero is recruiting recreational divers to help rebuild reefs in Florida one coral at a time\"}, {\"innerText\": \"This CNN Hero offers judgment-free veterinary care for the pets of those experiencing homelessness\"}, {\"innerText\": \"Don\\\\u2019t give up on milestones: A CNN Hero\\\\u2019s message for Autism Awareness Month\"}, {\"innerText\": \"CNN Hero of the Year Nelly Cheboi returned to Kenya with plans to lift more students out of poverty\"}]'" + "'[{\"innerText\": \"These Ukrainian veterinarians are risking their lives to care for dogs and cats in the war zone\"}, {\"innerText\": \"Life in the ocean\\\\u2019s \\\\u2018twilight zone\\\\u2019 could disappear due to the climate crisis\"}, {\"innerText\": \"Clashes renew in West Darfur as food and water shortages worsen in Sudan violence\"}, {\"innerText\": \"Thai policeman\\\\u2019s wife investigated over alleged murder and a dozen other poison cases\"}, {\"innerText\": \"American teacher escaped Sudan on French evacuation plane, with no help offered back home\"}, {\"innerText\": \"Dubai\\\\u2019s emerging hip-hop scene is finding its voice\"}, {\"innerText\": \"How an underwater film inspired a marine protected area off Kenya\\\\u2019s coast\"}, {\"innerText\": \"The Iranian drones deployed by Russia in Ukraine are powered by stolen Western technology, research reveals\"}, {\"innerText\": \"India says border violations erode \\\\u2018entire basis\\\\u2019 of ties with China\"}, {\"innerText\": \"Australian police sift through 3,000 tons of trash for missing woman\\\\u2019s remains\"}, {\"innerText\": \"As US and Philippine defense ties grow, China warns over Taiwan tensions\"}, {\"innerText\": \"Don McLean offers duet with South Korean president who sang \\\\u2018American Pie\\\\u2019 to Biden\"}, {\"innerText\": \"Almost two-thirds of elephant habitat lost across Asia, study finds\"}, {\"innerText\": \"\\\\u2018We don\\\\u2019t sleep \\\\u2026 I would call it fainting\\\\u2019: Working as a doctor in Sudan\\\\u2019s crisis\"}, {\"innerText\": \"Kenya arrests second pastor to face criminal charges \\\\u2018related to mass killing of his followers\\\\u2019\"}, {\"innerText\": \"Russia launches deadly wave of strikes across Ukraine\"}, {\"innerText\": \"Woman forced to leave her forever home or \\\\u2018walk to your death\\\\u2019 she says\"}, {\"innerText\": \"U.S. House Speaker Kevin McCarthy weighs in on Disney-DeSantis feud\"}, {\"innerText\": \"Two sides agree to extend Sudan ceasefire\"}, {\"innerText\": \"Spanish Leopard 2 tanks are on their way to Ukraine, defense minister confirms\"}, {\"innerText\": \"Flamb\\\\u00e9ed pizza thought to have sparked deadly Madrid restaurant fire\"}, {\"innerText\": \"Another bomb found in Belgorod just days after Russia accidentally struck the city\"}, {\"innerText\": \"A Black teen\\\\u2019s murder sparked a crisis over racism in British policing. 
Thirty years on, little has changed\"}, {\"innerText\": \"Belgium destroys shipment of American beer after taking issue with \\\\u2018Champagne of Beer\\\\u2019 slogan\"}, {\"innerText\": \"UK Prime Minister Rishi Sunak rocked by resignation of top ally Raab over bullying allegations\"}, {\"innerText\": \"Iran\\\\u2019s Navy seizes Marshall Islands-flagged ship\"}, {\"innerText\": \"A divided Israel stands at a perilous crossroads on its 75th birthday\"}, {\"innerText\": \"Palestinian reporter breaks barriers by reporting in Hebrew on Israeli TV\"}, {\"innerText\": \"One-fifth of water pollution comes from textile dyes. But a shellfish-inspired solution could clean it up\"}, {\"innerText\": \"\\\\u2018People sacrificed their lives for just\\\\u00a010 dollars\\\\u2019: At least 78 killed in Yemen crowd surge\"}, {\"innerText\": \"Israeli police say two men shot near Jewish tomb in Jerusalem in suspected \\\\u2018terror attack\\\\u2019\"}, {\"innerText\": \"King Charles III\\\\u2019s coronation: Who\\\\u2019s performing at the ceremony\"}, {\"innerText\": \"The week in 33 photos\"}, {\"innerText\": \"Hong Kong\\\\u2019s endangered turtles\"}, {\"innerText\": \"In pictures: Britain\\\\u2019s Queen Camilla\"}, {\"innerText\": \"Catastrophic drought that\\\\u2019s pushed millions into crisis made 100 times more likely by climate change, analysis finds\"}, {\"innerText\": \"For years, a UK mining giant was untouchable in Zambia for pollution until a former miner\\\\u2019s son took them on\"}, {\"innerText\": \"Former Sudanese minister Ahmed Haroun wanted on war crimes charges freed from Khartoum prison\"}, {\"innerText\": \"WHO warns of \\\\u2018biological risk\\\\u2019 after Sudan fighters seize lab, as violence mars US-brokered ceasefire\"}, {\"innerText\": \"How Colombia\\\\u2019s Petro, a former leftwing guerrilla, found his opening in Washington\"}, {\"innerText\": \"Bolsonaro accidentally created Facebook post questioning Brazil election results, say his attorneys\"}, {\"innerText\": \"Crowd kills over a dozen suspected gang members in Haiti\"}, {\"innerText\": \"Thousands of tequila bottles containing liquid meth seized\"}, {\"innerText\": \"Why send a US stealth submarine to South Korea \\\\u2013 and tell the world about it?\"}, {\"innerText\": \"Fukushima\\\\u2019s fishing industry survived a nuclear disaster. 12 years on, it fears Tokyo\\\\u2019s next move may finish it off\"}, {\"innerText\": \"Singapore executes man for trafficking two pounds of cannabis\"}, {\"innerText\": \"Conservative Thai party looks to woo voters with promise to legalize sex toys\"}, {\"innerText\": \"Inside the Italian village being repopulated by Americans\"}, {\"innerText\": \"Strikes, soaring airfares and yo-yoing hotel fees: A traveler\\\\u2019s guide to the coronation\"}, {\"innerText\": \"A year in Azerbaijan: From spring\\\\u2019s Grand Prix to winter ski adventures\"}, {\"innerText\": \"The bicycle mayor peddling a two-wheeled revolution in Cape Town\"}, {\"innerText\": \"Tokyo ramen shop bans customers from using their phones while eating\"}, {\"innerText\": \"South African opera star will perform at coronation of King Charles III\"}, {\"innerText\": \"Luxury loot under the hammer: France auctions goods seized from drug dealers\"}, {\"innerText\": \"Judy Blume\\\\u2019s books were formative for generations of readers. 
Here\\\\u2019s why they endure\"}, {\"innerText\": \"Craft, salvage and sustainability take center stage at Milan Design Week\"}, {\"innerText\": \"Life-sized chocolate King Charles III sculpture unveiled to celebrate coronation\"}, {\"innerText\": \"Severe storms to strike the South again as millions in Texas could see damaging winds and hail\"}, {\"innerText\": \"The South is in the crosshairs of severe weather again, as the multi-day threat of large hail and tornadoes continues\"}, {\"innerText\": \"Spring snowmelt has cities along the Mississippi bracing for flooding in homes and businesses\"}, {\"innerText\": \"Know the difference between a tornado watch, a tornado warning and a tornado emergency\"}, {\"innerText\": \"Reporter spotted familiar face covering Sudan evacuation. See what happened next\"}, {\"innerText\": \"This country will soon become the world\\\\u2019s most populated\"}, {\"innerText\": \"April 27, 2023 - Russia-Ukraine news\"}, {\"innerText\": \"\\\\u2018Often they shoot at each other\\\\u2019: Ukrainian drone operator details chaos in Russian ranks\"}, {\"innerText\": \"Hear from family members of Americans stuck in Sudan frustrated with US response\"}, {\"innerText\": \"U.S. talk show host Jerry Springer dies at 79\"}, {\"innerText\": \"Bureaucracy stalling at least one family\\\\u2019s evacuation from Sudan\"}, {\"innerText\": \"Girl to get life-saving treatment for rare immune disease\"}, {\"innerText\": \"Haiti\\\\u2019s crime rate more than doubles in a year\"}, {\"innerText\": \"Ocean census aims to discover 100,000 previously unknown marine species\"}, {\"innerText\": \"Wall Street Journal editor discusses reporter\\\\u2019s arrest in Moscow\"}, {\"innerText\": \"Can Tunisia\\\\u2019s democracy be saved?\"}, {\"innerText\": \"Yasmeen Lari, \\\\u2018starchitect\\\\u2019 turned social engineer, wins one of architecture\\\\u2019s most coveted prizes\"}, {\"innerText\": \"A massive, newly restored Frank Lloyd Wright mansion is up for sale\"}, {\"innerText\": \"Are these the most sustainable architectural projects in the world?\"}, {\"innerText\": \"Step inside a $72 million London townhouse in a converted army barracks\"}, {\"innerText\": \"A 3D-printing company is preparing to build on the lunar surface. 
But first, a moonshot at home\"}, {\"innerText\": \"Simona Halep says \\\\u2018the stress is huge\\\\u2019 as she battles to return to tennis following positive drug test\"}, {\"innerText\": \"Barcelona reaches third straight Women\\\\u2019s Champions League final with draw against Chelsea\"}, {\"innerText\": \"Wrexham: An intoxicating tale of Hollywood glamor and sporting romance\"}, {\"innerText\": \"Shohei Ohtani comes within inches of making yet more MLB history in Angels win\"}, {\"innerText\": \"This CNN Hero is recruiting recreational divers to help rebuild reefs in Florida one coral at a time\"}, {\"innerText\": \"This CNN Hero offers judgment-free veterinary care for the pets of those experiencing homelessness\"}, {\"innerText\": \"Don\\\\u2019t give up on milestones: A CNN Hero\\\\u2019s message for Autism Awareness Month\"}, {\"innerText\": \"CNN Hero of the Year Nelly Cheboi returned to Kenya with plans to lift more students out of poverty\"}]'" ] }, - "execution_count": 25, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The browser is shared across tools, so the agent can interact in a stateful manner\n", - "get_elements_tool.run({\"selector\": \".container__headline\", \"attributes\": [\"innerText\"]})" + "await get_elements_tool.arun({\"selector\": \".container__headline\", \"attributes\": [\"innerText\"]})" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'https://web.archive.org/web/20230428033754/https://www.cnn.com/world'" + "'https://web.archive.org/web/20230428133211/https://cnn.com/world'" ] }, - "execution_count": 26, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# If the agent wants to remember the current webpage, it can use the `current_webpage` tool\n", - "tools_by_name['current_webpage'].run({})" + "await tools_by_name['current_webpage'].arun({})" ] }, { @@ -156,7 +178,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -171,9 +193,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/docs/modules/agents/tools/examples/sceneXplain.ipynb b/docs/modules/agents/tools/examples/sceneXplain.ipynb new file mode 100644 index 0000000000000..48ec640226fd3 --- /dev/null +++ b/docs/modules/agents/tools/examples/sceneXplain.ipynb @@ -0,0 +1,116 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SceneXplain\n", + "\n", + "\n", + "[SceneXplain](https://scenex.jina.ai/) is an ImageCaptioning service accessible through the SceneXplain Tool.\n", + "\n", + "To use this tool, you'll need to make an account and fetch your API Token [from the website](https://scenex.jina.ai/api). Then you can instantiate the tool." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from langchain.tools import SceneXplainTool\n", + "\n", + "\n", + "os.environ[\"SCENEX_API_KEY\"] = \"\"\n", + "tool = SceneXplainTool()\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage in an Agent\n", + "\n", + "The tool can be used in any LangChain agent as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Thought: Do I need to use a tool? Yes\n", + "Action: Image Explainer\n", + "Action Input: https://storage.googleapis.com/causal-diffusion.appspot.com/imagePrompts%2F0rw369i5h9t%2Foriginal.png\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mIn a charmingly whimsical scene, a young girl is seen braving the rain alongside her furry companion, the lovable Totoro. The two are depicted standing on a bustling street corner, where they are sheltered from the rain by a bright yellow umbrella. The girl, dressed in a cheerful yellow frock, holds onto the umbrella with both hands while gazing up at Totoro with an expression of wonder and delight.\n", + "\n", + "Totoro, meanwhile, stands tall and proud beside his young friend, holding his own umbrella aloft to protect them both from the downpour. His furry body is rendered in rich shades of grey and white, while his large ears and wide eyes lend him an endearing charm.\n", + "\n", + "In the background of the scene, a street sign can be seen jutting out from the pavement amidst a flurry of raindrops. A sign with Chinese characters adorns its surface, adding to the sense of cultural diversity and intrigue. Despite the dreary weather, there is an undeniable sense of joy and camaraderie in this heartwarming image.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m Do I need to use a tool? No\n", + "AI: This image appears to be a still from the 1988 Japanese animated fantasy film My Neighbor Totoro. The film follows two young girls, Satsuki and Mei, as they explore the countryside and befriend the magical forest spirits, including the titular character Totoro.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "This image appears to be a still from the 1988 Japanese animated fantasy film My Neighbor Totoro. The film follows two young girls, Satsuki and Mei, as they explore the countryside and befriend the magical forest spirits, including the titular character Totoro.\n" + ] + } + ], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.agents import initialize_agent\n", + "from langchain.memory import ConversationBufferMemory\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "memory = ConversationBufferMemory(memory_key=\"chat_history\")\n", + "tools = [\n", + " tool\n", + "]\n", + "\n", + "agent = initialize_agent(\n", + " tools, llm, memory=memory, agent=\"conversational-react-description\", verbose=True\n", + ")\n", + "output = agent.run(\n", + " input=(\n", + " \"What is in this image https://storage.googleapis.com/causal-diffusion.appspot.com/imagePrompts%2F0rw369i5h9t%2Foriginal.png. \"\n", + " \"Is it movie or a game? 
If it is a movie, what is the name of the movie?\"\n", + " )\n", + ")\n", + "\n", + "print(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/agents/agent_toolkits/playwright/toolkit.py b/langchain/agents/agent_toolkits/playwright/toolkit.py index cd3722752e317..9b5a6fd804f50 100644 --- a/langchain/agents/agent_toolkits/playwright/toolkit.py +++ b/langchain/agents/agent_toolkits/playwright/toolkit.py @@ -1,13 +1,16 @@ """Playwright web browser toolkit.""" from __future__ import annotations -from typing import TYPE_CHECKING, List, Type +from typing import TYPE_CHECKING, List, Optional, Type, cast -from pydantic import Extra, Field, root_validator +from pydantic import Extra, root_validator from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.tools.base import BaseTool -from langchain.tools.playwright.base import BaseBrowserTool +from langchain.tools.playwright.base import ( + BaseBrowserTool, + lazy_import_playwright_browsers, +) from langchain.tools.playwright.click import ClickTool from langchain.tools.playwright.current_page import CurrentWebPageTool from langchain.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool @@ -15,16 +18,24 @@ from langchain.tools.playwright.get_elements import GetElementsTool from langchain.tools.playwright.navigate import NavigateTool from langchain.tools.playwright.navigate_back import NavigateBackTool -from langchain.tools.playwright.utils import create_playwright_browser if TYPE_CHECKING: from playwright.async_api import Browser as AsyncBrowser + from playwright.sync_api import Browser as SyncBrowser +else: + try: + # We do this so pydantic can resolve the types when instantiating + from playwright.async_api import Browser as AsyncBrowser + from playwright.sync_api import Browser as SyncBrowser + except ImportError: + pass class PlayWrightBrowserToolkit(BaseToolkit): """Toolkit for web browser tools.""" - browser: AsyncBrowser = Field(default_factory=create_playwright_browser) + sync_browser: Optional["SyncBrowser"] = None + async_browser: Optional["AsyncBrowser"] = None class Config: """Configuration for this pydantic object.""" @@ -33,15 +44,11 @@ class Config: arbitrary_types_allowed = True @root_validator - def check_args(cls, values: dict) -> dict: + def validate_imports_and_browser_provided(cls, values: dict) -> dict: """Check that the arguments are valid.""" - try: - from playwright.async_api import Browser as AsyncBrowser # noqa: F401 - except ImportError: - raise ValueError( - "The 'playwright' package is required to use this tool." - " Please install it with 'pip install playwright'." 
- ) + lazy_import_playwright_browsers() + if values.get("async_browser") is None and values.get("sync_browser") is None: + raise ValueError("Either async_browser or sync_browser must be specified.") return values def get_tools(self) -> List[BaseTool]: @@ -56,11 +63,21 @@ def get_tools(self) -> List[BaseTool]: CurrentWebPageTool, ] - return [tool_cls.from_browser(self.browser) for tool_cls in tool_classes] + tools = [ + tool_cls.from_browser( + sync_browser=self.sync_browser, async_browser=self.async_browser + ) + for tool_cls in tool_classes + ] + return cast(List[BaseTool], tools) @classmethod - def from_browser(cls, browser: AsyncBrowser) -> PlayWrightBrowserToolkit: - from playwright.async_api import Browser as AsyncBrowser - - cls.update_forward_refs(AsyncBrowser=AsyncBrowser) - return cls(browser=browser) + def from_browser( + cls, + sync_browser: Optional[SyncBrowser] = None, + async_browser: Optional[AsyncBrowser] = None, + ) -> PlayWrightBrowserToolkit: + """Instantiate the toolkit.""" + # This is to raise a better error than the forward ref ones Pydantic would have + lazy_import_playwright_browsers() + return cls(sync_browser=sync_browser, async_browser=async_browser) diff --git a/langchain/tools/__init__.py b/langchain/tools/__init__.py index f8080cd5dbb79..a03519ac9cc26 100644 --- a/langchain/tools/__init__.py +++ b/langchain/tools/__init__.py @@ -16,7 +16,6 @@ from langchain.tools.openapi.utils.api_models import APIOperation from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec from langchain.tools.playwright import ( - BaseBrowserTool, ClickTool, CurrentWebPageTool, ExtractHyperlinksTool, @@ -26,12 +25,12 @@ NavigateTool, ) from langchain.tools.plugin import AIPluginTool +from langchain.tools.scenexplain.tool import SceneXplainTool from langchain.tools.shell.tool import ShellTool __all__ = [ "AIPluginTool", "APIOperation", - "BaseBrowserTool", "BaseTool", "BaseTool", "BingSearchResults", @@ -59,4 +58,6 @@ "ReadFileTool", "ShellTool", "WriteFileTool", + "BaseTool", + "SceneXplainTool", ] diff --git a/langchain/tools/base.py b/langchain/tools/base.py index fded77c4d4dad..6799409e75c20 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -86,7 +86,14 @@ def get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys} + return {k: schema[k] for k in valid_keys if k != "run_manager"} + + +class _SchemaConfig: + """Configuration for the pydantic model.""" + + extra = Extra.forbid + arbitrary_types_allowed = True def create_schema_from_function( @@ -94,7 +101,10 @@ def create_schema_from_function( func: Callable, ) -> Type[BaseModel]: """Create a pydantic schema from a function's signature.""" - inferred_model = validate_arguments(func).model # type: ignore + validated = validate_arguments(func, config=_SchemaConfig) # type: ignore + inferred_model = validated.model # type: ignore + if "run_manager" in inferred_model.__fields__: + del inferred_model.__fields__["run_manager"] # Pydantic adds placeholder virtual fields we need to strip filtered_args = get_filtered_args(inferred_model, func) return _create_subset_model( @@ -143,8 +153,8 @@ def args(self) -> dict: if self.args_schema is not None: return self.args_schema.schema()["properties"] else: - inferred_model = validate_arguments(self._run).model # type: ignore - return get_filtered_args(inferred_model, self._run) + schema = 
create_schema_from_function(self.name, self._run)
+        return schema.schema()["properties"]
 
     def _parse_input(
         self,
diff --git a/langchain/tools/google_places/tool.py b/langchain/tools/google_places/tool.py
index cce83642826d3..27b0b56132f8d 100644
--- a/langchain/tools/google_places/tool.py
+++ b/langchain/tools/google_places/tool.py
@@ -2,7 +2,7 @@
 
 from typing import Optional
 
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForToolRun,
@@ -12,6 +12,10 @@
 from langchain.utilities.google_places_api import GooglePlacesAPIWrapper
 
 
+class GooglePlacesSchema(BaseModel):
+    query: str = Field(..., description="Query for Google Maps")
+
+
 class GooglePlacesTool(BaseTool):
     """Tool that adds the capability to query the Google places API."""
 
diff --git a/langchain/tools/playwright/__init__.py b/langchain/tools/playwright/__init__.py
index 8f7e6153eafff..2b58e50867cf3 100644
--- a/langchain/tools/playwright/__init__.py
+++ b/langchain/tools/playwright/__init__.py
@@ -1,6 +1,5 @@
 """Browser tools and toolkit."""
 
-from langchain.tools.playwright.base import BaseBrowserTool
 from langchain.tools.playwright.click import ClickTool
 from langchain.tools.playwright.current_page import CurrentWebPageTool
 from langchain.tools.playwright.extract_hyperlinks import ExtractHyperlinksTool
@@ -15,7 +14,6 @@
     "ExtractTextTool",
     "ExtractHyperlinksTool",
     "GetElementsTool",
-    "BaseBrowserTool",
     "ClickTool",
     "CurrentWebPageTool",
 ]
diff --git a/langchain/tools/playwright/base.py b/langchain/tools/playwright/base.py
index 79857b45b4e42..1220cbe803d22 100644
--- a/langchain/tools/playwright/base.py
+++ b/langchain/tools/playwright/base.py
@@ -1,48 +1,55 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
-from pydantic import Field, root_validator
+from pydantic import root_validator
 
-from langchain.callbacks.manager import (
-    CallbackManagerForToolRun,
-)
 from langchain.tools.base import BaseTool
-from langchain.tools.playwright.utils import create_playwright_browser, run_async
 
 if TYPE_CHECKING:
     from playwright.async_api import Browser as AsyncBrowser
+    from playwright.sync_api import Browser as SyncBrowser
+else:
+    try:
+        # We do this so pydantic can resolve the types when instantiating
+        from playwright.async_api import Browser as AsyncBrowser
+        from playwright.sync_api import Browser as SyncBrowser
+    except ImportError:
+        pass
+
+
+def lazy_import_playwright_browsers() -> Tuple[Type[AsyncBrowser], Type[SyncBrowser]]:
+    try:
+        from playwright.async_api import Browser as AsyncBrowser  # noqa: F401
+        from playwright.sync_api import Browser as SyncBrowser  # noqa: F401
+    except ImportError:
+        raise ValueError(
+            "The 'playwright' package is required to use the playwright tools."
+            " Please install it with 'pip install playwright'."
+        )
+    return AsyncBrowser, SyncBrowser
 
 
 class BaseBrowserTool(BaseTool):
     """Base class for browser tools."""
 
-    browser: AsyncBrowser = Field(default_factory=create_playwright_browser)
+    sync_browser: Optional["SyncBrowser"] = None
+    async_browser: Optional["AsyncBrowser"] = None
 
     @root_validator
-    def check_args(cls, values: dict) -> dict:
+    def validate_browser_provided(cls, values: dict) -> dict:
         """Check that the arguments are valid."""
-        try:
-            from playwright.async_api import Browser as AsyncBrowser  # noqa: F401
-        except ImportError:
-            raise ValueError(
-                "The 'playwright' package is required to use this tool."
- " Please install it with 'pip install playwright'." - ) + lazy_import_playwright_browsers() + if values.get("async_browser") is None and values.get("sync_browser") is None: + raise ValueError("Either async_browser or sync_browser must be specified.") return values - def _run( - self, - *args: Any, - run_manager: Optional[CallbackManagerForToolRun] = None, - **kwargs: Any, - ) -> str: - """Use the tool.""" - return run_async(self._arun(*args, **kwargs)) - @classmethod - def from_browser(cls, browser: AsyncBrowser) -> BaseBrowserTool: - from playwright.async_api import Browser as AsyncBrowser - - cls.update_forward_refs(AsyncBrowser=AsyncBrowser) - return cls(browser=browser) + def from_browser( + cls, + sync_browser: Optional[SyncBrowser] = None, + async_browser: Optional[AsyncBrowser] = None, + ) -> BaseBrowserTool: + """Instantiate the tool.""" + lazy_import_playwright_browsers() + return cls(sync_browser=sync_browser, async_browser=async_browser) diff --git a/langchain/tools/playwright/click.py b/langchain/tools/playwright/click.py index d11cdde46793b..671faf433eddb 100644 --- a/langchain/tools/playwright/click.py +++ b/langchain/tools/playwright/click.py @@ -4,9 +4,13 @@ from pydantic import BaseModel, Field -from langchain.callbacks.manager import AsyncCallbackManagerForToolRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( + aget_current_page, get_current_page, ) @@ -22,13 +26,28 @@ class ClickTool(BaseBrowserTool): description: str = "Click on an element with the given CSS selector" args_schema: Type[BaseModel] = ClickToolInput + def _run( + self, + selector: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the tool.""" + if self.sync_browser is None: + raise ValueError(f"Synchronous browser not provided to {self.name}") + page = get_current_page(self.sync_browser) + # Navigate to the desired webpage before using this tool + page.click(selector) + return f"Clicked element '{selector}'" + async def _arun( self, selector: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - page = await get_current_page(self.browser) + if self.async_browser is None: + raise ValueError(f"Asynchronous browser not provided to {self.name}") + page = await aget_current_page(self.async_browser) # Navigate to the desired webpage before using this tool await page.click(selector) return f"Clicked element '{selector}'" diff --git a/langchain/tools/playwright/current_page.py b/langchain/tools/playwright/current_page.py index fe6593fa3674a..b0e51c2586801 100644 --- a/langchain/tools/playwright/current_page.py +++ b/langchain/tools/playwright/current_page.py @@ -4,11 +4,12 @@ from pydantic import BaseModel -from langchain.callbacks.manager import AsyncCallbackManagerForToolRun -from langchain.tools.playwright.base import BaseBrowserTool -from langchain.tools.playwright.utils import ( - get_current_page, +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, ) +from langchain.tools.playwright.base import BaseBrowserTool +from langchain.tools.playwright.utils import aget_current_page, get_current_page class CurrentWebPageTool(BaseBrowserTool): @@ -16,10 +17,22 @@ class CurrentWebPageTool(BaseBrowserTool): description: str = "Returns the URL of the current page" args_schema: Type[BaseModel] = BaseModel + def _run( + 
self,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool."""
+        if self.sync_browser is None:
+            raise ValueError(f"Synchronous browser not provided to {self.name}")
+        page = get_current_page(self.sync_browser)
+        return str(page.url)
+
     async def _arun(
         self,
         run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
     ) -> str:
         """Use the tool."""
-        page = await get_current_page(self.browser)
+        if self.async_browser is None:
+            raise ValueError(f"Asynchronous browser not provided to {self.name}")
+        page = await aget_current_page(self.async_browser)
         return str(page.url)
diff --git a/langchain/tools/playwright/extract_hyperlinks.py b/langchain/tools/playwright/extract_hyperlinks.py
index 6d903e0101c97..781de348f4446 100644
--- a/langchain/tools/playwright/extract_hyperlinks.py
+++ b/langchain/tools/playwright/extract_hyperlinks.py
@@ -1,13 +1,16 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional, Type
 
 from pydantic import BaseModel, Field, root_validator
 
-from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
 from langchain.tools.playwright.base import BaseBrowserTool
-from langchain.tools.playwright.utils import get_current_page
+from langchain.tools.playwright.utils import aget_current_page, get_current_page
 
 if TYPE_CHECKING:
     pass
@@ -30,7 +33,7 @@ class ExtractHyperlinksTool(BaseBrowserTool):
     args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
 
     @root_validator
-    def check_args(cls, values: dict) -> dict:
+    def check_bs_import(cls, values: dict) -> dict:
         """Check that the arguments are valid."""
         try:
             from bs4 import BeautifulSoup  # noqa: F401
@@ -41,19 +44,12 @@ def check_args(cls, values: dict) -> dict:
             )
         return values
 
-    async def _arun(
-        self,
-        absolute_urls: bool = False,
-        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
-    ) -> str:
-        """Use the tool."""
+    @staticmethod
+    def scrape_page(page: Any, html_content: str, absolute_urls: bool) -> str:
         from urllib.parse import urljoin
 
         from bs4 import BeautifulSoup
 
-        page = await get_current_page(self.browser)
-        html_content = await page.content()
-
         # Parse the HTML content with BeautifulSoup
         soup = BeautifulSoup(html_content, "lxml")
 
@@ -64,6 +60,30 @@ async def _arun(
             links = [urljoin(base_url, anchor.get("href", "")) for anchor in anchors]
         else:
             links = [anchor.get("href", "") for anchor in anchors]
-
         # Return the list of links as a JSON string
         return json.dumps(links)
+
+    def _run(
+        self,
+        absolute_urls: bool = False,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool."""
+        if self.sync_browser is None:
+            raise ValueError(f"Synchronous browser not provided to {self.name}")
+        page = get_current_page(self.sync_browser)
+        html_content = page.content()
+        return self.scrape_page(page, html_content, absolute_urls)
+
+    async def _arun(
+        self,
+        absolute_urls: bool = False,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool asynchronously."""
+        if self.async_browser is None:
+            raise ValueError(f"Asynchronous browser not provided to {self.name}")
+        page = await aget_current_page(self.async_browser)
+        html_content = await page.content()
+        return self.scrape_page(page, html_content, absolute_urls)
diff --git a/langchain/tools/playwright/extract_text.py
b/langchain/tools/playwright/extract_text.py
index 5b20b7c38ab5b..5b228786c86fa 100644
--- a/langchain/tools/playwright/extract_text.py
+++ b/langchain/tools/playwright/extract_text.py
@@ -4,9 +4,12 @@
 
 from pydantic import BaseModel, root_validator
 
-from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
 from langchain.tools.playwright.base import BaseBrowserTool
-from langchain.tools.playwright.utils import get_current_page
+from langchain.tools.playwright.utils import aget_current_page, get_current_page
 
 
 class ExtractTextTool(BaseBrowserTool):
@@ -15,7 +18,7 @@ class ExtractTextTool(BaseBrowserTool):
     args_schema: Type[BaseModel] = BaseModel
 
     @root_validator
-    def check_args(cls, values: dict) -> dict:
+    def check_bs_import(cls, values: dict) -> dict:
         """Check that the arguments are valid."""
         try:
             from bs4 import BeautifulSoup  # noqa: F401
@@ -26,14 +29,32 @@ def check_args(cls, values: dict) -> dict:
             )
         return values
 
+    def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
+        """Use the tool."""
+        # Use Beautiful Soup since it's faster than looping through the elements
+        from bs4 import BeautifulSoup
+
+        if self.sync_browser is None:
+            raise ValueError(f"Synchronous browser not provided to {self.name}")
+
+        page = get_current_page(self.sync_browser)
+        html_content = page.content()
+
+        # Parse the HTML content with BeautifulSoup
+        soup = BeautifulSoup(html_content, "lxml")
+
+        return " ".join(text for text in soup.stripped_strings)
+
     async def _arun(
         self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
     ) -> str:
         """Use the tool."""
+        if self.async_browser is None:
+            raise ValueError(f"Asynchronous browser not provided to {self.name}")
         # Use Beautiful Soup since it's faster than looping through the elements
         from bs4 import BeautifulSoup
 
-        page = await get_current_page(self.browser)
+        page = await aget_current_page(self.async_browser)
         html_content = await page.content()
 
         # Parse the HTML content with BeautifulSoup
diff --git a/langchain/tools/playwright/get_elements.py b/langchain/tools/playwright/get_elements.py
index 910c022cf7fb9..a5ad232f2498b 100644
--- a/langchain/tools/playwright/get_elements.py
+++ b/langchain/tools/playwright/get_elements.py
@@ -5,12 +5,16 @@
 
 from pydantic import BaseModel, Field
 
-from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
 from langchain.tools.playwright.base import BaseBrowserTool
-from langchain.tools.playwright.utils import get_current_page
+from langchain.tools.playwright.utils import aget_current_page, get_current_page
 
 if TYPE_CHECKING:
     from playwright.async_api import Page as AsyncPage
+    from playwright.sync_api import Page as SyncPage
 
 
 class GetElementsToolInput(BaseModel):
@@ -26,7 +30,7 @@ class GetElementsToolInput(BaseModel):
     )
 
 
-async def _get_elements(
+async def _aget_elements(
     page: AsyncPage, selector: str, attributes: Sequence[str]
 ) -> List[dict]:
     """Get elements matching the given CSS selector."""
@@ -46,6 +50,26 @@ async def _get_elements(
     return results
 
 
+def _get_elements(
+    page: SyncPage, selector: str, attributes: Sequence[str]
+) -> List[dict]:
+    """Get elements matching the given CSS selector."""
+    elements = page.query_selector_all(selector)
+    results = []
+    for element in elements:
+        result = {}
+        for attribute in attributes:
+            if attribute
== "innerText": + val: Optional[str] = element.inner_text() + else: + val = element.get_attribute(attribute) + if val is not None and val.strip() != "": + result[attribute] = val + if result: + results.append(result) + return results + + class GetElementsTool(BaseBrowserTool): name: str = "get_elements" description: str = ( @@ -53,6 +77,20 @@ class GetElementsTool(BaseBrowserTool): ) args_schema: Type[BaseModel] = GetElementsToolInput + def _run( + self, + selector: str, + attributes: Sequence[str] = ["innerText"], + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the tool.""" + if self.sync_browser is None: + raise ValueError(f"Synchronous browser not provided to {self.name}") + page = get_current_page(self.sync_browser) + # Navigate to the desired webpage before using this tool + results = _get_elements(page, selector, attributes) + return json.dumps(results) + async def _arun( self, selector: str, @@ -60,7 +98,9 @@ async def _arun( run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - page = await get_current_page(self.browser) + if self.async_browser is None: + raise ValueError(f"Asynchronous browser not provided to {self.name}") + page = await aget_current_page(self.async_browser) # Navigate to the desired webpage before using this tool - results = await _get_elements(page, selector, attributes) + results = await _aget_elements(page, selector, attributes) return json.dumps(results) diff --git a/langchain/tools/playwright/navigate.py b/langchain/tools/playwright/navigate.py index cd1e4dbc1028d..f9af3fc8396bc 100644 --- a/langchain/tools/playwright/navigate.py +++ b/langchain/tools/playwright/navigate.py @@ -4,9 +4,13 @@ from pydantic import BaseModel, Field -from langchain.callbacks.manager import AsyncCallbackManagerForToolRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( + aget_current_page, get_current_page, ) @@ -22,13 +26,28 @@ class NavigateTool(BaseBrowserTool): description: str = "Navigate a browser to the specified URL" args_schema: Type[BaseModel] = NavigateToolInput + def _run( + self, + url: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the tool.""" + if self.sync_browser is None: + raise ValueError(f"Synchronous browser not provided to {self.name}") + page = get_current_page(self.sync_browser) + response = page.goto(url) + status = response.status if response else "unknown" + return f"Navigating to {url} returned status code {status}" + async def _arun( self, url: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - page = await get_current_page(self.browser) + if self.async_browser is None: + raise ValueError(f"Asynchronous browser not provided to {self.name}") + page = await aget_current_page(self.async_browser) response = await page.goto(url) status = response.status if response else "unknown" return f"Navigating to {url} returned status code {status}" diff --git a/langchain/tools/playwright/navigate_back.py b/langchain/tools/playwright/navigate_back.py index 1a93e2636b918..da4d35775498b 100644 --- a/langchain/tools/playwright/navigate_back.py +++ b/langchain/tools/playwright/navigate_back.py @@ -4,9 +4,13 @@ from pydantic import BaseModel -from langchain.callbacks.manager import AsyncCallbackManagerForToolRun +from langchain.callbacks.manager import 
( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( + aget_current_page, get_current_page, ) @@ -18,18 +22,35 @@ class NavigateBackTool(BaseBrowserTool): description: str = "Navigate back to the previous page in the browser history" args_schema: Type[BaseModel] = BaseModel + def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: + """Use the tool.""" + if self.sync_browser is None: + raise ValueError(f"Synchronous browser not provided to {self.name}") + page = get_current_page(self.sync_browser) + response = page.go_back() + + if response: + return ( + f"Navigated back to the previous page with URL '{response.url}'." + f" Status code {response.status}" + ) + else: + return "Unable to navigate back; no previous page in the history" + async def _arun( self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - page = await get_current_page(self.browser) + if self.async_browser is None: + raise ValueError(f"Asynchronous browser not provided to {self.name}") + page = await aget_current_page(self.async_browser) response = await page.go_back() if response: return ( f"Navigated back to the previous page with URL '{response.url}'." - " Status code {response.status}" + f" Status code {response.status}" ) else: return "Unable to navigate back; no previous page in the history" diff --git a/langchain/tools/playwright/utils.py b/langchain/tools/playwright/utils.py index 4903836ae7a52..b5e7f13677ecf 100644 --- a/langchain/tools/playwright/utils.py +++ b/langchain/tools/playwright/utils.py @@ -7,9 +7,11 @@ if TYPE_CHECKING: from playwright.async_api import Browser as AsyncBrowser from playwright.async_api import Page as AsyncPage + from playwright.sync_api import Browser as SyncBrowser + from playwright.sync_api import Page as SyncPage -async def get_current_page(browser: AsyncBrowser) -> AsyncPage: +async def aget_current_page(browser: AsyncBrowser) -> AsyncPage: if not browser.contexts: context = await browser.new_context() return await context.new_page() @@ -20,13 +22,31 @@ async def get_current_page(browser: AsyncBrowser) -> AsyncPage: return context.pages[-1] -def create_playwright_browser() -> AsyncBrowser: +def get_current_page(browser: SyncBrowser) -> SyncPage: + if not browser.contexts: + context = browser.new_context() + return context.new_page() + context = browser.contexts[0] # Assuming you're using the default browser context + if not context.pages: + return context.new_page() + # Assuming the last page in the list is the active one + return context.pages[-1] + + +def create_async_playwright_browser() -> AsyncBrowser: from playwright.async_api import async_playwright browser = run_async(async_playwright().start()) return run_async(browser.chromium.launch(headless=True)) +def create_sync_playwright_browser() -> SyncBrowser: + from playwright.sync_api import sync_playwright + + browser = sync_playwright().start() + return browser.chromium.launch(headless=True) + + T = TypeVar("T") diff --git a/langchain/tools/plugin.py b/langchain/tools/plugin.py index 0510eaaa2ce28..3d38895b0d43d 100644 --- a/langchain/tools/plugin.py +++ b/langchain/tools/plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Optional +from typing import Optional, Type import requests import yaml @@ -49,9 +49,16 @@ def marshal_spec(txt: str) -> dict: return yaml.safe_load(txt) +class 
AIPLuginToolSchema(BaseModel): + """AIPLuginToolSchema.""" + + tool_input: Optional[str] = "" + + class AIPluginTool(BaseTool): plugin: AIPlugin api_spec: str + args_schema: Type[AIPLuginToolSchema] = AIPLuginToolSchema @classmethod def from_plugin_url(cls, url: str) -> AIPluginTool: @@ -78,7 +85,7 @@ def from_plugin_url(cls, url: str) -> AIPluginTool: def _run( self, - tool_input: str, + tool_input: Optional[str] = "", run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" @@ -86,7 +93,7 @@ def _run( async def _arun( self, - tool_input: str, + tool_input: Optional[str] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool asynchronously.""" diff --git a/langchain/tools/scenexplain/__init__.py b/langchain/tools/scenexplain/__init__.py new file mode 100644 index 0000000000000..2e6553b73567d --- /dev/null +++ b/langchain/tools/scenexplain/__init__.py @@ -0,0 +1 @@ +"""SceneXplain API toolkit.""" diff --git a/langchain/tools/scenexplain/tool.py b/langchain/tools/scenexplain/tool.py new file mode 100644 index 0000000000000..94ce73c6a400a --- /dev/null +++ b/langchain/tools/scenexplain/tool.py @@ -0,0 +1,42 @@ +"""Tool for the SceneXplain API.""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool +from langchain.utilities.scenexplain import SceneXplainAPIWrapper + + +class SceneXplainInput(BaseModel): + """Input for SceneXplain.""" + + query: str = Field(..., description="The link to the image to explain") + + +class SceneXplainTool(BaseTool): + """Tool that adds the capability to explain images.""" + + name = "Image Explainer" + description = ( + "An Image Captioning Tool: Use this tool to generate a detailed caption " + "for an image. The input can be an image file of any format, and " + "the output will be a text description that covers every detail of the image." + ) + api_wrapper: SceneXplainAPIWrapper = Field(default_factory=SceneXplainAPIWrapper) + + def _run( + self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: + """Use the tool.""" + return self.api_wrapper.run(query) + + async def _arun( + self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None + ) -> str: + """Use the tool asynchronously.""" + raise NotImplementedError("SceneXplainTool does not support async") diff --git a/langchain/utilities/scenexplain.py b/langchain/utilities/scenexplain.py new file mode 100644 index 0000000000000..7b3342c16491b --- /dev/null +++ b/langchain/utilities/scenexplain.py @@ -0,0 +1,68 @@ +"""Util that calls SceneXplain. + +In order to set this up, you need API key for the SceneXplain API. +You can obtain a key by following the steps below. +- Sign up for a free account at https://scenex.jina.ai/. +- Navigate to the API Access page (https://scenex.jina.ai/api) and create a new API key. +""" +from typing import Dict + +import requests +from pydantic import BaseModel, root_validator + +from langchain.utils import get_from_dict_or_env + + +class SceneXplainAPIWrapper(BaseModel): + """Wrapper for SceneXplain API. + + In order to set this up, you need API key for the SceneXplain API. + You can obtain a key by following the steps below. + - Sign up for a free account at https://scenex.jina.ai/. + - Navigate to the API Access page (https://scenex.jina.ai/api) + and create a new API key. 
+ """ + + scenex_api_key: str + scenex_api_url: str = ( + "https://us-central1-causal-diffusion.cloudfunctions.net/describe" + ) + + def _describe_image(self, image: str) -> str: + headers = { + "x-api-key": f"token {self.scenex_api_key}", + "content-type": "application/json", + } + payload = { + "data": [ + { + "image": image, + "algorithm": "Ember", + "languages": ["en"], + } + ] + } + response = requests.post(self.scenex_api_url, headers=headers, json=payload) + response.raise_for_status() + result = response.json().get("result", []) + img = result[0] if result else {} + + return img.get("text", "") + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key exists in environment.""" + scenex_api_key = get_from_dict_or_env( + values, "scenex_api_key", "SCENEX_API_KEY" + ) + values["scenex_api_key"] = scenex_api_key + + return values + + def run(self, image: str) -> str: + """Run SceneXplain image explainer.""" + description = self._describe_image(image) + if not description: + return "No description found." + + return description diff --git a/tests/unit_tests/tools/test_signatures.py b/tests/unit_tests/tools/test_signatures.py index 6a7a912e050f6..b1634dfcb1fb0 100644 --- a/tests/unit_tests/tools/test_signatures.py +++ b/tests/unit_tests/tools/test_signatures.py @@ -8,14 +8,18 @@ import pytest from langchain.tools.base import BaseTool +from langchain.tools.playwright.base import BaseBrowserTool def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]: + to_skip = {BaseBrowserTool} # Abstract but not recognized subclasses = [] for subclass in cls.__subclasses__(): - if not getattr( - subclass, "__abstract__", None - ) and not subclass.__name__.startswith("_"): + if ( + not getattr(subclass, "__abstract__", None) + and not subclass.__name__.startswith("_") + and subclass not in to_skip + ): subclasses.append(subclass) subclasses.extend(get_non_abstract_subclasses(subclass)) return subclasses From 9192abcaf1a368b7da1c6148fcaa384c82c67379 Mon Sep 17 00:00:00 2001 From: vowelparrot <130414180+vowelparrot@users.noreply.github.com> Date: Fri, 28 Apr 2023 19:00:18 -0700 Subject: [PATCH 31/36] Notebook Nits --- docs/modules/agents/tools/examples/arxiv.ipynb | 2 ++ docs/modules/agents/tools/examples/gradio_tools.ipynb | 5 +++-- docs/modules/agents/tools/examples/python.ipynb | 10 +++++++++- docs/modules/agents/tools/examples/serpapi.ipynb | 10 +++++++++- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/modules/agents/tools/examples/arxiv.ipynb b/docs/modules/agents/tools/examples/arxiv.ipynb index 04b1cc478823e..38027d3c319b5 100644 --- a/docs/modules/agents/tools/examples/arxiv.ipynb +++ b/docs/modules/agents/tools/examples/arxiv.ipynb @@ -75,6 +75,8 @@ } ], "source": [ + "\n", + "arxiv = ArxivAPIWrapper()\n", "docs = arxiv.run(\"1605.08386\")\n", "docs" ] diff --git a/docs/modules/agents/tools/examples/gradio_tools.ipynb b/docs/modules/agents/tools/examples/gradio_tools.ipynb index a1a8c2ca3d1d7..d4a1891878710 100644 --- a/docs/modules/agents/tools/examples/gradio_tools.ipynb +++ b/docs/modules/agents/tools/examples/gradio_tools.ipynb @@ -69,7 +69,8 @@ } ], "source": [ - "StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")" + "local_file_path = StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")\n", + "local_file_path" ] }, { @@ -89,7 +90,7 @@ "metadata": {}, "outputs": [], "source": [ - "im = 
Image.open(\"/Users/harrisonchase/workplace/langchain/docs/modules/agents/tools/examples/b61c1dd9-47e2-46f1-a47c-20d27640993d/tmp4ap48vnm.jpg\")" + "im = Image.open(local_file_path)" ] }, { diff --git a/docs/modules/agents/tools/examples/python.ipynb b/docs/modules/agents/tools/examples/python.ipynb index db2824f2db522..a1428545379e0 100644 --- a/docs/modules/agents/tools/examples/python.ipynb +++ b/docs/modules/agents/tools/examples/python.ipynb @@ -19,6 +19,7 @@ "metadata": {}, "outputs": [], "source": [ + "from langchain.agents import Tool\n", "from langchain.utilities import PythonREPL" ] }, @@ -59,7 +60,14 @@ "id": "54fc1f03", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# You can create the tool to pass to an agent\n", + "repl_tool = Tool(\n", + " name=\"python_repl\",\n", + " description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n", + " func=python_repl\n", + ")" + ] } ], "metadata": { diff --git a/docs/modules/agents/tools/examples/serpapi.ipynb b/docs/modules/agents/tools/examples/serpapi.ipynb index c77821ca91a84..c4ad0a6bb7fb6 100644 --- a/docs/modules/agents/tools/examples/serpapi.ipynb +++ b/docs/modules/agents/tools/examples/serpapi.ipynb @@ -102,7 +102,15 @@ "id": "e0a1dc1c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from langchain.agents import Tool\n", + "# You can create the tool to pass to an agent\n", + "repl_tool = Tool(\n", + " name=\"python_repl\",\n", + " description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n", + " func=search.run,\n", + ")" + ] } ], "metadata": { From ebc6242fc887779dcceff00c0c916aa3b8b463cc Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Sat, 29 Apr 2023 11:13:32 -0700 Subject: [PATCH 32/36] fix docs --- docs/modules/chains/generic/custom_chain.ipynb | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/modules/chains/generic/custom_chain.ipynb b/docs/modules/chains/generic/custom_chain.ipynb index 9f71fe3375e8a..4916b14c00a6e 100644 --- a/docs/modules/chains/generic/custom_chain.ipynb +++ b/docs/modules/chains/generic/custom_chain.ipynb @@ -1,14 +1,13 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "593f7553-7038-498e-96d4-8255e5ce34f0", "metadata": {}, "source": [ "# Creating a custom Chain\n", "\n", - "To implement your own custom chain you can subclass `BaseChain` and implement the following methods:" + "To implement your own custom chain you can subclass `Chain` and implement the following methods:" ] }, { @@ -181,6 +180,18 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" } }, "nbformat": 4, From 737467a73180ce364052fdff44d3ce7565036362 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Sat, 29 Apr 2023 13:41:51 -0700 Subject: [PATCH 33/36] use UUID --- langchain/callbacks/base.py | 109 ++++---- langchain/callbacks/manager.py | 36 +-- langchain/callbacks/openai_info.py | 10 - langchain/callbacks/stdout.py | 2 +- langchain/callbacks/tracers/base.py | 76 +++--- 
.../callbacks/tracers/test_tracer.py | 238 +++++++++--------- 6 files changed, 235 insertions(+), 236 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 07c5913fcbb5c..69ba0f1a7cdfe 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -3,6 +3,7 @@ import copy from typing import Any, Dict, List, Optional, Union +from uuid import UUID from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -14,8 +15,8 @@ def on_llm_new_token( self, token: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on new LLM token. Only available when streaming is enabled.""" @@ -24,8 +25,8 @@ def on_llm_end( self, response: LLMResult, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when LLM ends running.""" @@ -34,8 +35,8 @@ def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when LLM errors.""" @@ -48,8 +49,8 @@ def on_chain_end( self, outputs: Dict[str, Any], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when chain ends running.""" @@ -58,8 +59,8 @@ def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when chain errors.""" @@ -68,8 +69,8 @@ def on_agent_action( self, action: AgentAction, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on agent action.""" @@ -78,8 +79,8 @@ def on_agent_finish( self, finish: AgentFinish, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on agent end.""" @@ -92,8 +93,8 @@ def on_tool_end( self, output: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when tool ends running.""" @@ -102,8 +103,8 @@ def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when tool errors.""" @@ -117,8 +118,8 @@ def on_llm_start( serialized: Dict[str, Any], prompts: List[str], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when LLM starts running.""" @@ -128,8 +129,8 @@ def on_chain_start( serialized: Dict[str, Any], inputs: Dict[str, Any], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when chain starts running.""" @@ -139,8 +140,8 @@ def on_tool_start( serialized: Dict[str, Any], input_str: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run when tool starts running.""" @@ -153,8 +154,8 @@ def on_text( self, text: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, 
+ parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run on arbitrary text.""" @@ -193,8 +194,8 @@ async def on_llm_start( serialized: Dict[str, Any], prompts: List[str], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when LLM starts running.""" @@ -203,8 +204,8 @@ async def on_llm_new_token( self, token: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" @@ -213,8 +214,8 @@ async def on_llm_end( self, response: LLMResult, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when LLM ends running.""" @@ -223,8 +224,8 @@ async def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when LLM errors.""" @@ -234,8 +235,8 @@ async def on_chain_start( serialized: Dict[str, Any], inputs: Dict[str, Any], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when chain starts running.""" @@ -244,8 +245,8 @@ async def on_chain_end( self, outputs: Dict[str, Any], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when chain ends running.""" @@ -254,8 +255,8 @@ async def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when chain errors.""" @@ -265,8 +266,8 @@ async def on_tool_start( serialized: Dict[str, Any], input_str: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when tool starts running.""" @@ -275,8 +276,8 @@ async def on_tool_end( self, output: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when tool ends running.""" @@ -285,8 +286,8 @@ async def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run when tool errors.""" @@ -295,8 +296,8 @@ async def on_text( self, text: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run on arbitrary text.""" @@ -305,8 +306,8 @@ async def on_agent_action( self, action: AgentAction, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run on agent action.""" @@ -315,8 +316,8 @@ async def on_agent_finish( self, finish: AgentFinish, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Run on agent end.""" @@ -329,14 +330,14 @@ def __init__( self, handlers: List[BaseCallbackHandler], inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, - parent_run_id: 
Optional[str] = None, + parent_run_id: Optional[UUID] = None, ) -> None: """Initialize callback manager.""" self.handlers: List[BaseCallbackHandler] = handlers self.inheritable_handlers: List[BaseCallbackHandler] = ( inheritable_handlers or [] ) - self.parent_run_id: Optional[str] = parent_run_id + self.parent_run_id: Optional[UUID] = parent_run_id @property def is_async(self) -> bool: diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 0d0a8bd993af5..e001e6e1cd61d 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -4,10 +4,10 @@ import copy import functools import os -import uuid from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union +from uuid import UUID, uuid4 from langchain.callbacks.base import ( BaseCallbackHandler, @@ -119,10 +119,10 @@ class BaseRunManager(RunManagerMixin): def __init__( self, - run_id: str, + run_id: UUID, handlers: List[BaseCallbackHandler], inheritable_handlers: List[BaseCallbackHandler], - parent_run_id: Optional[str] = None, + parent_run_id: Optional[UUID] = None, ) -> None: """Initialize run manager.""" self.run_id = run_id @@ -133,7 +133,7 @@ def __init__( @classmethod def get_noop_manager(cls: Type[BRM]) -> BRM: """Return a manager that doesn't perform any operations.""" - return cls("", [], []) + return cls(uuid4(), [], []) class RunManager(BaseRunManager): @@ -483,12 +483,12 @@ def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[str] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> CallbackManagerForLLMRun: """Run when LLM starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() _handle_event( self.handlers, @@ -509,12 +509,12 @@ def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - run_id: Optional[str] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> CallbackManagerForChainRun: """Run when chain starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() _handle_event( self.handlers, @@ -535,13 +535,13 @@ def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - run_id: Optional[str] = None, - parent_run_id: Optional[str] = None, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> CallbackManagerForToolRun: """Run when tool starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() _handle_event( self.handlers, @@ -581,12 +581,12 @@ async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[str] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForLLMRun: """Run when LLM starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() await _ahandle_event( self.handlers, @@ -607,12 +607,12 @@ async def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], - run_id: Optional[str] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForChainRun: """Run when chain starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() await _ahandle_event( self.handlers, @@ -633,13 +633,13 @@ async def on_tool_start( self, serialized: Dict[str, Any], input_str: str, - run_id: Optional[str] = None, - parent_run_id: Optional[str] = None, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> 
AsyncCallbackManagerForToolRun: """Run when tool starts running.""" if run_id is None: - run_id = str(uuid.uuid4()) + run_id = uuid4() await _ahandle_event( self.handlers, diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index 42005acbcdde4..9181ee897a186 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -148,16 +148,6 @@ def on_tool_error( """Do nothing.""" pass - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - pass - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" pass diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index 18eb0d2199994..90b0a83e1ac90 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -91,7 +91,7 @@ def on_text( text: str, color: Optional[str] = None, end: str = "", - **kwargs: Optional[str], + **kwargs: Any, ) -> None: """Run when agent ends.""" print_text(text, color=color if color else self.color, end=end) diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 0b6b21934ac6c..a7d3b32252381 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 +from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.tracers.schemas import ( @@ -120,21 +120,21 @@ def on_llm_start( serialized: Dict[str, Any], prompts: List[str], *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" if self.session is None: self.session = self.load_default_session() - if run_id is None: - run_id = str(uuid4()) + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None - execution_order = self._get_execution_order(parent_run_id) + execution_order = self._get_execution_order(parent_run_id_) llm_run = LLMRun( - uuid=run_id, - parent_uuid=parent_run_id, + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, prompts=prompts, extra=kwargs, @@ -145,12 +145,13 @@ def on_llm_start( ) self._start_trace(llm_run) - def on_llm_end(self, response: LLMResult, *, run_id: str, **kwargs: Any) -> None: + def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_end callback.") - llm_run = self.run_map.get(run_id) + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") @@ -162,14 +163,15 @@ def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, + run_id: UUID, **kwargs: Any, ) -> None: """Handle an error for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_error callback.") - llm_run = self.run_map.get(run_id) + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") @@ -182,18 +184,21 @@ def on_chain_start( serialized: Dict[str, Any], inputs: Dict[str, Any], *, - run_id: str, - parent_run_id: 
Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Start a trace for a chain run.""" if self.session is None: self.session = self.load_default_session() - execution_order = self._get_execution_order(parent_run_id) + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) chain_run = ChainRun( - uuid=run_id, - parent_uuid=parent_run_id, + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, inputs=inputs, extra=kwargs, @@ -206,10 +211,12 @@ def on_chain_start( self._start_trace(chain_run) def on_chain_end( - self, outputs: Dict[str, Any], *, run_id: str, **kwargs: Any + self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any ) -> None: """End a trace for a chain run.""" - chain_run = self.run_map.get(run_id) + run_id_ = str(run_id) + + chain_run = self.run_map.get(run_id_) if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") @@ -221,11 +228,13 @@ def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, + run_id: UUID, **kwargs: Any, ) -> None: """Handle an error for a chain run.""" - chain_run = self.run_map.get(run_id) + run_id_ = str(run_id) + + chain_run = self.run_map.get(run_id_) if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") @@ -238,18 +247,21 @@ def on_tool_start( serialized: Dict[str, Any], input_str: str, *, - run_id: str, - parent_run_id: Optional[str] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, **kwargs: Any, ) -> None: """Start a trace for a tool run.""" if self.session is None: self.session = self.load_default_session() - execution_order = self._get_execution_order(parent_run_id) + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) tool_run = ToolRun( - uuid=run_id, - parent_uuid=parent_run_id, + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, # TODO: this is duplicate info as above, not needed. 
action=str(serialized), @@ -263,9 +275,11 @@ def on_tool_start( ) self._start_trace(tool_run) - def on_tool_end(self, output: str, *, run_id: str, **kwargs: Any) -> None: + def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for a tool run.""" - tool_run = self.run_map.get(run_id) + run_id_ = str(run_id) + + tool_run = self.run_map.get(run_id_) if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") @@ -277,11 +291,13 @@ def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], *, - run_id: str, + run_id: UUID, **kwargs: Any, ) -> None: """Handle an error for a tool run.""" - tool_run = self.run_map.get(run_id) + run_id_ = str(run_id) + + tool_run = self.run_map.get(run_id_) if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index c60373d63042d..66f0c38727ca7 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -22,98 +22,6 @@ TEST_SESSION_ID = 2023 -@freeze_time("2023-01-01") -def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: - return ChainRun( - uuid="chain_uuid", - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - child_execution_order=4, - serialized={}, - inputs={}, - outputs={}, - session_id=TEST_SESSION_ID, - child_chain_runs=[], - child_tool_runs=[ - ToolRun( - uuid="tool_uuid", - parent_uuid="chain_uuid", - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=2, - child_execution_order=3, - serialized={}, - tool_input="test", - output="test", - action="{}", - session_id=TEST_SESSION_ID, - error=None, - child_chain_runs=[], - child_tool_runs=[], - child_llm_runs=[ - LLMRun( - uuid="llm_uuid1", - parent_uuid="tool_uuid", - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=3, - child_execution_order=3, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ) - ], - ), - ], - child_llm_runs=[ - LLMRun( - uuid="llm_uuid2", - parent_uuid="chain_uuid", - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=4, - child_execution_order=4, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ), - ], - ) - - -def _perform_nested_run(tracer: BaseTracer) -> None: - """Perform a nested run.""" - chain_uuid = "chain_uuid" - tool_uuid = "tool_uuid" - llm_uuid1 = "llm_uuid1" - llm_uuid2 = "llm_uuid2" - - tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) - tracer.on_tool_start( - serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid - ) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) - tracer.on_tool_end("test", run_id=tool_uuid) - tracer.on_llm_start( - serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid - ) - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) - tracer.on_chain_end(outputs={}, run_id=chain_uuid) - - def load_session(session_name: str) -> TracerSession: """Load a tracing session.""" return TracerSession(id=1, name=session_name, 
start_time=datetime.utcnow()) @@ -157,9 +65,9 @@ def load_default_session(self) -> TracerSession: @freeze_time("2023-01-01") def test_tracer_llm_run() -> None: """Test tracer on an LLM run.""" - uuid = str(uuid4()) + uuid = uuid4() compare_run = LLMRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -187,15 +95,15 @@ def test_tracer_llm_run_errors_no_start() -> None: tracer.new_session() with pytest.raises(TracerException): - tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=str(uuid4())) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) @freeze_time("2023-01-01") def test_tracer_multiple_llm_runs() -> None: """Test the tracer with multiple runs.""" - uuid = str(uuid4()) + uuid = uuid4() compare_run = LLMRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -222,9 +130,9 @@ def test_tracer_multiple_llm_runs() -> None: @freeze_time("2023-01-01") def test_tracer_chain_run() -> None: """Test tracer on a Chain run.""" - uuid = str(uuid4()) + uuid = uuid4() compare_run = ChainRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -248,9 +156,9 @@ def test_tracer_chain_run() -> None: @freeze_time("2023-01-01") def test_tracer_tool_run() -> None: """Test tracer on a Tool run.""" - uuid = str(uuid4()) + uuid = uuid4() compare_run = ToolRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -277,19 +185,103 @@ def test_tracer_nested_run() -> None: """Test tracer on a nested run.""" tracer = FakeTracer() tracer.new_session() + + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() for _ in range(10): - _perform_nested_run(tracer) - assert tracer.runs == [_get_compare_run()] * 10 + tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) + tracer.on_tool_start( + serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) + + compare_run = ChainRun( + uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=4, + serialized={}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + child_chain_runs=[], + child_tool_runs=[ + ToolRun( + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=3, + serialized={}, + tool_input="test", + output="test", + action="{}", + session_id=TEST_SESSION_ID, + error=None, + child_chain_runs=[], + child_tool_runs=[], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid1), + parent_uuid=str(tool_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ) + ], + ), + ], + child_llm_runs=[ + LLMRun( + 
uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=4, + serialized={}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ), + ], + ) + assert tracer.runs == [compare_run] * 10 @freeze_time("2023-01-01") def test_tracer_llm_run_on_error() -> None: """Test tracer on an LLM run with an error.""" exception = Exception("test") - uuid = str(uuid4()) + uuid = uuid4() compare_run = LLMRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -314,10 +306,10 @@ def test_tracer_llm_run_on_error() -> None: def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error.""" exception = Exception("test") - uuid = str(uuid4()) + uuid = uuid4() compare_run = ChainRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -342,10 +334,10 @@ def test_tracer_chain_run_on_error() -> None: def test_tracer_tool_run_on_error() -> None: """Test tracer on a Tool run with an error.""" exception = Exception("test") - uuid = str(uuid4()) + uuid = uuid4() compare_run = ToolRun( - uuid=uuid, + uuid=str(uuid), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -374,11 +366,11 @@ def test_tracer_nested_runs_on_error() -> None: tracer = FakeTracer() tracer.new_session() - chain_uuid = "chain_uuid" - tool_uuid = "tool_uuid" - llm_uuid1 = "llm_uuid1" - llm_uuid2 = "llm_uuid2" - llm_uuid3 = "llm_uuid3" + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + llm_uuid3 = uuid4() for _ in range(3): tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) @@ -401,7 +393,7 @@ def test_tracer_nested_runs_on_error() -> None: tracer.on_chain_error(exception, run_id=chain_uuid) compare_run = ChainRun( - uuid=chain_uuid, + uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -414,8 +406,8 @@ def test_tracer_nested_runs_on_error() -> None: outputs=None, child_llm_runs=[ LLMRun( - uuid=llm_uuid1, - parent_uuid=chain_uuid, + uuid=str(llm_uuid1), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -428,8 +420,8 @@ def test_tracer_nested_runs_on_error() -> None: response=LLMResult(generations=[[]], llm_output=None), ), LLMRun( - uuid=llm_uuid2, - parent_uuid=chain_uuid, + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -445,8 +437,8 @@ def test_tracer_nested_runs_on_error() -> None: child_chain_runs=[], child_tool_runs=[ ToolRun( - uuid=tool_uuid, - parent_uuid=chain_uuid, + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, @@ -460,8 +452,8 @@ def test_tracer_nested_runs_on_error() -> None: action="{}", child_llm_runs=[ LLMRun( - uuid=llm_uuid3, - parent_uuid=tool_uuid, + uuid=str(llm_uuid3), + parent_uuid=str(tool_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, From 3839703fe876bc4e785660ac31533b8ed39e68a8 Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Sat, 29 Apr 2023 14:16:35 -0700 Subject: [PATCH 34/36] bw compat environ variable --- langchain/callbacks/manager.py | 4 ++- .../callbacks/test_langchain_tracer.py | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git 
a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index e001e6e1cd61d..a21847a524f46 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -703,7 +703,9 @@ def _configure( tracer = tracing_callback_var.get() open_ai = openai_callback_var.get() tracing_enabled_ = ( - os.environ.get("LANGCHAIN_TRACING") is not None or tracer is not None + os.environ.get("LANGCHAIN_TRACING") is not None + or tracer is not None + or os.environ.get("LANGCHAIN_HANDLER") is not None ) tracer_session = os.environ.get("LANGCHAIN_SESSION") if tracer_session is None: diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index ffcb8c467446a..cd830ea28ef79 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -42,6 +42,20 @@ def test_tracing_sequential() -> None: agent.run(q) +def test_tracing_session_env_var() -> None: + os.environ["LANGCHAIN_TRACING"] = "true" + os.environ["LANGCHAIN_SESSION"] = "my_session" + + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(questions[0]) + if "LANGCHAIN_SESSION" in os.environ: + del os.environ["LANGCHAIN_SESSION"] + + @pytest.mark.asyncio async def test_tracing_concurrent() -> None: os.environ["LANGCHAIN_TRACING"] = "true" @@ -56,6 +70,24 @@ async def test_tracing_concurrent() -> None: await aiosession.close() +@pytest.mark.asyncio +async def test_tracing_concurrent_bw_compat_environ() -> None: + os.environ["LANGCHAIN_HANDLER"] = "langchain" + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + aiosession = ClientSession() + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + tasks = [agent.arun(q) for q in questions[:3]] + await asyncio.gather(*tasks) + await aiosession.close() + if "LANGCHAIN_HANDLER" in os.environ: + del os.environ["LANGCHAIN_HANDLER"] + + def test_tracing_context_manager() -> None: llm = OpenAI(temperature=0) tools = load_tools(["llm-math", "serpapi"], llm=llm) From fa1742ce69e22bcaccff853cbc5efad84aea6cbc Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Sat, 29 Apr 2023 15:04:13 -0700 Subject: [PATCH 35/36] fix openai callback --- langchain/callbacks/openai_info.py | 8 ++++++++ .../callbacks/test_openai_callback.py | 20 ++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index 9181ee897a186..3c77f1f2218d2 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -157,3 +157,11 @@ def on_agent_finish( ) -> None: """Run on agent end.""" pass + + def __copy__(self) -> "OpenAICallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/tests/integration_tests/callbacks/test_openai_callback.py b/tests/integration_tests/callbacks/test_openai_callback.py index 91a4a30aa1d1f..9704cb5612f86 100644 --- a/tests/integration_tests/callbacks/test_openai_callback.py +++ b/tests/integration_tests/callbacks/test_openai_callback.py @@ 
-3,8 +3,9 @@ import pytest -from langchain import OpenAI +from langchain.agents import AgentType, initialize_agent, load_tools from langchain.callbacks import get_openai_callback +from langchain.llms import OpenAI @pytest.mark.asyncio @@ -35,3 +36,20 @@ async def test_openai_callback() -> None: await task assert cb.total_tokens == total_tokens + + +def test_openai_callback_agent() -> None: + llm = OpenAI(temperature=0) + tools = load_tools(["serpapi", "llm-math"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + with get_openai_callback() as cb: + agent.run( + "Who is Olivia Wilde's boyfriend? " + "What is his current age raised to the 0.23 power?" + ) + print(f"Total Tokens: {cb.total_tokens}") + print(f"Prompt Tokens: {cb.prompt_tokens}") + print(f"Completion Tokens: {cb.completion_tokens}") + print(f"Total Cost (USD): ${cb.total_cost}") From fb78f69166e4dc1f6b6b774af0c3dc645074ce44 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 30 Apr 2023 10:12:41 -0700 Subject: [PATCH 36/36] cr --- docs/index.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 04e3abcb9df2f..0533d78c4e824 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,6 +44,8 @@ These modules are, in increasing order of complexity: - `Agents <./modules/agents.html>`_: Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end to end agents. +- `Callbacks <./modules/callbacks/getting_started.html>`_: It can be difficult to track all that occurs inside a chain or agent - callbacks help add a level of observability and introspection. + .. toctree:: :maxdepth: 1 @@ -57,6 +59,7 @@ These modules are, in increasing order of complexity: ./modules/memory.md ./modules/chains.md ./modules/agents.md + ./modules/callbacks/getting_started.ipynb Use Cases ----------
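
As a rough illustration of the callback interface these patches converge on, here is a minimal, hypothetical handler sketch. The UUID-typed keyword-only signatures come from the langchain/callbacks/base.py changes above; the class name RunLoggingHandler and the print-based logging are illustrative assumptions, not part of the patch series.

    from typing import Any, Dict, List, Optional
    from uuid import UUID

    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.schema import LLMResult


    class RunLoggingHandler(BaseCallbackHandler):
        """Hypothetical handler: logs run ids to show the UUID-typed callback API."""

        def on_llm_start(
            self,
            serialized: Dict[str, Any],
            prompts: List[str],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            # run_id / parent_run_id now arrive as uuid.UUID objects rather than strings.
            print(f"LLM run {run_id} started (parent={parent_run_id}), {len(prompts)} prompt(s)")

        def on_llm_end(
            self,
            response: LLMResult,
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            print(f"LLM run {run_id} finished with {len(response.generations)} generation(s)")

Since BaseCallbackHandler is no longer abstract, a handler only overrides the events it cares about. Such a handler would then be registered through a callback manager (e.g. CallbackManager([RunLoggingHandler()])), which the manager.py changes above extend to mint a fresh uuid4() run id for each start event when none is supplied.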