Skip to content

Commit

Permalink
Rebased with master and added tests
Browse files Browse the repository at this point in the history
Run_hook is now async and renamed util to test_util so it gets picked up by pytest.
  • Loading branch information
devintang3 committed Jan 13, 2022
1 parent 0016be4 commit 207cac2
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 57 deletions.
9 changes: 9 additions & 0 deletions docs/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ on both versions. Here the traitlet ``kernel_name`` helps simplify and
maintain consistency: we can just run a notebook twice, specifying first
"python2" and then "python3" as the kernel name.

In addition to the two above, we also support traitlets for hooks. They are as
follows: ``on_execution_start``, ``on_cell_start``, ``on_cell_complete``,
``on_cell_error``. These traitlets allow specifying a ``Callable`` function,
which will run at certain points during the notebook execution and is executed asynchronously.
``on_execution_start`` will run when the notebook client is kicked off.
``on_cell_start`` will run right before each cell is executed.
``on_cell_complete`` will run right after the cell is executed.
``on_cell_error`` will run if there is an error in the cell.

Handling errors and exceptions
------------------------------

Expand Down
92 changes: 53 additions & 39 deletions nbclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@
from jupyter_client.client import KernelClient
from nbformat import NotebookNode
from nbformat.v4 import output_from_msg
from traitlets import Any, Bool, Dict, Enum, Integer, List, Type, Unicode, default
from traitlets import (
Any,
Bool,
Callable,
Dict,
Enum,
Integer,
List,
Type,
Unicode,
default,
)
from traitlets.config.configurable import LoggingConfigurable

from .exceptions import (
Expand All @@ -25,7 +36,7 @@
DeadKernelError,
)
from .output_widget import OutputWidget
from .util import ensure_async, run_sync, run_hook
from .util import ensure_async, run_hook, run_sync


def timestamp() -> str:
Expand Down Expand Up @@ -245,43 +256,50 @@ class NotebookClient(LoggingConfigurable):

kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.')

on_execution_start: t.Optional[t.Callable] = Any(
on_execution_start: t.Optional[t.Callable] = Callable(
default_value=None,
allow_none=True,
help=dedent("""
Called after the kernel manager and kernel client are setup, and cells
are about to execute.
Called with kwargs `kernel_id`.
"""),
help=dedent(
"""
Called after the kernel manager and kernel client are setup, and cells
are about to execute.
"""
),
).tag(config=True)

on_cell_start: t.Optional[t.Callable] = Any(
on_cell_start: t.Optional[t.Callable] = Callable(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes before a cell is executed.
Called with kwargs `cell`, and `cell_index`.
"""),
help=dedent(
"""
A callable which executes before a cell is executed.
Called with kwargs `cell` and `cell_index`.
"""
),
).tag(config=True)

on_cell_complete: t.Optional[t.Callable] = Any(
on_cell_complete: t.Optional[t.Callable] = Callable(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes after a cell execution is complete. It is
called even when a cell results in a failure.
Called with kwargs `cell`, and `cell_index`.
"""),
help=dedent(
"""
A callable which executes after a cell execution is complete. It is
called even when a cell results in a failure.
Called with kwargs `cell` and `cell_index`.
"""
),
).tag(config=True)

on_cell_error: t.Optional[t.Callable] = Any(
on_cell_error: t.Optional[t.Callable] = Callable(
default_value=None,
allow_none=True,
help=dedent("""
A callable which executes when a cell execution results in an error.
This is executed even if errors are suppressed with `cell_allows_errors`.
Called with kwargs `cell`, and `cell_index`.
"""),
help=dedent(
"""
A callable which executes when a cell execution results in an error.
This is executed even if errors are suppressed with `cell_allows_errors`.
Called with kwargs `cell` and `cell_index`.
"""
),
).tag(config=True)

@default('kernel_manager_class')
Expand Down Expand Up @@ -465,7 +483,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
await self._async_cleanup_kernel()
raise
self.kc.allow_stdin = False
run_hook(sself.on_execution_start)
await run_hook(self.on_execution_start)
return self.kc

start_new_kernel_client = run_sync(async_start_new_kernel_client)
Expand Down Expand Up @@ -769,11 +787,9 @@ def _passed_deadline(self, deadline: int) -> bool:
return True
return False

def _check_raise_for_error(
self,
cell: NotebookNode,
cell_index: int,
exec_reply: t.Optional[t.Dict]) -> None:
async def _check_raise_for_error(
self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict]
) -> None:

if exec_reply is None:
return None
Expand All @@ -787,11 +803,9 @@ def _check_raise_for_error(
or exec_reply_content.get('ename') in self.allow_error_names
or "raises-exception" in cell.metadata.get("tags", [])
)

if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
if self.force_raise_errors or not cell_allows_errors:
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
await run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
if not cell_allows_errors:
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)

async def async_execute_cell(
self,
Expand Down Expand Up @@ -851,13 +865,13 @@ async def async_execute_cell(
self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])
)

run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
await run_hook(self.on_cell_start, cell=cell, cell_index=cell_index)
parent_msg_id = await ensure_async(
self.kc.execute(
cell.source, store_history=store_history, stop_on_error=not cell_allows_errors
)
)
run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
await run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index)
# We launched a code cell to execute
self.code_cells_executed += 1
exec_timeout = self._get_timeout(cell)
Expand Down Expand Up @@ -891,7 +905,7 @@ async def async_execute_cell(

if execution_count:
cell['execution_count'] = execution_count
self._check_raise_for_error(cell, cell_index, exec_reply)
await self._check_raise_for_error(cell, cell_index, exec_reply)
self.nb['cells'][cell_index] = cell
return cell

Expand Down
163 changes: 155 additions & 8 deletions nbclient/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from base64 import b64decode, b64encode
from queue import Empty
from unittest.mock import AsyncMock as AMock
from unittest.mock import MagicMock, Mock

import nbformat
Expand Down Expand Up @@ -345,11 +346,7 @@ def test_async_parallel_notebooks(capfd, tmpdir):
res = notebook_resources()

with modified_env({"NBEXECUTE_TEST_PARALLEL_TMPDIR": str(tmpdir)}):
tasks = [
async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")
]
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(*tasks))
[async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")]

captured = capfd.readouterr()
assert filter_messages_on_error_output(captured.err) == ""
Expand All @@ -370,9 +367,7 @@ def test_many_async_parallel_notebooks(capfd):
# run once, to trigger creating the original context
run_notebook(input_file, opts, res)

tasks = [async_run_notebook(input_file, opts, res) for i in range(4)]
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(*tasks))
[async_run_notebook(input_file, opts, res) for i in range(4)]

captured = capfd.readouterr()
assert filter_messages_on_error_output(captured.err) == ""
Expand Down Expand Up @@ -741,6 +736,80 @@ def test_widgets(self):
assert 'version_major' in wdata
assert 'version_minor' in wdata

def test_execution_hook(self):
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
executor = NotebookClient(
input_nb,
on_cell_start=hook1,
on_cell_complete=hook2,
on_cell_error=hook3,
on_execution_start=hook4,
)
executor.execute()
hook1.assert_called_once()
hook2.assert_called_once()
hook3.assert_not_called()
hook4.assert_called_once()

def test_error_execution_hook_error(self):
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
executor = NotebookClient(
input_nb,
on_cell_start=hook1,
on_cell_complete=hook2,
on_cell_error=hook3,
on_execution_start=hook4,
)
with pytest.raises(CellExecutionError):
executor.execute()
hook1.assert_called_once()
hook2.assert_called_once()
hook3.assert_called_once()
hook4.assert_called_once()

def test_async_execution_hook(self):
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
executor = NotebookClient(
input_nb,
on_cell_start=hook1,
on_cell_complete=hook2,
on_cell_error=hook3,
on_execution_start=hook4,
)
executor.execute()
hook1.assert_called_once()
hook2.assert_called_once()
hook3.assert_not_called()
hook4.assert_called_once()

def test_error_async_execution_hook(self):
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
executor = NotebookClient(
input_nb,
on_cell_start=hook1,
on_cell_complete=hook2,
on_cell_error=hook3,
on_execution_start=hook4,
)
with pytest.raises(CellExecutionError):
executor.execute().execute()
hook1.assert_called_once()
hook2.assert_called_once()
hook3.assert_called_once()
hook4.assert_called_once()


class TestRunCell(NBClientTestsBase):
"""Contains test functions for NotebookClient.execute_cell"""
Expand Down Expand Up @@ -1524,3 +1593,81 @@ def test_no_source(self, executor, cell_mock, message_mock):
assert message_mock.call_count == 0
# Should also consume the message stream
assert cell_mock.outputs == []

@prepare_cell_mocks()
def test_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
executor.execute_cell(cell_mock, 0)
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
hook3.assert_not_called()
hook4.assert_not_called()

@prepare_cell_mocks(
{
'msg_type': 'error',
'header': {'msg_type': 'error'},
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
},
reply_msg={
'msg_type': 'execute_reply',
'header': {'msg_type': 'execute_reply'},
# ERROR
'content': {'status': 'error'},
},
)
def test_error_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
with self.assertRaises(CellExecutionError):
executor.execute_cell(cell_mock, 0)
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
hook4.assert_not_called()

@prepare_cell_mocks()
def test_async_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
executor.execute_cell(cell_mock, 0)
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
hook3.assert_not_called()
hook4.assert_not_called()

@prepare_cell_mocks(
{
'msg_type': 'error',
'header': {'msg_type': 'error'},
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
},
reply_msg={
'msg_type': 'execute_reply',
'header': {'msg_type': 'execute_reply'},
# ERROR
'content': {'status': 'error'},
},
)
def test_error_async_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock()
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
with self.assertRaises(CellExecutionError):
executor.execute_cell(cell_mock, 0)
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
hook4.assert_not_called()
Loading

0 comments on commit 207cac2

Please sign in to comment.