Skip to content

Commit

Permalink
Add remote kernel support for papermill operator (#34840)
Browse files Browse the repository at this point in the history
Co-authored-by: Akshay Chitneni <[email protected]>
  • Loading branch information
akshaychitneni and Akshay Chitneni authored Nov 13, 2023
1 parent 4f5e482 commit 18dac61
Show file tree
Hide file tree
Showing 14 changed files with 600 additions and 3 deletions.
17 changes: 17 additions & 0 deletions airflow/providers/papermill/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
171 changes: 171 additions & 0 deletions airflow/providers/papermill/hooks/kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from jupyter_client import AsyncKernelManager
from papermill.clientwrap import PapermillNotebookClient
from papermill.engines import NBClientEngine
from papermill.utils import merge_kwargs, remove_args
from traitlets import Unicode

if TYPE_CHECKING:
from pydantic import typing

from airflow.hooks.base import BaseHook

JUPYTER_KERNEL_SHELL_PORT = 60316
JUPYTER_KERNEL_IOPUB_PORT = 60317
JUPYTER_KERNEL_STDIN_PORT = 60318
JUPYTER_KERNEL_CONTROL_PORT = 60319
JUPYTER_KERNEL_HB_PORT = 60320
REMOTE_KERNEL_ENGINE = "remote_kernel_engine"


class KernelConnection:
"""Class to represent kernel connection object."""

ip: str
shell_port: int
iopub_port: int
stdin_port: int
control_port: int
hb_port: int
session_key: str


class KernelHook(BaseHook):
"""
The KernelHook can be used to interact with remote jupyter kernel.
Takes kernel host/ip from connection and refers to jupyter kernel ports and session_key
from ``extra`` field.
:param kernel_conn_id: connection that has kernel host/ip
"""

conn_name_attr = "kernel_conn_id"
default_conn_name = "jupyter_kernel_default"
conn_type = "jupyter_kernel"
hook_name = "Jupyter Kernel"

def __init__(self, kernel_conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.kernel_conn = self.get_connection(kernel_conn_id)
register_remote_kernel_engine()

def get_conn(self) -> KernelConnection:
kernel_connection = KernelConnection()
kernel_connection.ip = self.kernel_conn.host
kernel_connection.shell_port = self.kernel_conn.extra_dejson.get(
"shell_port", JUPYTER_KERNEL_SHELL_PORT
)
kernel_connection.iopub_port = self.kernel_conn.extra_dejson.get(
"iopub_port", JUPYTER_KERNEL_IOPUB_PORT
)
kernel_connection.stdin_port = self.kernel_conn.extra_dejson.get(
"stdin_port", JUPYTER_KERNEL_STDIN_PORT
)
kernel_connection.control_port = self.kernel_conn.extra_dejson.get(
"control_port", JUPYTER_KERNEL_CONTROL_PORT
)
kernel_connection.hb_port = self.kernel_conn.extra_dejson.get("hb_port", JUPYTER_KERNEL_HB_PORT)
kernel_connection.session_key = self.kernel_conn.extra_dejson.get("session_key", "")
return kernel_connection


def register_remote_kernel_engine():
"""Registers ``RemoteKernelEngine`` papermill engine."""
from papermill.engines import papermill_engines

papermill_engines.register(REMOTE_KERNEL_ENGINE, RemoteKernelEngine)


class RemoteKernelManager(AsyncKernelManager):
"""Jupyter kernel manager that connects to a remote kernel."""

session_key = Unicode("", config=True, help="Session key to connect to remote kernel")

@property
def has_kernel(self) -> bool:
return True

async def _async_is_alive(self) -> bool:
return True

def shutdown_kernel(self, now: bool = False, restart: bool = False):
pass

def client(self, **kwargs: typing.Any):
"""Create a client configured to connect to our kernel."""
kernel_client = super().client(**kwargs)
# load connection info to set session_key
config: dict[str, int | str | bytes] = dict(
ip=self.ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
control_port=self.control_port,
hb_port=self.hb_port,
key=self.session_key,
transport="tcp",
signature_scheme="hmac-sha256",
)
kernel_client.load_connection_info(config)
return kernel_client


class RemoteKernelEngine(NBClientEngine):
"""Papermill engine to use ``RemoteKernelManager`` to connect to remote kernel and execute notebook."""

@classmethod
def execute_managed_notebook(
cls,
nb_man,
kernel_name,
log_output=False,
stdout_file=None,
stderr_file=None,
start_timeout=60,
execution_timeout=None,
**kwargs,
):
"""Performs the actual execution of the parameterized notebook locally."""
km = RemoteKernelManager()
km.ip = kwargs["kernel_ip"]
km.shell_port = kwargs["kernel_shell_port"]
km.iopub_port = kwargs["kernel_iopub_port"]
km.stdin_port = kwargs["kernel_stdin_port"]
km.control_port = kwargs["kernel_control_port"]
km.hb_port = kwargs["kernel_hb_port"]
km.ip = kwargs["kernel_ip"]
km.session_key = kwargs["kernel_session_key"]

# Exclude parameters that named differently downstream
safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs)

final_kwargs = merge_kwargs(
safe_kwargs,
timeout=execution_timeout if execution_timeout else kwargs.get("timeout"),
startup_timeout=start_timeout,
log_output=False,
stdout_file=stdout_file,
stderr_file=stderr_file,
)

return PapermillNotebookClient(nb_man, km=km, **final_kwargs).execute()
40 changes: 39 additions & 1 deletion airflow/providers/papermill/operators/papermill.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, ClassVar, Collection, Sequence

import attr
import papermill as pm

from airflow.lineage.entities import File
from airflow.models import BaseOperator
from airflow.providers.papermill.hooks.kernel import REMOTE_KERNEL_ENGINE, KernelHook

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -54,7 +56,14 @@ class PapermillOperator(BaseOperator):

supports_lineage = True

template_fields: Sequence[str] = ("input_nb", "output_nb", "parameters", "kernel_name", "language_name")
template_fields: Sequence[str] = (
"input_nb",
"output_nb",
"parameters",
"kernel_name",
"language_name",
"kernel_conn_id",
)

def __init__(
self,
Expand All @@ -64,6 +73,7 @@ def __init__(
parameters: dict | None = None,
kernel_name: str | None = None,
language_name: str | None = None,
kernel_conn_id: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -85,11 +95,29 @@ def __init__(

self.kernel_name = kernel_name
self.language_name = language_name
self.kernel_conn_id = kernel_conn_id

self.inlets.append(self.input_nb)
self.outlets.append(self.output_nb)

def execute(self, context: Context):
remote_kernel_kwargs = {}
kernel_hook = self.hook
if kernel_hook:
engine_name = REMOTE_KERNEL_ENGINE
kernel_connection = kernel_hook.get_conn()
remote_kernel_kwargs = {
"kernel_ip": kernel_connection.ip,
"kernel_shell_port": kernel_connection.shell_port,
"kernel_iopub_port": kernel_connection.iopub_port,
"kernel_stdin_port": kernel_connection.stdin_port,
"kernel_control_port": kernel_connection.control_port,
"kernel_hb_port": kernel_connection.hb_port,
"kernel_session_key": kernel_connection.session_key,
}
else:
engine_name = None

pm.execute_notebook(
self.input_nb.url,
self.output_nb.url,
Expand All @@ -98,4 +126,14 @@ def execute(self, context: Context):
report_mode=True,
kernel_name=self.kernel_name,
language=self.language_name,
engine_name=engine_name,
**remote_kernel_kwargs,
)

@cached_property
def hook(self) -> KernelHook | None:
"""Get valid hook."""
if self.kernel_conn_id:
return KernelHook(kernel_conn_id=self.kernel_conn_id)
else:
return None
10 changes: 10 additions & 0 deletions airflow/providers/papermill/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies:
- apache-airflow>=2.5.0
- papermill[all]>=1.2.1
- scrapbook[all]
- ipykernel

integrations:
- integration-name: Papermill
Expand All @@ -57,3 +58,12 @@ operators:
- integration-name: Papermill
python-modules:
- airflow.providers.papermill.operators.papermill

hooks:
- integration-name: Papermill
python-modules:
- airflow.providers.papermill.hooks.kernel

connection-types:
- hook-class-name: airflow.providers.papermill.hooks.kernel.KernelHook
connection-type: jupyter_kernel
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-papermill/connections/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Jupyter Kernel Connections
==========================


.. toctree::
:maxdepth: 1
:glob:

*
Loading

0 comments on commit 18dac61

Please sign in to comment.