diff --git a/pyproject.toml b/pyproject.toml index 152fa593..16109230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ dependencies = [ "gunicorn~=21.2.0", "inject~=5.0.0", "python-dotenv~=1.0.0", - "requests~=2.31.0" + "requests~=2.31.0", + "sqlalchemy~=2.0.22" ] [project.optional-dependencies] diff --git a/schemes/__init__.py b/schemes/__init__.py index e568b521..9fc497e6 100644 --- a/schemes/__init__.py +++ b/schemes/__init__.py @@ -7,10 +7,11 @@ from flask import Flask, Response, render_template, request, url_for from inject import Binder from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader, PrefixLoader +from sqlalchemy import Engine, MetaData, create_engine from schemes import api, auth, home, start from schemes.config import DevConfig -from schemes.users import DatabaseUserRepository, User, UserRepository +from schemes.users import DatabaseUserRepository, User, UserRepository, add_tables def create_app(test_config: Mapping[str, Any] | None = None) -> Flask: @@ -27,8 +28,6 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask: _configure_basic_auth(app) _configure_govuk_frontend(app) _configure_oidc(app) - if not app.testing: - _configure_users() app.register_blueprint(start.bp) app.register_blueprint(auth.bp, url_prefix="/auth") @@ -36,11 +35,16 @@ def create_app(test_config: Mapping[str, Any] | None = None) -> Flask: if app.testing: app.register_blueprint(api.bp, url_prefix="/api") + _create_database() + if not app.testing: + _configure_users() + return app def _bindings(binder: Binder) -> None: - binder.bind(UserRepository, DatabaseUserRepository()) + binder.bind(Engine, create_engine("sqlite+pysqlite:///file::memory:?uri=true")) + binder.bind_to_constructor(UserRepository, DatabaseUserRepository) def _configure_error_pages(app: Flask) -> None: @@ -94,6 +98,14 @@ def _configure_oidc(app: Flask) -> None: ) +def _create_database() -> None: + metadata = MetaData() + add_tables(metadata) + + engine = inject.instance(Engine) + metadata.create_all(engine) + + def _configure_users() -> None: users = inject.instance(UserRepository) users.add(User("alex.coleman@activetravelengland.gov.uk")) diff --git a/schemes/users.py b/schemes/users.py index 4cf3a304..1cb390d1 100644 --- a/schemes/users.py +++ b/schemes/users.py @@ -1,5 +1,8 @@ from dataclasses import dataclass -from typing import List, TypeGuard +from typing import List + +import inject +from sqlalchemy import Column, Engine, MetaData, Table, Text, text @dataclass @@ -21,22 +24,31 @@ def get_all(self) -> List[User]: raise NotImplementedError() +def add_tables(metadata: MetaData) -> None: + Table("users", metadata, Column("email", Text, nullable=False)) + + # pylint: disable=duplicate-code class DatabaseUserRepository(UserRepository): - def __init__(self) -> None: - self._users: List[User] = [] + @inject.autoparams() + def __init__(self, engine: Engine): + self._engine = engine def add(self, user: User) -> None: - self._users.append(user) + with self._engine.begin() as connection: + connection.execute(text("INSERT INTO users (email) VALUES (:email)"), {"email": user.email}) def clear(self) -> None: - self._users.clear() + with self._engine.begin() as connection: + connection.execute(text("DELETE FROM users")) def get(self, email: str) -> User | None: - def by_email(user: User) -> TypeGuard[User]: - return user.email == email - - return next(filter(by_email, self._users), None) + with self._engine.connect() as connection: + result = connection.execute(text("SELECT email FROM users WHERE email = :email"), {"email": email}) + row = result.one_or_none() + return User(row.email) if row else None def get_all(self) -> List[User]: - return self._users + with self._engine.connect() as connection: + result = connection.execute(text("SELECT email FROM users")) + return [User(row.email) for row in result] diff --git a/tests/unit/test_users.py b/tests/unit/test_users.py index 97fcd050..bfc624f8 100644 --- a/tests/unit/test_users.py +++ b/tests/unit/test_users.py @@ -1,11 +1,19 @@ import pytest +from sqlalchemy import MetaData, create_engine -from schemes.users import DatabaseUserRepository, User +from schemes.users import DatabaseUserRepository, User, add_tables @pytest.fixture(name="users") def users_fixture() -> DatabaseUserRepository: - return DatabaseUserRepository() + metadata = MetaData() + add_tables(metadata) + + engine = create_engine("sqlite+pysqlite:///:memory:", echo=True) + metadata.create_all(engine) + + repository: DatabaseUserRepository = DatabaseUserRepository(engine) + return repository def test_add_user(users: DatabaseUserRepository) -> None: