Skip to content

Commit

Permalink
Merge pull request #169 from dispatchrun/flask
Browse files Browse the repository at this point in the history
Flask integration
  • Loading branch information
chriso authored May 20, 2024
2 parents c699fc6 + 293c97a commit fd42f89
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 200 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ dev = [
"requests >= 2.31.0",
"types-requests >= 2.31.0.20240125",
"uvicorn >= 0.28.0",
"types-Flask >= 1.1.6",
"flask >= 3",
"awslambdaric-stubs"
]

Expand Down
139 changes: 18 additions & 121 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,14 @@ def read_root():

import asyncio
import logging
import os
from datetime import timedelta
from typing import Optional, Union
from urllib.parse import urlparse

import fastapi
import fastapi.responses
from http_message_signatures import InvalidSignature

from dispatch.function import Batch, Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
CaseInsensitiveDict,
Ed25519PublicKey,
Request,
parse_verification_key,
verify_request,
)
from dispatch.status import Status

from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,25 +80,11 @@ def __init__(
app.mount("/dispatch.sdk.v1.FunctionService", function_service)


class _ConnectResponse(fastapi.Response):
media_type = "application/grpc+proto"


class _ConnectError(fastapi.HTTPException):
__slots__ = ("status", "code", "message")

def __init__(self, status, code, message):
super().__init__(status)
self.status = status
self.code = code
self.message = message


def _new_app(function_registry: Dispatch, verification_key: Optional[Ed25519PublicKey]):
def _new_app(function_registry: Registry, verification_key: Optional[Ed25519PublicKey]):
app = fastapi.FastAPI()

@app.exception_handler(_ConnectError)
async def on_error(request: fastapi.Request, exc: _ConnectError):
@app.exception_handler(FunctionServiceError)
async def on_error(request: fastapi.Request, exc: FunctionServiceError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status, content={"code": exc.code, "message": exc.message}
Expand All @@ -121,103 +95,26 @@ async def on_error(request: fastapi.Request, exc: _ConnectError):
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
response_class=_ConnectResponse,
)
async def execute(request: fastapi.Request):
# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()
logger.debug("handling run request with %d byte body", len(data))

if verification_key is None:
logger.debug("skipping request signature verification")
else:
signed_request = Request(
method=request.method,
url=str(request.url),
headers=CaseInsensitiveDict(request.headers),
body=data,
)
max_age = timedelta(minutes=5)
try:
verify_request(signed_request, verification_key, max_age)
except ValueError as e:
raise _ConnectError(401, "unauthenticated", str(e))
except InvalidSignature as e:
# The http_message_signatures package sometimes wraps does not
# attach a message to the exception, so we set a default to
# have some context about the reason for the error.
message = str(e) or "invalid signature"
raise _ConnectError(403, "permission_denied", message)

req = function_pb.RunRequest.FromString(data)
if not req.function:
raise _ConnectError(400, "invalid_argument", "function is required")

try:
func = function_registry.functions[req.function]
except KeyError:
logger.debug("function '%s' not found", req.function)
raise _ConnectError(
404, "not_found", f"function '{req.function}' does not exist"
)

input = Input(req)
logger.info("running function '%s'", req.function)

loop = asyncio.get_running_loop()

try:
output = await loop.run_in_executor(None, func._primitive_call, input)
except Exception:
# This indicates that an exception was raised in a primitive
# function. Primitive functions must catch exceptions, categorize
# them in order to derive a Status, and then return a RunResponse
# that carries the Status and the error details. A failure to do
# so indicates a problem, and we return a 500 rather than attempt
# to catch and categorize the error here.
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise _ConnectError(
500, "internal", f"function '{req.function}' fatal error"
)

response = output._message
status = Status(response.status)

if response.HasField("poll"):
logger.debug(
"function '%s' polling with %d call(s)",
req.function,
len(response.poll.calls),
)
elif response.HasField("exit"):
exit = response.exit
if not exit.HasField("result"):
logger.debug("function '%s' exiting with no result", req.function)
else:
result = exit.result
if result.HasField("output"):
logger.debug(
"function '%s' exiting with output value", req.function
)
elif result.HasField("error"):
err = result.error
logger.debug(
"function '%s' exiting with error: %s (%s)",
req.function,
err.message,
err.type,
)
if exit.HasField("tail_call"):
logger.debug(
"function '%s' tail calling function '%s'",
exit.tail_call.function,
)

logger.debug("finished handling run request with status %s", status.name)
return fastapi.Response(
content=response.SerializeToString(), media_type="application/proto"
content = await loop.run_in_executor(
None,
function_service_run,
str(request.url),
request.method,
request.headers,
data,
function_registry,
verification_key,
)

return fastapi.Response(content=content, media_type="application/proto")

return app
103 changes: 103 additions & 0 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Integration of Dispatch functions with Flask.
Example:
from flask import Flask
from dispatch.flask import Dispatch
app = Flask(__name__)
dispatch = Dispatch(app, api_key="test-key")
@dispatch.function
def my_function():
return "Hello World!"
@app.get("/")
def read_root():
my_function.dispatch()
"""

import logging
from typing import Optional, Union

from flask import Flask, make_response, request

from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.signature import Ed25519PublicKey, parse_verification_key

logger = logging.getLogger(__name__)


class Dispatch(Registry):
"""A Dispatch instance, powered by Flask."""

def __init__(
self,
app: Flask,
endpoint: Optional[str] = None,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
):
"""Initialize a Dispatch endpoint, and integrate it into a Flask app.
It mounts a sub-app that implements the Dispatch gRPC interface.
Args:
app: The Flask app to configure.
endpoint: Full URL of the application the Dispatch instance will
be running on. Uses the value of the DISPATCH_ENDPOINT_URL
environment variable by default.
verification_key: Key to use when verifying signed requests. Uses
the value of the DISPATCH_VERIFICATION_KEY environment variable
if omitted. The environment variable is expected to carry an
Ed25519 public key in base64 or PEM format.
If not set, request signature verification is disabled (a warning
will be logged by the constructor).
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).
Raises:
ValueError: If any of the required arguments are missing.
"""
if not app:
raise ValueError(
"missing Flask app as first argument of the Dispatch constructor"
)

super().__init__(endpoint, api_key=api_key, api_url=api_url)

self._verification_key = parse_verification_key(
verification_key, endpoint=endpoint
)

app.errorhandler(FunctionServiceError)(self._handle_error)

app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute)

def _handle_error(self, exc: FunctionServiceError):
return {"code": exc.code, "message": exc.message}, exc.status

def _execute(self):
data: bytes = request.get_data(cache=False)

content = function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
)

res = make_response(content)
res.content_type = "application/proto"
return res
Loading

0 comments on commit fd42f89

Please sign in to comment.