diff --git a/package/kedro_viz/launchers/jupyter.py b/package/kedro_viz/launchers/jupyter.py index f46ca30274..5c9da6dcd6 100644 --- a/package/kedro_viz/launchers/jupyter.py +++ b/package/kedro_viz/launchers/jupyter.py @@ -5,12 +5,14 @@ import logging import multiprocessing import os +import shlex import socket from contextlib import closing from typing import Any, Dict import IPython from IPython.display import HTML, display +from watchgod import RegExpWatcher, run_process from kedro_viz.launchers.utils import _check_viz_up, _wait_for from kedro_viz.server import DEFAULT_HOST, DEFAULT_PORT, run_server @@ -78,13 +80,24 @@ def _display_databricks_html(port: int): # pragma: no cover print(f"Kedro-Viz is available at {url}") -def run_viz(port: int = None, local_ns: Dict[str, Any] = None) -> None: +def parse_args(args): # pragma: no cover + """Parses the args string and returns a dictionary of arguments.""" + parsed_args = shlex.split(args) + arg_dict = { + arg.lstrip("-").split("=")[0]: arg.split("=")[1] if "=" in arg else True + for arg in parsed_args + } + return arg_dict + + +def run_viz( # pylint: disable=too-many-locals + args: str = "", local_ns: Dict[str, Any] = None +) -> None: """ - Line magic function to start kedro viz. It calls a kedro viz in a process and displays it in - the Jupyter notebook environment. + Line magic function to start Kedro Viz with optional arguments. Args: - port: TCP port that viz will listen to. Defaults to 4141. + args: String of arguments to pass to Kedro Viz. If empty, defaults will be used. local_ns: Local namespace with local variables of the scope where the line magic is invoked. This argument must be in the signature, even though it is not used. This is because the Kedro IPython extension registers line magics with @@ -92,12 +105,23 @@ def run_viz(port: int = None, local_ns: Dict[str, Any] = None) -> None: https://ipython.readthedocs.io/en/stable/config/custommagics.html """ - port = port or DEFAULT_PORT # Default argument doesn't work in Jupyter line magic. - host = _DATABRICKS_HOST if _is_databricks() else DEFAULT_HOST - port = _allocate_port( - host, start_at=int(port) - ) # Line magic provides string arguments by default, so we need to convert to int. - + # Parse arguments + arg_dict = parse_args(args) + + host = arg_dict.get("host", _DATABRICKS_HOST if _is_databricks() else DEFAULT_HOST) + port = int(arg_dict.get("port", DEFAULT_PORT)) + load_file = arg_dict.get("load-file", None) + save_file = arg_dict.get("save-file", None) + pipeline = arg_dict.get("pipeline", None) + env = arg_dict.get("env", None) + autoreload = arg_dict.get("autoreload", False) + ignore_plugins = arg_dict.get("ignore-plugins", False) + params = arg_dict.get("params", "") + + # Allocate port + port = _allocate_port(host, start_at=port) + + # Terminate existing process if needed if port in _VIZ_PROCESSES and _VIZ_PROCESSES[port].is_alive(): _VIZ_PROCESSES[port].terminate() @@ -107,12 +131,34 @@ def run_viz(port: int = None, local_ns: Dict[str, Any] = None) -> None: else None ) + run_server_kwargs = { + "host": host, + "port": port, + "load_file": load_file, + "save_file": save_file, + "pipeline_name": pipeline, + "env": env, + "autoreload": autoreload, + "ignore_plugins": ignore_plugins, + "extra_params": params, + "project_path": project_path, + } process_context = multiprocessing.get_context("spawn") - viz_process = process_context.Process( - target=run_server, - daemon=True, - kwargs={"project_path": project_path, "host": host, "port": port}, - ) + if autoreload: + run_process_kwargs = { + "path": project_path, + "target": run_server, + "kwargs": run_server_kwargs, + "watcher_cls": RegExpWatcher, + "watcher_kwargs": {"re_files": r"^.*(\.yml|\.yaml|\.py|\.json)$"}, + } + viz_process = process_context.Process( + target=run_process, daemon=False, kwargs={**run_process_kwargs} + ) + else: + viz_process = process_context.Process( + target=run_server, daemon=True, kwargs={**run_server_kwargs} + ) viz_process.start() _VIZ_PROCESSES[port] = viz_process diff --git a/package/tests/test_launchers/test_jupyter.py b/package/tests/test_launchers/test_jupyter.py index 60447a4550..6a13309713 100644 --- a/package/tests/test_launchers/test_jupyter.py +++ b/package/tests/test_launchers/test_jupyter.py @@ -29,6 +29,13 @@ def test_run_viz(self, mocker, patched_check_viz_up): "project_path": None, "host": "127.0.0.1", "port": 4141, + "load_file": None, + "save_file": None, + "pipeline_name": None, + "env": None, + "autoreload": False, + "ignore_plugins": False, + "extra_params": "", }, ) mock_jupyter_display.assert_called_once() @@ -46,13 +53,20 @@ def test_run_viz(self, mocker, patched_check_viz_up): "project_path": None, "host": "127.0.0.1", "port": 4141, + "load_file": None, + "save_file": None, + "pipeline_name": None, + "env": None, + "autoreload": False, + "ignore_plugins": False, + "extra_params": "", }, ) assert set(_VIZ_PROCESSES.keys()) == {4141} def test_run_viz_invalid_port(self, mocker, patched_check_viz_up): with pytest.raises(ValueError): - run_viz(port=999999) + run_viz("--port=999999") def test_exception_when_viz_cannot_be_launched(self, mocker): mocker.patch( @@ -85,6 +99,13 @@ def test_run_viz_on_databricks(self, mocker, patched_check_viz_up, monkeypatch): "project_path": None, "host": "0.0.0.0", "port": 4141, + "load_file": None, + "save_file": None, + "pipeline_name": None, + "env": None, + "autoreload": False, + "ignore_plugins": False, + "extra_params": "", }, ) databricks_display.assert_called_once() @@ -104,3 +125,29 @@ def test_run_viz_creates_correct_link(self, mocker, patched_check_viz_up): displayed_html = mock_display_html.call_args[0][0].data assert 'target="_blank"' in displayed_html assert "Open Kedro-Viz" in displayed_html + + def test_run_viz_with_autoreload(self, mocker, patched_check_viz_up): + mock_process_context = mocker.patch("multiprocessing.get_context") + mock_context_instance = mocker.Mock() + mock_process_context.return_value = mock_context_instance + mock_process = mocker.patch.object(mock_context_instance, "Process") + + run_viz("--autoreload", None) + + mock_process.assert_called_once_with( + target=mocker.ANY, + daemon=False, # No daemon for autoreload + kwargs=mocker.ANY, + ) + + def test_run_viz_without_autoreload(self, mocker, patched_check_viz_up): + mock_process_context = mocker.patch("multiprocessing.get_context") + mock_context_instance = mocker.Mock() + mock_process_context.return_value = mock_context_instance + mock_process = mocker.patch.object(mock_context_instance, "Process") + + run_viz("", None) + + mock_process.assert_called_once_with( + target=run_server, daemon=True, kwargs=mocker.ANY + )