feat(autofix): Stream output in Autofix agent (#1617)
Replaces the Autofix agent's LLM calls with streaming versions. The output is saved to the step object in the run state as it arrives. All other functionality is identical.

I also added a function for streaming with OpenAI and tested that it works, but we don't use it at the moment; it's just future-proofing.
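
For reference, a minimal sketch of how the new streaming API can be consumed (it mirrors the `AutofixAgent.get_completion` loop in this diff; `memory` and the model name are illustrative stand-ins, and `OpenAiProvider.model(...)` is assumed to be the existing provider factory):

    from seer.automation.agent.client import LlmClient, OpenAiProvider
    from seer.automation.agent.models import Message, ToolCall, Usage

    client = LlmClient()
    model = OpenAiProvider.model("gpt-4o-2024-08-06")  # assumed factory helper
    memory = [Message(role="user", content="Summarize the stack trace.")]

    content_chunks: list[str] = []
    tool_calls: list[ToolCall] = []
    usage = Usage()

    # generate_text_stream yields three kinds of values: text deltas (str),
    # completed tool calls (ToolCall), and a final token-usage record (Usage).
    for chunk in client.generate_text_stream(messages=memory, model=model):
        if isinstance(chunk, str):
            content_chunks.append(chunk)  # surface partial text to the UI here
        elif isinstance(chunk, ToolCall):
            tool_calls.append(chunk)
        elif isinstance(chunk, Usage):
            usage += chunk

    # Once the stream ends, reassemble the pieces into a single Message.
    message = client.construct_message_from_stream(
        content_chunks=content_chunks, tool_calls=tool_calls, model=model
    )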

---------

Co-authored-by: Jenn Mueng <[email protected]>
roaga and jennmueng authored Dec 13, 2024
1 parent 5d644ca commit 9296c53
Showing 17 changed files with 2,409 additions and 806 deletions.
253 changes: 252 additions & 1 deletion src/seer/automation/agent/client.py
@@ -2,7 +2,7 @@
import logging
import re
from dataclasses import dataclass
-from typing import Any, ClassVar, Iterable, Type, Union, cast
+from typing import Any, ClassVar, Iterable, Iterator, Type, Union, cast

import anthropic
from anthropic import NOT_GIVEN
@@ -270,6 +270,94 @@ def _prep_message_and_tools(

        return message_dicts, tool_dicts

    @observe(as_type="generation", name="OpenAI Stream")
    def generate_text_stream(
        self,
        *,
        prompt: str | None = None,
        messages: list[Message] | None = None,
        system_prompt: str | None = None,
        tools: list[FunctionTool] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        timeout: float | None = None,
    ) -> Iterator[str | ToolCall | Usage]:
        message_dicts, tool_dicts = self._prep_message_and_tools(
            messages=messages,
            prompt=prompt,
            system_prompt=system_prompt,
            tools=tools,
        )

        openai_client = self.get_client()

        stream = openai_client.chat.completions.create(
            model=self.model_name,
            messages=cast(Iterable[ChatCompletionMessageParam], message_dicts),
            temperature=temperature,
            tools=(
                cast(Iterable[ChatCompletionToolParam], tool_dicts)
                if tool_dicts
                else openai.NotGiven()
            ),
            max_tokens=max_tokens or openai.NotGiven(),
            timeout=timeout or openai.NotGiven(),
            stream=True,
            stream_options={"include_usage": True},
        )

        try:
            current_tool_call: dict[str, Any] | None = None
            current_tool_call_index = 0

            for chunk in stream:
                # With stream_options={"include_usage": True}, the final chunk has
                # no choices and carries only the usage totals.
                if not chunk.choices and chunk.usage:
                    usage = Usage(
                        completion_tokens=chunk.usage.completion_tokens,
                        prompt_tokens=chunk.usage.prompt_tokens,
                        total_tokens=chunk.usage.total_tokens,
                    )
                    yield usage
                    langfuse_context.update_current_observation(model=self.model_name, usage=usage)
                    break

                delta = chunk.choices[0].delta
                if delta.tool_calls:
                    tool_call = delta.tool_calls[0]

                    if (
                        not current_tool_call or current_tool_call_index != tool_call.index
                    ):  # Start of new tool call
                        current_tool_call_index = tool_call.index
                        if current_tool_call:
                            yield ToolCall(**current_tool_call)
                            current_tool_call = None
                        current_tool_call = {
                            "id": tool_call.id,
                            "function": tool_call.function.name if tool_call.function.name else "",
                            "args": (
                                tool_call.function.arguments if tool_call.function.arguments else ""
                            ),
                        }
                    else:
                        if tool_call.function.arguments:
                            current_tool_call["args"] += tool_call.function.arguments
                if chunk.choices[0].finish_reason == "tool_calls" and current_tool_call:
                    yield ToolCall(**current_tool_call)
                if delta.content:
                    yield delta.content
        finally:
            stream.response.close()

    def construct_message_from_stream(
        self, content_chunks: list[str], tool_calls: list[ToolCall]
    ) -> Message:
        return Message(
            role="assistant",
            content="".join(content_chunks) if content_chunks else None,
            tool_calls=tool_calls if tool_calls else None,
        )


@dataclass
class AnthropicProvider:
@@ -455,6 +543,96 @@ def _prep_message_and_tools(

        return message_dicts, tool_dicts, system_prompt

    @observe(as_type="generation", name="Anthropic Stream")
    def generate_text_stream(
        self,
        *,
        prompt: str | None = None,
        messages: list[Message] | None = None,
        system_prompt: str | None = None,
        tools: list[FunctionTool] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        timeout: float | None = None,
    ) -> Iterator[str | ToolCall | Usage]:
        message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
            messages=messages,
            prompt=prompt,
            system_prompt=system_prompt,
            tools=tools,
        )

        anthropic_client = self.get_client()

        stream = anthropic_client.messages.create(
            system=system_prompt or NOT_GIVEN,
            model=self.model_name,
            tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
            messages=cast(Iterable[MessageParam], message_dicts),
            max_tokens=max_tokens or 8192,
            temperature=temperature or NOT_GIVEN,
            timeout=timeout or NOT_GIVEN,
            stream=True,
        )

        try:
            current_tool_call: dict[str, Any] | None = None
            current_input_json = []
            total_input_tokens = 0
            total_output_tokens = 0

            for chunk in stream:
                # Token counts arrive on the message_start and message_delta events.
                if chunk.type == "message_start" and chunk.message.usage:
                    total_input_tokens += chunk.message.usage.input_tokens
                    total_output_tokens += chunk.message.usage.output_tokens
                elif chunk.type == "message_delta" and chunk.usage:
                    total_output_tokens += chunk.usage.output_tokens

                if chunk.type == "message_stop":
                    break
                elif chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
                    yield chunk.delta.text
                elif chunk.type == "content_block_start" and chunk.content_block.type == "tool_use":
                    # Start accumulating a new tool call
                    current_tool_call = {
                        "id": chunk.content_block.id,
                        "function": chunk.content_block.name,
                        "args": "",
                    }
                elif chunk.type == "content_block_delta" and chunk.delta.type == "input_json_delta":
                    # Accumulate the input JSON
                    if current_tool_call:
                        current_input_json.append(chunk.delta.partial_json)
                elif chunk.type == "content_block_stop" and current_tool_call:
                    # Tool call is complete, yield it
                    current_tool_call["args"] = "".join(current_input_json)
                    yield ToolCall(**current_tool_call)
                    current_tool_call = None
                    current_input_json = []
        finally:
            usage = Usage(
                completion_tokens=total_output_tokens,
                prompt_tokens=total_input_tokens,
                total_tokens=total_input_tokens + total_output_tokens,
            )
            yield usage
            langfuse_context.update_current_observation(model=self.model_name, usage=usage)
            stream.response.close()

    def construct_message_from_stream(
        self, content_chunks: list[str], tool_calls: list[ToolCall]
    ) -> Message:
        message = Message(
            role="tool_use" if tool_calls else "assistant",
            content="".join(content_chunks) if content_chunks else None,
        )

        if tool_calls:
            message.tool_calls = tool_calls
            message.tool_call_id = tool_calls[0].id

        return message


LlmProvider = Union[OpenAiProvider, AnthropicProvider]

@@ -558,6 +736,64 @@ def generate_structured(
            logger.exception(f"Text generation failed with provider {model.provider_name}: {e}")
            raise e

    @observe(name="Generate Text Stream")
    def generate_text_stream(
        self,
        *,
        prompt: str | None = None,
        messages: list[Message] | None = None,
        model: LlmProvider,
        system_prompt: str | None = None,
        tools: list[FunctionTool] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        run_name: str | None = None,
        timeout: float | None = None,
    ) -> Iterator[str | ToolCall | Usage]:
        try:
            if run_name:
                langfuse_context.update_current_observation(
                    name=run_name + " - Generate Text Stream"
                )
            langfuse_context.flush()

            defaults = model.defaults
            default_temperature = defaults.temperature if defaults else None

            messages = LlmClient.clean_message_content(messages if messages else [])
            if not tools:
                messages = LlmClient.clean_tool_call_assistant_messages(messages)

            if model.provider_name == LlmProviderType.OPENAI:
                model = cast(OpenAiProvider, model)
                yield from model.generate_text_stream(
                    max_tokens=max_tokens,
                    messages=messages,
                    prompt=prompt,
                    system_prompt=system_prompt,
                    temperature=temperature or default_temperature,
                    tools=tools,
                    timeout=timeout,
                )
            elif model.provider_name == LlmProviderType.ANTHROPIC:
                model = cast(AnthropicProvider, model)
                yield from model.generate_text_stream(
                    max_tokens=max_tokens,
                    messages=messages,
                    prompt=prompt,
                    system_prompt=system_prompt,
                    temperature=temperature or default_temperature,
                    tools=tools,
                    timeout=timeout,
                )
            else:
                raise ValueError(f"Invalid provider: {model.provider_name}")
        except Exception as e:
            logger.exception(
                f"Text stream generation failed with provider {model.provider_name}: {e}"
            )
            raise e

    @staticmethod
    def clean_tool_call_assistant_messages(messages: list[Message]) -> list[Message]:
        new_messages = []
@@ -585,6 +821,21 @@ def clean_message_content(messages: list[Message]) -> list[Message]:
            new_messages.append(message)
        return new_messages

    def construct_message_from_stream(
        self,
        content_chunks: list[str],
        tool_calls: list[ToolCall],
        model: LlmProvider,
    ) -> Message:
        if model.provider_name == LlmProviderType.OPENAI:
            model = cast(OpenAiProvider, model)
            return model.construct_message_from_stream(content_chunks, tool_calls)
        elif model.provider_name == LlmProviderType.ANTHROPIC:
            model = cast(AnthropicProvider, model)
            return model.construct_message_from_stream(content_chunks, tool_calls)
        else:
            raise ValueError(f"Invalid provider: {model.provider_name}")


@module.provider
def provide_llm_client() -> LlmClient:
54 changes: 53 additions & 1 deletion src/seer/automation/autofix/autofix_agent.py
@@ -4,7 +4,13 @@
from typing import Optional

from seer.automation.agent.agent import AgentConfig, LlmAgent, RunConfig
-from seer.automation.agent.models import Message
+from seer.automation.agent.models import (
+    LlmGenerateTextResponse,
+    LlmResponseMetadata,
+    Message,
+    ToolCall,
+    Usage,
+)
from seer.automation.agent.tools import FunctionTool
from seer.automation.autofix.autofix_context import AutofixContext
from seer.automation.autofix.components.insight_sharing.component import create_insight_output
@@ -63,6 +69,52 @@ def _check_prompt_for_help(self, run_config: RunConfig):
"You're taking a while. If you need help, ask me a concrete question using the tool provided."
)

    def get_completion(self, run_config: RunConfig):
        """
        Streams the preliminary output to the current step and only returns when the output is complete.
        """
        content_chunks = []
        tool_calls = []
        usage = Usage()

        stream = self.client.generate_text_stream(
            messages=self.memory,
            model=run_config.model,
            system_prompt=run_config.system_prompt if run_config.system_prompt else None,
            tools=(self.tools if len(self.tools) > 0 else None),
            temperature=run_config.temperature or 0.0,
        )

        cleared = False
        for chunk in stream:
            if isinstance(chunk, str):
                with self.context.state.update() as cur:
                    cur_step = cur.steps[-1]
                    if not cleared:
                        # Drop any stale streamed output before the first new chunk.
                        cur_step.clear_output_stream()
                        cleared = True
                    cur_step.receive_output_stream(chunk)
                content_chunks.append(chunk)
            elif isinstance(chunk, ToolCall):
                tool_calls.append(chunk)
            elif isinstance(chunk, Usage):
                usage += chunk

        message = self.client.construct_message_from_stream(
            content_chunks=content_chunks,
            tool_calls=tool_calls,
            model=run_config.model,
        )

        return LlmGenerateTextResponse(
            message=message,
            metadata=LlmResponseMetadata(
                model=run_config.model.model_name,
                provider_name=run_config.model.provider_name,
                usage=usage,
            ),
        )

    def run_iteration(self, run_config: RunConfig):
        logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----")

@@ -128,7 +128,7 @@ def invoke(
        for i, item in enumerate(cause_model.code_context):
            item.id = i
            # Find line range for the snippet
-            if item.snippet.file_path and item.snippet.snippet:
+            if item.snippet and item.snippet.file_path and item.snippet.snippet:
                try:
                    file_contents = self.context.get_file_contents(
                        item.snippet.file_path,
1 change: 1 addition & 0 deletions src/seer/automation/autofix/event_manager.py
@@ -68,6 +68,7 @@ def restart_step(self, step: Step):
            cur_step = cur.find_or_add(step)
            cur_step.status = AutofixStatus.PROCESSING
            cur_step.progress = []
            cur_step.clear_output_stream()
            cur_step.completedMessage = None  # type: ignore[assignment]
            cur.status = AutofixStatus.PROCESSING
            cur.mark_triggered()
9 changes: 9 additions & 0 deletions src/seer/automation/autofix/models.py
@@ -127,10 +127,19 @@ class BaseStep(BaseModel):
    completedMessage: Optional[str] = None

    queued_user_messages: list[str] = []
    output_stream: str | None = None

    def receive_user_message(self, message: str):
        self.queued_user_messages.append(message)

    def receive_output_stream(self, stream_chunk: str):
        # Lazily initialize so an untouched step keeps output_stream == None.
        if self.output_stream is None:
            self.output_stream = ""
        self.output_stream += stream_chunk

    def clear_output_stream(self):
        self.output_stream = None

    def find_child(self, *, id: str) -> "Step | None":
        for step in self.progress:
            if isinstance(step, (DefaultStep, RootCauseStep, ChangesStep)) and step.id == id:
