Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flask integration tests #171

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Python package to develop applications with the Dispatch platform.
- [Running Dispatch Applications](#running-dispatch-applications)
- [Writing Transactional Applications with Dispatch](#writing-transactional-applications-with-dispatch)
- [Integration with FastAPI](#integration-with-fastapi)
- [Integration with Flask](#integration-with-flask)
- [Configuration](#configuration)
- [Serialization](#serialization)
- [Examples](#examples)
Expand Down Expand Up @@ -198,6 +199,22 @@ In this example, GET requests on the HTTP server dispatch calls to the
`publish` function. The function runs concurrently to the rest of the
program, driven by the Dispatch SDK.

### Integration with Flask

Dispatch can also be integrated with web applications built on [Flask][flask].

The API is nearly identical to FastAPI above, instead use:

```python
from flask import Flask
from dispatch.flask import Dispatch

app = Flask(__name__)
dispatch = Dispatch(app)
```

[flask]: https://flask.palletsprojects.com/en/3.0.x/

### Configuration

The Dispatch CLI automatically configures the SDK, so manual configuration is
Expand Down
46 changes: 46 additions & 0 deletions src/dispatch/test/flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Mapping

import werkzeug.test
from flask import Flask

from dispatch.test.http import HttpClient, HttpResponse


def http_client(app: Flask) -> HttpClient:
"""Build a client for a Flask app."""
return Client(app.test_client())


class Client(HttpClient):
def __init__(self, client: werkzeug.test.Client):
self.client = client

def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse:
response = self.client.get(url, headers=headers.items())
return Response(response)

def post(
self, url: str, body: bytes, headers: Mapping[str, str] = {}
) -> HttpResponse:
response = self.client.post(url, data=body, headers=headers.items())
return Response(response)

def url_for(self, path: str) -> str:
return "http://localhost" + path


class Response(HttpResponse):
def __init__(self, response):
self.response = response

@property
def status_code(self):
return self.response.status_code

@property
def body(self):
return self.response.data

def raise_for_status(self):
if self.response.status_code // 100 != 2:
raise RuntimeError(f"HTTP status code {self.response.status_code}")
143 changes: 143 additions & 0 deletions tests/test_flask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import base64
import os
import pickle
import struct
import unittest
from typing import Any, Optional
from unittest import mock

import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from flask import Flask

import dispatch
from dispatch.experimental.durable.registry import clear_functions
from dispatch.flask import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
parse_verification_key,
private_key_from_pem,
public_key_from_pem,
)
from dispatch.status import Status
from dispatch.test import DispatchServer, DispatchService, EndpointClient
from dispatch.test.flask import http_client


def create_dispatch_instance(app: Flask, endpoint: str):
return Dispatch(
app,
endpoint=endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
)


def create_endpoint_client(app: Flask, signing_key: Optional[Ed25519PrivateKey] = None):
return EndpointClient(http_client(app), signing_key)


class TestFlask(unittest.TestCase):
def test_flask(self):
app = Flask(__name__)
dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/")

@dispatch.primitive_function
def my_function(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

client = create_endpoint_client(app)
pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")


signing_key = private_key_from_pem(
"""
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
-----END PRIVATE KEY-----
"""
)

verification_key = public_key_from_pem(
"""
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
-----END PUBLIC KEY-----
"""
)


class TestFlaskE2E(unittest.TestCase):
def setUp(self):
self.endpoint_app = Flask(__name__)
endpoint_client = create_endpoint_client(self.endpoint_app, signing_key)

api_key = "0000000000000000"
self.dispatch_service = DispatchService(
endpoint_client, api_key, collect_roundtrips=True
)
self.dispatch_server = DispatchServer(self.dispatch_service)
self.dispatch_client = dispatch.Client(
api_key, api_url=self.dispatch_server.url
)

self.dispatch = Dispatch(
self.endpoint_app,
endpoint="http://function-service", # unused
verification_key=verification_key,
api_key=api_key,
api_url=self.dispatch_server.url,
)

self.dispatch_server.start()

def tearDown(self):
self.dispatch_server.stop()

def test_simple_end_to_end(self):
# The Flask server.
@self.dispatch.function
def my_function(name: str) -> str:
return f"Hello world: {name}"

call = my_function.build_call(52)
self.assertEqual(call.function.split(".")[-1], "my_function")

# The client.
[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])

# Simulate execution for testing purposes.
self.dispatch_service.dispatch_calls()

# Validate results.
roundtrips = self.dispatch_service.roundtrips[dispatch_id]
self.assertEqual(len(roundtrips), 1)
_, response = roundtrips[0]
self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52")