Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jupyter server api runtime #99

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions rplugin/python3/magma/jupyter_server_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import json
import uuid
import re
import time
from queue import Empty as EmptyQueueException
from typing import Any, Dict
from threading import Thread
from queue import Queue
from urllib.parse import urlparse

import requests
import websocket

from magma.runtime_state import RuntimeState


class JupyterAPIClient:
def __init__(self,
url: str,
kernel_info: Dict[str, Any],
headers: Dict[str, str]):
self._base_url = url
self._kernel_info = kernel_info
self._headers = headers

self._recv_queue: Queue[Dict[str, Any]] = Queue()

def wait_for_ready(self, **kwargs):
while True:
response = requests.get(self._kernel_api_base,
headers=self._headers)
response = json.loads(response.text)
if response["execution_state"] in ("idle", "starting"):
return
time.sleep(0.1)


def start_channels(self) -> None:
parsed_url = urlparse(self._base_url)
self._socket = websocket.create_connection(f"ws://{parsed_url.hostname}:{parsed_url.port}"
f"/api/kernels/{self._kernel_info['id']}/channels",
header=self._headers,
)
self._kernel_api_base = f"{self._base_url}/api/kernels/{self._kernel_info['id']}"

self._iopub_recv_thread = Thread(target=self._recv_message)
self._iopub_recv_thread.start()

def _recv_message(self) -> None:
while True:
response = json.loads(self._socket.recv())
self._recv_queue.put(response)

def get_iopub_msg(self, **kwargs):
if self._recv_queue.empty():
raise EmptyQueueException

response = self._recv_queue.get()

return response

def execute(self, code: str):
header = {
'msg_type': 'execute_request',
'msg_id': uuid.uuid1().hex,
'session': uuid.uuid1().hex
}

message = json.dumps({
'header': header,
'parent_header': header,
'metadata': {},
'content': {
'code': code,
'silent': False
}
})
self._socket.send(message)

def shutdown(self):
requests.delete(self._kernel_api_base,
headers=self._headers)
self._socket.close()


class JupyterAPIManager:
def __init__(self,
url: str,
):
parsed_url = urlparse(url)
self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

token_part = re.search(r"token=(.*)", parsed_url.query)

if token_part:
token = token_part.groups()[0]
self._headers = {'Authorization': 'token ' + token}
else:
# Run notebook with --NotebookApp.disable_check_xsrf="True".
self._headers = {}

def start_kernel(self) -> None:
url = f"{self._base_url}/api/kernels"
response = requests.post(url,
headers=self._headers)
self._kernel_info = json.loads(response.text)
assert "id" in self._kernel_info, "Could not connect to Jupyter Server API. The URL specified may be incorrect."
self._kernel_api_base = f"{url}/{self._kernel_info['id']}"

def client(self) -> JupyterAPIClient:
return JupyterAPIClient(url=self._base_url,
kernel_info=self._kernel_info,
headers=self._headers)

def interrupt_kernel(self) -> None:
requests.post(f"{self._kernel_api_base}/interrupt",
headers=self._headers)

def restart_kernel(self) -> None:
self.state = RuntimeState.STARTING
requests.post(f"{self._kernel_api_base}/restart",
headers=self._headers)
32 changes: 19 additions & 13 deletions rplugin/python3/magma/runtime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, Tuple, List, Dict, Generator, IO, Any
from enum import Enum
from typing import Optional, Tuple, List, Dict, Generator, IO, Any, Union
from contextlib import contextmanager
from queue import Empty as EmptyQueueException
import os
Expand All @@ -8,6 +7,7 @@

import jupyter_client

from magma.runtime_state import RuntimeState
from magma.options import MagmaOptions
from magma.outputchunks import (
Output,
Expand All @@ -18,20 +18,15 @@
to_outputchunk,
clean_up_text
)


class RuntimeState(Enum):
STARTING = 0
IDLE = 1
RUNNING = 2
from magma.jupyter_server_api import JupyterAPIClient, JupyterAPIManager


class JupyterRuntime:
state: RuntimeState
kernel_name: str

kernel_manager: jupyter_client.KernelManager
kernel_client: jupyter_client.KernelClient
kernel_manager: Union[jupyter_client.KernelManager, JupyterAPIManager]
kernel_client: Union[jupyter_client.KernelClient, JupyterAPIClient]

allocated_files: List[str]

Expand All @@ -41,7 +36,18 @@ def __init__(self, kernel_name: str, options: MagmaOptions):
self.state = RuntimeState.STARTING
self.kernel_name = kernel_name

if ".json" not in self.kernel_name:
if kernel_name.startswith("http://") or kernel_name.startswith("https://"):
self.external_kernel = True
self.kernel_manager = JupyterAPIManager(kernel_name)
self.kernel_manager.start_kernel()
self.kernel_client = self.kernel_manager.client()
self.kernel_client.start_channels()

self.allocated_files = []

self.options = options

elif ".json" not in self.kernel_name:

self.external_kernel = True
self.kernel_manager = jupyter_client.manager.KernelManager(
Expand Down Expand Up @@ -202,8 +208,8 @@ def tick(self, output: Optional[Output]) -> bool:
assert isinstance(
self.kernel_client,
jupyter_client.blocking.client.BlockingKernelClient,
)

) or isinstance(
self.kernel_client, JupyterAPIClient)
if not self.is_ready():
try:
self.kernel_client.wait_for_ready(timeout=0)
Expand Down
7 changes: 7 additions & 0 deletions rplugin/python3/magma/runtime_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class RuntimeState(Enum):
STARTING = 0
IDLE = 1
RUNNING = 2