diff --git a/schemes/__init__.py b/schemes/__init__.py index c21bacec..663307ff 100644 --- a/schemes/__init__.py +++ b/schemes/__init__.py @@ -122,5 +122,4 @@ def _create_database() -> None: def _configure_users() -> None: users = inject.instance(UserRepository) if not users.get_all(): - users.add(User("alex.coleman@activetravelengland.gov.uk")) - users.add(User("mark.hobson@activetravelengland.gov.uk")) + users.add(User("alex.coleman@activetravelengland.gov.uk"), User("mark.hobson@activetravelengland.gov.uk")) diff --git a/schemes/api.py b/schemes/api.py index eae5d2e4..3c03a712 100644 --- a/schemes/api.py +++ b/schemes/api.py @@ -8,9 +8,9 @@ @bp.route("/users", methods=["POST"]) @inject.autoparams() -def add_user(users: UserRepository) -> Response: - user = User(request.get_json()["email"]) - users.add(user) +def add_users(users: UserRepository) -> Response: + json = request.get_json() + users.add(*[User(element["email"]) for element in json]) return Response(status=201) diff --git a/schemes/users.py b/schemes/users.py index 9e16e098..f7d2a872 100644 --- a/schemes/users.py +++ b/schemes/users.py @@ -11,7 +11,7 @@ class User: class UserRepository: - def add(self, user: User) -> None: + def add(self, *users: User) -> None: raise NotImplementedError() def clear(self) -> None: @@ -38,9 +38,10 @@ class DatabaseUserRepository(UserRepository): def __init__(self, engine: Engine): self._engine = engine - def add(self, user: User) -> None: + def add(self, *users: User) -> None: with self._engine.begin() as connection: - connection.execute(text("INSERT INTO users (email) VALUES (:email)"), {"email": user.email}) + for user in users: + connection.execute(text("INSERT INTO users (email) VALUES (:email)"), {"email": user.email}) def clear(self) -> None: with self._engine.begin() as connection: diff --git a/tests/e2e/app_client.py b/tests/e2e/app_client.py index 66b2efb1..a0a845b7 100644 --- a/tests/e2e/app_client.py +++ b/tests/e2e/app_client.py @@ -8,8 +8,8 @@ def __init__(self, url: str): self._url = url def add_user(self, email: str) -> None: - user = {"email": email} - response = requests.post(f"{self._url}/api/users", json=user, timeout=self.DEFAULT_TIMEOUT) + users = [{"email": email}] + response = requests.post(f"{self._url}/api/users", json=users, timeout=self.DEFAULT_TIMEOUT) assert response.status_code == 201 def clear_users(self) -> None: diff --git a/tests/integration/fakes.py b/tests/integration/fakes.py index 79707a02..88364668 100644 --- a/tests/integration/fakes.py +++ b/tests/integration/fakes.py @@ -7,8 +7,8 @@ class MemoryUserRepository(UserRepository): def __init__(self) -> None: self._users: List[User] = [] - def add(self, user: User) -> None: - self._users.append(user) + def add(self, *users: User) -> None: + self._users.extend(users) def clear(self) -> None: self._users.clear() diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index a45831ad..42dd6d8d 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -12,11 +12,11 @@ def users_fixture() -> UserRepository: return inject.instance(UserRepository) -def test_add_user(users: UserRepository, client: FlaskClient) -> None: - response = client.post("/api/users", json={"email": "boardman@example.com"}) +def test_add_users(users: UserRepository, client: FlaskClient) -> None: + response = client.post("/api/users", json=[{"email": "boardman@example.com"}, {"email": "obree@example.com"}]) assert response.status_code == 201 - assert users.get_by_email("boardman@example.com") + assert users.get_all() == [User("boardman@example.com"), User("obree@example.com")] def test_clear_users(users: UserRepository, client: FlaskClient) -> None: diff --git a/tests/unit/test_users.py b/tests/unit/test_users.py index 368c3045..33b2d107 100644 --- a/tests/unit/test_users.py +++ b/tests/unit/test_users.py @@ -16,10 +16,10 @@ def users_fixture() -> DatabaseUserRepository: return repository -def test_add_user(users: DatabaseUserRepository) -> None: - users.add(User("boardman@example.com")) +def test_add_users(users: DatabaseUserRepository) -> None: + users.add(User("boardman@example.com"), User("obree@example.com")) - assert users.get_by_email("boardman@example.com") == User("boardman@example.com") + assert users.get_all() == [User("boardman@example.com"), User("obree@example.com")] def test_get_user_by_email(users: DatabaseUserRepository) -> None: