From 897632feba2c65b386745d9e72d9f3b882eaa50d Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Thu, 30 Jan 2025 15:15:30 +0100 Subject: [PATCH] Add OpenAI compatibilty (#60) * Add openai router ; Chat completion route can be static ('model'-based) ; Update tests * Refactoring: move all pydantic models in one file * Avoid running the real pipeline in tests * Add streaming support * Pass streming_callback as additional input to pipeline component which supports it * Add tests for streaming (single request and concurrent requests) * Add examples ; update test files * Rename run_chat into run_chat_completion * Add return type to run_chat_completion also on base wrapper class --- .../chat_with_website/chat_with_website.yml | 50 +++++ .../chat_with_website/pipeline_wrapper.py | 30 +++ .../chat_with_website.yml | 50 +++++ .../pipeline_wrapper.py | 32 +++ src/hayhooks/server/app.py | 3 +- src/hayhooks/server/pipelines/utils.py | 86 ++++++++ src/hayhooks/server/routers/__init__.py | 3 +- src/hayhooks/server/routers/openai.py | 156 ++++++++++++++ .../server/utils/base_pipeline_wrapper.py | 8 +- src/hayhooks/server/utils/deploy_utils.py | 60 ++---- tests/conftest.py | 8 + tests/test_deploy_at_startup.py | 2 +- tests/test_deploy_utils.py | 8 +- .../chat_with_website/pipeline_wrapper.py | 21 +- .../chat_with_website.yml | 50 +++++ .../pipeline_wrapper.py | 40 ++++ .../files/no_chat/pipeline_wrapper.py | 9 + .../files/setup_error/pipeline_wrapper.py | 2 +- .../chat_with_website/pipeline_wrapper.py | 11 +- tests/test_it_deploy_files.py | 2 +- tests/test_it_draw.py | 2 +- tests/test_it_openai.py | 204 ++++++++++++++++++ tests/test_registry.py | 2 +- 23 files changed, 765 insertions(+), 74 deletions(-) create mode 100644 examples/chat_with_website/chat_with_website.yml create mode 100644 examples/chat_with_website/pipeline_wrapper.py create mode 100644 examples/chat_with_website_streaming/chat_with_website.yml create mode 100644 examples/chat_with_website_streaming/pipeline_wrapper.py create mode 100644 src/hayhooks/server/pipelines/utils.py create mode 100644 src/hayhooks/server/routers/openai.py create mode 100644 tests/test_files/files/chat_with_website_streaming/chat_with_website.yml create mode 100644 tests/test_files/files/chat_with_website_streaming/pipeline_wrapper.py create mode 100644 tests/test_files/files/no_chat/pipeline_wrapper.py create mode 100644 tests/test_it_openai.py diff --git a/examples/chat_with_website/chat_with_website.yml b/examples/chat_with_website/chat_with_website.yml new file mode 100644 index 0000000..db4063f --- /dev/null +++ b/examples/chat_with_website/chat_with_website.yml @@ -0,0 +1,50 @@ +components: + converter: + type: haystack.components.converters.html.HTMLToDocument + init_parameters: + extraction_kwargs: null + + fetcher: + init_parameters: + raise_on_failure: true + retry_attempts: 2 + timeout: 3 + user_agents: + - haystack/LinkContentFetcher/2.0.0b8 + type: haystack.components.fetchers.link_content.LinkContentFetcher + + llm: + init_parameters: + api_base_url: null + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: gpt-4o-mini + streaming_callback: null + system_prompt: null + type: haystack.components.generators.openai.OpenAIGenerator + + prompt: + init_parameters: + template: | + "According to the contents of this website: + {% for document in documents %} + {{document.content}} + {% endfor %} + Answer the given question: {{query}} + Answer: + " + type: 
haystack.components.builders.prompt_builder.PromptBuilder + +connections: + - receiver: converter.sources + sender: fetcher.streams + - receiver: prompt.documents + sender: converter.documents + - receiver: llm.prompt + sender: prompt.prompt + +metadata: {} diff --git a/examples/chat_with_website/pipeline_wrapper.py b/examples/chat_with_website/pipeline_wrapper.py new file mode 100644 index 0000000..ff56530 --- /dev/null +++ b/examples/chat_with_website/pipeline_wrapper.py @@ -0,0 +1,30 @@ +from pathlib import Path +from typing import Generator, List, Union +from haystack import Pipeline +from hayhooks.server.pipelines.utils import get_last_user_message +from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper +from hayhooks.server.logger import log + + +URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] + + +class PipelineWrapper(BasePipelineWrapper): + def setup(self) -> None: + pipeline_yaml = (Path(__file__).parent / "chat_with_website.yml").read_text() + self.pipeline = Pipeline.loads(pipeline_yaml) + + def run_api(self, urls: List[str], question: str) -> str: + log.trace(f"Running pipeline with urls: {urls} and question: {question}") + result = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}}) + return result["llm"]["replies"][0] + + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]: + log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}") + + question = get_last_user_message(messages) + log.trace(f"Question: {question}") + + # Plain pipeline run, will return a string + result = self.pipeline.run({"fetcher": {"urls": URLS}, "prompt": {"query": question}}) + return result["llm"]["replies"][0] diff --git a/examples/chat_with_website_streaming/chat_with_website.yml b/examples/chat_with_website_streaming/chat_with_website.yml new file mode 100644 index 0000000..db4063f --- /dev/null +++ b/examples/chat_with_website_streaming/chat_with_website.yml @@ -0,0 +1,50 @@ +components: + converter: + type: haystack.components.converters.html.HTMLToDocument + init_parameters: + extraction_kwargs: null + + fetcher: + init_parameters: + raise_on_failure: true + retry_attempts: 2 + timeout: 3 + user_agents: + - haystack/LinkContentFetcher/2.0.0b8 + type: haystack.components.fetchers.link_content.LinkContentFetcher + + llm: + init_parameters: + api_base_url: null + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: gpt-4o-mini + streaming_callback: null + system_prompt: null + type: haystack.components.generators.openai.OpenAIGenerator + + prompt: + init_parameters: + template: | + "According to the contents of this website: + {% for document in documents %} + {{document.content}} + {% endfor %} + Answer the given question: {{query}} + Answer: + " + type: haystack.components.builders.prompt_builder.PromptBuilder + +connections: + - receiver: converter.sources + sender: fetcher.streams + - receiver: prompt.documents + sender: converter.documents + - receiver: llm.prompt + sender: prompt.prompt + +metadata: {} diff --git a/examples/chat_with_website_streaming/pipeline_wrapper.py b/examples/chat_with_website_streaming/pipeline_wrapper.py new file mode 100644 index 0000000..0cc44b6 --- /dev/null +++ b/examples/chat_with_website_streaming/pipeline_wrapper.py @@ -0,0 +1,32 @@ +from pathlib import Path +from typing import Generator, List, Union +from haystack import Pipeline +from 
hayhooks.server.pipelines.utils import get_last_user_message, streaming_generator +from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper +from hayhooks.server.logger import log + + +URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] + + +class PipelineWrapper(BasePipelineWrapper): + def setup(self) -> None: + pipeline_yaml = (Path(__file__).parent / "chat_with_website.yml").read_text() + self.pipeline = Pipeline.loads(pipeline_yaml) + + def run_api(self, urls: List[str], question: str) -> str: + log.trace(f"Running pipeline with urls: {urls} and question: {question}") + result = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}}) + return result["llm"]["replies"][0] + + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]: + log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}") + + question = get_last_user_message(messages) + log.trace(f"Question: {question}") + + # Streaming pipeline run, will return a generator + return streaming_generator( + pipeline=self.pipeline, + pipeline_run_args={"fetcher": {"urls": URLS}, "prompt": {"query": question}}, + ) diff --git a/src/hayhooks/server/app.py b/src/hayhooks/server/app.py index 44901b9..621eed5 100644 --- a/src/hayhooks/server/app.py +++ b/src/hayhooks/server/app.py @@ -6,7 +6,7 @@ deploy_pipeline_files, read_pipeline_files_from_folder, ) -from hayhooks.server.routers import status_router, draw_router, deploy_router, undeploy_router +from hayhooks.server.routers import status_router, draw_router, deploy_router, undeploy_router, openai_router from hayhooks.settings import settings from hayhooks.server.logger import log @@ -77,6 +77,7 @@ def create_app() -> FastAPI: app.include_router(draw_router) app.include_router(deploy_router) app.include_router(undeploy_router) + app.include_router(openai_router) # Deploy all pipelines in the pipelines directory pipelines_dir = settings.pipelines_dir diff --git a/src/hayhooks/server/pipelines/utils.py b/src/hayhooks/server/pipelines/utils.py new file mode 100644 index 0000000..8924d96 --- /dev/null +++ b/src/hayhooks/server/pipelines/utils.py @@ -0,0 +1,86 @@ +import threading +from queue import Queue +from typing import Generator, List, Union, Dict, Tuple +from haystack import Pipeline +from haystack.core.component import Component +from hayhooks.server.logger import log +from hayhooks.server.routers.openai import Message + + +def is_user_message(msg: Union[Message, Dict]) -> bool: + if isinstance(msg, Message): + return msg.role == "user" + return msg.get("role") == "user" + + +def get_content(msg: Union[Message, Dict]) -> str: + if isinstance(msg, Message): + return msg.content + return msg.get("content") + + +def get_last_user_message(messages: List[Union[Message, Dict]]) -> Union[str, None]: + user_messages = (msg for msg in reversed(messages) if is_user_message(msg)) + + for message in user_messages: + return get_content(message) + + return None + + +def find_streaming_component(pipeline) -> Tuple[Component, str]: + """ + Finds the component in the pipeline that supports streaming_callback + + Returns: + The first component that supports streaming + """ + streaming_component = None + streaming_component_name = None + + for name, component in pipeline.walk(): + if hasattr(component, "streaming_callback"): + log.trace(f"Streaming component found in '{name}' with type {type(component)}") + streaming_component = component + streaming_component_name 
= name + if not streaming_component: + raise ValueError("No streaming-capable component found in the pipeline") + + return streaming_component, streaming_component_name + + +def streaming_generator(pipeline: Pipeline, pipeline_run_args: Dict) -> Generator: + """ + Creates a generator that yields streaming chunks from a pipeline execution. + Automatically finds the streaming-capable component in the pipeline. + """ + queue = Queue() + + def streaming_callback(chunk): + queue.put(chunk.content) + + _, streaming_component_name = find_streaming_component(pipeline) + pipeline_run_args = pipeline_run_args.copy() + + if streaming_component_name not in pipeline_run_args: + pipeline_run_args[streaming_component_name] = {} + + pipeline_run_args[streaming_component_name]["streaming_callback"] = streaming_callback + log.trace(f"Streaming pipeline run args: {pipeline_run_args}") + + def run_pipeline(): + try: + pipeline.run(pipeline_run_args) + finally: + queue.put(None) + + thread = threading.Thread(target=run_pipeline) + thread.start() + + while True: + chunk = queue.get() + if chunk is None: + break + yield chunk + + thread.join() diff --git a/src/hayhooks/server/routers/__init__.py b/src/hayhooks/server/routers/__init__.py index 880dc49..49931a9 100644 --- a/src/hayhooks/server/routers/__init__.py +++ b/src/hayhooks/server/routers/__init__.py @@ -2,5 +2,6 @@ from hayhooks.server.routers.draw import router as draw_router from hayhooks.server.routers.deploy import router as deploy_router from hayhooks.server.routers.undeploy import router as undeploy_router +from hayhooks.server.routers.openai import router as openai_router -__all__ = ['status_router', 'draw_router', 'deploy_router', 'undeploy_router'] +__all__ = ['status_router', 'draw_router', 'deploy_router', 'undeploy_router', 'openai_router'] diff --git a/src/hayhooks/server/routers/openai.py b/src/hayhooks/server/routers/openai.py new file mode 100644 index 0000000..e981844 --- /dev/null +++ b/src/hayhooks/server/routers/openai.py @@ -0,0 +1,156 @@ +import time +import uuid +from typing import Generator, List, Literal, Union +from fastapi import APIRouter, HTTPException +from fastapi.concurrency import run_in_threadpool +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, ConfigDict +from hayhooks.server.pipelines import registry +from hayhooks.server.utils.deploy_utils import handle_pipeline_exceptions +from hayhooks.server.logger import log + +router = APIRouter() + + +class ModelObject(BaseModel): + id: str + name: str + object: Literal["model"] + created: int + owned_by: str + + +class ModelsResponse(BaseModel): + data: List[ModelObject] + object: Literal["list"] + + +class OpenAIBaseModel(BaseModel): + model_config = ConfigDict(extra="allow") + + +class ChatRequest(OpenAIBaseModel): + model: str + messages: List[dict] + stream: bool = False + + +class Message(OpenAIBaseModel): + role: Literal["user", "assistant"] + content: str + + +class Choice(OpenAIBaseModel): + index: int + delta: Union[Message, None] = None + finish_reason: Union[Literal["stop"], None] = None + logprobs: Union[None, dict] = None + message: Union[Message, None] = None + + +class ChatCompletion(OpenAIBaseModel): + id: str + object: Union[Literal["chat.completion"], Literal["chat.completion.chunk"]] + created: int + model: str + choices: List[Choice] + + +@router.get("/v1/models", response_model=ModelsResponse) +@router.get("/models", response_model=ModelsResponse) +async def get_models(): + """ + Implementation of OpenAI /models endpoint. 
+ + Here we list all hayhooks pipelines (using `name` field). + They will appear as selectable models in `open-webui` frontend. + + References: + - https://github.com/ollama/ollama/blob/main/docs/openai.md + - https://platform.openai.com/docs/api-reference/models/list + """ + pipelines = registry.get_names() + + return ModelsResponse( + data=[ + ModelObject( + id=pipeline_name, + name=pipeline_name, + object="model", + created=int(time.time()), + owned_by="hayhooks", + ) + for pipeline_name in pipelines + ], + object="list", + ) + + +@router.post("/v1/chat/completions", response_model=ChatCompletion) +@router.post("/chat/completions", response_model=ChatCompletion) +@router.post("/{pipeline_name}/chat", response_model=ChatCompletion) +@handle_pipeline_exceptions() +async def chat_endpoint(chat_req: ChatRequest) -> ChatCompletion: + pipeline_wrapper = registry.get(chat_req.model) + + if not pipeline_wrapper: + raise HTTPException(status_code=404, detail=f"Pipeline '{chat_req.model}' not found") + + if not pipeline_wrapper._is_run_chat_completion_implemented: + raise HTTPException(status_code=501, detail="Chat endpoint not implemented for this model") + + result = await run_in_threadpool( + pipeline_wrapper.run_chat_completion, + model=chat_req.model, + messages=chat_req.messages, + body=chat_req.model_dump(), + ) + + resp_id = f"{chat_req.model}-{uuid.uuid4()}" + + if isinstance(result, str): + # If the pipeline returns a string, we can directly return a ChatCompletion object + + resp = ChatCompletion( + id=resp_id, + object="chat.completion", + created=int(time.time()), + model=chat_req.model, + choices=[Choice(index=0, message=Message(role="assistant", content=result), finish_reason="stop")], + ) + + log.debug(f"resp: {resp.model_dump_json()}") + return resp + + elif isinstance(result, Generator): + # If the pipeline returns a generator, we need to stream the chunks as SSE events + + def stream_chunks() -> Generator: + # Consume the input generator sending chunks as SSE events + for chunk in result: + resp = ChatCompletion( + id=resp_id, + object="chat.completion.chunk", + created=int(time.time()), + model=chat_req.model, + choices=[Choice(index=0, delta=Message(role="assistant", content=chunk), finish_reason=None)], + ) + + # This is the format for SSE + # Ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events + yield f"data: {resp.model_dump_json()}\n\n" + + # After consuming the generator, send a final event with finish_reason "stop" + final_resp = ChatCompletion( + id=resp_id, + object="chat.completion.chunk", + created=int(time.time()), + model=chat_req.model, + choices=[Choice(index=0, finish_reason="stop")], + ) + yield f"data: {final_resp.model_dump_json()}\n\n" + + return StreamingResponse(stream_chunks(), media_type="text/event-stream") + + else: + raise HTTPException(status_code=500, detail="Unsupported response type from pipeline") diff --git a/src/hayhooks/server/utils/base_pipeline_wrapper.py b/src/hayhooks/server/utils/base_pipeline_wrapper.py index 1608fad..967f89e 100644 --- a/src/hayhooks/server/utils/base_pipeline_wrapper.py +++ b/src/hayhooks/server/utils/base_pipeline_wrapper.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Generator, List, Union class BasePipelineWrapper(ABC): @@ -32,7 +32,7 @@ def run_api(self): """ raise NotImplementedError("run_api not implemented") - def run_chat(self, model_id: str, messages: List[dict], body: dict): + def 
run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]: """ This method is called when a user sends an OpenAI-compatible chat completion request. @@ -42,8 +42,8 @@ def run_chat(self, model_id: str, messages: List[dict], body: dict): This method will be used as the handler for the `/chat` API endpoint. Args: - model_id: The `name` of the deployed Haystack pipeline to run + model: The `name` of the deployed Haystack pipeline to run messages: The history of messages as OpenAI-compatible list of dicts body: Additional parameters and configuration options """ - raise NotImplementedError("run_chat not implemented") + raise NotImplementedError("run_chat_completion not implemented") diff --git a/src/hayhooks/server/utils/deploy_utils.py b/src/hayhooks/server/utils/deploy_utils.py index ae4ce09..250be44 100644 --- a/src/hayhooks/server/utils/deploy_utils.py +++ b/src/hayhooks/server/utils/deploy_utils.py @@ -1,8 +1,9 @@ import inspect import importlib.util from functools import wraps +import traceback from types import ModuleType -from typing import Callable, Union, List +from typing import Callable, Union from fastapi import FastAPI, HTTPException from fastapi.concurrency import run_in_threadpool from fastapi.responses import JSONResponse @@ -23,18 +24,7 @@ from hayhooks.server.logger import log from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper from hayhooks.settings import settings -from pydantic import BaseModel, create_model - - -class ChatRequest(BaseModel): - user_message: str - model_id: str - messages: List[dict] - body: dict - - -class ChatResponse(BaseModel): - result: dict +from pydantic import create_model def deploy_pipeline_def(app, pipeline_def: PipelineDefinition): @@ -192,9 +182,11 @@ def decorator(func): async def wrapper(*args, **kwargs): try: return await func(*args, **kwargs) + except HTTPException as e: + raise e from e except Exception as e: - log.error(f"Pipeline execution error: {str(e)}") - raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {str(e)}") + log.error(f"Pipeline execution error: {str(e)} - {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {str(e)}") from e return wrapper @@ -245,10 +237,9 @@ def deploy_pipeline_files(app: FastAPI, pipeline_name: str, files: dict[str, str RunResponse = create_response_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run') @handle_pipeline_exceptions() - # See comment on pipeline_run() for explanation of the "type: ignore" below - async def run_endpoint(run_req: RunRequest) -> JSONResponse: # type: ignore + async def run_endpoint(run_req: RunRequest) -> RunResponse: # type: ignore result = await run_in_threadpool(pipeline_wrapper.run_api, urls=run_req.urls, question=run_req.question) - return JSONResponse({"result": result}, status_code=200) + return RunResponse(result=result) app.add_api_route( path=f"/{pipeline_name}/run", @@ -259,35 +250,11 @@ async def run_endpoint(run_req: RunRequest) -> JSONResponse: # type: ignore tags=["pipelines"], ) - if pipeline_wrapper._is_run_chat_implemented: - clog.debug("Creating dynamic Pydantic models for run_chat") - - @handle_pipeline_exceptions() - async def chat_endpoint(chat_req: ChatRequest) -> JSONResponse: - result = await run_in_threadpool( - pipeline_wrapper.run_chat, - user_message=chat_req.user_message, - model_id=chat_req.model_id, - messages=chat_req.messages, - body=chat_req.body, - ) - return JSONResponse({"result": result}, 
status_code=200) - - app.add_api_route( - path=f"/{pipeline_name}/chat", - endpoint=chat_endpoint, - methods=["POST"], - name=f"{pipeline_name}_chat", - response_model=ChatResponse, - tags=["pipelines"], - ) - clog.debug("Setting up FastAPI app") app.openapi_schema = None app.setup() clog.success("Pipeline deployment complete") - return {"name": pipeline_name} @@ -303,10 +270,13 @@ def create_pipeline_wrapper_instance(pipeline_module: ModuleType) -> BasePipelin raise PipelineWrapperError(f"Failed to call setup() on pipeline wrapper instance: {str(e)}") from e pipeline_wrapper._is_run_api_implemented = pipeline_wrapper.run_api.__func__ is not BasePipelineWrapper.run_api - pipeline_wrapper._is_run_chat_implemented = pipeline_wrapper.run_chat.__func__ is not BasePipelineWrapper.run_chat + pipeline_wrapper._is_run_chat_completion_implemented = pipeline_wrapper.run_chat_completion.__func__ is not BasePipelineWrapper.run_chat_completion + + log.debug(f"pipeline_wrapper._is_run_api_implemented: {pipeline_wrapper._is_run_api_implemented}") + log.debug(f"pipeline_wrapper._is_run_chat_completion_implemented: {pipeline_wrapper._is_run_chat_completion_implemented}") - if not (pipeline_wrapper._is_run_api_implemented or pipeline_wrapper._is_run_chat_implemented): - raise PipelineWrapperError("At least one of run_api or run_chat must be implemented") + if not (pipeline_wrapper._is_run_api_implemented or pipeline_wrapper._is_run_chat_completion_implemented): + raise PipelineWrapperError("At least one of run_api or run_chat_completion must be implemented") return pipeline_wrapper diff --git a/tests/conftest.py b/tests/conftest.py index e71d4ab..36666ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,3 +32,11 @@ def _status_pipeline(client: TestClient, pipeline_name: str): status_response = client.get(f"/status/{pipeline_name}") return status_response return _status_pipeline + + +@pytest.fixture +def deploy_files(): + def _deploy_files(client: TestClient, pipeline_name: str, pipeline_files: dict): + deploy_response = client.post("/deploy_files", json={"name": pipeline_name, "files": pipeline_files}) + return deploy_response + return _deploy_files diff --git a/tests/test_deploy_at_startup.py b/tests/test_deploy_at_startup.py index aaadbea..5b5b961 100644 --- a/tests/test_deploy_at_startup.py +++ b/tests/test_deploy_at_startup.py @@ -68,7 +68,7 @@ def test_app_loads_pipeline_from_files_directory(test_client_files, test_files_p assert response.status_code == 200 pipelines = response.json()["pipelines"] - assert len(pipelines) == 1 # only one pipeline should be loaded + assert len(pipelines) >= 1 # at least one pipeline should be loaded assert "chat_with_website" in pipelines diff --git a/tests/test_deploy_utils.py b/tests/test_deploy_utils.py index ca30512..a39019e 100644 --- a/tests/test_deploy_utils.py +++ b/tests/test_deploy_utils.py @@ -37,7 +37,7 @@ def test_load_pipeline_module(): assert module is not None assert hasattr(module, "PipelineWrapper") assert isinstance(getattr(module.PipelineWrapper, "run_api"), Callable) - assert isinstance(getattr(module.PipelineWrapper, "run_chat"), Callable) + assert isinstance(getattr(module.PipelineWrapper, "run_chat_completion"), Callable) assert isinstance(getattr(module.PipelineWrapper, "setup"), Callable) @@ -136,7 +136,7 @@ def setup(self): def run_api(self): pass - def run_chat(self, model_id, messages, body): + def run_chat_completion(self, model, messages, body): pass module = type('Module', (), {'PipelineWrapper': ValidPipelineWrapper}) @@ 
-144,7 +144,7 @@ def run_chat(self, model_id, messages, body): wrapper = create_pipeline_wrapper_instance(module) assert isinstance(wrapper, BasePipelineWrapper) assert hasattr(wrapper, 'run_api') - assert hasattr(wrapper, 'run_chat') + assert hasattr(wrapper, 'run_chat_completion') assert isinstance(wrapper.pipeline, Pipeline) @@ -182,5 +182,5 @@ def setup(self): module = type('Module', (), {'PipelineWrapper': IncompleteWrapper}) - with pytest.raises(PipelineWrapperError, match="At least one of run_api or run_chat must be implemented"): + with pytest.raises(PipelineWrapperError, match="At least one of run_api or run_chat_completion must be implemented"): create_pipeline_wrapper_instance(module) diff --git a/tests/test_files/files/chat_with_website/pipeline_wrapper.py b/tests/test_files/files/chat_with_website/pipeline_wrapper.py index 43ed503..60ea622 100644 --- a/tests/test_files/files/chat_with_website/pipeline_wrapper.py +++ b/tests/test_files/files/chat_with_website/pipeline_wrapper.py @@ -1,11 +1,12 @@ from pathlib import Path -from typing import List +from typing import Generator, List, Union from haystack import Pipeline +from hayhooks.server.pipelines.utils import get_last_user_message from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper from hayhooks.server.logger import log -URLS = ["https://haystack.deepset.ai", "https://www.redis.io"] +URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] class PipelineWrapper(BasePipelineWrapper): @@ -18,10 +19,12 @@ def run_api(self, urls: List[str], question: str) -> str: result = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}}) return result["llm"]["replies"][0] - def run_chat(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> str: - log.trace( - f"Running pipeline with user_message: {user_message}, model_id: {model_id}, messages: {messages}, body: {body}" - ) - question = user_message - result = self.pipeline.run({"fetcher": {"urls": URLS}, "prompt": {"query": question}}) - return result["llm"]["replies"][0] + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]: + log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}") + + question = get_last_user_message(messages) + log.trace(f"Question: {question}") + + # Mock streaming pipeline run, will return a fixed string + # NOTE: This is used in tests, please don't change it + return "This is a mock response from the pipeline" diff --git a/tests/test_files/files/chat_with_website_streaming/chat_with_website.yml b/tests/test_files/files/chat_with_website_streaming/chat_with_website.yml new file mode 100644 index 0000000..db4063f --- /dev/null +++ b/tests/test_files/files/chat_with_website_streaming/chat_with_website.yml @@ -0,0 +1,50 @@ +components: + converter: + type: haystack.components.converters.html.HTMLToDocument + init_parameters: + extraction_kwargs: null + + fetcher: + init_parameters: + raise_on_failure: true + retry_attempts: 2 + timeout: 3 + user_agents: + - haystack/LinkContentFetcher/2.0.0b8 + type: haystack.components.fetchers.link_content.LinkContentFetcher + + llm: + init_parameters: + api_base_url: null + api_key: + env_vars: + - OPENAI_API_KEY + strict: true + type: env_var + generation_kwargs: {} + model: gpt-4o-mini + streaming_callback: null + system_prompt: null + type: haystack.components.generators.openai.OpenAIGenerator + + prompt: + init_parameters: + template: | + "According to the 
contents of this website: + {% for document in documents %} + {{document.content}} + {% endfor %} + Answer the given question: {{query}} + Answer: + " + type: haystack.components.builders.prompt_builder.PromptBuilder + +connections: + - receiver: converter.sources + sender: fetcher.streams + - receiver: prompt.documents + sender: converter.documents + - receiver: llm.prompt + sender: prompt.prompt + +metadata: {} diff --git a/tests/test_files/files/chat_with_website_streaming/pipeline_wrapper.py b/tests/test_files/files/chat_with_website_streaming/pipeline_wrapper.py new file mode 100644 index 0000000..d178f96 --- /dev/null +++ b/tests/test_files/files/chat_with_website_streaming/pipeline_wrapper.py @@ -0,0 +1,40 @@ +from pathlib import Path +from pprint import pprint +from typing import Generator, List, Union +from haystack import Pipeline +from hayhooks.server.pipelines.utils import get_last_user_message, streaming_generator +from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper +from hayhooks.server.logger import log + + +URLS = ["https://haystack.deepset.ai", "https://www.redis.io", "https://ssi.inc"] + + +class PipelineWrapper(BasePipelineWrapper): + def setup(self) -> None: + pipeline_yaml = (Path(__file__).parent / "chat_with_website.yml").read_text() + self.pipeline = Pipeline.loads(pipeline_yaml) + + def run_api(self, urls: List[str], question: str) -> str: + log.trace(f"Running pipeline with urls: {urls} and question: {question}") + result = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}}) + return result["llm"]["replies"][0] + + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> Union[str, Generator]: + log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}") + + question = get_last_user_message(messages) + log.trace(f"Question: {question}") + + # Mock streaming pipeline run, will return a fixed string + # NOTE: This is used in tests, please don't change it + if "Redis" in question: + mock_response = "Redis is an in-memory data structure store, used as a database, cache and message broker." 
+ else: + mock_response = "This is a mock response from the pipeline" + + def mock_generator(): + for word in mock_response.split(): + yield word + " " + + return mock_generator() diff --git a/tests/test_files/files/no_chat/pipeline_wrapper.py b/tests/test_files/files/no_chat/pipeline_wrapper.py new file mode 100644 index 0000000..4846814 --- /dev/null +++ b/tests/test_files/files/no_chat/pipeline_wrapper.py @@ -0,0 +1,9 @@ +from haystack import Pipeline +from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper + +class PipelineWrapper(BasePipelineWrapper): + def setup(self): + self.pipeline = Pipeline() + + def run_api(self) -> dict: + return {"result": "Dummy result"} diff --git a/tests/test_files/files/setup_error/pipeline_wrapper.py b/tests/test_files/files/setup_error/pipeline_wrapper.py index f0cb2b0..cbf53e9 100644 --- a/tests/test_files/files/setup_error/pipeline_wrapper.py +++ b/tests/test_files/files/setup_error/pipeline_wrapper.py @@ -4,5 +4,5 @@ class PipelineWrapper(BasePipelineWrapper): def setup(self): raise ValueError("Setup failed!") - def run_api(self): + def run_api(self) -> dict: return {"result": "This should never be reached"} diff --git a/tests/test_files/mixed/chat_with_website/pipeline_wrapper.py b/tests/test_files/mixed/chat_with_website/pipeline_wrapper.py index 43ed503..f6014ed 100644 --- a/tests/test_files/mixed/chat_with_website/pipeline_wrapper.py +++ b/tests/test_files/mixed/chat_with_website/pipeline_wrapper.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import List from haystack import Pipeline +from hayhooks.server.pipelines.utils import get_last_user_message from hayhooks.server.utils.base_pipeline_wrapper import BasePipelineWrapper from hayhooks.server.logger import log @@ -18,10 +19,10 @@ def run_api(self, urls: List[str], question: str) -> str: result = self.pipeline.run({"fetcher": {"urls": urls}, "prompt": {"query": question}}) return result["llm"]["replies"][0] - def run_chat(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> str: - log.trace( - f"Running pipeline with user_message: {user_message}, model_id: {model_id}, messages: {messages}, body: {body}" - ) - question = user_message + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> str: + log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}") + + question = get_last_user_message(messages) result = self.pipeline.run({"fetcher": {"urls": URLS}, "prompt": {"query": question}}) + return result["llm"]["replies"][0] diff --git a/tests/test_it_deploy_files.py b/tests/test_it_deploy_files.py index c1d9386..87dc49a 100644 --- a/tests/test_it_deploy_files.py +++ b/tests/test_it_deploy_files.py @@ -107,7 +107,7 @@ def test_deploy_files_missing_required_methods(): response = client.post("/deploy_files", json={"name": "test_pipeline", "files": invalid_files}) print(response.json()) assert response.status_code == 422 - assert "At least one of run_api or run_chat must be implemented" in response.json()["detail"] + assert "At least one of run_api or run_chat_completion must be implemented" in response.json()["detail"] def test_deploy_files_setup_error(): diff --git a/tests/test_it_draw.py b/tests/test_it_draw.py index b297837..f51ab9f 100644 --- a/tests/test_it_draw.py +++ b/tests/test_it_draw.py @@ -36,7 +36,7 @@ def setup(self) -> None: def run_api(self, urls: List[str], question: str) -> dict: return {} - def run_chat(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> 
dict: + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> dict: return {} wrapper = TestPipelineWrapper() diff --git a/tests/test_it_openai.py b/tests/test_it_openai.py new file mode 100644 index 0000000..31e25b8 --- /dev/null +++ b/tests/test_it_openai.py @@ -0,0 +1,204 @@ +from concurrent.futures import ThreadPoolExecutor +import json +import shutil +import pytest +from hayhooks.settings import settings +from pathlib import Path +from fastapi.testclient import TestClient +from hayhooks.server import app +from hayhooks.server.pipelines import registry +from hayhooks.server.routers.openai import ChatRequest, ChatCompletion, ModelObject, ModelsResponse + +client = TestClient(app) + + +def cleanup(): + registry.clear() + if Path(settings.pipelines_dir).exists(): + shutil.rmtree(settings.pipelines_dir) + + +@pytest.fixture(autouse=True) +def clear_registry(): + cleanup() + yield + + +@pytest.fixture(scope="session", autouse=True) +def final_cleanup(): + yield + cleanup() + + +def collect_chunks(response): + chunks = [] + for event in response.iter_lines(): + if event: + chunks.append(event) + return chunks + + +TEST_FILES_DIR = Path(__file__).parent / "test_files/files/chat_with_website" +SAMPLE_PIPELINE_FILES = { + "pipeline_wrapper.py": (TEST_FILES_DIR / "pipeline_wrapper.py").read_text(), + "chat_with_website.yml": (TEST_FILES_DIR / "chat_with_website.yml").read_text(), +} + +TEST_FILES_DIR_STREAMING = Path(__file__).parent / "test_files/files/chat_with_website_streaming" +SAMPLE_PIPELINE_FILES_STREAMING = { + "pipeline_wrapper.py": (TEST_FILES_DIR_STREAMING / "pipeline_wrapper.py").read_text(), + "chat_with_website.yml": (TEST_FILES_DIR_STREAMING / "chat_with_website.yml").read_text(), +} + + +def test_get_models_empty(): + response = client.get("/models") + assert response.status_code == 200 + assert response.json() == {"data": [], "object": "list"} + + +def test_get_models(): + pipeline_data = {"name": "test_pipeline", "files": SAMPLE_PIPELINE_FILES} + + response = client.post("/deploy_files", json=pipeline_data) + assert response.status_code == 200 + assert response.json() == {"name": "test_pipeline"} + + response = client.get("/models") + response_data = response.json() + + expected_response = ModelsResponse( + object="list", + data=[ + ModelObject( + id="test_pipeline", + name="test_pipeline", + object="model", + created=response_data["data"][0]["created"], + owned_by="hayhooks", + ) + ], + ) + + assert response.status_code == 200 + assert response_data == expected_response.model_dump() + + +def test_chat_completion_success(deploy_files): + pipeline_data = {"name": "test_pipeline", "files": SAMPLE_PIPELINE_FILES} + + response = deploy_files(client, pipeline_data["name"], pipeline_data["files"]) + assert response.status_code == 200 + assert response.json() == {"name": "test_pipeline"} + + # This is a sample request coming from openai-webui + request = ChatRequest( + stream=False, + model="test_pipeline", + messages=[{"role": "user", "content": "what is Redis?"}], + features={"web_search": False}, + session_id="_Qtpw_fE4g9dMKVKAAAP", + chat_id="7d436049-d316-462a-b1c6-c61740f979c9", + id="b8050e7d-d6ec-4dbc-b69e-6b38d36d847e", + background_tasks={"title_generation": True, "tags_generation": True}, + ) + + response = client.post("/chat/completions", json=request.model_dump()) + assert response.status_code == 200 + + response_data = response.json() + chat_completion = ChatCompletion(**response_data) + assert chat_completion.object == "chat.completion" + 
assert chat_completion.model == "test_pipeline" + assert len(chat_completion.choices) == 1 + assert chat_completion.choices[0].message.content + assert chat_completion.choices[0].index == 0 + assert chat_completion.choices[0].logprobs is None + + +def test_chat_completion_invalid_model(): + request = ChatRequest(model="nonexistent_model", messages=[{"role": "user", "content": "Hello"}]) + + response = client.post("/chat/completions", json=request.model_dump()) + assert response.status_code == 404 + + +def test_chat_completion_not_implemented(deploy_files): + pipeline_file = Path(__file__).parent / "test_files/files/no_chat/pipeline_wrapper.py" + pipeline_data = {"name": "test_pipeline_no_chat", "files": {"pipeline_wrapper.py": pipeline_file.read_text()}} + + response = deploy_files(client, pipeline_data["name"], pipeline_data["files"]) + assert response.status_code == 200 + assert response.json() == {"name": "test_pipeline_no_chat"} + + request = ChatRequest(model="test_pipeline_no_chat", messages=[{"role": "user", "content": "Hello"}]) + + response = client.post("/chat/completions", json=request.model_dump()) + assert response.status_code == 501 + assert response.json()["detail"] == "Chat endpoint not implemented for this model" + + +def test_chat_completion_streaming(deploy_files): + pipeline_data = {"name": "test_pipeline_streaming", "files": SAMPLE_PIPELINE_FILES_STREAMING} + + response = deploy_files(client, pipeline_data["name"], pipeline_data["files"]) + assert response.status_code == 200 + assert response.json() == {"name": "test_pipeline_streaming"} + + request = ChatRequest( + model="test_pipeline_streaming", + messages=[{"role": "user", "content": "what is Redis?"}], + ) + + response = client.post("/chat/completions", json=request.model_dump()) + + # response is a stream of SSE events + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/event-stream; charset=utf-8" + + # collect the chunks + chunks = collect_chunks(response) + + # check if the chunks are valid + assert len(chunks) > 0 + assert chunks[0].startswith("data:") + assert chunks[-1].startswith("data:") + + # check if the chunks are valid ChatCompletion objects + sample_chunk = chunks[1] + chat_completion = ChatCompletion(**json.loads(sample_chunk.split("data:")[1])) + assert chat_completion.object == "chat.completion.chunk" + assert chat_completion.model == "test_pipeline_streaming" + assert chat_completion.choices[0].delta.content + assert chat_completion.choices[0].delta.role == "assistant" + assert chat_completion.choices[0].index == 0 + assert chat_completion.choices[0].logprobs is None + + +def test_chat_completion_concurrent_requests(deploy_files): + pipeline_data = {"name": "test_pipeline_streaming", "files": SAMPLE_PIPELINE_FILES_STREAMING} + + response = deploy_files(client, pipeline_data["name"], pipeline_data["files"]) + assert response.status_code == 200 + assert response.json() == {"name": "test_pipeline_streaming"} + + request_1 = ChatRequest(model="test_pipeline_streaming", messages=[{"role": "user", "content": "what is Redis?"}]) + request_2 = ChatRequest(model="test_pipeline_streaming", messages=[{"role": "user", "content": "what is MongoDB?"}]) + + # run the requests concurrently + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(client.post, "/chat/completions", json=request_1.model_dump()), + executor.submit(client.post, "/chat/completions", json=request_2.model_dump()), + ] + results = [future.result() for future in futures] + + assert 
results[0].status_code == 200 + assert results[1].status_code == 200 + + chunks_1 = collect_chunks(results[0]) + chunks_2 = collect_chunks(results[1]) + + # check if the responses are valid + assert "Redis" in chunks_1[0] # "Redis" is the first chunk (see pipeline_wrapper.py) + assert "This" in chunks_2[0] # "This" is the first chunk (see pipeline_wrapper.py) diff --git a/tests/test_registry.py b/tests/test_registry.py index fdfbc9b..94b7aee 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -27,7 +27,7 @@ def setup(self) -> None: def run_api(self, urls: List[str], question: str) -> dict: return {} - def run_chat(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> dict: + def run_chat_completion(self, model: str, messages: List[dict], body: dict) -> dict: return {} return TestPipelineWrapper
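
Usage sketch (not part of the patch above): a minimal client-side example, assuming a Hayhooks instance is reachable at http://localhost:1416 (adjust host/port to your deployment) and that the example pipelines shown in this diff ("chat_with_website", "chat_with_website_streaming") are deployed. It exercises the new OpenAI-compatible /v1/models and /v1/chat/completions routes with the official openai Python client; the api_key value is an arbitrary placeholder, since the routes added in this patch never read it.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:1416/v1", api_key="unused")

# Deployed Hayhooks pipelines are listed as selectable "models"
for model in client.models.list().data:
    print(model.id)

# Non-streaming chat completion: run_chat_completion() returns a plain string,
# which the router wraps in a single "chat.completion" response
resp = client.chat.completions.create(
    model="chat_with_website",
    messages=[{"role": "user", "content": "What is Redis?"}],
)
print(resp.choices[0].message.content)

# Streaming chat completion: run_chat_completion() returns a generator, so the
# router emits "chat.completion.chunk" SSE events, ending with finish_reason="stop"
stream = client.chat.completions.create(
    model="chat_with_website_streaming",
    messages=[{"role": "user", "content": "What is Redis?"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta and delta.content:
        print(delta.content, end="", flush=True)
print()

Because the final streamed event carries only finish_reason="stop" with no delta content (mirroring the OpenAI API), clients should guard against an empty delta as done above.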