Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

%load_node experiments #3568

9 changes: 7 additions & 2 deletions kedro/ipython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def _find_node(node_name: str, pipelines: _ProjectPipelines) -> Node:
except ValueError:
continue
# If reached the node was not found in the project
raise ValueError(f"Node with name='{node_name}' not found in any pipelines.")
raise ValueError(
f"Node with name='{node_name}' not found in any pipelines. Remember to specify the node name, not the node function."
)


def _prepare_imports(node_func: Callable) -> str:
Expand Down Expand Up @@ -280,7 +282,10 @@ def _prepare_node_inputs(node: Node) -> str:
node_inputs = node.inputs
func_params = list(signature.parameters)

statements = ["# Prepare necessary inputs for debugging"]
statements = [
"# Prepare necessary inputs for debugging",
"# All debugging inputs must be defined in your project catalog",
]

for node_input, func_param in zip(node_inputs, func_params):
statements.append(f'{func_param} = catalog.load("{node_input}")')
Expand Down
69 changes: 24 additions & 45 deletions tests/ipython/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def my_func:,


@pytest.fixture
def dummy_pipeline(dummy_node):
# return a list of pipelines
return {"dummy": modular_pipeline([dummy_node])}
def dummy_pipelines(dummy_node):
# return a dict of pipelines
return {"dummy_pipeline": modular_pipeline([dummy_node])}


class TestLoadKedroObjects:
Expand Down Expand Up @@ -382,7 +382,7 @@ class MockKedroContext:


class TestLoadNodeMagic:
def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipeline):
def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipelines):
# Reimport `pipelines` from `kedro.framework.project` to ensure that
# it was not removed by prior tests.
from kedro.framework.project import pipelines
Expand All @@ -393,40 +393,22 @@ def test_load_node_magic(self, mocker, dummy_function_file_lines, dummy_pipeline
mocker.patch(
"builtins.open", mocker.mock_open(read_data=dummy_function_file_lines)
)
pipelines.configure("dummy_pipeline") # Setup the pipelines
my_pipelines = dummy_pipeline

def my_register_pipeline():
return my_pipelines

mocker.patch.object(
pipelines,
"_get_pipelines_registry_callable",
return_value=my_register_pipeline,
)
mock_pipeline_values = dummy_pipelines.values()
mocker.patch.object(pipelines, "values", return_value=mock_pipeline_values)

node_to_load = "dummy_node"
magic_load_node(node_to_load)

def test_load_node(self, mocker, dummy_function_file_lines, dummy_pipeline):
def test_load_node(self, mocker, dummy_function_file_lines, dummy_pipelines):
# wraps all the other functions
mocker.patch(
"builtins.open", mocker.mock_open(read_data=dummy_function_file_lines)
)
pipelines.configure("dummy_pipeline") # Setup the pipelines

my_pipelines = dummy_pipeline

def my_register_pipeline():
return my_pipelines

mocker.patch.object(
pipelines,
"_get_pipelines_registry_callable",
return_value=my_register_pipeline,
)
mock_pipeline_values = dummy_pipelines.values()
mocker.patch.object(pipelines, "values", return_value=mock_pipeline_values)

node_inputs = """# Prepare necessary inputs for debugging
# All debugging inputs must be defined in your project catalog
dummy_input = catalog.load("dummy_input")
my_input = catalog.load("extra_input")"""

Expand Down Expand Up @@ -455,20 +437,20 @@ def my_register_pipeline():
for cell, expected_cell in zip(cells_list, expected_cells):
assert cell == expected_cell

def test_find_node(self, mocker, dummy_pipeline, dummy_node):
mocker.patch.object(pipelines, "values", return_value=dummy_pipeline)
def test_find_node(self, dummy_pipelines, dummy_node):
node_to_find = "dummy_node"
result = _find_node(node_to_find, dummy_pipeline)
result = _find_node(node_to_find, dummy_pipelines)
assert result == dummy_node

def test_node_not_found(self, mocker, dummy_pipeline):
mocker.patch.object(pipelines, "values", return_value=dummy_pipeline)
def test_node_not_found(self, dummy_pipelines):
node_to_find = "not_a_node"
dummy_registered_pipelines = dummy_pipelines
with pytest.raises(ValueError) as excinfo:
_find_node(node_to_find, dummy_pipeline)
_find_node(node_to_find, dummy_registered_pipelines)

assert f"Node with name='{node_to_find}' not found in any pipelines." in str(
excinfo.value
assert (
f"Node with name='{node_to_find}' not found in any pipelines. Remember to specify the node name, not the node function."
in str(excinfo.value)
)

def test_prepare_imports(self, mocker, dummy_function_file_lines):
Expand All @@ -493,8 +475,8 @@ def test_prepare_imports_func_not_found(self, mocker):
assert f"Could not find {dummy_function.__name__}" in str(excinfo.value)

def test_prepare_node_inputs(self, dummy_node):
# TODO Ahdra check - does this address parameters properly?
func_inputs = """# Prepare necessary inputs for debugging
# All debugging inputs must be defined in your project catalog
dummy_input = catalog.load("dummy_input")
my_input = catalog.load("extra_input")"""

Expand All @@ -516,16 +498,13 @@ def test_get_lambda_function_body(self, lambda_node):
result = _prepare_function_body(lambda_node.func)
assert result == "func=lambda x: x"

@pytest.mark.skip(reason="Not supported yet")
def test_get_nested_function_body(self):
func_strings = [
"def nested_function(input):",
"\nreturn not input",
"\nreturn nested_function(dummy_input)\n",
]
func_strings = """def nested_function(input):
return not input
return nested_function(dummy_input)"""

result = _prepare_function_body(dummy_nested_function)
assert result == "".join(func_strings)
# TODO fix - fails because skips nested function definition
assert result == func_strings

def test_get_function_with_loop_body(self):
func_strings = """for x in dummy_list:
Expand Down