core, community: propagate context between threads (#15171)
When `chain.batch` is used, the default implementation runs the chains in separate threads via a `ThreadPoolExecutor`. An issue with this approach is that [the token counting
callback](https://python.langchain.com/docs/modules/callbacks/token_counting)
fails to work, because the context is not propagated to those threads. This PR propagates the context to the newly spawned threads and adds thread synchronization to the OpenAI callback. With this change, the token counting callback works as intended.
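
For context, a minimal sketch of the affected scenario (illustrative prompts, the default model, and an `OPENAI_API_KEY` in the environment are assumed): `get_openai_callback` stores its handler in a contextvar, so before this change the counters stayed at zero whenever the requests ran on `batch`'s worker threads.

```python
from langchain.callbacks import get_openai_callback
from langchain_community.chat_models import ChatOpenAI

llm = ChatOpenAI()  # any OpenAI chat model; requires OPENAI_API_KEY

with get_openai_callback() as cb:
    # .batch fans the calls out to a ThreadPoolExecutor; the handler lives in a
    # contextvar, so without context propagation the worker threads never see it.
    llm.batch(["Tell me a joke", "Tell me a fact"])

# Previously these stayed at 0 under .batch; with this change they are populated.
print(cb.total_tokens, cb.total_cost)
```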

The context propagation change is also beneficial for anyone implementing custom callbacks with similar functionality; a sketch of the underlying mechanism is shown below.
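
A minimal sketch of that mechanism, using a hypothetical contextvar in place of the one that holds the callback handlers: a plain `ThreadPoolExecutor` worker starts with a fresh context, while an executor initialized with `copy_context()` (as in `_executor_w_context` / `get_executor_for_config` below) sees the caller's values.

```python
from concurrent.futures import ThreadPoolExecutor
from contextvars import Context, ContextVar, copy_context

# Hypothetical contextvar standing in for the one that stores callback handlers.
current_handler: ContextVar[str] = ContextVar("current_handler", default="<unset>")


def _set_context(context: Context) -> None:
    # Re-apply every (ContextVar, value) pair captured in the caller's thread.
    for var, value in context.items():
        var.set(value)


def read_handler() -> str:
    # Runs in a worker thread; without propagation it only sees the default.
    return current_handler.get()


current_handler.set("token-counting handler")

# Plain executor: the worker thread starts with an empty context.
with ThreadPoolExecutor(1) as executor:
    print(executor.submit(read_handler).result())  # -> "<unset>"

# Executor whose workers are initialized with a copy of the caller's context.
with ThreadPoolExecutor(
    1, initializer=_set_context, initargs=(copy_context(),)
) as executor:
    print(executor.submit(read_handler).result())  # -> "token-counting handler"
```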

---------

Co-authored-by: Nuno Campos <[email protected]>
joshy-deshaw and nfcampos authored Dec 28, 2023
1 parent f74151b commit bf53855
Showing 5 changed files with 80 additions and 26 deletions.
24 changes: 20 additions & 4 deletions libs/community/langchain_community/callbacks/openai_info.py
@@ -1,4 +1,5 @@
"""Callback Handler that prints to std out."""
import threading
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
@@ -154,6 +155,10 @@ class OpenAICallbackHandler(BaseCallbackHandler):
successful_requests: int = 0
total_cost: float = 0.0

def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()

def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
@@ -182,9 +187,13 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
self.successful_requests += 1

if "token_usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None

# compute tokens and cost for this request
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
@@ -194,10 +203,17 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
model_name, completion_tokens, is_completion=True
)
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
else:
completion_cost = 0
prompt_cost = 0

# update shared state behind lock
with self._lock:
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.successful_requests += 1

def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
48 changes: 35 additions & 13 deletions libs/core/langchain_core/callbacks/manager.py
@@ -6,6 +6,7 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import Context, copy_context
from typing import (
TYPE_CHECKING,
Any,
@@ -271,12 +272,25 @@ def handle_event(
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with ThreadPoolExecutor(1) as executor:
with _executor_w_context(1) as executor:
executor.submit(_run_coros, coros).result()
else:
_run_coros(coros)


def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)


def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
return ThreadPoolExecutor(
max_workers=max_workers,
initializer=_set_context,
initargs=(copy_context(),),
)


def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
@@ -301,6 +315,7 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:


async def _ahandle_event_for_handler(
executor: ThreadPoolExecutor,
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
@@ -317,12 +332,13 @@ async def _ahandle_event_for_handler(
event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
executor, functools.partial(event, *args, **kwargs)
)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler(
executor,
handler,
"on_llm_start",
"ignore_llm",
@@ -364,19 +380,25 @@ async def ahandle_event(
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
)
await asyncio.gather(
*(
_ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
with _executor_w_context() as executor:
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
executor, handler, event_name, ignore_condition_name, *args, **kwargs
)
await asyncio.gather(
*(
_ahandle_event_for_handler(
executor,
handler,
event_name,
ignore_condition_name,
*args,
**kwargs,
)
for handler in handlers
if not handler.run_inline
)
for handler in handlers
if not handler.run_inline
)
)


BRM = TypeVar("BRM", bound="BaseRunManager")
3 changes: 2 additions & 1 deletion libs/core/langchain_core/language_models/chat_models.py
@@ -260,7 +260,8 @@ async def astream(
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
BaseMessageChunk,
await self.ainvoke(input, config=config, stop=stop, **kwargs),
)
else:
config = config or {}
14 changes: 8 additions & 6 deletions libs/core/langchain_core/runnables/base.py
@@ -472,9 +472,10 @@ async def ainvoke(
Subclasses should override this method if they can run asynchronously.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, **kwargs), input, config
)
with get_executor_for_config(config) as executor:
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.invoke, **kwargs), input, config
)

def batch(
self,
Expand Down Expand Up @@ -2882,9 +2883,10 @@ async def _ainvoke(

@wraps(self.func)
async def f(*args, **kwargs): # type: ignore[no-untyped-def]
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.func, **kwargs), *args
)
with get_executor_for_config(config) as executor:
return await asyncio.get_running_loop().run_in_executor(
executor, partial(self.func, **kwargs), *args
)

afunc = f

17 changes: 15 additions & 2 deletions libs/core/langchain_core/runnables/config.py
@@ -2,6 +2,7 @@

from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import Context, copy_context
from typing import (
TYPE_CHECKING,
Any,
@@ -387,8 +388,15 @@ def get_async_callback_manager_for_config(
)


def _set_context(context: Context) -> None:
for var, value in context.items():
var.set(value)


@contextmanager
def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None, None]:
def get_executor_for_config(
config: Optional[RunnableConfig]
) -> Generator[Executor, None, None]:
"""Get an executor for a config.
Args:
@@ -397,5 +405,10 @@ def get_executor_for_config(config: RunnableConfig) -> Generator[Executor, None,
Yields:
Generator[Executor, None, None]: The executor.
"""
with ThreadPoolExecutor(max_workers=config.get("max_concurrency")) as executor:
config = config or {}
with ThreadPoolExecutor(
max_workers=config.get("max_concurrency"),
initializer=_set_context,
initargs=(copy_context(),),
) as executor:
yield executor
