diff --git a/haystack/core/pipeline/draw/mermaid.py b/haystack/core/pipeline/draw.py similarity index 71% rename from haystack/core/pipeline/draw/mermaid.py rename to haystack/core/pipeline/draw.py index eb059c7c5c..60aa6ac806 100644 --- a/haystack/core/pipeline/draw/mermaid.py +++ b/haystack/core/pipeline/draw.py @@ -1,17 +1,50 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import logging import base64 +import logging -import requests import networkx # type:ignore +import requests from haystack.core.errors import PipelineDrawingError +from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs from haystack.core.type_utils import _type_name logger = logging.getLogger(__name__) + +def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph: + """ + Add some extra nodes to show the inputs and outputs of the pipeline. + Also adds labels to edges. + """ + # Label the edges + for inp, outp, key, data in graph.edges(keys=True, data=True): + data[ + "label" + ] = f"{data['from_socket'].name} -> {data['to_socket'].name}{' (opt.)' if not data['mandatory'] else ''}" + graph.add_edge(inp, outp, key=key, **data) + + # Add inputs fake node + graph.add_node("input") + for node, in_sockets in find_pipeline_inputs(graph).items(): + for in_socket in in_sockets: + if not in_socket.senders and in_socket.is_mandatory: + # If this socket has no sender it could be a socket that receives input + # directly when running the Pipeline. We can't know that for sure, in doubt + # we draw it as receiving input directly. + graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type)) + + # Add outputs fake node + graph.add_node("output") + for node, out_sockets in find_pipeline_outputs(graph).items(): + for out_socket in out_sockets: + graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type)) + + return graph + + ARROWTAIL_MANDATORY = "--" ARROWTAIL_OPTIONAL = "-." ARROWHEAD_MANDATORY = "-->" @@ -31,6 +64,8 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph): """ Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access. """ + # Copy the graph to avoid modifying the original + graph = _prepare_for_drawing(graph.copy()) graph_styled = _to_mermaid_text(graph=graph) graphbytes = graph_styled.encode("ascii") @@ -63,6 +98,8 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation with `mermaid` codeblocks and it will be automatically rendered. """ + # Copy the graph to avoid modifying the original + graph = _prepare_for_drawing(graph.copy()) sockets = { comp: "".join( [ diff --git a/haystack/core/pipeline/draw/__init__.py b/haystack/core/pipeline/draw/__init__.py deleted file mode 100644 index c1764a6e03..0000000000 --- a/haystack/core/pipeline/draw/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack/core/pipeline/draw/draw.py b/haystack/core/pipeline/draw/draw.py deleted file mode 100644 index c68ac3277f..0000000000 --- a/haystack/core/pipeline/draw/draw.py +++ /dev/null @@ -1,100 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from typing import Literal, Optional, Dict, get_args, Any - -import logging -from pathlib import Path - -import networkx # type:ignore - -from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs -from haystack.core.pipeline.draw.graphviz import _to_agraph -from haystack.core.pipeline.draw.mermaid import _to_mermaid_image, _to_mermaid_text -from haystack.core.type_utils import _type_name - -logger = logging.getLogger(__name__) -RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"] - - -def _draw( - graph: networkx.MultiDiGraph, - path: Path, - engine: RenderingEngines = "mermaid-image", - style_map: Optional[Dict[str, str]] = None, -) -> None: - """ - Renders the pipeline graph and saves it to file. - """ - converted_graph = _convert(graph=graph, engine=engine, style_map=style_map) - - if engine == "graphviz": - converted_graph.draw(path) - - elif engine == "mermaid-image": - with open(path, "wb") as imagefile: - imagefile.write(converted_graph) - - elif engine == "mermaid-text": - with open((path), "w", encoding="utf-8") as textfile: - textfile.write(converted_graph) - - else: - raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.") - - logger.debug("Pipeline diagram saved at %s", path) - - -def _convert( - graph: networkx.MultiDiGraph, engine: RenderingEngines = "mermaid-image", style_map: Optional[Dict[str, str]] = None -) -> Any: - """ - Renders the pipeline graph with the correct render and returns it. - """ - graph = _prepare_for_drawing(graph=graph, style_map=style_map or {}) - - if engine == "graphviz": - return _to_agraph(graph=graph) - - if engine == "mermaid-image": - return _to_mermaid_image(graph=graph) - - if engine == "mermaid-text": - return _to_mermaid_text(graph=graph) - - raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.") - - -def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]) -> networkx.MultiDiGraph: - """ - Prepares the graph to be drawn: adds explitic input and output nodes, labels the edges, applies the styles, etc. - """ - # Apply the styles - if style_map: - for node, style in style_map.items(): - graph.nodes[node]["style"] = style - - # Label the edges - for inp, outp, key, data in graph.edges(keys=True, data=True): - data[ - "label" - ] = f"{data['from_socket'].name} -> {data['to_socket'].name}{' (opt.)' if not data['mandatory'] else ''}" - graph.add_edge(inp, outp, key=key, **data) - - # Draw the inputs - graph.add_node("input") - for node, in_sockets in find_pipeline_inputs(graph).items(): - for in_socket in in_sockets: - if not in_socket.senders and in_socket.is_mandatory: - # If this socket has no sender it could be a socket that receives input - # directly when running the Pipeline. We can't know that for sure, in doubt - # we draw it as receiving input directly. - graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type)) - - # Draw the outputs - graph.add_node("output") - for node, out_sockets in find_pipeline_outputs(graph).items(): - for out_socket in out_sockets: - graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type)) - - return graph diff --git a/haystack/core/pipeline/draw/graphviz.py b/haystack/core/pipeline/draw/graphviz.py deleted file mode 100644 index fb7f311ba3..0000000000 --- a/haystack/core/pipeline/draw/graphviz.py +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import logging - -import networkx # type:ignore - -from networkx.drawing.nx_agraph import to_agraph as nx_to_agraph # type:ignore - - -logger = logging.getLogger(__name__) - -# pyright: reportMissingImports=false -# pylint: disable=unused-import,import-outside-toplevel - - -def _to_agraph(graph: networkx.MultiDiGraph): - """ - Renders a pipeline graph using PyGraphViz. You need to install it and all its system dependencies for it to work. - """ - try: - import pygraphviz # type: ignore - except (ModuleNotFoundError, ImportError) as exc: - raise ImportError( - "Can't use 'pygraphviz' to draw this pipeline: pygraphviz could not be imported. " - "Make sure pygraphviz is installed and all its system dependencies are setup correctly." - ) from exc - - for inp, outp, key, data in graph.out_edges("input", keys=True, data=True): - data["style"] = "dashed" - graph.add_edge(inp, outp, key=key, **data) - - for inp, outp, key, data in graph.in_edges("output", keys=True, data=True): - data["style"] = "dashed" - graph.add_edge(inp, outp, key=key, **data) - - graph.nodes["input"]["shape"] = "plain" - graph.nodes["output"]["shape"] = "plain" - agraph = nx_to_agraph(graph) - agraph.layout("dot") - return agraph diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index eb38e4dc18..98ba8df871 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -11,11 +11,19 @@ import networkx # type:ignore from haystack.core.component import Component, InputSocket, OutputSocket, component -from haystack.core.errors import PipelineConnectError, PipelineError, PipelineRuntimeError, PipelineValidationError -from haystack.core.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs -from haystack.core.pipeline.draw.draw import RenderingEngines, _draw +from haystack.core.errors import ( + PipelineConnectError, + PipelineDrawingError, + PipelineError, + PipelineRuntimeError, + PipelineValidationError, +) from haystack.core.serialization import component_from_dict, component_to_dict from haystack.core.type_utils import _type_name, _types_are_compatible +from haystack.utils import is_in_jupyter + +from .descriptions import find_pipeline_inputs, find_pipeline_outputs +from .draw import _to_mermaid_image logger = logging.getLogger(__name__) @@ -63,6 +71,34 @@ def __eq__(self, other) -> bool: return False return self.to_dict() == other.to_dict() + def __repr__(self) -> str: + """ + Returns a text representation of the Pipeline. + If this runs in a Jupyter notebook, it will instead display the Pipeline image. + """ + if is_in_jupyter(): + # If we're in a Jupyter notebook we want to display the image instead of the text repr. + self.show() + return "" + + res = f"{object.__repr__(self)}\n" + if self.metadata: + res += "🧱 Metadata\n" + for k, v in self.metadata.items(): + res += f" - {k}: {v}\n" + + res += "🚅 Components\n" + for name, instance in self.graph.nodes(data="instance"): + res += f" - {name}: {instance.__class__.__name__}\n" + + res += "🛤️ Connections\n" + for sender, receiver, edge_data in self.graph.edges(data=True): + sender_socket = edge_data["from_socket"].name + receiver_socket = edge_data["to_socket"].name + res += f" - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n" + + return res + def to_dict(self) -> Dict[str, Any]: """ Returns this Pipeline instance as a dictionary. @@ -441,24 +477,29 @@ def outputs(self) -> Dict[str, Dict[str, Any]]: } return outputs - def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None: + def show(self) -> None: """ - Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid. - Run `pip install graphviz` or `pip install mermaid` to install missing dependencies. + If running in a Jupyter notebook, display an image representing this `Pipeline`. - Args: - path: where to save the diagram. - engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'. - Default is 'mermaid-image'. + """ + if is_in_jupyter(): + from IPython.display import Image, display - Returns: - None + image_data = _to_mermaid_image(self.graph) - Raises: - ImportError: if `engine='graphviz'` and `pygraphviz` is not installed. - HTTPConnectionError: (and similar) if the internet connection is down or other connection issues. + display(Image(image_data)) + else: + msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally." + raise PipelineDrawingError(msg) + + def draw(self, path: Path) -> None: + """ + Save an image representing this `Pipeline` to `path`. """ - _draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine) + # Before drawing we edit a bit the graph, to avoid modifying the original that is + # used for running the pipeline we copy it. + image_data = _to_mermaid_image(self.graph) + Path(path).write_bytes(image_data) def warm_up(self): """ diff --git a/haystack/utils/__init__.py b/haystack/utils/__init__.py index 9e0ae39000..9282a612fb 100644 --- a/haystack/utils/__init__.py +++ b/haystack/utils/__init__.py @@ -1,5 +1,19 @@ -from haystack.utils.expit import expit -from haystack.utils.requests_utils import request_with_retry -from haystack.utils.filters import document_matches_filter -from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap -from haystack.utils.auth import Secret, deserialize_secrets_inplace +from .auth import Secret, deserialize_secrets_inplace +from .device import ComponentDevice, Device, DeviceMap, DeviceType +from .expit import expit +from .filters import document_matches_filter +from .jupyter import is_in_jupyter +from .requests_utils import request_with_retry + +__all__ = [ + "Secret", + "deserialize_secrets_inplace", + "ComponentDevice", + "Device", + "DeviceMap", + "DeviceType", + "expit", + "document_matches_filter", + "is_in_jupyter", + "request_with_retry", +] diff --git a/haystack/utils/jupyter.py b/haystack/utils/jupyter.py new file mode 100644 index 0000000000..3c4ffdc6c3 --- /dev/null +++ b/haystack/utils/jupyter.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +def is_in_jupyter() -> bool: + """ + Utility function to easily check if we are in a Jupyter or Google Colab environment. + + Inspired by: + https://github.com/explosion/spaCy/blob/e1249d3722765aaca56f538e830add7014d20e2a/spacy/util.py#L1079 + + Returns True if in Jupyter or Google Colab, False otherwise + """ + # + # + try: + # We don't need to import `get_ipython` as it's always present in Jupyter notebooks + if get_ipython().__class__.__name__ == "ZMQInteractiveShell": # type: ignore[name-defined] + return True # Jupyter notebook or qtconsole + if get_ipython().__class__.__module__ == "google.colab._shell": # type: ignore[name-defined] + return True # Colab notebook + except NameError: + pass # Probably standard Python interpreter + return False diff --git a/releasenotes/notes/enhance-pipeline-draw-5fe3131db71f6f54.yaml b/releasenotes/notes/enhance-pipeline-draw-5fe3131db71f6f54.yaml new file mode 100644 index 0000000000..8f586b90fb --- /dev/null +++ b/releasenotes/notes/enhance-pipeline-draw-5fe3131db71f6f54.yaml @@ -0,0 +1,7 @@ +--- +enhancements: + - | + Add new `Pipeline.show()` method to generated image inline if run in a Jupyter notebook. + If called outside a notebook it will raise a `PipelineDrawingError`. + `Pipeline.draw()` has also been simplified and the `engine` argument has been removed. + Now all images will be generated using Mermaid. diff --git a/releasenotes/notes/enhance-repr-0c5efa1e2ca6bafa.yaml b/releasenotes/notes/enhance-repr-0c5efa1e2ca6bafa.yaml new file mode 100644 index 0000000000..a9f1914efc --- /dev/null +++ b/releasenotes/notes/enhance-repr-0c5efa1e2ca6bafa.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Customize `Pipeline.__repr__()` to return a nice text representation of it. + If run on a Jupyter notebook it will instead have the same behaviour as `Pipeline.show()`. diff --git a/test/core/conftest.py b/test/core/conftest.py index 1c2fde1b86..4be63b2f2a 100644 --- a/test/core/conftest.py +++ b/test/core/conftest.py @@ -1,23 +1,9 @@ from pathlib import Path +from unittest.mock import MagicMock, patch import pytest -from unittest.mock import patch, MagicMock - @pytest.fixture def test_files(): return Path(__file__).parent / "test_files" - - -@pytest.fixture(autouse=True) -def mock_mermaid_request(test_files): - """ - Prevents real requests to https://mermaid.ink/ - """ - with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read() - mock_get.return_value = mock_response - yield diff --git a/test/core/pipeline/test_default_value.py b/test/core/pipeline/test_default_value.py index 8b37017f2a..18f2f93542 100644 --- a/test/core/pipeline/test_default_value.py +++ b/test/core/pipeline/test_default_value.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import logging from pathlib import Path from haystack.core.component import component from haystack.core.pipeline import Pipeline -import logging - logging.basicConfig(level=logging.DEBUG) @@ -18,10 +17,9 @@ def run(self, a: int, b: int = 2): return {"c": a + b} -def test_pipeline(tmp_path): +def test_pipeline(): pipeline = Pipeline() pipeline.add_component("with_defaults", WithDefault()) - pipeline.draw(tmp_path / "default_value.png") # Pass all the inputs results = pipeline.run({"with_defaults": {"a": 40, "b": 30}}) diff --git a/test/core/pipeline/test_double_loop_pipeline.py b/test/core/pipeline/test_double_loop_pipeline.py index 7791e62c8a..886e91ec12 100644 --- a/test/core/pipeline/test_double_loop_pipeline.py +++ b/test/core/pipeline/test_double_loop_pipeline.py @@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG) -def test_pipeline(tmp_path): +def test_pipeline(): accumulator = Accumulate() pipeline = Pipeline(max_loops_allowed=10) @@ -31,8 +31,6 @@ def test_pipeline(tmp_path): pipeline.connect("add_three.result", "multiplexer") pipeline.connect("below_10.above", "add_two.value") - pipeline.draw(tmp_path / "double_loop_pipeline.png") - results = pipeline.run({"add_one": {"value": 3}}) assert results == {"add_two": {"result": 13}} diff --git a/test/core/pipeline/test_draw.py b/test/core/pipeline/test_draw.py index a142f01c57..6554f241cf 100644 --- a/test/core/pipeline/test_draw.py +++ b/test/core/pipeline/test_draw.py @@ -1,41 +1,52 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import os -import filecmp +from unittest.mock import MagicMock, patch -from unittest.mock import patch, MagicMock import pytest import requests -from haystack.core.pipeline import Pipeline -from haystack.core.pipeline.draw.draw import _draw, _convert from haystack.core.errors import PipelineDrawingError -from haystack.testing.sample_components import Double, AddFixedValue +from haystack.core.pipeline import Pipeline +from haystack.core.pipeline.draw import _to_mermaid_image, _to_mermaid_text +from haystack.testing.sample_components import AddFixedValue, Double @pytest.mark.integration -def test_draw_mermaid_image(tmp_path, test_files): +def test_to_mermaid_image(test_files): pipe = Pipeline() pipe.add_component("comp1", Double()) pipe.add_component("comp2", Double()) pipe.connect("comp1", "comp2") pipe.connect("comp2", "comp1") - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") - assert os.path.exists(tmp_path / "test_pipe.jpg") - assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png") + image_data = _to_mermaid_image(pipe.graph) + test_image = test_files / "test_mermaid_graph.png" + assert test_image.read_bytes() == image_data -@pytest.mark.integration -def test_draw_mermaid_img_failing_request(tmp_path): +@patch("haystack.core.pipeline.draw.requests") +def test_to_mermaid_image_does_not_edit_graph(mock_requests): + pipe = Pipeline() + pipe.add_component("comp1", AddFixedValue(add=3)) + pipe.add_component("comp2", Double()) + pipe.connect("comp1.result", "comp2.value") + pipe.connect("comp2.value", "comp1.value") + + mock_requests.get.return_value = MagicMock(status_code=200) + expected_pipe = pipe.to_dict() + _to_mermaid_image(pipe.graph) + assert expected_pipe == pipe.to_dict() + + +def test_to_mermaid_image_failing_request(tmp_path): pipe = Pipeline() pipe.add_component("comp1", Double()) pipe.add_component("comp2", Double()) pipe.connect("comp1", "comp2") pipe.connect("comp2", "comp1") - with patch("haystack.core.pipeline.draw.mermaid.requests.get") as mock_get: + with patch("haystack.core.pipeline.draw.requests.get") as mock_get: def raise_for_status(self): raise requests.HTTPError() @@ -47,21 +58,19 @@ def raise_for_status(self): mock_get.return_value = mock_response with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"): - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") + _to_mermaid_image(pipe.graph) -@pytest.mark.integration -def test_draw_mermaid_text(tmp_path): +def test_to_mermaid_text(): pipe = Pipeline() pipe.add_component("comp1", AddFixedValue(add=3)) pipe.add_component("comp2", Double()) pipe.connect("comp1.result", "comp2.value") pipe.connect("comp2.value", "comp1.value") - _draw(pipe.graph, tmp_path / "test_pipe.md", engine="mermaid-text") - assert os.path.exists(tmp_path / "test_pipe.md") + text = _to_mermaid_text(pipe.graph) assert ( - open(tmp_path / "test_pipe.md", "r").read() + text == """ %%{ init: {'theme': 'neutral' } }%% @@ -75,23 +84,13 @@ def test_draw_mermaid_text(tmp_path): ) -def test_draw_unknown_engine(tmp_path): - pipe = Pipeline() - pipe.add_component("comp1", Double()) - pipe.add_component("comp2", Double()) - pipe.connect("comp1", "comp2") - pipe.connect("comp2", "comp1") - - with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"): - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="unknown") - - -def test_convert_unknown_engine(tmp_path): +def test_to_mermaid_text_does_not_edit_graph(): pipe = Pipeline() - pipe.add_component("comp1", Double()) + pipe.add_component("comp1", AddFixedValue(add=3)) pipe.add_component("comp2", Double()) - pipe.connect("comp1", "comp2") - pipe.connect("comp2", "comp1") + pipe.connect("comp1.result", "comp2.value") + pipe.connect("comp2.value", "comp1.value") - with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"): - _convert(pipe.graph, engine="unknown") + expected_pipe = pipe.to_dict() + _to_mermaid_text(pipe.graph) + assert expected_pipe == pipe.to_dict() diff --git a/test/core/pipeline/test_draw_graphviz.py b/test/core/pipeline/test_draw_graphviz.py deleted file mode 100644 index 03c4197540..0000000000 --- a/test/core/pipeline/test_draw_graphviz.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import os -import filecmp - -import pytest - -from haystack.core.pipeline import Pipeline -from haystack.core.pipeline.draw.draw import _draw -from haystack.testing.sample_components import Double - - -pygraphviz = pytest.importorskip("pygraphviz") - - -@pytest.mark.integration -def test_draw_pygraphviz(tmp_path, test_files): - pipe = Pipeline() - pipe.add_component("comp1", Double()) - pipe.add_component("comp2", Double()) - pipe.connect("comp1", "comp2") - - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="graphviz") - assert os.path.exists(tmp_path / "test_pipe.jpg") - assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg") diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 45f1c883da..4e66f38f17 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import logging from typing import Optional +from unittest.mock import patch import pytest from haystack.core.component.types import InputSocket, OutputSocket -from haystack.core.errors import PipelineError, PipelineRuntimeError +from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineRuntimeError from haystack.core.pipeline import Pipeline from haystack.testing.factory import component_class from haystack.testing.sample_components import AddFixedValue, Double @@ -15,6 +16,41 @@ logging.basicConfig(level=logging.DEBUG) +@patch("haystack.core.pipeline.pipeline._to_mermaid_image") +@patch("haystack.core.pipeline.pipeline.is_in_jupyter") +@patch("IPython.display.Image") +@patch("IPython.display.display") +def test_show_in_notebook(mock_ipython_display, mock_ipython_image, mock_is_in_jupyter, mock_to_mermaid_image): + pipe = Pipeline() + + mock_to_mermaid_image.return_value = b"some_image_data" + mock_is_in_jupyter.return_value = True + + pipe.show() + mock_ipython_image.assert_called_once_with(b"some_image_data") + mock_ipython_display.assert_called_once() + + +@patch("haystack.core.pipeline.pipeline.is_in_jupyter") +def test_show_not_in_notebook(mock_is_in_jupyter): + pipe = Pipeline() + + mock_is_in_jupyter.return_value = False + + with pytest.raises(PipelineDrawingError): + pipe.show() + + +@patch("haystack.core.pipeline.pipeline._to_mermaid_image") +def test_draw(mock_to_mermaid_image, tmp_path): + pipe = Pipeline() + mock_to_mermaid_image.return_value = b"some_image_data" + + image_path = tmp_path / "test.png" + pipe.draw(path=image_path) + assert image_path.read_bytes() == mock_to_mermaid_image.return_value + + def test_add_component_to_different_pipelines(): first_pipe = Pipeline() second_pipe = Pipeline() @@ -43,6 +79,49 @@ def test_get_component_name_not_added_to_pipeline(): assert pipe.get_component_name(some_component) == "" +@patch("haystack.core.pipeline.pipeline.is_in_jupyter") +def test_repr(mock_is_in_jupyter): + pipe = Pipeline(metadata={"test": "test"}, max_loops_allowed=42) + pipe.add_component("add_two", AddFixedValue(add=2)) + pipe.add_component("add_default", AddFixedValue()) + pipe.add_component("double", Double()) + pipe.connect("add_two", "double") + pipe.connect("double", "add_default") + + expected_repr = ( + f"{object.__repr__(pipe)}\n" + "🧱 Metadata\n" + " - test: test\n" + "🚅 Components\n" + " - add_two: AddFixedValue\n" + " - add_default: AddFixedValue\n" + " - double: Double\n" + "🛤️ Connections\n" + " - add_two.result -> double.value (int)\n" + " - double.value -> add_default.value (int)\n" + ) + # Simulate not being in a notebook + mock_is_in_jupyter.return_value = False + assert repr(pipe) == expected_repr + + +@patch("haystack.core.pipeline.pipeline.is_in_jupyter") +def test_repr_in_notebook(mock_is_in_jupyter): + pipe = Pipeline(metadata={"test": "test"}, max_loops_allowed=42) + pipe.add_component("add_two", AddFixedValue(add=2)) + pipe.add_component("add_default", AddFixedValue()) + pipe.add_component("double", Double()) + pipe.connect("add_two", "double") + pipe.connect("double", "add_default") + + # Simulate being in a notebook + mock_is_in_jupyter.return_value = True + + with patch.object(Pipeline, "show") as mock_show: + assert repr(pipe) == "" + mock_show.assert_called_once_with() + + def test_run_with_component_that_does_not_return_dict(): BrokenComponent = component_class( "BrokenComponent", input_types={"a": int}, output_types={"b": int}, output=1 # type:ignore diff --git a/test/core/test_files/mermaid_mock/test_response.png b/test/core/test_files/mermaid_mock/test_response.png deleted file mode 100644 index 3bdd0db2f0..0000000000 Binary files a/test/core/test_files/mermaid_mock/test_response.png and /dev/null differ diff --git a/test/core/test_files/test_mermaid_graph.png b/test/core/test_files/test_mermaid_graph.png new file mode 100644 index 0000000000..359e59f625 Binary files /dev/null and b/test/core/test_files/test_mermaid_graph.png differ