diff --git a/src/palace/manager/api/controller/analytics.py b/src/palace/manager/api/controller/analytics.py index ba1022a7d..b07adb310 100644 --- a/src/palace/manager/api/controller/analytics.py +++ b/src/palace/manager/api/controller/analytics.py @@ -1,13 +1,12 @@ from __future__ import annotations -import flask from flask import Response from palace.manager.api.controller.circulation_manager import ( CirculationManagerController, ) from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE -from palace.manager.api.util.flask import get_request_library +from palace.manager.api.util.flask import get_request_library, get_request_patron from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.problem_detail import ProblemDetail @@ -21,8 +20,8 @@ def track_event(self, identifier_type, identifier, event_type): if event_type in CirculationEvent.CLIENT_EVENTS: library = get_request_library() # Authentication on the AnalyticsController is optional, - # so flask.request.patron may or may not be set. - patron = getattr(flask.request, "patron", None) + # so we may not have a patron. + patron = get_request_patron(default=None) neighborhood = None if patron: neighborhood = getattr(patron, "neighborhood", None) diff --git a/src/palace/manager/api/controller/annotation.py b/src/palace/manager/api/controller/annotation.py index ec9d14b05..90aacc066 100644 --- a/src/palace/manager/api/controller/annotation.py +++ b/src/palace/manager/api/controller/annotation.py @@ -12,6 +12,7 @@ CirculationManagerController, ) from palace.manager.api.problem_details import NO_ANNOTATION +from palace.manager.api.util.flask import get_request_patron from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.patron import Annotation from palace.manager.sqlalchemy.util import get_one @@ -30,7 +31,7 @@ def container(self, identifier=None, accept_post=True): if flask.request.method == "HEAD": return Response(status=200, headers=headers) - patron = flask.request.patron + patron = get_request_patron() if flask.request.method == "GET": headers["Link"] = [ @@ -78,7 +79,7 @@ def detail(self, annotation_id): if flask.request.method == "HEAD": return Response(status=200, headers=headers) - patron = flask.request.patron + patron = get_request_patron() annotation = get_one( self._db, Annotation, patron=patron, id=annotation_id, active=True diff --git a/src/palace/manager/api/controller/device_tokens.py b/src/palace/manager/api/controller/device_tokens.py index 94bf34734..673ffed67 100644 --- a/src/palace/manager/api/controller/device_tokens.py +++ b/src/palace/manager/api/controller/device_tokens.py @@ -11,6 +11,7 @@ DEVICE_TOKEN_NOT_FOUND, DEVICE_TOKEN_TYPE_INVALID, ) +from palace.manager.api.util.flask import get_request_patron from palace.manager.sqlalchemy.model.devicetokens import ( DeviceToken, DuplicateDeviceTokenError, @@ -20,7 +21,7 @@ class DeviceTokensController(CirculationManagerController): def get_patron_device(self): - patron = flask.request.patron + patron = get_request_patron() device_token = flask.request.args["device_token"] token: DeviceToken = ( self._db.query(DeviceToken) @@ -35,7 +36,7 @@ def get_patron_device(self): return dict(token_type=token.token_type, device_token=token.device_token), 200 def create_patron_device(self): - patron = flask.request.patron + patron = get_request_patron() device_token = flask.request.json["device_token"] token_type = flask.request.json["token_type"] @@ -49,7 +50,7 @@ def create_patron_device(self): return "", 201 def delete_patron_device(self): - patron = flask.request.patron + patron = get_request_patron() device_token = flask.request.json["device_token"] token_type = flask.request.json["token_type"] diff --git a/src/palace/manager/api/controller/loan.py b/src/palace/manager/api/controller/loan.py index 75067962b..55a19dcac 100644 --- a/src/palace/manager/api/controller/loan.py +++ b/src/palace/manager/api/controller/loan.py @@ -23,7 +23,7 @@ NO_ACTIVE_LOAN_OR_HOLD, NO_LICENSES, ) -from palace.manager.api.util.flask import get_request_library +from palace.manager.api.util.flask import get_request_library, get_request_patron from palace.manager.celery.tasks.patron_activity import sync_patron_activity from palace.manager.core.problem_details import INTERNAL_SERVER_ERROR from palace.manager.feed.acquisition import OPDSAcquisitionFeed @@ -46,7 +46,7 @@ def sync(self) -> Response: :return: A Response containing an OPDS feed with up-to-date information. """ - patron: Patron = flask.request.patron # type: ignore[attr-defined] + patron: Patron = get_request_patron() try: # Parse the refresh query parameter as a boolean. @@ -89,7 +89,7 @@ def borrow( "http://opds-spec.org/acquisition", which can be used to fetch the book or the license file. """ - patron = flask.request.patron # type: ignore[attr-defined] + patron = get_request_patron() library = get_request_library() header = self.authorization_header() @@ -428,7 +428,7 @@ def can_fulfill_without_loan( return self.circulation.can_fulfill_without_loan(patron, pool, lpdm) def revoke(self, license_pool_id: int) -> OPDSEntryResponse | ProblemDetail: - patron = flask.request.patron # type: ignore[attr-defined] + patron = get_request_patron() pool = self.load_licensepool(license_pool_id) if isinstance(pool, ProblemDetail): return pool @@ -492,7 +492,7 @@ def revoke(self, license_pool_id: int) -> OPDSEntryResponse | ProblemDetail: def detail( self, identifier_type: str, identifier: str ) -> OPDSEntryResponse | ProblemDetail | None: - patron = flask.request.patron # type: ignore[attr-defined] + patron = get_request_patron() library = get_request_library() pools = self.load_licensepools(library, identifier_type, identifier) if isinstance(pools, ProblemDetail): diff --git a/src/palace/manager/api/controller/patron_activity_history.py b/src/palace/manager/api/controller/patron_activity_history.py index 89594add0..5f1471187 100644 --- a/src/palace/manager/api/controller/patron_activity_history.py +++ b/src/palace/manager/api/controller/patron_activity_history.py @@ -2,10 +2,9 @@ import uuid -import flask from flask import Response -from palace.manager.sqlalchemy.model.patron import Patron +from palace.manager.api.util.flask import get_request_patron class PatronActivityHistoryController: @@ -13,6 +12,6 @@ class PatronActivityHistoryController: def reset_statistics_uuid(self) -> Response: """Resets the patron's the statistics UUID that links the patron to past activity thus effectively erasing the link between activity history and patron.""" - patron: Patron = flask.request.patron # type: ignore [attr-defined] + patron = get_request_patron() patron.uuid = uuid.uuid4() return Response("UUID reset", 200) diff --git a/src/palace/manager/api/controller/patron_auth_token.py b/src/palace/manager/api/controller/patron_auth_token.py index 87070fdbb..765595dc3 100644 --- a/src/palace/manager/api/controller/patron_auth_token.py +++ b/src/palace/manager/api/controller/patron_auth_token.py @@ -8,6 +8,7 @@ ) from palace.manager.api.model.patron_auth import PatronAuthAccessToken from palace.manager.api.problem_details import PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE +from palace.manager.api.util.flask import get_request_patron from palace.manager.util.log import LoggerMixin from palace.manager.util.problem_detail import ProblemDetailException @@ -15,11 +16,11 @@ class PatronAuthTokenController(CirculationManagerController, LoggerMixin): def get_token(self): """Create a Patron Auth access token for an authenticated patron""" - patron = flask.request.patron + patron = get_request_patron(default=None) auth = flask.request.authorization token_expiry = 3600 - if not patron or auth.type.lower() != "basic": + if patron is None or auth is None or auth.type.lower() != "basic": return PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE try: diff --git a/src/palace/manager/api/controller/playtime_entries.py b/src/palace/manager/api/controller/playtime_entries.py index 2fd74fd71..58ff753c4 100644 --- a/src/palace/manager/api/controller/playtime_entries.py +++ b/src/palace/manager/api/controller/playtime_entries.py @@ -14,7 +14,7 @@ PlaytimeEntriesPostResponse, ) from palace.manager.api.problem_details import NOT_FOUND_ON_REMOTE -from palace.manager.api.util.flask import get_request_library +from palace.manager.api.util.flask import get_request_library, get_request_patron from palace.manager.core.problem_details import INVALID_INPUT from palace.manager.core.query.playtime_entries import PlaytimeEntries from palace.manager.sqlalchemy.model.collection import Collection @@ -77,7 +77,7 @@ def track_playtimes(self, collection_id, identifier_type, identifier_idn): .join(LicensePool) .where( LicensePool.identifier == identifier, - Loan.patron == flask.request.patron, + Loan.patron == get_request_patron(), Loan.start <= entry_max_start_time, or_(Loan.end > entry_min_end_time, Loan.end == None), ) diff --git a/src/palace/manager/api/controller/profile.py b/src/palace/manager/api/controller/profile.py index 99bc65583..60c5ce496 100644 --- a/src/palace/manager/api/controller/profile.py +++ b/src/palace/manager/api/controller/profile.py @@ -7,6 +7,7 @@ from palace.manager.api.controller.circulation_manager import ( CirculationManagerController, ) +from palace.manager.api.util.flask import get_request_patron from palace.manager.core.user_profile import ProfileController as CoreProfileController from palace.manager.util.problem_detail import ProblemDetail @@ -21,7 +22,7 @@ def _controller(self, patron): def protocol(self): """Handle a UPMP request.""" - patron = flask.request.patron + patron = get_request_patron() controller = self._controller(patron) if flask.request.method == "GET": result = controller.get() diff --git a/src/palace/manager/api/controller/work.py b/src/palace/manager/api/controller/work.py index bf2fb19b4..03d0d8c45 100644 --- a/src/palace/manager/api/controller/work.py +++ b/src/palace/manager/api/controller/work.py @@ -17,7 +17,7 @@ SeriesLane, ) from palace.manager.api.problem_details import NO_SUCH_LANE, NOT_FOUND_ON_REMOTE -from palace.manager.api.util.flask import get_request_library +from palace.manager.api.util.flask import get_request_library, get_request_patron from palace.manager.core.app_server import load_pagination_from_request from palace.manager.core.config import CannotLoadConfiguration from palace.manager.core.metadata_layer import ContributorData @@ -111,7 +111,7 @@ def permalink(self, identifier_type, identifier): if isinstance(work, ProblemDetail): return work - patron = flask.request.patron + patron = get_request_patron(default=None) if patron: pools = self.load_licensepools(library, identifier_type, identifier) diff --git a/src/palace/manager/api/util/flask.py b/src/palace/manager/api/util/flask.py index f7db64c59..fa3f82f05 100644 --- a/src/palace/manager/api/util/flask.py +++ b/src/palace/manager/api/util/flask.py @@ -7,6 +7,7 @@ from palace.manager.core.exceptions import PalaceValueError from palace.manager.sqlalchemy.model.library import Library +from palace.manager.sqlalchemy.model.patron import Patron from palace.manager.util.sentinel import SentinelType if TYPE_CHECKING: @@ -104,6 +105,33 @@ def get_request_library( return get_request_var("library", Library, default=default) +@overload +def get_request_patron() -> Patron: ... + + +@overload +def get_request_patron(*, default: TDefault) -> Patron | TDefault: ... + + +def get_request_patron( + *, default: TDefault | Literal[SentinelType.NotGiven] = SentinelType.NotGiven +) -> Patron | TDefault: + """ + Retrieve the 'patron' attribute from the current Flask request object. + + This attribute should be set by using the @requires_auth or @allows_auth decorator + on the route or by calling the BaseCirculationManagerController.authenticated_patron_from_request + method. + + :param default: The default value to return if the 'patron' attribute is not set + or if there is no request context. If not provided, a `PalaceValueError` will be + raised if the attribute is missing or has an incorrect type. + + :return: The `Patron` object from the request, or the default value if provided. + """ + return get_request_var("patron", Patron, default=default) + + class PalaceFlask(flask.Flask): """ A subclass of Flask sets properties used by Palace. diff --git a/tests/fixtures/api_routes.py b/tests/fixtures/api_routes.py index 3c5be4963..e36bd5e4f 100644 --- a/tests/fixtures/api_routes.py +++ b/tests/fixtures/api_routes.py @@ -97,8 +97,7 @@ def __init__(self, name): def authenticated_patron_from_request(self): if self.authenticated: - patron = object() - flask.request.patron = self.AUTHENTICATED_PATRON + setattr(flask.request, "patron", self.AUTHENTICATED_PATRON) return self.AUTHENTICATED_PATRON else: return flask.Response( diff --git a/tests/fixtures/flask.py b/tests/fixtures/flask.py index 43d8cffa4..ec594e700 100644 --- a/tests/fixtures/flask.py +++ b/tests/fixtures/flask.py @@ -12,6 +12,7 @@ from palace.manager.api.util.flask import PalaceFlask from palace.manager.sqlalchemy.model.admin import Admin, AdminRole from palace.manager.sqlalchemy.model.library import Library +from palace.manager.sqlalchemy.model.patron import Patron from palace.manager.sqlalchemy.util import get_one_or_create from tests.fixtures.database import DatabaseTransactionFixture @@ -38,12 +39,14 @@ def test_request_context( *args: Any, admin: Admin | None = None, library: Library | None = None, + patron: Patron | None = None, **kwargs: Any, ) -> Generator[RequestContext]: with self.app.test_request_context(*args, **kwargs) as c: self.db.session.begin_nested() setattr(c.request, "library", library) setattr(c.request, "admin", admin) + setattr(c.request, "patron", patron) setattr(c.request, "form", ImmutableMultiDict()) setattr(c.request, "files", ImmutableMultiDict()) try: diff --git a/tests/manager/api/controller/test_analytics.py b/tests/manager/api/controller/test_analytics.py index 1f391bde5..673c745bf 100644 --- a/tests/manager/api/controller/test_analytics.py +++ b/tests/manager/api/controller/test_analytics.py @@ -1,4 +1,3 @@ -import flask import pytest from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE @@ -42,8 +41,8 @@ def test_track_event(self, analytics_fixture: AnalyticsFixture): patron = db.patron() patron.neighborhood = "Mars Grid 4810579" - with analytics_fixture.request_context_with_library("/"): - flask.request.patron = patron # type: ignore + with analytics_fixture.request_context_with_library("/") as ctx: + setattr(ctx.request, "patron", patron) response = analytics_fixture.manager.analytics_controller.track_event( analytics_fixture.identifier.type, analytics_fixture.identifier.identifier, diff --git a/tests/manager/api/controller/test_patron_access_token.py b/tests/manager/api/controller/test_patron_access_token.py index 802b360f4..28278ca45 100644 --- a/tests/manager/api/controller/test_patron_access_token.py +++ b/tests/manager/api/controller/test_patron_access_token.py @@ -1,66 +1,72 @@ -from typing import TYPE_CHECKING +from unittest.mock import MagicMock import pytest from werkzeug.datastructures import Authorization +from palace.manager.api.controller.patron_auth_token import PatronAuthTokenController from palace.manager.api.problem_details import PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE -from tests.fixtures.api_controller import CirculationControllerFixture from tests.fixtures.database import DatabaseTransactionFixture -from tests.fixtures.services import ServicesFixture +from tests.fixtures.flask import FlaskAppFixture -if TYPE_CHECKING: - from palace.manager.api.controller.patron_auth_token import ( - PatronAuthTokenController, - ) - -class PatronAuthTokenControllerFixture(CirculationControllerFixture): - def __init__( - self, db: DatabaseTransactionFixture, services_fixture: ServicesFixture - ) -> None: - super().__init__(db, services_fixture) - self.controller: PatronAuthTokenController = self.manager.patron_auth_token +class PatronAuthTokenControllerFixture: + def __init__(self, db: DatabaseTransactionFixture) -> None: + mock_manager = MagicMock() + mock_manager._db = db.session + self.controller = PatronAuthTokenController(mock_manager) @pytest.fixture(scope="function") -def patron_auth_token_fixture( - db: DatabaseTransactionFixture, services_fixture: ServicesFixture -): - return PatronAuthTokenControllerFixture(db, services_fixture) +def patron_auth_token_fixture(db: DatabaseTransactionFixture): + return PatronAuthTokenControllerFixture(db) class TestPatronAuthTokenController: def test_get_token( - self, patron_auth_token_fixture: PatronAuthTokenControllerFixture + self, + patron_auth_token_fixture: PatronAuthTokenControllerFixture, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, ): - fxtr = patron_auth_token_fixture - db = fxtr.db patron = db.patron() - with fxtr.request_context_with_library("/") as ctx: - ctx.request.authorization = Authorization( - auth_type="Basic", data=dict(username="user", password="pass") - ) - ctx.request.patron = patron - token = fxtr.controller.get_token() + with flask_app_fixture.test_request_context( + auth=Authorization( + auth_type="basic", data=dict(username="user", password="pass") + ), + patron=patron, + ): + token = patron_auth_token_fixture.controller.get_token() assert ("accessToken", "tokenType", "expiresIn") == tuple(token.keys()) assert token["expiresIn"] == 3600 assert token["tokenType"] == "Bearer" def test_get_token_errors( - self, patron_auth_token_fixture: PatronAuthTokenControllerFixture + self, + patron_auth_token_fixture: PatronAuthTokenControllerFixture, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, ): - fxtr = patron_auth_token_fixture - db = fxtr.db patron = db.patron() - with fxtr.request_context_with_library("/") as ctx: - ctx.request.authorization = Authorization( - auth_type="Bearer", token="Some-token" + with flask_app_fixture.test_request_context( + patron=patron, auth=Authorization(auth_type="bearer", token="Some-token") + ): + assert ( + patron_auth_token_fixture.controller.get_token() + == PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE ) - ctx.request.patron = patron - assert fxtr.controller.get_token() == PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE - ctx.request.authorization = Authorization( - auth_type="Basic", data=dict(username="user", password="pass") + with flask_app_fixture.test_request_context(patron=patron): + assert ( + patron_auth_token_fixture.controller.get_token() + == PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE + ) + + with flask_app_fixture.test_request_context( + auth=Authorization( + auth_type="basic", data=dict(username="user", password="pass") + ) + ): + assert ( + patron_auth_token_fixture.controller.get_token() + == PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE ) - ctx.request.patron = None - assert fxtr.controller.get_token() == PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE diff --git a/tests/manager/api/controller/test_patron_activity_history.py b/tests/manager/api/controller/test_patron_activity_history.py index 772ded9d9..ac92ce4d7 100644 --- a/tests/manager/api/controller/test_patron_activity_history.py +++ b/tests/manager/api/controller/test_patron_activity_history.py @@ -1,12 +1,10 @@ from __future__ import annotations -import flask import pytest from palace.manager.api.controller.patron_activity_history import ( PatronActivityHistoryController, ) -from palace.manager.sqlalchemy.model.patron import Patron from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.flask import FlaskAppFixture @@ -34,12 +32,10 @@ def test_reset_statistics_uuid( self, controller_fixture: PatronActivityHistoryControllerFixture, flask_app_fixture: FlaskAppFixture, + db: DatabaseTransactionFixture, ): - - with flask_app_fixture.test_request_context("/", method="PUT"): - patron = controller_fixture.db.patron() - flask.request.patron = patron # type: ignore[attr-defined] - assert isinstance(patron, Patron) + patron = db.patron() + with flask_app_fixture.test_request_context("/", method="PUT", patron=patron): uuid1 = patron.uuid assert uuid1 response = controller_fixture.controller.reset_statistics_uuid() diff --git a/tests/manager/api/controller/test_playtime_entries.py b/tests/manager/api/controller/test_playtime_entries.py index c8f9a45da..937e5bb0d 100644 --- a/tests/manager/api/controller/test_playtime_entries.py +++ b/tests/manager/api/controller/test_playtime_entries.py @@ -1,13 +1,13 @@ import datetime import hashlib -from unittest.mock import patch +from unittest.mock import MagicMock, patch -import flask import pytest from sqlalchemy.exc import IntegrityError from palace.manager.api.controller.playtime_entries import ( MISSING_LOAN_IDENTIFIER, + PlaytimeEntriesController, resolve_loan_identifier, ) from palace.manager.sqlalchemy.model.patron import Loan @@ -15,16 +15,31 @@ from palace.manager.sqlalchemy.util import get_one from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.problem_detail import ProblemDetail -from tests.fixtures.api_controller import CirculationControllerFixture +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.flask import FlaskAppFixture -def date_string(hour=None, minute=None): - now = utc_now() - if hour is not None: - now = now.replace(hour=hour) - if minute is not None: - now = now.replace(minute=minute) - return now.strftime("%Y-%m-%dT%H:%M:00+00:00") +class PlaytimeEntriesControllerFixture: + def __init__(self, db: DatabaseTransactionFixture) -> None: + mock_manager = MagicMock() + mock_manager._db = db.session + self.controller = PlaytimeEntriesController(mock_manager) + + @staticmethod + def date_string(hour=None, minute=None): + now = utc_now() + if hour is not None: + now = now.replace(hour=hour) + if minute is not None: + now = now.replace(minute=minute) + return now.strftime("%Y-%m-%dT%H:%M:00+00:00") + + +@pytest.fixture() +def playtime_entries_controller_fixture( + db: DatabaseTransactionFixture, +) -> PlaytimeEntriesControllerFixture: + return PlaytimeEntriesControllerFixture(db) class TestPlaytimeEntriesController: @@ -36,9 +51,12 @@ class TestPlaytimeEntriesController: ], ) def test_track_playtime( - self, circulation_fixture: CirculationControllerFixture, no_loan_end_date: bool - ): - db = circulation_fixture.db + self, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, + playtime_entries_controller_fixture: PlaytimeEntriesControllerFixture, + no_loan_end_date: bool, + ) -> None: identifier = db.identifier() collection = db.default_collection() # Attach the identifier to the collection @@ -50,7 +68,9 @@ def test_track_playtime( ) patron = db.patron() - loan_exists_date_str = date_string(hour=12, minute=0) + loan_exists_date_str = playtime_entries_controller_fixture.date_string( + hour=12, minute=0 + ) inscope_loan_start = datetime.datetime.fromisoformat(loan_exists_date_str) inscope_loan_end = ( None @@ -75,7 +95,9 @@ def test_track_playtime( }, { "id": "tracking-id-1", - "during_minute": date_string(hour=12, minute=1), + "during_minute": playtime_entries_controller_fixture.date_string( + hour=12, minute=1 + ), "seconds_played": 17, }, { @@ -85,11 +107,10 @@ def test_track_playtime( }, ] ) - with circulation_fixture.request_context_with_library( - "/", method="POST", json=data + with flask_app_fixture.test_request_context( + "/", method="POST", json=data, patron=patron, library=db.default_library() ): - flask.request.patron = patron # type: ignore - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, identifier.identifier ) @@ -115,7 +136,10 @@ def test_track_playtime( assert entry.collection == collection assert entry.library == db.default_library() assert entry.total_seconds_played == 12 - assert entry.timestamp.isoformat() == date_string(hour=12, minute=0) + assert ( + entry.timestamp.isoformat() + == playtime_entries_controller_fixture.date_string(hour=12, minute=0) + ) assert entry.loan_identifier == expected_loan_identifier entry = get_one(db.session, PlaytimeEntry, tracking_id="tracking-id-1") @@ -124,7 +148,10 @@ def test_track_playtime( assert entry.collection == collection assert entry.library == db.default_library() assert entry.total_seconds_played == 17 - assert entry.timestamp.isoformat() == date_string(hour=12, minute=1) + assert ( + entry.timestamp.isoformat() + == playtime_entries_controller_fixture.date_string(hour=12, minute=1) + ) assert entry.loan_identifier == expected_loan_identifier # The very old entry does not get recorded @@ -132,7 +159,7 @@ def test_track_playtime( db.session, PlaytimeEntry, tracking_id="tracking-id-2" ) - def test_resolve_loan_identifier(self): + def test_resolve_loan_identifier(self) -> None: no_loan = resolve_loan_identifier(loan=None) test_id = 1 test_loan_identifier = resolve_loan_identifier(Loan(id=test_id)) @@ -143,9 +170,11 @@ def test_resolve_loan_identifier(self): ) def test_track_playtime_duplicate_id_ok( - self, circulation_fixture: CirculationControllerFixture + self, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, + playtime_entries_controller_fixture: PlaytimeEntriesControllerFixture, ): - db = circulation_fixture.db identifier = db.identifier() collection = db.default_collection() library = db.default_library() @@ -163,7 +192,9 @@ def test_track_playtime_duplicate_id_ok( db.session.add( PlaytimeEntry( tracking_id="tracking-id-0", - timestamp=date_string(hour=12, minute=0), + timestamp=playtime_entries_controller_fixture.date_string( + hour=12, minute=0 + ), total_seconds_played=12, identifier_id=identifier.id, collection_id=collection.id, @@ -179,21 +210,24 @@ def test_track_playtime_duplicate_id_ok( timeEntries=[ { "id": "tracking-id-0", - "during_minute": date_string(hour=12, minute=0), + "during_minute": playtime_entries_controller_fixture.date_string( + hour=12, minute=0 + ), "seconds_played": 12, }, { "id": "tracking-id-1", - "during_minute": date_string(hour=12, minute=1), + "during_minute": playtime_entries_controller_fixture.date_string( + hour=12, minute=1 + ), "seconds_played": 12, }, ] ) - with circulation_fixture.request_context_with_library( - "/", method="POST", json=data + with flask_app_fixture.test_request_context( + "/", method="POST", json=data, library=library, patron=patron ): - flask.request.patron = patron # type: ignore - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, identifier.identifier ) @@ -210,9 +244,11 @@ def test_track_playtime_duplicate_id_ok( ) def test_track_playtime_failure( - self, circulation_fixture: CirculationControllerFixture + self, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, + playtime_entries_controller_fixture: PlaytimeEntriesControllerFixture, ): - db = circulation_fixture.db identifier = db.identifier() collection = db.default_collection() # Attach the identifier to the collection @@ -228,23 +264,26 @@ def test_track_playtime_failure( timeEntries=[ { "id": "tracking-id-1", - "during_minute": date_string(hour=12, minute=1), + "during_minute": playtime_entries_controller_fixture.date_string( + hour=12, minute=1 + ), "seconds_played": 12, } ] ) - with circulation_fixture.request_context_with_library( - "/", method="POST", json=data + with flask_app_fixture.test_request_context( + "/", method="POST", json=data, patron=patron, library=db.default_library() ): - flask.request.patron = patron # type: ignore with patch( "palace.manager.core.query.playtime_entries.create" ) as mock_create: mock_create.side_effect = IntegrityError( "STATEMENT", {}, Exception("Fake Exception") ) - response = circulation_fixture.manager.playtime_entries.track_playtimes( - collection.id, identifier.type, identifier.identifier + response = ( + playtime_entries_controller_fixture.controller.track_playtimes( + collection.id, identifier.type, identifier.identifier + ) ) assert response.status_code == 207 @@ -255,20 +294,22 @@ def test_track_playtime_failure( dict(status=400, message="Fake Exception", id="tracking-id-1") ] - def test_api_validation(self, circulation_fixture: CirculationControllerFixture): - db = circulation_fixture.db + def test_api_validation( + self, + db: DatabaseTransactionFixture, + flask_app_fixture: FlaskAppFixture, + playtime_entries_controller_fixture: PlaytimeEntriesControllerFixture, + ): identifier = db.identifier() collection = db.collection() library = db.default_library() patron = db.patron() - with circulation_fixture.request_context_with_library( - "/", method="POST", json={} + with flask_app_fixture.test_request_context( + "/", method="POST", json={}, library=library, patron=patron ): - flask.request.patron = patron # type: ignore - # Bad identifier - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, "not-an-identifier" ) assert isinstance(response, ProblemDetail) @@ -279,7 +320,7 @@ def test_api_validation(self, circulation_fixture: CirculationControllerFixture) ) # Bad collection - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( 9088765, identifier.type, identifier.identifier ) assert isinstance(response, ProblemDetail) @@ -287,7 +328,7 @@ def test_api_validation(self, circulation_fixture: CirculationControllerFixture) assert response.detail == f"The collection 9088765 was not found." # Collection not in library - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, identifier.identifier ) assert isinstance(response, ProblemDetail) @@ -296,7 +337,7 @@ def test_api_validation(self, circulation_fixture: CirculationControllerFixture) # Identifier not part of collection collection.libraries.append(library) - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, identifier.identifier ) assert isinstance(response, ProblemDetail) @@ -312,7 +353,7 @@ def test_api_validation(self, circulation_fixture: CirculationControllerFixture) ) # Incorrect JSON format - response = circulation_fixture.manager.playtime_entries.track_playtimes( + response = playtime_entries_controller_fixture.controller.track_playtimes( collection.id, identifier.type, identifier.identifier ) assert isinstance(response, ProblemDetail) diff --git a/tests/manager/api/controller/test_work.py b/tests/manager/api/controller/test_work.py index d7ce1935c..2b77ed4d0 100644 --- a/tests/manager/api/controller/test_work.py +++ b/tests/manager/api/controller/test_work.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, create_autospec import feedparser -import flask import pytest from flask import url_for @@ -302,13 +301,13 @@ def test_permalink(self, work_fixture: WorkFixture): def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_without_loans( self, work_fixture: WorkFixture ): - with work_fixture.request_context_with_library("/"): + with work_fixture.request_context_with_library("/") as ctx: # We have two patrons. patron_1 = work_fixture.db.patron() patron_2 = work_fixture.db.patron() # But the request was initiated by the first patron. - flask.request.patron = patron_1 # type: ignore + setattr(ctx.request, "patron", patron_1) identifier_type = Identifier.GUTENBERG_ID identifier = "1234567890" @@ -349,13 +348,13 @@ def test_permalink_does_not_return_fulfillment_links_for_authenticated_patrons_w def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_loans( self, work_fixture: WorkFixture ): - with work_fixture.request_context_with_library("/"): + with work_fixture.request_context_with_library("/") as ctx: # We have two patrons. patron_1 = work_fixture.db.patron() patron_2 = work_fixture.db.patron() # But the request was initiated by the first patron. - flask.request.patron = patron_1 # type: ignore + setattr(ctx.request, "patron", patron_1) identifier_type = Identifier.GUTENBERG_ID identifier = "1234567890" @@ -399,7 +398,7 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf ): auth = dict(Authorization=work_fixture.valid_auth) - with work_fixture.request_context_with_library("/", headers=auth): + with work_fixture.request_context_with_library("/", headers=auth) as ctx: content_link = "https://content" # We have two patrons. @@ -409,7 +408,7 @@ def test_permalink_returns_fulfillment_links_for_authenticated_patrons_with_fulf patron_2 = work_fixture.db.patron() # But the request was initiated by the first patron. - flask.request.patron = patron_1 # type: ignore + setattr(ctx.request, "patron", patron_1) identifier_type = Identifier.GUTENBERG_ID identifier = "1234567890" diff --git a/tests/manager/api/test_circulationapi.py b/tests/manager/api/test_circulationapi.py index 0568eccd9..7c367bf1b 100644 --- a/tests/manager/api/test_circulationapi.py +++ b/tests/manager/api/test_circulationapi.py @@ -5,7 +5,6 @@ from typing import cast from unittest.mock import MagicMock, create_autospec -import flask import pytest from flask import Flask from freezegun import freeze_time @@ -994,10 +993,10 @@ def assert_event(inp, outp): # Now let's check neighborhood gathering. p2.neighborhood = "Compton" - with app.test_request_context(): + with app.test_request_context() as ctx: # Neighborhood is only gathered if we explicitly ask for # it. - flask.request.patron = p2 # type: ignore + setattr(ctx.request, "patron", p2) assert_event((p2, None, "event"), (l2, None, "event", None, p2)) assert_event((p2, None, "event", False), (l2, None, "event", None, p2)) assert_event((p2, None, "event", True), (l2, None, "event", "Compton", p2)) @@ -1006,10 +1005,10 @@ def assert_event(inp, outp): # patron is not the patron who triggered the event. assert_event((p1, None, "event", True), (l1, None, "event", None, p1)) - with app.test_request_context(): + with app.test_request_context() as ctx: # Even if we ask for it, neighborhood is not gathered if # the data isn't available. - flask.request.patron = p1 # type: ignore + setattr(ctx.request, "patron", p1) assert_event((p1, None, "event", True), (l1, None, "event", None, p1)) # Finally, remove the mock Analytics object entirely and diff --git a/tests/manager/api/test_device_tokens.py b/tests/manager/api/test_device_tokens.py index f2f89b914..e227df8a5 100644 --- a/tests/manager/api/test_device_tokens.py +++ b/tests/manager/api/test_device_tokens.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -9,6 +9,7 @@ ) from palace.manager.sqlalchemy.model.devicetokens import DeviceToken, DeviceTokenTypes from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.flask import FlaskAppFixture @pytest.fixture @@ -18,37 +19,43 @@ def controller(db: DatabaseTransactionFixture) -> DeviceTokensController: return DeviceTokensController(mock_manager) -@patch("palace.manager.api.controller.device_tokens.flask") class TestDeviceTokens: def test_create_invalid_type( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): - request = MagicMock() - request.patron = db.patron() - request.json = {"device_token": "xx", "token_type": "aninvalidtoken"} - flask.request = request - detail = controller.create_patron_device() + with flask_app_fixture.test_request_context( + json={"device_token": "xx", "token_type": "aninvalidtoken"}, + patron=db.patron(), + ): + detail = controller.create_patron_device() assert detail is DEVICE_TOKEN_TYPE_INVALID assert detail.status_code == 400 def test_create_token( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): - request = MagicMock() - request.patron = db.patron() - request.json = { - "device_token": "xxx", - "token_type": DeviceTokenTypes.FCM_ANDROID, - } - flask.request = request - response = controller.create_patron_device() + patron = db.patron() + with flask_app_fixture.test_request_context( + json={ + "device_token": "xxx", + "token_type": DeviceTokenTypes.FCM_ANDROID, + }, + patron=patron, + ): + response = controller.create_patron_device() assert response[1] == 201 devices = ( db.session.query(DeviceToken) - .filter(DeviceToken.patron_id == request.patron.id) + .filter(DeviceToken.patron_id == patron.id) .all() ) @@ -58,119 +65,119 @@ def test_create_token( assert device.token_type == DeviceTokenTypes.FCM_ANDROID def test_get_token( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create( db.session, DeviceTokenTypes.FCM_ANDROID, "xx", patron ) - request = MagicMock() - request.patron = patron - request.args = {"device_token": "xx"} - flask.request = request - response = controller.get_patron_device() + with flask_app_fixture.test_request_context( + path="/?device_token=xx", patron=patron + ): + response = controller.get_patron_device() assert response[1] == 200 assert response[0]["token_type"] == DeviceTokenTypes.FCM_ANDROID assert response[0]["device_token"] == "xx" def test_get_token_not_found( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create( db.session, DeviceTokenTypes.FCM_ANDROID, "xx", patron ) - request = MagicMock() - request.patron = patron - request.args = {"device_token": "xxs"} - flask.request = request - detail = controller.get_patron_device() + with flask_app_fixture.test_request_context( + path="/?device_token=xxs", patron=patron + ): + detail = controller.get_patron_device() assert detail == DEVICE_TOKEN_NOT_FOUND def test_get_token_different_patron( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create( db.session, DeviceTokenTypes.FCM_ANDROID, "xx", patron ) - request = MagicMock() - request.patron = db.patron() - request.args = {"device_token": "xx"} - flask.request = request - detail = controller.get_patron_device() + with flask_app_fixture.test_request_context( + path="/?device_token=xx", patron=db.patron() + ): + detail = controller.get_patron_device() assert detail == DEVICE_TOKEN_NOT_FOUND def test_create_duplicate_token( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create(db.session, DeviceTokenTypes.FCM_IOS, "xxx", patron) - - # Same patron same token - request = MagicMock() - request.patron = patron - request.json = { + json_data = { "device_token": "xxx", "token_type": DeviceTokenTypes.FCM_ANDROID, } - flask.request = request - nested = db.session.begin_nested() # rollback only affects device create - response = controller.create_patron_device() + + # Same patron same token + with flask_app_fixture.test_request_context(json=json_data, patron=patron): + response = controller.create_patron_device() assert response == (dict(exists=True), 200) # different patron same token - patron1 = db.patron() - request = MagicMock() - request.patron = patron1 - request.json = { - "device_token": "xxx", - "token_type": DeviceTokenTypes.FCM_ANDROID, - } - flask.request = request - response = controller.create_patron_device() + with flask_app_fixture.test_request_context(json=json_data, patron=db.patron()): + response = controller.create_patron_device() assert response[1] == 201 def test_delete_token( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create(db.session, DeviceTokenTypes.FCM_IOS, "xxx", patron) - - request = MagicMock() - request.patron = patron - request.json = { + json_data = { "device_token": "xxx", "token_type": DeviceTokenTypes.FCM_IOS, } - flask.request = request - response = controller.delete_patron_device() - db.session.commit() + assert db.session.query(DeviceToken).get(device.id) is not None + + with flask_app_fixture.test_request_context(json=json_data, patron=patron): + response = controller.delete_patron_device() assert response.status_code == 204 - assert db.session.query(DeviceToken).get(device.id) == None + assert db.session.query(DeviceToken).get(device.id) is None def test_delete_no_token( - self, flask, controller: DeviceTokensController, db: DatabaseTransactionFixture + self, + flask_app_fixture: FlaskAppFixture, + controller: DeviceTokensController, + db: DatabaseTransactionFixture, ): patron = db.patron() device = DeviceToken.create(db.session, DeviceTokenTypes.FCM_IOS, "xxx", patron) - - request = MagicMock() - request.patron = patron - request.json = { + json_data = { "device_token": "xxxy", "token_type": DeviceTokenTypes.FCM_IOS, } - flask.request = request - response = controller.delete_patron_device() + with flask_app_fixture.test_request_context(json=json_data, patron=patron): + response = controller.delete_patron_device() assert response == DEVICE_TOKEN_NOT_FOUND