Skip to content

Commit

Permalink
GH-4: Extract UserRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
markhobson authored and Sparrow0hawk committed Oct 13, 2023
1 parent f132ef9 commit c7ba9f7
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 16 deletions.
11 changes: 6 additions & 5 deletions schemes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from schemes import api, auth, home, start
from schemes.config import DevConfig
from schemes.users import User
from schemes.users import User, UserRepository


def create_app(test_config: Mapping[str, Any] | None = None) -> Flask:
Expand Down Expand Up @@ -75,9 +75,10 @@ def _configure_oidc(app: Flask) -> None:


def _configure_users(app: Flask) -> None:
app.extensions["users"] = []
users = UserRepository()

if not app.testing:
app.extensions["users"].extend(
[User("[email protected]"), User("[email protected]")]
)
users.add(User("[email protected]"))
users.add(User("[email protected]"))

app.extensions["users"] = users
8 changes: 5 additions & 3 deletions schemes/api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from flask import Blueprint, Response, current_app, request

from schemes.users import User
from schemes.users import User, UserRepository

bp = Blueprint("api", __name__)


@bp.route("/users", methods=["POST"])
def add_user() -> Response:
user = User(request.get_json()["email"])
current_app.extensions["users"].append(user)
users: UserRepository = current_app.extensions["users"]
users.add(user)
return Response(status=201)


@bp.route("/users", methods=["DELETE"])
def clear_users() -> Response:
current_app.extensions["users"].clear()
users: UserRepository = current_app.extensions["users"]
users.clear()
return Response(status=204)
6 changes: 4 additions & 2 deletions schemes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
)
from werkzeug.wrappers import Response as BaseResponse

from schemes.users import UserRepository

bp = Blueprint("auth", __name__)


Expand Down Expand Up @@ -68,8 +70,8 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> T | Response:


def _is_authorized(user: UserInfo) -> bool:
users = current_app.extensions["users"]
return user["email"] in [user.email for user in users]
users: UserRepository = current_app.extensions["users"]
return users.get(user["email"]) is not None


def _get_oauth() -> OAuth:
Expand Down
21 changes: 21 additions & 0 deletions schemes/users.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
from dataclasses import dataclass
from typing import List, TypeGuard


@dataclass
class User:
email: str


class UserRepository:
def __init__(self) -> None:
self._users: List[User] = []

def add(self, user: User) -> None:
self._users.append(user)

def clear(self) -> None:
self._users.clear()

def get(self, email: str) -> User | None:
def by_email(user: User) -> TypeGuard[User]:
return user.email == email

return next(filter(by_email, self._users), None)

def get_all(self) -> List[User]:
return self._users
6 changes: 3 additions & 3 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ def test_add_user(client: FlaskClient) -> None:
response = client.post("/api/users", json={"email": "[email protected]"})

assert response.status_code == 201
assert User("[email protected]") in current_app.extensions["users"]
assert current_app.extensions["users"].get("[email protected]")


def test_clear_users(client: FlaskClient) -> None:
current_app.extensions["users"].append(User("[email protected]"))
current_app.extensions["users"].add(User("[email protected]"))

response = client.delete("/api/users")

assert response.status_code == 204
assert not current_app.extensions["users"]
assert not current_app.extensions["users"].get_all()


class TestProduction:
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def config_fixture(config: Mapping[str, Any]) -> Mapping[str, Any]:


def test_callback_logs_in(client: FlaskClient) -> None:
current_app.extensions["users"].append(User("[email protected]"))
current_app.extensions["users"].add(User("[email protected]"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "[email protected]"}))

Expand All @@ -28,7 +28,7 @@ def test_callback_logs_in(client: FlaskClient) -> None:


def test_callback_redirects_to_home(client: FlaskClient) -> None:
current_app.extensions["users"].append(User("[email protected]"))
current_app.extensions["users"].add(User("[email protected]"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "[email protected]"}))

Expand All @@ -38,7 +38,7 @@ def test_callback_redirects_to_home(client: FlaskClient) -> None:


def test_callback_when_unauthorized_redirects_to_unauthorized(client: FlaskClient) -> None:
current_app.extensions["users"].append(User("[email protected]"))
current_app.extensions["users"].add(User("[email protected]"))
_given_oidc_returns_token_response({"id_token": "jwt"})
_given_oidc_returns_user_info(UserInfo({"email": "[email protected]"}))

Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from schemes.users import User, UserRepository


@pytest.fixture(name="user_repository")
def user_repository_fixture() -> UserRepository:
return UserRepository()


def test_add_user(user_repository: UserRepository) -> None:
user_repository.add(User("[email protected]"))

assert user_repository.get("[email protected]") == User("[email protected]")


def test_get_user(user_repository: UserRepository) -> None:
user_repository.add(User("[email protected]"))

assert user_repository.get("[email protected]") == User("[email protected]")


def test_get_user_who_does_not_exist(user_repository: UserRepository) -> None:
user_repository.add(User("[email protected]"))

assert user_repository.get("[email protected]") is None


def test_get_all_users(user_repository: UserRepository) -> None:
user_repository.add(User("[email protected]"))
user_repository.add(User("[email protected]"))

user_list = user_repository.get_all()

assert user_list == [User("[email protected]"), User("[email protected]")]


def test_clear_all_users(user_repository: UserRepository) -> None:
user_repository.add(User("[email protected]"))
user_repository.add(User("[email protected]"))

user_repository.clear()

assert user_repository.get_all() == []

0 comments on commit c7ba9f7

Please sign in to comment.