diff --git a/pyproject.toml b/pyproject.toml index ea31036..f55def8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ dev = [ "requests >= 2.31.0", "types-requests >= 2.31.0.20240125", "uvicorn >= 0.28.0", + "types-Flask >= 1.1.6", + "flask >= 3", "awslambdaric-stubs" ] diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 1e4a709..206e7bc 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -19,26 +19,14 @@ def read_root(): import asyncio import logging -import os -from datetime import timedelta from typing import Optional, Union -from urllib.parse import urlparse import fastapi import fastapi.responses -from http_message_signatures import InvalidSignature - -from dispatch.function import Batch, Registry -from dispatch.proto import Input -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import ( - CaseInsensitiveDict, - Ed25519PublicKey, - Request, - parse_verification_key, - verify_request, -) -from dispatch.status import Status + +from dispatch.function import Registry +from dispatch.http import FunctionServiceError, function_service_run +from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) @@ -92,25 +80,11 @@ def __init__( app.mount("/dispatch.sdk.v1.FunctionService", function_service) -class _ConnectResponse(fastapi.Response): - media_type = "application/grpc+proto" - - -class _ConnectError(fastapi.HTTPException): - __slots__ = ("status", "code", "message") - - def __init__(self, status, code, message): - super().__init__(status) - self.status = status - self.code = code - self.message = message - - -def _new_app(function_registry: Dispatch, verification_key: Optional[Ed25519PublicKey]): +def _new_app(function_registry: Registry, verification_key: Optional[Ed25519PublicKey]): app = fastapi.FastAPI() - @app.exception_handler(_ConnectError) - async def on_error(request: fastapi.Request, exc: _ConnectError): + 
@app.exception_handler(FunctionServiceError) + async def on_error(request: fastapi.Request, exc: FunctionServiceError): # https://connectrpc.com/docs/protocol/#error-end-stream return fastapi.responses.JSONResponse( status_code=exc.status, content={"code": exc.code, "message": exc.message} @@ -121,103 +95,26 @@ async def on_error(request: fastapi.Request, exc: _ConnectError): # gains more endpoints, this should be turned into a dynamic dispatch # like the official gRPC server does. "/Run", - response_class=_ConnectResponse, ) async def execute(request: fastapi.Request): # Raw request body bytes are only available through the underlying # starlette Request object's body method, which returns an awaitable, # forcing execute() to be async. data: bytes = await request.body() - logger.debug("handling run request with %d byte body", len(data)) - - if verification_key is None: - logger.debug("skipping request signature verification") - else: - signed_request = Request( - method=request.method, - url=str(request.url), - headers=CaseInsensitiveDict(request.headers), - body=data, - ) - max_age = timedelta(minutes=5) - try: - verify_request(signed_request, verification_key, max_age) - except ValueError as e: - raise _ConnectError(401, "unauthenticated", str(e)) - except InvalidSignature as e: - # The http_message_signatures package sometimes wraps does not - # attach a message to the exception, so we set a default to - # have some context about the reason for the error. 
- message = str(e) or "invalid signature" - raise _ConnectError(403, "permission_denied", message) - - req = function_pb.RunRequest.FromString(data) - if not req.function: - raise _ConnectError(400, "invalid_argument", "function is required") - - try: - func = function_registry.functions[req.function] - except KeyError: - logger.debug("function '%s' not found", req.function) - raise _ConnectError( - 404, "not_found", f"function '{req.function}' does not exist" - ) - - input = Input(req) - logger.info("running function '%s'", req.function) loop = asyncio.get_running_loop() - try: - output = await loop.run_in_executor(None, func._primitive_call, input) - except Exception: - # This indicates that an exception was raised in a primitive - # function. Primitive functions must catch exceptions, categorize - # them in order to derive a Status, and then return a RunResponse - # that carries the Status and the error details. A failure to do - # so indicates a problem, and we return a 500 rather than attempt - # to catch and categorize the error here. 
- logger.error("function '%s' fatal error", req.function, exc_info=True) - raise _ConnectError( - 500, "internal", f"function '{req.function}' fatal error" - ) - - response = output._message - status = Status(response.status) - - if response.HasField("poll"): - logger.debug( - "function '%s' polling with %d call(s)", - req.function, - len(response.poll.calls), - ) - elif response.HasField("exit"): - exit = response.exit - if not exit.HasField("result"): - logger.debug("function '%s' exiting with no result", req.function) - else: - result = exit.result - if result.HasField("output"): - logger.debug( - "function '%s' exiting with output value", req.function - ) - elif result.HasField("error"): - err = result.error - logger.debug( - "function '%s' exiting with error: %s (%s)", - req.function, - err.message, - err.type, - ) - if exit.HasField("tail_call"): - logger.debug( - "function '%s' tail calling function '%s'", - exit.tail_call.function, - ) - - logger.debug("finished handling run request with status %s", status.name) - return fastapi.Response( - content=response.SerializeToString(), media_type="application/proto" + content = await loop.run_in_executor( + None, + function_service_run, + str(request.url), + request.method, + request.headers, + data, + function_registry, + verification_key, ) + return fastapi.Response(content=content, media_type="application/proto") + return app diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py new file mode 100644 index 0000000..cbaded6 --- /dev/null +++ b/src/dispatch/flask.py @@ -0,0 +1,103 @@ +"""Integration of Dispatch functions with Flask. + +Example: + + from flask import Flask + from dispatch.flask import Dispatch + + app = Flask(__name__) + dispatch = Dispatch(app, api_key="test-key") + + @dispatch.function + def my_function(): + return "Hello World!" 
+ + @app.get("/") + def read_root(): + my_function.dispatch() + """ + +import logging +from typing import Optional, Union + +from flask import Flask, make_response, request + +from dispatch.function import Registry +from dispatch.http import FunctionServiceError, function_service_run +from dispatch.signature import Ed25519PublicKey, parse_verification_key + +logger = logging.getLogger(__name__) + + +class Dispatch(Registry): + """A Dispatch instance, powered by Flask.""" + + def __init__( + self, + app: Flask, + endpoint: Optional[str] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + ): + """Initialize a Dispatch endpoint, and integrate it into a Flask app. + + It mounts a sub-app that implements the Dispatch gRPC interface. + + Args: + app: The Flask app to configure. + + endpoint: Full URL of the application the Dispatch instance will + be running on. Uses the value of the DISPATCH_ENDPOINT_URL + environment variable by default. + + verification_key: Key to use when verifying signed requests. Uses + the value of the DISPATCH_VERIFICATION_KEY environment variable + if omitted. The environment variable is expected to carry an + Ed25519 public key in base64 or PEM format. + If not set, request signature verification is disabled (a warning + will be logged by the constructor). + + api_key: Dispatch API key to use for authentication. Uses the value of + the DISPATCH_API_KEY environment variable by default. + + api_url: The URL of the Dispatch API to use. Uses the value of the + DISPATCH_API_URL environment variable if set, otherwise + defaults to the public Dispatch API (DEFAULT_API_URL). + + Raises: + ValueError: If any of the required arguments are missing. 
class FunctionServiceError(Exception):
    """Error raised when a function service request cannot be served.

    The HTTP integrations turn this exception into an error response:
    ``status`` is used as the HTTP status code, while ``code`` and ``message``
    are sent back in the response body.

    Args:
        status: HTTP status code of the error response (e.g. 404).
        code: Short error code (e.g. "not_found").
        message: Human-readable description of the error.
    """

    __slots__ = ("status", "code", "message")

    def __init__(self, status: int, code: str, message: str):
        # Initialize the base class explicitly with the same arguments the
        # constructor takes, so `args` stays (status, code, message) and the
        # exception can be rebuilt from it (copy/pickle).
        super().__init__(status, code, message)
        self.status = status
        self.code = code
        self.message = message

    def __str__(self) -> str:
        # Without this, str(exc) is the repr of the args tuple, which is hard
        # to read in logs and tracebacks.
        return f"{self.code}: {self.message}"
def function_service_run(
    url: str,
    method: str,
    headers: Mapping[str, str],
    data: bytes,
    function_registry: Registry,
    verification_key: Optional[Ed25519PublicKey],
) -> bytes:
    """Handle a function service Run request and return the encoded response.

    This is the framework-independent part of request handling: callers
    extract the URL, method, headers and raw body from their HTTP framework,
    and send the returned bytes back as the response body.

    Args:
        url: URL of the request, used for signature verification.
        method: HTTP method of the request, used for signature verification.
        headers: Request headers, used for signature verification.
        data: Raw request body, decoded as a RunRequest message.
        function_registry: Registry in which the requested function is
            looked up by name.
        verification_key: Key used to verify the request signature. When
            None, signature verification is skipped.

    Returns:
        The serialized response message produced by the function.

    Raises:
        FunctionServiceError: With status 401 or 403 when signature
            verification fails, 400 when the request does not name a
            function, 404 when the function is not registered, and 500 when
            the function itself raises an exception.
    """
    logger.debug("handling run request with %d byte body", len(data))

    if verification_key is None:
        logger.debug("skipping request signature verification")
    else:
        signed_request = Request(
            method=method,
            url=url,
            headers=CaseInsensitiveDict(headers),
            body=data,
        )
        max_age = timedelta(minutes=5)
        try:
            verify_request(signed_request, verification_key, max_age)
        except ValueError as e:
            raise FunctionServiceError(401, "unauthenticated", str(e)) from e
        except InvalidSignature as e:
            # The http_message_signatures package does not always attach a
            # message to the exception, so we set a default to have some
            # context about the reason for the error.
            message = str(e) or "invalid signature"
            raise FunctionServiceError(403, "permission_denied", message) from e

    req = function_pb.RunRequest.FromString(data)
    if not req.function:
        raise FunctionServiceError(400, "invalid_argument", "function is required")

    try:
        func = function_registry.functions[req.function]
    except KeyError:
        logger.debug("function '%s' not found", req.function)
        # The KeyError carries no information beyond the function name, which
        # is already part of the message, so it is not chained.
        raise FunctionServiceError(
            404, "not_found", f"function '{req.function}' does not exist"
        ) from None

    run_input = Input(req)
    logger.info("running function '%s'", req.function)

    try:
        output = func._primitive_call(run_input)
    except Exception as e:
        # This indicates that an exception was raised in a primitive
        # function. Primitive functions must catch exceptions, categorize
        # them in order to derive a Status, and then return a RunResponse
        # that carries the Status and the error details. A failure to do
        # so indicates a problem, and we return a 500 rather than attempt
        # to catch and categorize the error here.
        logger.error("function '%s' fatal error", req.function, exc_info=True)
        raise FunctionServiceError(
            500, "internal", f"function '{req.function}' fatal error"
        ) from e

    response = output._message
    status = Status(response.status)

    _log_run_response(req.function, response)

    logger.debug("finished handling run request with status %s", status.name)
    return response.SerializeToString()


def _log_run_response(function_name: str, response) -> None:
    """Emit debug logs describing the directive carried by a run response."""
    if response.HasField("poll"):
        logger.debug(
            "function '%s' polling with %d call(s)",
            function_name,
            len(response.poll.calls),
        )
    elif response.HasField("exit"):
        exit_directive = response.exit
        if not exit_directive.HasField("result"):
            logger.debug("function '%s' exiting with no result", function_name)
        else:
            result = exit_directive.result
            if result.HasField("output"):
                logger.debug("function '%s' exiting with output value", function_name)
            elif result.HasField("error"):
                err = result.error
                logger.debug(
                    "function '%s' exiting with error: %s (%s)",
                    function_name,
                    err.message,
                    err.type,
                )
        if exit_directive.HasField("tail_call"):
            # Both placeholders need an argument: the calling function first,
            # then the function being tail called.
            logger.debug(
                "function '%s' tail calling function '%s'",
                function_name,
                exit_directive.tail_call.function,
            )