make langchain CoT more readable (#518)
* make langchain CoT more readable

* fix type error

* enhance ignore logic

* make only retriever and llm runs annotatable

* enhance readability

* support functions in the prompt playground

* fix OpenAI chat without functions
willydouhard authored Nov 6, 2023
1 parent 3fa5973 commit 058cab1
Showing 17 changed files with 368 additions and 57 deletions.
160 changes: 130 additions & 30 deletions backend/chainlit/langchain/callbacks.py
@@ -183,6 +183,9 @@ def _build_completion_prompt(self, serialized: Dict, inputs: Dict):
template_format = kwargs.get("template_format")
stringified_inputs = {k: str(v) for (k, v) in inputs.items()}

if not template:
return

self.prompt_sequence.append(
Prompt(
template=template,
@@ -258,6 +261,9 @@ def build_template_messages() -> List[PromptMessage]:

template_messages = build_template_messages()

if not template_messages:
return

stringified_inputs = {k: str(v) for (k, v) in inputs.items()}
self.prompt_sequence.append(
Prompt(messages=template_messages, inputs=stringified_inputs)
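
Both early returns above handle the same case: with LCEL, a chain or prompt run can reach these builders with no template attached, and bailing out keeps empty entries off prompt_sequence. A minimal, self-contained sketch of the guard, using simplified stand-ins rather than the real chainlit types:

    from dataclasses import dataclass, field
    from typing import Dict, List, Optional

    @dataclass
    class SketchPrompt:  # stand-in for chainlit.prompt.Prompt
        template: Optional[str] = None
        inputs: Optional[Dict[str, str]] = None

    @dataclass
    class SketchHelper:
        prompt_sequence: List[SketchPrompt] = field(default_factory=list)

        def build_completion_prompt(self, serialized: Dict, inputs: Dict) -> None:
            template = serialized.get("kwargs", {}).get("template")
            if not template:
                return  # template-less run (plain LCEL step): record nothing
            self.prompt_sequence.append(
                SketchPrompt(
                    template=template,
                    inputs={k: str(v) for k, v in inputs.items()},
                )
            )
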
@@ -344,14 +350,27 @@ def _build_llm_settings(
return provider, settings


DEFAULT_TO_IGNORE = ["RunnableSequence", "RunnableParallel", "<lambda>"]
DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]


class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
llm_stream_message: Dict[str, Message]
parent_id_map: Dict[str, str]
ignored_runs: set

def __init__(
self,
# Token sequence that prefixes the answer
answer_prefix_tokens: Optional[List[str]] = None,
# Should we stream the final answer?
stream_final_answer: bool = False,
# Should force stream the first response?
force_stream_final_answer: bool = False,
# Runs to ignore to enhance readability
to_ignore: Optional[List[str]] = None,
# Runs to keep within ignored runs
to_keep: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
BaseTracer.__init__(self, **kwargs)
@@ -364,13 +383,84 @@ def __init__(
)
self.context = context_var.get()
self.llm_stream_message = {}
self.parent_id_map = {}
self.ignored_runs = set()
self.root_parent_id = (
self.context.session.root_message.id
if self.context.session.root_message
else None
)

if to_ignore is None:
self.to_ignore = DEFAULT_TO_IGNORE
else:
self.to_ignore = to_ignore

if to_keep is None:
self.to_keep = DEFAULT_TO_KEEP
else:
self.to_keep = to_keep

def _run_sync(self, co):
asyncio.run_coroutine_threadsafe(co, loop=self.context.loop)
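
_run_sync exists because LangChain fires tracer hooks from synchronous code, often on a worker thread, while chainlit messages are sent via coroutines; run_coroutine_threadsafe hands the coroutine to the session's event loop instead of awaiting it in place. A standalone sketch of that primitive (the loop/thread setup is illustrative, not how chainlit wires its loop):

    import asyncio
    import threading

    # Background event loop standing in for the chainlit session's loop.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def send_message(content: str) -> str:
        return f"sent: {content}"

    # From synchronous tracer code: schedule the coroutine on that loop.
    future = asyncio.run_coroutine_threadsafe(send_message("hello"), loop)
    print(future.result(timeout=1))  # sent: hello
    loop.call_soon_threadsafe(loop.stop)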

def _persist_run(self, run: Run) -> None:
pass

def _get_run_parent_id(self, run: Run):
parent_id = str(run.parent_run_id) if run.parent_run_id else self.root_parent_id

return parent_id

def _get_non_ignored_parent_id(self, current_parent_id: Optional[str] = None):
if not current_parent_id:
return self.root_parent_id

if current_parent_id not in self.parent_id_map:
return current_parent_id

while current_parent_id in self.parent_id_map:
current_parent_id = self.parent_id_map[current_parent_id]

return current_parent_id

def _should_ignore_run(self, run: Run):
parent_id = self._get_run_parent_id(run)

ignore_by_name = run.name in self.to_ignore
ignore_by_parent = parent_id in self.ignored_runs

ignore = ignore_by_name or ignore_by_parent

if ignore:
if parent_id:
# Add the parent id of the ignored run in the mapping
# so we can re-attach a kept child to the right parent id
self.parent_id_map[str(run.id)] = parent_id
# Tag the run as ignored
self.ignored_runs.add(str(run.id))

# If the ignore cause is the parent being ignored, check if we should nonetheless keep the child
if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep:
return False, self._get_non_ignored_parent_id(str(run.id))
else:
return ignore, parent_id
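
parent_id_map and _get_non_ignored_parent_id work as a pair: every run flagged as ignored records its parent, and the walk above climbs that chain so a kept descendant (say, an llm run nested inside ignored runnables) re-attaches to the nearest surviving ancestor. A worked example with plain dicts:

    # Run tree: A (kept) -> B (ignored RunnableSequence) -> C (ignored <lambda>) -> D (kept llm)
    parent_id_map = {"B": "A", "C": "B", "D": "C"}  # each flagged run id -> its parent id

    def non_ignored_parent(run_id: str) -> str:
        # Climb until we step out of the ignored region of the tree.
        while run_id in parent_id_map:
            run_id = parent_id_map[run_id]
        return run_id

    assert non_ignored_parent("D") == "A"  # D's message attaches under A in the UI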

def _is_annotable(self, run: Run):
return run.run_type in ["retriever", "llm"]

def _get_completion(self, generation: Dict):
if message := generation.get("message"):
kwargs = message.get("kwargs", {})
if function_call := kwargs.get("additional_kwargs", {}).get(
"function_call"
):
return json.dumps(function_call), "json"
else:
return kwargs.get("content", ""), None
else:
return generation.get("text", ""), None
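
_get_completion is what lets function-call responses render as JSON instead of an empty message. The shapes below assume the layout LangChain uses when serializing an OpenAI chat generation; treat the exact keys as an assumption:

    # Plain text generation -> ("Paris is the capital of France.", None)
    text_generation = {"text": "Paris is the capital of France."}

    # Function-call generation -> (JSON string of the call, "json")
    fc_generation = {
        "message": {
            "kwargs": {
                "content": "",
                "additional_kwargs": {
                    "function_call": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris"}',
                    }
                },
            }
        }
    }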

def on_chat_model_start(
self,
serialized: Dict[str, Any],
@@ -436,51 +526,51 @@ def on_llm_new_token(

def _start_trace(self, run: Run) -> None:
super()._start_trace(run)

context_var.set(self.context)
- root_message_id = (
- self.context.session.root_message.id
- if self.context.session.root_message
- else None
- )
- parent_id = str(run.parent_run_id) if run.parent_run_id else root_message_id

if run.run_type in ["chain", "prompt"]:
# Prompt templates are contained in chains or prompts (lcel)
self._build_prompt(run.serialized or {}, run.inputs)

ignore, parent_id = self._should_ignore_run(run)

if ignore:
return

disable_human_feedback = not self._is_annotable(run)

if run.run_type == "llm":
msg = Message(
id=run.id,
content="",
author=run.name,
parent_id=parent_id,
disable_human_feedback=disable_human_feedback,
)
self.llm_stream_message[str(run.id)] = msg
self._run_sync(msg.send())
return

- content = run.inputs

self._run_sync(
Message(
id=run.id,
- content=content,
- language="json",
content="",
author=run.name,
parent_id=parent_id,
disable_human_feedback=disable_human_feedback,
).send()
)
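
For llm runs, _start_trace sends an empty placeholder message and parks it in llm_stream_message; _on_run_update below fills that same message with the completion and calls update(). A minimal sketch of the lifecycle, with a dict standing in for chainlit's Message:

    from typing import Dict, Optional

    llm_stream_message: Dict[str, dict] = {}

    def start_llm_run(run_id: str, author: str) -> None:
        msg = {"id": run_id, "author": author, "content": ""}
        llm_stream_message[run_id] = msg  # remembered so the update can find it
        # real tracer: self._run_sync(msg.send())

    def finish_llm_run(run_id: str, completion: str, language: Optional[str]) -> None:
        if msg := llm_stream_message.get(run_id):
            msg["content"] = completion
            msg["language"] = language
            # real tracer: self._run_sync(msg.update())

    start_llm_run("run-1", "ChatOpenAI")
    finish_llm_run("run-1", "Paris is the capital of France.", None)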

def _on_run_update(self, run: Run) -> None:
"""Process a run upon update."""
context_var.set(self.context)

- root_message_id = (
- self.context.session.root_message.id
- if self.context.session.root_message
- else None
- )
- parent_id = str(run.parent_run_id) if run.parent_run_id else root_message_id
ignore, parent_id = self._should_ignore_run(run)

if ignore:
return

disable_human_feedback = not self._is_annotable(run)

if run.run_type in ["chain"]:
if self.prompt_sequence:
@@ -491,8 +581,7 @@ def _on_run_update(self, run: Run) -> None:
(run.serialized or {}), (run.extra or {}).get("invocation_params")
)
generations = (run.outputs or {}).get("generations", [])
- completion = generations[0][0]["text"]
- generation_type = generations[0][0]["type"]
completion, language = self._get_completion(generations[0][0])

current_prompt = (
self.prompt_sequence.pop() if self.prompt_sequence else None
@@ -503,34 +592,45 @@ def _on_run_update(self, run: Run) -> None:
current_prompt.settings = llm_settings
current_prompt.completion = completion
else:
generation_type = generations[0][0].get("type", "")
current_prompt = self._build_default_prompt(
run, generation_type, provider, llm_settings, completion
)

msg = self.llm_stream_message.get(str(run.id), None)
if msg:
msg.content = completion
msg.language = language
msg.prompt = current_prompt
self._run_sync(msg.update())
return

outputs = run.outputs or {}
output_keys = list(outputs.keys())
if output_keys:
content = outputs.get(output_keys[0], "")
else:
return

if run.run_type in ["agent", "chain", "tool"]:
- # Add the response of the chain/tool
- self._run_sync(
- Message(
- content=run.outputs or {},
- language="json",
- author=run.name,
- parent_id=parent_id,
- ).send()
- )
pass
# # Add the response of the chain/tool
# self._run_sync(
# Message(
# content=content,
# author=run.name,
# parent_id=parent_id,
# disable_human_feedback=disable_human_feedback,
# ).send()
# )
else:
self._run_sync(
Message(
id=run.id,
- content=run.outputs or {},
- language="json",
content=content,
author=run.name,
parent_id=parent_id,
disable_human_feedback=disable_human_feedback,
).update()
)
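
Taken together, the tracer now renders a pruned tree: LCEL plumbing runs disappear, kept descendants are re-parented, and only retriever and llm messages accept human feedback. A hypothetical configuration hiding one extra wrapper, assuming it is constructed inside an active Chainlit session (the constructor reads the session context); the non-default name is illustrative:

    tracer = LangchainTracer(
        stream_final_answer=True,
        to_ignore=["RunnableSequence", "RunnableParallel", "<lambda>", "RunnablePassthrough"],
        to_keep=["retriever", "llm", "agent", "chain", "tool"],
    )
    # chain.invoke({"question": "..."}, config={"callbacks": [tracer]})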

34 changes: 25 additions & 9 deletions backend/chainlit/playground/providers/openai.py
@@ -1,3 +1,4 @@
import json
from contextlib import contextmanager

from chainlit.input_widget import Select, Slider, Tags
@@ -113,7 +114,6 @@ async def create_completion(self, request):
import openai

env_settings = self.validate_env(request=request)

deployment_id = self.get_var(request, "OPENAI_API_DEPLOYMENT_ID")

if deployment_id:
@@ -134,7 +134,11 @@ async def create_completion(self, request):

llm_settings["stop"] = stop

llm_settings["stream"] = True
if request.prompt.functions:
llm_settings["functions"] = request.prompt.functions
llm_settings["stream"] = False
else:
llm_settings["stream"] = True

with handle_openai_error():
response = await openai.ChatCompletion.acreate(
@@ -143,13 +147,25 @@
**llm_settings,
)

- async def create_event_stream():
- async for stream_resp in response:
- if hasattr(stream_resp, 'choices') and len(stream_resp.choices) > 0:
- token = stream_resp.choices[0]["delta"].get("content", "")
- yield token
- else:
- continue
if llm_settings["stream"]:

async def create_event_stream():
async for stream_resp in response:
if hasattr(stream_resp, "choices") and len(stream_resp.choices) > 0:
delta = stream_resp.choices[0]["delta"]
token = delta.get("content", "")
if token:
yield token
else:
continue

else:

async def create_event_stream():
function_call = json.dumps(
response.choices[0]["message"]["function_call"]
)
yield function_call

return StreamingResponse(create_event_stream())
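
Either branch feeds the same StreamingResponse, so the playground client consumes one shape: streamed completions arrive as content deltas, while a function call arrives as a single JSON chunk. Roughly what each branch sees from the pre-1.0 openai SDK (field layout assumed from that SDK's chat API):

    import json

    # Streaming branch: each chunk carries a partial delta.
    stream_chunk = {"choices": [{"delta": {"content": "Par"}}]}

    # Non-streaming branch (functions set): one complete message.
    full_response = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": None,
                    "function_call": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris"}',
                    },
                }
            }
        ]
    }

    print(json.dumps(full_response["choices"][0]["message"]["function_call"]))
    # {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}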

1 change: 1 addition & 0 deletions backend/chainlit/prompt.py
@@ -37,3 +37,4 @@ class Prompt(BaseTemplate):
completion: Optional[str] = None
settings: Optional[Dict[str, Any]] = None
messages: Optional[List[PromptMessage]] = None
functions: Optional[List[Dict]] = None
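
With the new field, a prompt can carry its function definitions into the prompt playground. A hypothetical construction, assuming the inherited BaseTemplate fields are as optional as the ones shown here:

    prompt = Prompt(
        functions=[
            {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            }
        ]
    )
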
2 changes: 1 addition & 1 deletion backend/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "chainlit"
version = "0.7.400"
version = "0.7.401"
keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'langchain']
description = "A faster way to build chatbot UIs."
authors = ["Chainlit"]
17 changes: 16 additions & 1 deletion frontend/src/components/organisms/chat/Messages/container.tsx
@@ -9,6 +9,7 @@ import {
IAction,
IAsk,
IAvatarElement,
IFunction,
IMessage,
IMessageElement
} from '@chainlit/components';
@@ -58,8 +59,22 @@
(message: IMessage) => {
setPlayground((old) => ({
...old,
- prompt: message.prompt,
prompt: message.prompt
? {
...message.prompt,
functions:
(message.prompt.settings
?.functions as unknown as IFunction[]) || []
}
: undefined,
originalPrompt: message.prompt
? {
...message.prompt,
functions:
(message.prompt.settings
?.functions as unknown as IFunction[]) || []
}
: undefined
}));
},
[setPlayground]