diff --git a/schemes/__init__.py b/schemes/__init__.py index 101b48ae..de4c2b6b 100644 --- a/schemes/__init__.py +++ b/schemes/__init__.py @@ -6,7 +6,7 @@ from flask import Flask, Response, request, url_for from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader, PrefixLoader -from schemes import auth, home, start +from schemes import api, auth, home, start from schemes.config import DevConfig @@ -21,10 +21,13 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask: _configure_basic_auth(app) _configure_govuk_frontend(app) _configure_oidc(app) + _configure_users(app) app.register_blueprint(start.bp) app.register_blueprint(auth.bp, url_prefix="/auth") app.register_blueprint(home.bp, url_prefix="/home") + if app.testing: + app.register_blueprint(api.bp, url_prefix="/api") return app @@ -68,3 +71,12 @@ def _configure_oidc(app: Flask) -> None: "token_endpoint_auth_method": PrivateKeyJWT(app.config["GOVUK_TOKEN_ENDPOINT"]), }, ) + + +def _configure_users(app: Flask) -> None: + app.extensions["users"] = [] + + if not app.testing: + app.extensions["users"].extend( + ["alex.coleman@activetravelengland.gov.uk", "mark.hobson@activetravelengland.gov.uk"] + ) diff --git a/schemes/api.py b/schemes/api.py new file mode 100644 index 00000000..368366c4 --- /dev/null +++ b/schemes/api.py @@ -0,0 +1,16 @@ +from flask import Blueprint, Response, current_app, request + +bp = Blueprint("api", __name__) + + +@bp.route("/users", methods=["POST"]) +def add_user() -> Response: + email = request.get_json()["email"] + current_app.extensions["users"].append(email) + return Response(status=201) + + +@bp.route("/users", methods=["DELETE"]) +def clear_users() -> Response: + current_app.extensions["users"].clear() + return Response(status=204) diff --git a/schemes/auth.py b/schemes/auth.py index 5ee048cc..43fa19c7 100644 --- a/schemes/auth.py +++ b/schemes/auth.py @@ -13,7 +13,12 @@ def callback() -> BaseResponse: oauth = _get_oauth() token = oauth.govuk.authorize_access_token() - session["user"] = oauth.govuk.userinfo(token=token) + user = oauth.govuk.userinfo(token=token) + + if user["email"] not in current_app.extensions["users"]: + return Response("

Unauthorized

", status=401) + + session["user"] = user session["id_token"] = token["id_token"] return redirect(url_for("home.index")) diff --git a/tests/e2e/app_client.py b/tests/e2e/app_client.py new file mode 100644 index 00000000..66b2efb1 --- /dev/null +++ b/tests/e2e/app_client.py @@ -0,0 +1,17 @@ +import requests + + +class AppClient: + DEFAULT_TIMEOUT = 10 + + def __init__(self, url: str): + self._url = url + + def add_user(self, email: str) -> None: + user = {"email": email} + response = requests.post(f"{self._url}/api/users", json=user, timeout=self.DEFAULT_TIMEOUT) + assert response.status_code == 201 + + def clear_users(self) -> None: + response = requests.delete(f"{self._url}/api/users", timeout=self.DEFAULT_TIMEOUT) + assert response.status_code == 204 diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 34422703..074e3f72 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -18,6 +18,7 @@ from pytest_flask.live_server import LiveServer from schemes import create_app +from tests.e2e.app_client import AppClient from tests.e2e.oidc_server.app import OidcServerApp from tests.e2e.oidc_server.app import create_app as oidc_server_create_app from tests.e2e.oidc_server.clients import StubClient @@ -70,6 +71,14 @@ def configure_live_server_fixture() -> None: multiprocessing.set_start_method("fork") +@pytest.fixture(name="app_client") +def app_client_fixture(live_server: LiveServer) -> Generator[AppClient, Any, Any]: + url = f"http://{live_server.host}:{live_server.port}" + client = AppClient(url) + yield client + client.clear_users() + + @pytest.fixture(name="oidc_server_app", scope="class") def oidc_server_app_fixture() -> OidcServerApp: os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" diff --git a/tests/e2e/pages.py b/tests/e2e/pages.py index 57746854..45b4d278 100644 --- a/tests/e2e/pages.py +++ b/tests/e2e/pages.py @@ -44,6 +44,10 @@ def open_when_unauthenticated(self) -> LoginPage: self.open() return LoginPage(self._app, self._page) + def open_when_unauthorized(self) -> UnauthorizedPage: + self.open() + return UnauthorizedPage(self._app, self._page) + def visible(self) -> bool: return self._page.get_by_role("heading", name="Home").is_visible() @@ -67,6 +71,15 @@ def visible(self) -> bool: return self._page.get_by_role("heading", name="Login").is_visible() +class UnauthorizedPage: + def __init__(self, app: Flask, page: Page): + self._app = app + self._page = page + + def visible(self) -> bool: + return self._page.get_by_role("heading", name="Unauthorized").is_visible() + + def _get_base_url(app: Flask) -> str: scheme = app.config["PREFERRED_URL_SCHEME"] server_name = app.config["SERVER_NAME"] diff --git a/tests/e2e/test_home.py b/tests/e2e/test_home.py index baf015cc..cf1a04d2 100644 --- a/tests/e2e/test_home.py +++ b/tests/e2e/test_home.py @@ -2,18 +2,27 @@ from flask import Flask from playwright.sync_api import Page +from tests.e2e.app_client import AppClient from tests.e2e.pages import HomePage @pytest.mark.usefixtures("live_server", "oidc_server", "oidc_user") @pytest.mark.oidc_user(id="stub_user", email="user@domain.com") class TestAuthenticated: - def test_home_when_authenticated(self, app: Flask, page: Page) -> None: + def test_home_when_authorized(self, app_client: AppClient, app: Flask, page: Page) -> None: + app_client.add_user("user@domain.com") + home_page = HomePage(app, page).open() assert home_page.visible() - def test_header_sign_out(self, app: Flask, page: Page) -> None: + def test_home_when_unauthorized(self, app: Flask, page: Page) -> None: + unauthorized_page = HomePage(app, page).open_when_unauthorized() + + assert unauthorized_page.visible() + + def test_header_sign_out(self, app_client: AppClient, app: Flask, page: Page) -> None: + app_client.add_user("user@domain.com") home_page = HomePage(app, page).open() start_page = home_page.header.sign_out() diff --git a/tests/e2e/test_start.py b/tests/e2e/test_start.py index 79c4c601..bdb7c68e 100644 --- a/tests/e2e/test_start.py +++ b/tests/e2e/test_start.py @@ -2,6 +2,7 @@ from flask import Flask from playwright.sync_api import Page +from tests.e2e.app_client import AppClient from tests.e2e.pages import StartPage @@ -24,7 +25,8 @@ def test_start_shows_login(self, app: Flask, page: Page) -> None: @pytest.mark.usefixtures("live_server", "oidc_server", "oidc_user") @pytest.mark.oidc_user(id="stub_user", email="user@domain.com") class TestAuthenticated: - def test_start_shows_home(self, app: Flask, page: Page) -> None: + def test_start_shows_home(self, app_client: AppClient, app: Flask, page: Page) -> None: + app_client.add_user("user@domain.com") start_page = StartPage(app, page).open() start_page.start() diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py new file mode 100644 index 00000000..176bae85 --- /dev/null +++ b/tests/integration/test_api.py @@ -0,0 +1,32 @@ +from typing import Any, Mapping + +import pytest +from flask import current_app +from flask.testing import FlaskClient + + +def test_add_user(client: FlaskClient) -> None: + response = client.post("/api/users", json={"email": "hello@test.com"}) + + assert response.status_code == 201 + assert "hello@test.com" in current_app.extensions["users"] + + +def test_clear_users(client: FlaskClient) -> None: + current_app.extensions["users"].append("hello@test.com") + + response = client.delete("/api/users") + + assert response.status_code == 204 + assert not current_app.extensions["users"] + + +class TestProduction: + @pytest.fixture(name="config") + def config_fixture(self, config: Mapping[str, Any]) -> Mapping[str, Any]: + return config | {"TESTING": False} + + def test_cannot_add_user(self, client: FlaskClient) -> None: + response = client.post("/api/users", json={"email": "hello@test.com"}) + + assert response.status_code == 404 diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index 70f1cbd4..2efbda57 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -14,24 +14,35 @@ def config_fixture(config: Mapping[str, Any]) -> Mapping[str, Any]: def test_callback_logs_in(client: FlaskClient) -> None: + current_app.extensions["users"].append("user@domain.com") _given_token_response({"id_token": "jwt"}) - _given_user_info(UserInfo({"sub": "123"})) + _given_user_info(UserInfo({"email": "user@domain.com"})) with client: client.get("/auth") - assert session["user"] == UserInfo({"sub": "123"}) and session["id_token"] == "jwt" + assert session["user"] == UserInfo({"email": "user@domain.com"}) and session["id_token"] == "jwt" def test_callback_redirects_to_home(client: FlaskClient) -> None: + current_app.extensions["users"].append("user@domain.com") _given_token_response({"id_token": "jwt"}) - _given_user_info(UserInfo({"sub": "123"})) + _given_user_info(UserInfo({"email": "user@domain.com"})) response = client.get("/auth") assert response.status_code == 302 and response.location == "/home" +def test_callback_when_unauthorized_shows_unauthorized(client: FlaskClient) -> None: + _given_token_response({"id_token": "jwt"}) + _given_user_info(UserInfo({"email": "user@domain.com"})) + + response = client.get("/auth") + + assert response.status_code == 401 and response.text == "

Unauthorized

" + + def test_logout_logs_out_from_oidc(client: FlaskClient) -> None: with client.session_transaction() as setup_session: setup_session["user"] = "test"