Skip to content

Commit

Permalink
fetching id_token from refresh_token for culling (jupyter-server#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akshay Chitneni authored and GitHub Enterprise committed Jun 7, 2022
1 parent a462918 commit 41129dc
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 3 deletions.
9 changes: 9 additions & 0 deletions data_studio_jupyter_extensions/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from jupyter_server.utils import url_path_join

from data_studio_jupyter_extensions.base_handler import DSExtensionHandlerMixin
from data_studio_jupyter_extensions.configurables.refresh_token import (
RefreshTokenConfigurable,
)


class JWTLoginHandler(DSExtensionHandlerMixin, LoginHandler):
"""The basic tornado login handler
authenticates with JWT.
"""

refresh_token_provider = RefreshTokenConfigurable.instance()

def _render(self, message=None):
if self.datastudio_url:
login_url = url_path_join(
Expand All @@ -38,7 +43,11 @@ def set_additional_login_cookie(cls, handler, name, value):

def post(self):
data_studio_jwt = self.get_argument("jwt", default="")
data_studio_refresh_token = self.get_argument("refresh_token", default=None)

if self.authenticator.is_authenticated(data_studio_jwt):
if data_studio_refresh_token:
self.refresh_token_provider.refresh_token = data_studio_refresh_token
self.set_login_cookie(self, uuid.uuid4().hex)
self.set_additional_login_cookie(
self, self.datastudio_secure_cookie_name.lower(), data_studio_jwt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
from traitlets import default
from traitlets import Instance
from traitlets import Unicode
from traitlets.config import SingletonConfigurable

from data_studio_jupyter_extensions import constants
from data_studio_jupyter_extensions.configurables.hubble import hubble
from data_studio_jupyter_extensions.configurables.refresh_token import (
RefreshTokenConfigurable,
)
from data_studio_jupyter_extensions.traits import IntFromEnv
from data_studio_jupyter_extensions.traits import UnicodeFromEnv
from data_studio_jupyter_extensions.utils import get_ssl_cert


class NotebookServiceClient(SingletonConfigurable):
class NotebookServiceClient(RefreshTokenConfigurable):
"""A client to interact with the Notebook Service API defined here:
https://github.pie.apple.com/pie-data-studio/notebook-service/blob/master/public/openapi.yaml
Expand Down Expand Up @@ -136,6 +138,8 @@ async def fetch(self, *parts, method="GET", data=None):
self.request_token = ""
elif not self.request_token:
self.request_token = await self.fetch_token()
elif not self.is_token_valid(self.request_token):
self.request_token = await self.fetch_id_token_from_refresh_token()

url = ujoin(self.base_url, *parts)
self.log.debug(f"Making {method.upper()} request against {url}")
Expand Down
129 changes: 129 additions & 0 deletions data_studio_jupyter_extensions/configurables/refresh_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import base64
import os.path as osp

from Crypto.Cipher import AES
from ias_jwt_tools import payload_reader
from ias_jwt_tools import token_issuer
from jose import JWTError
from jupyter_core.paths import jupyter_runtime_dir
from traitlets import default
from traitlets import Instance

from data_studio_jupyter_extensions.configurables.session_key import (
SessionKeyConfigurable,
)
from data_studio_jupyter_extensions.traits import UnicodeFromEnv

DEFAULT_REFRESH_TOKEN_PATH = osp.join(jupyter_runtime_dir(), "refresh_token")


class RefreshTokenConfigurable(SessionKeyConfigurable):
"""A configurable that extends SessionKeyConfigurable and
uses AES-CFB to encrypt and decrypt refresh tokens.
"""

ref_token_ias_client_id = UnicodeFromEnv(
name="REFRESH_TOKEN_IAS_CLIENT_ID", allow_none=True
).tag(config=True)
ref_token_ias_client_secret = UnicodeFromEnv(
name="REFRESH_TOKEN_IAS_CLIENT_SECRET", allow_none=True
).tag(config=True)
ref_token_aud = UnicodeFromEnv(name="REFRESH_TOKEN_AUD", allow_none=True).tag(
config=True
)
ref_token_may_act_sub = UnicodeFromEnv(
name="REFRESH_TOKEN_MAY_ACT_SUB", allow_none=True
).tag(config=True)
ref_token_path = UnicodeFromEnv(
name="REFRESH_TOKEN_PATH",
allow_none=True,
default_value=DEFAULT_REFRESH_TOKEN_PATH,
).tag(config=True)

token_issuer = Instance(token_issuer.TokenIssuer).tag(config=True)

@default("token_issuer")
def _default_token_issuer(self): # pragma: no cover
return token_issuer.TokenIssuer(
ias_client_id=self.ref_token_ias_client_id,
ias_client_secret=self.ref_token_ias_client_secret,
grant_type="urn:ietf:params:oauth:grant-type:token-exchange",
aud=self.ref_token_aud,
may_act_sub=self.ref_token_may_act_sub,
)

@property
def refresh_token(self):
try:
with open(self.ref_token_path, "r") as f:
encrypt_ref_token = f.read()
return self.decrypt_token(encrypt_ref_token.rstrip())
except Exception as e:
self.log(
f"Failed to retrieve refresh_token from path: {self.ref_token_path}, error: {e}"
)
return None

@refresh_token.setter
def refresh_token(self, ref_token):
try:
encrypt_ref_token = self.encrypt_token(ref_token)
with open(self.ref_token_path, "w") as f:
f.write(encrypt_ref_token)
except Exception as e:
self.log(f"Failed to save refresh_token, error: {e}")

def is_token_valid(self, token) -> bool:
try:
payload_reader(token) # verifies jwt signature and expiration
return True
except JWTError:
# Add expiration specific error in ias-jwt-tools
return False

def encrypt_token(self, token: str) -> str:
if self.shared_key_enabled() and self.valid_keys():
aes = AES.new(
self.shared_encrypt_key.encode("utf-8"),
AES.MODE_CFB,
self.shared_seed.encode("utf-8"),
)
return base64.b64encode(aes.encrypt(token.encode("utf-8"))).decode("utf-8")
else:
raise RuntimeError("refresh_token not supported")

def decrypt_token(self, encrypted_token: str) -> str:
if self.shared_key_enabled() and self.valid_keys():
aes = AES.new(
self.shared_encrypt_key.encode("utf-8"),
AES.MODE_CFB,
self.shared_seed.encode("utf-8"),
)
return aes.decrypt(
base64.b64decode(encrypted_token.encode("utf-8"))
).decode("utf-8")
else:
raise RuntimeError("refresh_token not supported")

async def fetch_id_token_from_refresh_token(self) -> str:
refresh_token = self.refresh_token
if refresh_token is None:
raise RuntimeError("Failed to fetch refresh_token")
self.token_issuer.subject_token_type = (
"urn:ietf:params:oauth:token-type:refresh_token"
)
self.token_issuer.requested_token_type = (
"urn:ietf:params:oauth:token-type:id_token"
)
self.token_issuer.subject_token = refresh_token
return self.token_issuer.token


if __name__ == "__main__": # pragma: no cover
ref_token_provider = RefreshTokenConfigurable()
ref_token_provider.shared_encrypt_key = "UOAdBJVt3lN0Eabc"
ref_token_provider.shared_seed = "2vsJua9A5kRNsabc"
token = "test_refresh_token"
encrypt_token = ref_token_provider.encrypt_token(token)
decrypt_token = ref_token_provider.decrypt_token(encrypt_token)
assert decrypt_token == token
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,19 @@ def notebook_service_client(
async def fake_fetch_token(self):
return "faketoken"

def is_token_valid(self, token):
return False

async def fetch_id_token_from_refresh_token(self):
return "newfaketoken"

monkeypatch.setattr(NotebookServiceClient, "fetch_token", fake_fetch_token)
monkeypatch.setattr(NotebookServiceClient, "is_token_valid", is_token_valid)
monkeypatch.setattr(
NotebookServiceClient,
"fetch_id_token_from_refresh_token",
fetch_id_token_from_refresh_token,
)
return NotebookServiceClient(ssl_cert_file=None)


Expand Down Expand Up @@ -87,3 +99,11 @@ async def test_externel_links_for_kernel(notebook_service_client):

for item in response:
assert "label" in item


async def test_ref_token(notebook_service_client):
await notebook_service_client.stop_kernel("process1")
assert notebook_service_client.request_token == "faketoken"

await notebook_service_client.stop_kernel("process2")
assert notebook_service_client.request_token == "newfaketoken"
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from data_studio_jupyter_extensions.configurables.refresh_token import (
RefreshTokenConfigurable,
)


@pytest.fixture
def refresh_token_encryption_provider():
ref_token_provider = RefreshTokenConfigurable()
ref_token_provider.shared_encrypt_key = "UOAdBJVt3lN0Eabc"
ref_token_provider.shared_seed = "2vsJua9A5kRNsabc"
return ref_token_provider


def test_refresh_token_encrypt_decrypt(refresh_token_encryption_provider):
token = "test_refresh_token"
encrypt_token = refresh_token_encryption_provider.encrypt_token(token)
decrypt_token = refresh_token_encryption_provider.decrypt_token(encrypt_token)
assert decrypt_token == token


def test_refresh_token_encrypt_decrypt_multiple_invocations(
refresh_token_encryption_provider,
):
token = "test_refresh_token"
encrypt_token1 = refresh_token_encryption_provider.encrypt_token(token)
decrypt_token1 = refresh_token_encryption_provider.decrypt_token(encrypt_token1)
encrypt_token2 = refresh_token_encryption_provider.encrypt_token(token)
decrypt_token2 = refresh_token_encryption_provider.decrypt_token(encrypt_token1)
assert encrypt_token1 == encrypt_token2
assert decrypt_token1 == token
assert decrypt_token2 == token


def test_encrypt_error(refresh_token_encryption_provider):
token = "test_refresh_token"
refresh_token_encryption_provider.shared_encrypt_key = None

with pytest.raises(RuntimeError) as encryptError:
refresh_token_encryption_provider.encrypt_token(token)

assert "refresh_token not supported" in str(encryptError)


def test_decrypt_error(refresh_token_encryption_provider):
token = "test_refresh_token"
refresh_token_encryption_provider.shared_encrypt_key = None

with pytest.raises(RuntimeError) as decryptError:
refresh_token_encryption_provider.decrypt_token(token)

assert "refresh_token not supported" in str(decryptError)
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ install_requires =
jupyter_server==1.16.0
jupyter_telemetry
dnspython~=1.15.0
tenacity~=5.0.4
tenacity~=4.12.0
pyoneer
jupyterlab~=3.4
nest_asyncio
Expand All @@ -32,6 +32,7 @@ install_requires =
validators~=0.18.2
pycryptodome~=3.14.1
jupyter_server_synchronizer==0.0.8
ias-jwt-tools==1.5.0

[options.entry_points]
jupyter_client.kernel_provisioners =
Expand Down

0 comments on commit 41129dc

Please sign in to comment.