Skip to content

Commit

Permalink
Merge pull request #120 from stealthrocket/verification-key-str
Browse files Browse the repository at this point in the history
Improve UX around verification keys
  • Loading branch information
chriso authored Mar 10, 2024
2 parents 5139a05 + 76ce71e commit 49a1c2f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 18 deletions.
54 changes: 37 additions & 17 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self,
app: fastapi.FastAPI,
endpoint: str | None = None,
verification_key: Ed25519PublicKey | None = None,
verification_key: Ed25519PublicKey | str | bytes | None = None,
api_key: str | None = None,
api_url: str | None = None,
):
Expand All @@ -70,7 +70,7 @@ def __init__(
verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
by default. The environment variable is expected to carry an
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).
Expand Down Expand Up @@ -99,21 +99,6 @@ def __init__(
"missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable"
)

if not verification_key:
try:
verification_key_raw = os.environ["DISPATCH_VERIFICATION_KEY"]
except KeyError:
pass
else:
# Be forgiving when accepting keys in PEM format.
verification_key_raw = verification_key_raw.replace("\\n", "\n")
try:
verification_key = public_key_from_pem(verification_key_raw)
except ValueError:
verification_key = public_key_from_bytes(
base64.b64decode(verification_key_raw)
)

logger.info("configuring Dispatch endpoint %s", endpoint)

parsed_url = urlparse(endpoint)
Expand All @@ -122,6 +107,7 @@ def __init__(
f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)"
)

verification_key = parse_verification_key(verification_key)
if verification_key:
base64_key = base64.b64encode(verification_key.public_bytes_raw()).decode()
logger.info("verifying request signatures using key %s", base64_key)
Expand All @@ -137,6 +123,40 @@ def __init__(
app.mount("/dispatch.sdk.v1.FunctionService", function_service)


def parse_verification_key(
verification_key: Ed25519PublicKey | str | bytes | None,
) -> Ed25519PublicKey | None:
if isinstance(verification_key, Ed25519PublicKey):
return verification_key

from_env = False
if not verification_key:
try:
verification_key = os.environ["DISPATCH_VERIFICATION_KEY"]
except KeyError:
return None
from_env = True

if isinstance(verification_key, bytes):
verification_key = verification_key.decode()

# Be forgiving when accepting keys in PEM format, which may span
# multiple lines. Users attempting to pass a PEM key via an environment
# variable may accidentally include literal "\n" bytes rather than a
# newline char (0xA).
try:
return public_key_from_pem(verification_key.replace("\\n", "\n"))
except ValueError:
pass

try:
return public_key_from_bytes(base64.b64decode(verification_key.encode()))
except ValueError:
if from_env:
raise ValueError(f"invalid DISPATCH_VERIFICATION_KEY '{verification_key}'")
raise ValueError(f"invalid verification key '{verification_key}'")


class _ConnectResponse(fastapi.Response):
media_type = "application/grpc+proto"

Expand Down
76 changes: 75 additions & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import os
import pickle
import unittest
Expand All @@ -8,17 +9,25 @@
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from fastapi.testclient import TestClient

from dispatch.experimental.durable.registry import clear_functions
from dispatch.fastapi import Dispatch
from dispatch.fastapi import Dispatch, parse_verification_key
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import public_key_from_pem
from dispatch.status import Status
from dispatch.test import EndpointClient

public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
public_key = public_key_from_pem(public_key_pem)
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)


def create_dispatch_instance(app, endpoint):
return Dispatch(
Expand Down Expand Up @@ -98,6 +107,71 @@ def my_function(input: Input) -> Output:

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem})
def test_parse_verification_key_env_pem_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2})
def test_parse_verification_key_env_pem_escaped_newline_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()})
def test_parse_verification_key_env_b64_str(self):
verification_key = parse_verification_key(None)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_none(self):
# The verification key is optional. Both Dispatch(verification_key=...) and
# DISPATCH_VERIFICATION_KEY may be omitted/None.
verification_key = parse_verification_key(None)
self.assertIsNone(verification_key)

def test_parse_verification_key_ed25519publickey(self):
verification_key = parse_verification_key(public_key)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_str(self):
verification_key = parse_verification_key(public_key_pem)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_escaped_newline_str(self):
verification_key = parse_verification_key(public_key_pem2)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_pem_bytes(self):
verification_key = parse_verification_key(public_key_pem.encode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_b64_str(self):
verification_key = parse_verification_key(public_key_b64.decode())
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_b64_bytes(self):
verification_key = parse_verification_key(public_key_b64)
self.assertIsInstance(verification_key, Ed25519PublicKey)
self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes)

def test_parse_verification_key_invalid(self):
with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"):
parse_verification_key("foo")

@mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"})
def test_parse_verification_key_invalid_env(self):
with self.assertRaisesRegex(
ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'"
):
parse_verification_key(None)


def response_output(resp: function_pb.RunResponse) -> Any:
return any_unpickle(resp.exit.result.output)
Expand Down

0 comments on commit 49a1c2f

Please sign in to comment.