Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable %run_viz magic line to use the arguments that Kedro-Viz supports on the command line #1733

76 changes: 61 additions & 15 deletions package/kedro_viz/launchers/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import logging
import multiprocessing
import os
import shlex
import socket
from contextlib import closing
from typing import Any, Dict

import IPython
from IPython.display import HTML, display
from watchgod import RegExpWatcher, run_process

from kedro_viz.launchers.utils import _check_viz_up, _wait_for
from kedro_viz.server import DEFAULT_HOST, DEFAULT_PORT, run_server
Expand Down Expand Up @@ -78,26 +80,48 @@ def _display_databricks_html(port: int): # pragma: no cover
print(f"Kedro-Viz is available at {url}")


def run_viz(port: int = None, local_ns: Dict[str, Any] = None) -> None:
def parse_args(args):  # pragma: no cover
    """Parse a line-magic argument string into a dictionary of options.

    The string is tokenised with ``shlex.split`` so quoted values may contain
    spaces. Each token becomes one entry: leading dashes are stripped from
    the option name, ``--name=value`` maps ``name`` to the string ``value``,
    and a bare ``--flag`` maps ``flag`` to ``True``.

    Args:
        args: Raw argument string, e.g. ``"--port=4141 --autoreload"``.
            An empty string or ``None`` yields an empty dictionary.

    Returns:
        Dictionary mapping option names (without leading dashes) to either
        the string value given after the first ``=`` or ``True`` for flags.
    """
    arg_dict = {}
    # ``args or ""`` keeps shlex from falling back to reading stdin on None.
    for token in shlex.split(args or ""):
        # Split on the FIRST "=" only, so values that themselves contain "="
        # (for example ``--params=key=value``) are preserved intact.
        name, has_value, value = token.partition("=")
        arg_dict[name.lstrip("-")] = value if has_value else True
    return arg_dict


def run_viz( # pylint: disable=too-many-locals
args: str = "", local_ns: Dict[str, Any] = None
) -> None:
"""
Line magic function to start kedro viz. It calls a kedro viz in a process and displays it in
the Jupyter notebook environment.
Line magic function to start Kedro Viz with optional arguments.

Args:
port: TCP port that viz will listen to. Defaults to 4141.
args: String of arguments to pass to Kedro Viz. If empty, defaults will be used.
local_ns: Local namespace with local variables of the scope where the line magic
is invoked. This argument must be in the signature, even though it is not
used. This is because the Kedro IPython extension registers line magics with
needs_local_scope.
https://ipython.readthedocs.io/en/stable/config/custommagics.html

"""
port = port or DEFAULT_PORT # Default argument doesn't work in Jupyter line magic.
host = _DATABRICKS_HOST if _is_databricks() else DEFAULT_HOST
port = _allocate_port(
host, start_at=int(port)
) # Line magic provides string arguments by default, so we need to convert to int.

# Parse arguments
arg_dict = parse_args(args)

host = arg_dict.get("host", _DATABRICKS_HOST if _is_databricks() else DEFAULT_HOST)
port = int(arg_dict.get("port", DEFAULT_PORT))
load_file = arg_dict.get("load-file", None)
save_file = arg_dict.get("save-file", None)
pipeline = arg_dict.get("pipeline", None)
env = arg_dict.get("env", None)
SajidAlamQB marked this conversation as resolved.
Show resolved Hide resolved
autoreload = arg_dict.get("autoreload", False)
ignore_plugins = arg_dict.get("ignore-plugins", False)
params = arg_dict.get("params", "")

# Allocate port
port = _allocate_port(host, start_at=port)

# Terminate existing process if needed
if port in _VIZ_PROCESSES and _VIZ_PROCESSES[port].is_alive():
_VIZ_PROCESSES[port].terminate()

Expand All @@ -107,12 +131,34 @@ def run_viz(port: int = None, local_ns: Dict[str, Any] = None) -> None:
else None
)

run_server_kwargs = {
"host": host,
"port": port,
"load_file": load_file,
"save_file": save_file,
"pipeline_name": pipeline,
"env": env,
"autoreload": autoreload,
"ignore_plugins": ignore_plugins,
"extra_params": params,
"project_path": project_path,
}
process_context = multiprocessing.get_context("spawn")
viz_process = process_context.Process(
target=run_server,
daemon=True,
kwargs={"project_path": project_path, "host": host, "port": port},
)
if autoreload:
run_process_kwargs = {
"path": project_path,
"target": run_server,
"kwargs": run_server_kwargs,
"watcher_cls": RegExpWatcher,
"watcher_kwargs": {"re_files": r"^.*(\.yml|\.yaml|\.py|\.json)$"},
}
viz_process = process_context.Process(
target=run_process, daemon=False, kwargs={**run_process_kwargs}
)
else:
viz_process = process_context.Process(
target=run_server, daemon=True, kwargs={**run_server_kwargs}
)

viz_process.start()
_VIZ_PROCESSES[port] = viz_process
Expand Down
49 changes: 48 additions & 1 deletion package/tests/test_launchers/test_jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def test_run_viz(self, mocker, patched_check_viz_up):
"project_path": None,
"host": "127.0.0.1",
"port": 4141,
"load_file": None,
"save_file": None,
"pipeline_name": None,
"env": None,
"autoreload": False,
"ignore_plugins": False,
"extra_params": "",
},
)
mock_jupyter_display.assert_called_once()
Expand All @@ -46,13 +53,20 @@ def test_run_viz(self, mocker, patched_check_viz_up):
"project_path": None,
"host": "127.0.0.1",
"port": 4141,
"load_file": None,
"save_file": None,
"pipeline_name": None,
"env": None,
"autoreload": False,
"ignore_plugins": False,
"extra_params": "",
},
)
assert set(_VIZ_PROCESSES.keys()) == {4141}

def test_run_viz_invalid_port(self, mocker, patched_check_viz_up):
with pytest.raises(ValueError):
run_viz(port=999999)
run_viz("--port=999999")

def test_exception_when_viz_cannot_be_launched(self, mocker):
mocker.patch(
Expand Down Expand Up @@ -85,6 +99,13 @@ def test_run_viz_on_databricks(self, mocker, patched_check_viz_up, monkeypatch):
"project_path": None,
"host": "0.0.0.0",
"port": 4141,
"load_file": None,
"save_file": None,
"pipeline_name": None,
"env": None,
"autoreload": False,
"ignore_plugins": False,
"extra_params": "",
},
)
databricks_display.assert_called_once()
Expand All @@ -104,3 +125,29 @@ def test_run_viz_creates_correct_link(self, mocker, patched_check_viz_up):
displayed_html = mock_display_html.call_args[0][0].data
assert 'target="_blank"' in displayed_html
assert "Open Kedro-Viz" in displayed_html

def test_run_viz_with_autoreload(self, mocker, patched_check_viz_up):
    """``--autoreload`` must create exactly one non-daemon process."""
    # Replace the multiprocessing context so no real process is spawned.
    fake_context = mocker.Mock()
    mocker.patch("multiprocessing.get_context", return_value=fake_context)
    process_factory = mocker.patch.object(fake_context, "Process")

    run_viz("--autoreload", None)

    # Target and kwargs are not pinned here; only the daemon flag matters.
    expected_call = {
        "target": mocker.ANY,
        "daemon": False,
        "kwargs": mocker.ANY,
    }
    process_factory.assert_called_once_with(**expected_call)

def test_run_viz_without_autoreload(self, mocker, patched_check_viz_up):
    """Without ``--autoreload`` the server runs directly in a daemon process."""
    # Replace the multiprocessing context so no real process is spawned.
    fake_context = mocker.Mock()
    mocker.patch("multiprocessing.get_context", return_value=fake_context)
    process_factory = mocker.patch.object(fake_context, "Process")

    run_viz("", None)

    # The process must target ``run_server`` itself and be a daemon.
    expected_call = {
        "target": run_server,
        "daemon": True,
        "kwargs": mocker.ANY,
    }
    process_factory.assert_called_once_with(**expected_call)
Loading