-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add openai router ; Chat completion route can be static ('model'-based) ; Update tests * Refactoring: move all pydantic models in one file * Avoid running the real pipeline in tests * Add streaming support * Pass streaming_callback as additional input to pipeline component which supports it * Add tests for streaming (single request and concurrent requests) * Add examples ; update test files * Rename run_chat into run_chat_completion * Add return type to run_chat_completion also on base wrapper class
- Loading branch information
1 parent
aec8faf
commit 897632f
Showing
23 changed files
with
765 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Haystack pipeline: fetch web pages, convert the HTML to Documents, build a
# prompt from their contents, and answer a question with an OpenAI generator.
components:
  # Converts fetched HTML streams into Haystack Documents.
  converter:
    type: haystack.components.converters.html.HTMLToDocument
    init_parameters:
      extraction_kwargs: null

  # Downloads the content of the given URLs.
  fetcher:
    init_parameters:
      raise_on_failure: true
      retry_attempts: 2
      timeout: 3
      user_agents:
      - haystack/LinkContentFetcher/2.0.0b8
    type: haystack.components.fetchers.link_content.LinkContentFetcher

  # OpenAI generator; the API key is read from the OPENAI_API_KEY env var.
  llm:
    init_parameters:
      api_base_url: null
      api_key:
        env_vars:
        - OPENAI_API_KEY
        strict: true
        type: env_var
      generation_kwargs: {}
      model: gpt-4o-mini
      # null here means streaming is off by default; a callback can be
      # injected at run time (see streaming_generator in the server utils).
      streaming_callback: null
      system_prompt: null
    type: haystack.components.generators.openai.OpenAIGenerator

  # Jinja2 template combining the fetched documents with the user question.
  prompt:
    init_parameters:
      template: |
        "According to the contents of this website:
        {% for document in documents %}
          {{document.content}}
        {% endfor %}
        Answer the given question: {{query}}
        Answer:
        "
    type: haystack.components.builders.prompt_builder.PromptBuilder

# Wiring: fetcher -> converter -> prompt -> llm
connections:
- receiver: converter.sources
  sender: fetcher.streams
- receiver: prompt.documents
  sender: converter.documents
- receiver: llm.prompt
  sender: prompt.prompt

metadata: {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from pathlib import Path | ||
from typing import Generator, List, Union | ||
from haystack import Pipeline | ||
from hayhooks.server.pipelines.utils import get_last_user_message | ||
from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper | ||
from hayhooks.server.logger import log | ||
|
||
|
||
URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] | ||
|
||
|
||
class PipelineWrapper(BasePipelineWrapper):
    """Wraps the chat_with_website pipeline for the API and chat-completion routes."""

    def setup(self) -> None:
        """Load the pipeline from the YAML definition stored next to this file."""
        yaml_definition = (Path(__file__).parent / "chat_with_website.yml").read_text()
        self.pipeline = Pipeline.loads(yaml_definition)

    def run_api(self, urls: List[str], question: str) -> str:
        """Fetch the given URLs and answer the question; returns the first LLM reply."""
        log.trace(f"Running pipeline with urls: {urls} and question: {question}")
        pipeline_output = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}})
        return pipeline_output["llm"]["replies"][0]

    def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]:
        """OpenAI-compatible entry point: answer the last user message using the fixed URLS."""
        log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}")

        question = get_last_user_message(messages)
        log.trace(f"Question: {question}")

        # Plain pipeline run, will return a string
        pipeline_output = self.pipeline.run({"fetcher": {"urls": URLS}, "prompt": {"query": question}})
        return pipeline_output["llm"]["replies"][0]
50 changes: 50 additions & 0 deletions
50
examples/chat_with_website_streaming/chat_with_website.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Haystack pipeline (streaming example): fetch web pages, convert the HTML to
# Documents, build a prompt from their contents, and answer with an OpenAI LLM.
components:
  # Converts fetched HTML streams into Haystack Documents.
  converter:
    type: haystack.components.converters.html.HTMLToDocument
    init_parameters:
      extraction_kwargs: null

  # Downloads the content of the given URLs.
  fetcher:
    init_parameters:
      raise_on_failure: true
      retry_attempts: 2
      timeout: 3
      user_agents:
      - haystack/LinkContentFetcher/2.0.0b8
    type: haystack.components.fetchers.link_content.LinkContentFetcher

  # OpenAI generator; the API key is read from the OPENAI_API_KEY env var.
  llm:
    init_parameters:
      api_base_url: null
      api_key:
        env_vars:
        - OPENAI_API_KEY
        strict: true
        type: env_var
      generation_kwargs: {}
      model: gpt-4o-mini
      # null here means streaming is off by default; the wrapper injects a
      # streaming_callback at run time via streaming_generator.
      streaming_callback: null
      system_prompt: null
    type: haystack.components.generators.openai.OpenAIGenerator

  # Jinja2 template combining the fetched documents with the user question.
  prompt:
    init_parameters:
      template: |
        "According to the contents of this website:
        {% for document in documents %}
          {{document.content}}
        {% endfor %}
        Answer the given question: {{query}}
        Answer:
        "
    type: haystack.components.builders.prompt_builder.PromptBuilder

# Wiring: fetcher -> converter -> prompt -> llm
connections:
- receiver: converter.sources
  sender: fetcher.streams
- receiver: prompt.documents
  sender: converter.documents
- receiver: llm.prompt
  sender: prompt.prompt

metadata: {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from pathlib import Path | ||
from typing import Generator, List, Union | ||
from haystack import Pipeline | ||
from hayhooks.server.pipelines.utils import get_last_user_message, streaming_generator | ||
from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper | ||
from hayhooks.server.logger import log | ||
|
||
|
||
URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] | ||
|
||
|
||
class PipelineWrapper(BasePipelineWrapper):
    """Wraps the chat_with_website pipeline; chat completions stream their output."""

    def setup(self) -> None:
        """Load the pipeline from the YAML definition stored next to this file."""
        yaml_definition = (Path(__file__).parent / "chat_with_website.yml").read_text()
        self.pipeline = Pipeline.loads(yaml_definition)

    def run_api(self, urls: List[str], question: str) -> str:
        """Fetch the given URLs and answer the question; returns the first LLM reply."""
        log.trace(f"Running pipeline with urls: {urls} and question: {question}")
        pipeline_output = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}})
        return pipeline_output["llm"]["replies"][0]

    def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]:
        """OpenAI-compatible entry point: stream an answer to the last user message."""
        log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}")

        question = get_last_user_message(messages)
        log.trace(f"Question: {question}")

        # Streaming pipeline run, will return a generator
        return streaming_generator(
            pipeline=self.pipeline,
            pipeline_run_args={"fetcher": {"urls": URLS}, "prompt": {"query": question}},
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import threading | ||
from queue import Queue | ||
from typing import Generator, List, Union, Dict, Tuple | ||
from haystack import Pipeline | ||
from haystack.core.component import Component | ||
from hayhooks.server.logger import log | ||
from hayhooks.server.routers.openai import Message | ||
|
||
|
||
def is_user_message(msg: Union[Message, Dict]) -> bool:
    """Return True when *msg* (a Message model or a plain dict) has the "user" role."""
    role = msg.role if isinstance(msg, Message) else msg.get("role")
    return role == "user"
|
||
|
||
def get_content(msg: Union[Message, Dict]) -> str:
    """Return the text content of *msg*, whether it is a Message model or a plain dict."""
    return msg.content if isinstance(msg, Message) else msg.get("content")
|
||
|
||
def get_last_user_message(messages: List[Union[Message, Dict]]) -> Union[str, None]:
    """Return the content of the most recent user message, or None if there is none."""
    return next(
        (get_content(msg) for msg in reversed(messages) if is_user_message(msg)),
        None,
    )
|
||
|
||
def find_streaming_component(pipeline) -> Tuple[Component, str]:
    """
    Find the first component in the pipeline that supports a streaming_callback.

    Args:
        pipeline: a Haystack Pipeline (anything exposing walk() yielding
            (name, component) pairs).

    Returns:
        A (component, component_name) tuple for the first streaming-capable
        component encountered.

    Raises:
        ValueError: if no component exposes a streaming_callback attribute.
    """
    for name, component in pipeline.walk():
        if hasattr(component, "streaming_callback"):
            # Return immediately: the documented contract is "first match".
            # The original kept scanning and effectively returned the LAST
            # matching component, contradicting its own docstring.
            log.trace(f"Streaming component found in '{name}' with type {type(component)}")
            return component, name
    raise ValueError("No streaming-capable component found in the pipeline")
|
||
|
||
def streaming_generator(pipeline: Pipeline, pipeline_run_args: Dict) -> Generator:
    """
    Create a generator that yields streaming chunks from a pipeline execution.

    Automatically finds the streaming-capable component in the pipeline, injects
    a streaming_callback into its run arguments, runs the pipeline in a
    background thread, and yields each chunk's content as it is produced.

    Raises:
        ValueError: if the pipeline has no streaming-capable component.
        BaseException: any error raised by the pipeline run is re-raised to the
            consumer after the stream ends. (The original implementation
            swallowed such errors: the `finally` clause pushed the sentinel and
            the generator terminated as if the run had succeeded, silently
            truncating the stream.)
    """
    queue = Queue()
    # Single-slot holder for an exception raised on the pipeline thread, so it
    # can be re-raised on the consuming thread.
    error: List[BaseException] = []

    def streaming_callback(chunk):
        queue.put(chunk.content)

    _, streaming_component_name = find_streaming_component(pipeline)

    # Shallow copy so the caller's top-level dict isn't mutated.
    # NOTE(review): a pre-existing nested dict for the streaming component is
    # still shared with the caller and gains the callback key — same behavior
    # as the original; confirm callers don't reuse their args dict.
    pipeline_run_args = pipeline_run_args.copy()
    pipeline_run_args.setdefault(streaming_component_name, {})
    pipeline_run_args[streaming_component_name]["streaming_callback"] = streaming_callback
    log.trace(f"Streaming pipeline run args: {pipeline_run_args}")

    def run_pipeline():
        try:
            pipeline.run(pipeline_run_args)
        except BaseException as exc:
            # Capture for re-raise on the consumer side instead of losing it.
            error.append(exc)
        finally:
            # Sentinel telling the consumer loop that the run has finished.
            queue.put(None)

    thread = threading.Thread(target=run_pipeline)
    thread.start()

    while True:
        chunk = queue.get()
        if chunk is None:
            break
        yield chunk

    thread.join()
    if error:
        raise error[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.