diff --git a/schemes/__init__.py b/schemes/__init__.py
index 101b48ae..de4c2b6b 100644
--- a/schemes/__init__.py
+++ b/schemes/__init__.py
@@ -6,7 +6,7 @@
from flask import Flask, Response, request, url_for
from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader, PrefixLoader
-from schemes import auth, home, start
+from schemes import api, auth, home, start
from schemes.config import DevConfig
@@ -21,10 +21,13 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask:
_configure_basic_auth(app)
_configure_govuk_frontend(app)
_configure_oidc(app)
+ _configure_users(app)
app.register_blueprint(start.bp)
app.register_blueprint(auth.bp, url_prefix="/auth")
app.register_blueprint(home.bp, url_prefix="/home")
+ if app.testing:
+ app.register_blueprint(api.bp, url_prefix="/api")
return app
@@ -68,3 +71,12 @@ def _configure_oidc(app: Flask) -> None:
"token_endpoint_auth_method": PrivateKeyJWT(app.config["GOVUK_TOKEN_ENDPOINT"]),
},
)
+
+
+def _configure_users(app: Flask) -> None:
+ app.extensions["users"] = []
+
+ if not app.testing:
+ app.extensions["users"].extend(
+ ["alex.coleman@activetravelengland.gov.uk", "mark.hobson@activetravelengland.gov.uk"]
+ )
diff --git a/schemes/api.py b/schemes/api.py
new file mode 100644
index 00000000..368366c4
--- /dev/null
+++ b/schemes/api.py
@@ -0,0 +1,16 @@
+from flask import Blueprint, Response, current_app, request
+
+bp = Blueprint("api", __name__)
+
+
+@bp.route("/users", methods=["POST"])
+def add_user() -> Response:
+ email = request.get_json()["email"]
+ current_app.extensions["users"].append(email)
+ return Response(status=201)
+
+
+@bp.route("/users", methods=["DELETE"])
+def clear_users() -> Response:
+ current_app.extensions["users"].clear()
+ return Response(status=204)
diff --git a/schemes/auth.py b/schemes/auth.py
index 5ee048cc..43fa19c7 100644
--- a/schemes/auth.py
+++ b/schemes/auth.py
@@ -13,7 +13,12 @@
def callback() -> BaseResponse:
oauth = _get_oauth()
token = oauth.govuk.authorize_access_token()
- session["user"] = oauth.govuk.userinfo(token=token)
+ user = oauth.govuk.userinfo(token=token)
+
+ if user["email"] not in current_app.extensions["users"]:
+ return Response("
Unauthorized
", status=401)
+
+ session["user"] = user
session["id_token"] = token["id_token"]
return redirect(url_for("home.index"))
diff --git a/tests/e2e/app_client.py b/tests/e2e/app_client.py
new file mode 100644
index 00000000..66b2efb1
--- /dev/null
+++ b/tests/e2e/app_client.py
@@ -0,0 +1,17 @@
+import requests
+
+
+class AppClient:
+ DEFAULT_TIMEOUT = 10
+
+ def __init__(self, url: str):
+ self._url = url
+
+ def add_user(self, email: str) -> None:
+ user = {"email": email}
+ response = requests.post(f"{self._url}/api/users", json=user, timeout=self.DEFAULT_TIMEOUT)
+ assert response.status_code == 201
+
+ def clear_users(self) -> None:
+ response = requests.delete(f"{self._url}/api/users", timeout=self.DEFAULT_TIMEOUT)
+ assert response.status_code == 204
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index 34422703..074e3f72 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -18,6 +18,7 @@
from pytest_flask.live_server import LiveServer
from schemes import create_app
+from tests.e2e.app_client import AppClient
from tests.e2e.oidc_server.app import OidcServerApp
from tests.e2e.oidc_server.app import create_app as oidc_server_create_app
from tests.e2e.oidc_server.clients import StubClient
@@ -70,6 +71,14 @@ def configure_live_server_fixture() -> None:
multiprocessing.set_start_method("fork")
+@pytest.fixture(name="app_client")
+def app_client_fixture(live_server: LiveServer) -> Generator[AppClient, Any, Any]:
+ url = f"http://{live_server.host}:{live_server.port}"
+ client = AppClient(url)
+ yield client
+ client.clear_users()
+
+
@pytest.fixture(name="oidc_server_app", scope="class")
def oidc_server_app_fixture() -> OidcServerApp:
os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true"
diff --git a/tests/e2e/pages.py b/tests/e2e/pages.py
index 57746854..45b4d278 100644
--- a/tests/e2e/pages.py
+++ b/tests/e2e/pages.py
@@ -44,6 +44,10 @@ def open_when_unauthenticated(self) -> LoginPage:
self.open()
return LoginPage(self._app, self._page)
+ def open_when_unauthorized(self) -> UnauthorizedPage:
+ self.open()
+ return UnauthorizedPage(self._app, self._page)
+
def visible(self) -> bool:
return self._page.get_by_role("heading", name="Home").is_visible()
@@ -67,6 +71,15 @@ def visible(self) -> bool:
return self._page.get_by_role("heading", name="Login").is_visible()
+class UnauthorizedPage:
+ def __init__(self, app: Flask, page: Page):
+ self._app = app
+ self._page = page
+
+ def visible(self) -> bool:
+ return self._page.get_by_role("heading", name="Unauthorized").is_visible()
+
+
def _get_base_url(app: Flask) -> str:
scheme = app.config["PREFERRED_URL_SCHEME"]
server_name = app.config["SERVER_NAME"]
diff --git a/tests/e2e/test_home.py b/tests/e2e/test_home.py
index baf015cc..cf1a04d2 100644
--- a/tests/e2e/test_home.py
+++ b/tests/e2e/test_home.py
@@ -2,18 +2,27 @@
from flask import Flask
from playwright.sync_api import Page
+from tests.e2e.app_client import AppClient
from tests.e2e.pages import HomePage
@pytest.mark.usefixtures("live_server", "oidc_server", "oidc_user")
@pytest.mark.oidc_user(id="stub_user", email="user@domain.com")
class TestAuthenticated:
- def test_home_when_authenticated(self, app: Flask, page: Page) -> None:
+ def test_home_when_authorized(self, app_client: AppClient, app: Flask, page: Page) -> None:
+ app_client.add_user("user@domain.com")
+
home_page = HomePage(app, page).open()
assert home_page.visible()
- def test_header_sign_out(self, app: Flask, page: Page) -> None:
+ def test_home_when_unauthorized(self, app: Flask, page: Page) -> None:
+ unauthorized_page = HomePage(app, page).open_when_unauthorized()
+
+ assert unauthorized_page.visible()
+
+ def test_header_sign_out(self, app_client: AppClient, app: Flask, page: Page) -> None:
+ app_client.add_user("user@domain.com")
home_page = HomePage(app, page).open()
start_page = home_page.header.sign_out()
diff --git a/tests/e2e/test_start.py b/tests/e2e/test_start.py
index 79c4c601..bdb7c68e 100644
--- a/tests/e2e/test_start.py
+++ b/tests/e2e/test_start.py
@@ -2,6 +2,7 @@
from flask import Flask
from playwright.sync_api import Page
+from tests.e2e.app_client import AppClient
from tests.e2e.pages import StartPage
@@ -24,7 +25,8 @@ def test_start_shows_login(self, app: Flask, page: Page) -> None:
@pytest.mark.usefixtures("live_server", "oidc_server", "oidc_user")
@pytest.mark.oidc_user(id="stub_user", email="user@domain.com")
class TestAuthenticated:
- def test_start_shows_home(self, app: Flask, page: Page) -> None:
+ def test_start_shows_home(self, app_client: AppClient, app: Flask, page: Page) -> None:
+ app_client.add_user("user@domain.com")
start_page = StartPage(app, page).open()
start_page.start()
diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py
new file mode 100644
index 00000000..176bae85
--- /dev/null
+++ b/tests/integration/test_api.py
@@ -0,0 +1,32 @@
+from typing import Any, Mapping
+
+import pytest
+from flask import current_app
+from flask.testing import FlaskClient
+
+
+def test_add_user(client: FlaskClient) -> None:
+ response = client.post("/api/users", json={"email": "hello@test.com"})
+
+ assert response.status_code == 201
+ assert "hello@test.com" in current_app.extensions["users"]
+
+
+def test_clear_users(client: FlaskClient) -> None:
+ current_app.extensions["users"].append("hello@test.com")
+
+ response = client.delete("/api/users")
+
+ assert response.status_code == 204
+ assert not current_app.extensions["users"]
+
+
+class TestProduction:
+ @pytest.fixture(name="config")
+ def config_fixture(self, config: Mapping[str, Any]) -> Mapping[str, Any]:
+ return config | {"TESTING": False}
+
+ def test_cannot_add_user(self, client: FlaskClient) -> None:
+ response = client.post("/api/users", json={"email": "hello@test.com"})
+
+ assert response.status_code == 404
diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py
index 70f1cbd4..2efbda57 100644
--- a/tests/integration/test_auth.py
+++ b/tests/integration/test_auth.py
@@ -14,24 +14,35 @@ def config_fixture(config: Mapping[str, Any]) -> Mapping[str, Any]:
def test_callback_logs_in(client: FlaskClient) -> None:
+ current_app.extensions["users"].append("user@domain.com")
_given_token_response({"id_token": "jwt"})
- _given_user_info(UserInfo({"sub": "123"}))
+ _given_user_info(UserInfo({"email": "user@domain.com"}))
with client:
client.get("/auth")
- assert session["user"] == UserInfo({"sub": "123"}) and session["id_token"] == "jwt"
+ assert session["user"] == UserInfo({"email": "user@domain.com"}) and session["id_token"] == "jwt"
def test_callback_redirects_to_home(client: FlaskClient) -> None:
+ current_app.extensions["users"].append("user@domain.com")
_given_token_response({"id_token": "jwt"})
- _given_user_info(UserInfo({"sub": "123"}))
+ _given_user_info(UserInfo({"email": "user@domain.com"}))
response = client.get("/auth")
assert response.status_code == 302 and response.location == "/home"
+def test_callback_when_unauthorized_shows_unauthorized(client: FlaskClient) -> None:
+ _given_token_response({"id_token": "jwt"})
+ _given_user_info(UserInfo({"email": "user@domain.com"}))
+
+ response = client.get("/auth")
+
+ assert response.status_code == 401 and response.text == "Unauthorized
"
+
+
def test_logout_logs_out_from_oidc(client: FlaskClient) -> None:
with client.session_transaction() as setup_session:
setup_session["user"] = "test"