diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 17eab9fb..2e29260b 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -53,7 +53,7 @@ def __init__( self, app: fastapi.FastAPI, endpoint: str | None = None, - verification_key: Ed25519PublicKey | None = None, + verification_key: Ed25519PublicKey | str | bytes | None = None, api_key: str | None = None, api_url: str | None = None, ): @@ -70,7 +70,7 @@ def __init__( verification_key: Key to use when verifying signed requests. Uses the value of the DISPATCH_VERIFICATION_KEY environment variable - by default. The environment variable is expected to carry an + if omitted. The environment variable is expected to carry an Ed25519 public key in base64 or PEM format. If not set, request signature verification is disabled (a warning will be logged by the constructor). @@ -99,21 +99,6 @@ def __init__( "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" ) - if not verification_key: - try: - verification_key_raw = os.environ["DISPATCH_VERIFICATION_KEY"] - except KeyError: - pass - else: - # Be forgiving when accepting keys in PEM format. - verification_key_raw = verification_key_raw.replace("\\n", "\n") - try: - verification_key = public_key_from_pem(verification_key_raw) - except ValueError: - verification_key = public_key_from_bytes( - base64.b64decode(verification_key_raw) - ) - logger.info("configuring Dispatch endpoint %s", endpoint) parsed_url = urlparse(endpoint) @@ -122,6 +107,7 @@ def __init__( f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)" ) + verification_key = parse_verification_key(verification_key) if verification_key: base64_key = base64.b64encode(verification_key.public_bytes_raw()).decode() logger.info("verifying request signatures using key %s", base64_key) @@ -137,6 +123,40 @@ def __init__( app.mount("/dispatch.sdk.v1.FunctionService", function_service) +def parse_verification_key( + verification_key: Ed25519PublicKey | str | bytes | None, +) -> Ed25519PublicKey | None: + if isinstance(verification_key, Ed25519PublicKey): + return verification_key + + from_env = False + if not verification_key: + try: + verification_key = os.environ["DISPATCH_VERIFICATION_KEY"] + except KeyError: + return None + from_env = True + + if isinstance(verification_key, bytes): + verification_key = verification_key.decode() + + # Be forgiving when accepting keys in PEM format, which may span + # multiple lines. Users attempting to pass a PEM key via an environment + # variable may accidentally include literal "\n" bytes rather than a + # newline char (0xA). + try: + return public_key_from_pem(verification_key.replace("\\n", "\n")) + except ValueError: + pass + + try: + return public_key_from_bytes(base64.b64decode(verification_key.encode())) + except ValueError: + if from_env: + raise ValueError(f"invalid DISPATCH_VERIFICATION_KEY '{verification_key}'") + raise ValueError(f"invalid verification key '{verification_key}'") + + class _ConnectResponse(fastapi.Response): media_type = "application/grpc+proto" diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 21467033..33ee7160 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,3 +1,4 @@ +import base64 import os import pickle import unittest @@ -8,17 +9,25 @@ import google.protobuf.any_pb2 import google.protobuf.wrappers_pb2 import httpx +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey from fastapi.testclient import TestClient from dispatch.experimental.durable.registry import clear_functions -from dispatch.fastapi import Dispatch +from dispatch.fastapi import Dispatch, parse_verification_key from dispatch.function import Arguments, Error, Function, Input, Output from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb +from dispatch.signature import public_key_from_pem from dispatch.status import Status from dispatch.test import EndpointClient +public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" +public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" +public_key = public_key_from_pem(public_key_pem) +public_key_bytes = public_key.public_bytes_raw() +public_key_b64 = base64.b64encode(public_key_bytes) + def create_dispatch_instance(app, endpoint): return Dispatch( @@ -98,6 +107,71 @@ def my_function(input: Input) -> Output: self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem}) + def test_parse_verification_key_env_pem_str(self): + verification_key = parse_verification_key(None) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_pem2}) + def test_parse_verification_key_env_pem_escaped_newline_str(self): + verification_key = parse_verification_key(None) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": public_key_b64.decode()}) + def test_parse_verification_key_env_b64_str(self): + verification_key = parse_verification_key(None) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_none(self): + # The verification key is optional. Both Dispatch(verification_key=...) and + # DISPATCH_VERIFICATION_KEY may be omitted/None. + verification_key = parse_verification_key(None) + self.assertIsNone(verification_key) + + def test_parse_verification_key_ed25519publickey(self): + verification_key = parse_verification_key(public_key) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_pem_str(self): + verification_key = parse_verification_key(public_key_pem) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_pem_escaped_newline_str(self): + verification_key = parse_verification_key(public_key_pem2) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_pem_bytes(self): + verification_key = parse_verification_key(public_key_pem.encode()) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_b64_str(self): + verification_key = parse_verification_key(public_key_b64.decode()) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_b64_bytes(self): + verification_key = parse_verification_key(public_key_b64) + self.assertIsInstance(verification_key, Ed25519PublicKey) + self.assertEqual(verification_key.public_bytes_raw(), public_key_bytes) + + def test_parse_verification_key_invalid(self): + with self.assertRaisesRegex(ValueError, "invalid verification key 'foo'"): + parse_verification_key("foo") + + @mock.patch.dict(os.environ, {"DISPATCH_VERIFICATION_KEY": "foo"}) + def test_parse_verification_key_invalid_env(self): + with self.assertRaisesRegex( + ValueError, "invalid DISPATCH_VERIFICATION_KEY 'foo'" + ): + parse_verification_key(None) + def response_output(resp: function_pb.RunResponse) -> Any: return any_unpickle(resp.exit.result.output)