Skip to content

Commit

Permalink
Rebased with master and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devintang3 committed Dec 29, 2021
1 parent d28ce02 commit 4f46196
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 20 deletions.
41 changes: 22 additions & 19 deletions nbclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
DeadKernelError,
)
from .output_widget import OutputWidget
from .util import ensure_async, run_sync, run_hook
from .util import ensure_async, run_hook, run_sync


def timestamp() -> str:
Expand Down Expand Up @@ -248,40 +248,47 @@ class NotebookClient(LoggingConfigurable):
on_execution_start: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
help=dedent(
"""
Called after the kernel manager and kernel client are setup, and cells
are about to execute.
Called with kwargs `kernel_id`.
"""),
"""
),
).tag(config=True)

on_cell_start: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
help=dedent(
"""
A callable which executes before a cell is executed.
Called with kwargs `cell`, and `cell_index`.
"""),
"""
),
).tag(config=True)

on_cell_complete: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
help=dedent(
"""
A callable which executes after a cell execution is complete. It is
called even when a cell results in a failure.
Called with kwargs `cell`, and `cell_index`.
"""),
"""
),
).tag(config=True)

on_cell_error: t.Optional[t.Callable] = Any(
default_value=None,
allow_none=True,
help=dedent("""
help=dedent(
"""
A callable which executes when a cell execution results in an error.
This is executed even if errors are suppressed with `cell_allows_errors`.
Called with kwargs `cell`, and `cell_index`.
"""),
"""
),
).tag(config=True)

@default('kernel_manager_class')
Expand Down Expand Up @@ -465,7 +472,7 @@ async def async_start_new_kernel_client(self) -> KernelClient:
await self._async_cleanup_kernel()
raise
self.kc.allow_stdin = False
run_hook(sself.on_execution_start)
run_hook(self.on_execution_start)
return self.kc

start_new_kernel_client = run_sync(async_start_new_kernel_client)
Expand Down Expand Up @@ -770,10 +777,8 @@ def _passed_deadline(self, deadline: int) -> bool:
return False

def _check_raise_for_error(
self,
cell: NotebookNode,
cell_index: int,
exec_reply: t.Optional[t.Dict]) -> None:
self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict]
) -> None:

if exec_reply is None:
return None
Expand All @@ -787,11 +792,9 @@ def _check_raise_for_error(
or exec_reply_content.get('ename') in self.allow_error_names
or "raises-exception" in cell.metadata.get("tags", [])
)

if (exec_reply is not None) and exec_reply['content']['status'] == 'error':
if not cell_allows_errors:
run_hook(self.on_cell_error, cell=cell, cell_index=cell_index)
if self.force_raise_errors or not cell_allows_errors:
raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content'])
raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content)

async def async_execute_cell(
self,
Expand Down
87 changes: 87 additions & 0 deletions nbclient/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,45 @@ def test_widgets(self):
assert 'version_major' in wdata
assert 'version_minor' in wdata

def test_execution_hook(self):
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
cell_hook = MagicMock()
execution_hook = MagicMock()

executor = NotebookClient(
input_nb,
resources=NBClientTestsBase().build_resources(),
on_cell_start=cell_hook,
on_cell_complete=cell_hook,
on_cell_error=cell_hook,
on_execution_start=execution_hook,
)
executor.execute()
execution_hook.assert_called_once()
assert cell_hook.call_count == 2

def test_error_execution_hook_error(self):
filename = os.path.join(current_dir, 'files', 'Error.ipynb')
with open(filename) as f:
input_nb = nbformat.read(f, 4)
cell_hook = MagicMock()
execution_hook = MagicMock()

executor = NotebookClient(
input_nb,
resources=NBClientTestsBase().build_resources(),
on_cell_start=cell_hook,
on_cell_complete=cell_hook,
on_cell_error=cell_hook,
on_execution_start=execution_hook,
)
with pytest.raises(CellExecutionError):
executor.execute()
execution_hook.assert_called_once()
assert cell_hook.call_count == 3


class TestRunCell(NBClientTestsBase):
"""Contains test functions for NotebookClient.execute_cell"""
Expand Down Expand Up @@ -1520,3 +1559,51 @@ def test_no_source(self, executor, cell_mock, message_mock):
assert message_mock.call_count == 0
# Should also consume the message stream
assert cell_mock.outputs == []

@prepare_cell_mocks()
async def test_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
tasks = [hook1, hook2, hook3, hook4]
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
executor.execute_cell(cell_mock, 0)
await asyncio.gather(*tasks)
assert hook1.call_count == 1
assert hook2.call_count == 1
assert hook3.call_count == 0
assert hook4.call_count == 0
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)

@prepare_cell_mocks(
{
'msg_type': 'error',
'header': {'msg_type': 'error'},
'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']},
},
reply_msg={
'msg_type': 'execute_reply',
'header': {'msg_type': 'execute_reply'},
# ERROR
'content': {'status': 'error'},
},
)
async def test_error_cell_hooks(self, executor, cell_mock, message_mock):
hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock()
tasks = [hook1, hook2, hook3, hook4]
executor.on_cell_start = hook1
executor.on_cell_complete = hook2
executor.on_cell_error = hook3
executor.on_execution_start = hook4
with self.assertRaises(CellExecutionError):
executor.execute_cell(cell_mock, 0)
await asyncio.gather(*tasks)
assert hook1.call_count == 1
assert hook2.call_count == 1
assert hook3.call_count == 1
assert hook4.call_count == 0
hook1.assert_called_once_with(cell=cell_mock, cell_index=0)
hook2.assert_called_once_with(cell=cell_mock, cell_index=0)
hook3.assert_called_once_with(cell=cell_mock, cell_index=0)
2 changes: 1 addition & 1 deletion nbclient/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import asyncio
import inspect
import sys
from typing import Any, Awaitable, Callable, Optional, Union
from functools import partial
from typing import Any, Awaitable, Callable, Optional, Union


def check_ipython() -> None:
Expand Down

0 comments on commit 4f46196

Please sign in to comment.