Skip to content

Commit

Permalink
Helper function to get flask.request.patron (#2178)
Browse files Browse the repository at this point in the history
* Do same work for flask.request.patron

* Add some doc comments

* Update playtime tests
  • Loading branch information
jonathangreen authored Nov 20, 2024
1 parent b1d8cb2 commit 9a6e727
Show file tree
Hide file tree
Showing 19 changed files with 286 additions and 207 deletions.
7 changes: 3 additions & 4 deletions src/palace/manager/api/controller/analytics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import flask
from flask import Response

from palace.manager.api.controller.circulation_manager import (
CirculationManagerController,
)
from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent
from palace.manager.util.datetime_helpers import utc_now
from palace.manager.util.problem_detail import ProblemDetail
Expand All @@ -21,8 +20,8 @@ def track_event(self, identifier_type, identifier, event_type):
if event_type in CirculationEvent.CLIENT_EVENTS:
library = get_request_library()
# Authentication on the AnalyticsController is optional,
# so flask.request.patron may or may not be set.
patron = getattr(flask.request, "patron", None)
# so we may not have a patron.
patron = get_request_patron(default=None)
neighborhood = None
if patron:
neighborhood = getattr(patron, "neighborhood", None)
Expand Down
5 changes: 3 additions & 2 deletions src/palace/manager/api/controller/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CirculationManagerController,
)
from palace.manager.api.problem_details import NO_ANNOTATION
from palace.manager.api.util.flask import get_request_patron
from palace.manager.sqlalchemy.model.identifier import Identifier
from palace.manager.sqlalchemy.model.patron import Annotation
from palace.manager.sqlalchemy.util import get_one
Expand All @@ -30,7 +31,7 @@ def container(self, identifier=None, accept_post=True):
if flask.request.method == "HEAD":
return Response(status=200, headers=headers)

patron = flask.request.patron
patron = get_request_patron()

if flask.request.method == "GET":
headers["Link"] = [
Expand Down Expand Up @@ -78,7 +79,7 @@ def detail(self, annotation_id):
if flask.request.method == "HEAD":
return Response(status=200, headers=headers)

patron = flask.request.patron
patron = get_request_patron()

annotation = get_one(
self._db, Annotation, patron=patron, id=annotation_id, active=True
Expand Down
7 changes: 4 additions & 3 deletions src/palace/manager/api/controller/device_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEVICE_TOKEN_NOT_FOUND,
DEVICE_TOKEN_TYPE_INVALID,
)
from palace.manager.api.util.flask import get_request_patron
from palace.manager.sqlalchemy.model.devicetokens import (
DeviceToken,
DuplicateDeviceTokenError,
Expand All @@ -20,7 +21,7 @@

class DeviceTokensController(CirculationManagerController):
def get_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.args["device_token"]
token: DeviceToken = (
self._db.query(DeviceToken)
Expand All @@ -35,7 +36,7 @@ def get_patron_device(self):
return dict(token_type=token.token_type, device_token=token.device_token), 200

def create_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.json["device_token"]
token_type = flask.request.json["token_type"]

Expand All @@ -49,7 +50,7 @@ def create_patron_device(self):
return "", 201

def delete_patron_device(self):
patron = flask.request.patron
patron = get_request_patron()
device_token = flask.request.json["device_token"]
token_type = flask.request.json["token_type"]

Expand Down
10 changes: 5 additions & 5 deletions src/palace/manager/api/controller/loan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
NO_ACTIVE_LOAN_OR_HOLD,
NO_LICENSES,
)
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.celery.tasks.patron_activity import sync_patron_activity
from palace.manager.core.problem_details import INTERNAL_SERVER_ERROR
from palace.manager.feed.acquisition import OPDSAcquisitionFeed
Expand All @@ -46,7 +46,7 @@ def sync(self) -> Response:
:return: A Response containing an OPDS feed with up-to-date information.
"""
patron: Patron = flask.request.patron # type: ignore[attr-defined]
patron: Patron = get_request_patron()

try:
# Parse the refresh query parameter as a boolean.
Expand Down Expand Up @@ -89,7 +89,7 @@ def borrow(
"http://opds-spec.org/acquisition", which can be used to fetch the
book or the license file.
"""
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()
library = get_request_library()

header = self.authorization_header()
Expand Down Expand Up @@ -428,7 +428,7 @@ def can_fulfill_without_loan(
return self.circulation.can_fulfill_without_loan(patron, pool, lpdm)

def revoke(self, license_pool_id: int) -> OPDSEntryResponse | ProblemDetail:
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()
pool = self.load_licensepool(license_pool_id)
if isinstance(pool, ProblemDetail):
return pool
Expand Down Expand Up @@ -492,7 +492,7 @@ def revoke(self, license_pool_id: int) -> OPDSEntryResponse | ProblemDetail:
def detail(
self, identifier_type: str, identifier: str
) -> OPDSEntryResponse | ProblemDetail | None:
patron = flask.request.patron # type: ignore[attr-defined]
patron = get_request_patron()
library = get_request_library()
pools = self.load_licensepools(library, identifier_type, identifier)
if isinstance(pools, ProblemDetail):
Expand Down
5 changes: 2 additions & 3 deletions src/palace/manager/api/controller/patron_activity_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import uuid

import flask
from flask import Response

from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.api.util.flask import get_request_patron


class PatronActivityHistoryController:

def reset_statistics_uuid(self) -> Response:
"""Resets the patron's the statistics UUID that links the patron to past activity thus effectively erasing the
link between activity history and patron."""
patron: Patron = flask.request.patron # type: ignore [attr-defined]
patron = get_request_patron()
patron.uuid = uuid.uuid4()
return Response("UUID reset", 200)
5 changes: 3 additions & 2 deletions src/palace/manager/api/controller/patron_auth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
)
from palace.manager.api.model.patron_auth import PatronAuthAccessToken
from palace.manager.api.problem_details import PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE
from palace.manager.api.util.flask import get_request_patron
from palace.manager.util.log import LoggerMixin
from palace.manager.util.problem_detail import ProblemDetailException


class PatronAuthTokenController(CirculationManagerController, LoggerMixin):
def get_token(self):
"""Create a Patron Auth access token for an authenticated patron"""
patron = flask.request.patron
patron = get_request_patron(default=None)
auth = flask.request.authorization
token_expiry = 3600

if not patron or auth.type.lower() != "basic":
if patron is None or auth is None or auth.type.lower() != "basic":
return PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE

try:
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/api/controller/playtime_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
PlaytimeEntriesPostResponse,
)
from palace.manager.api.problem_details import NOT_FOUND_ON_REMOTE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.core.problem_details import INVALID_INPUT
from palace.manager.core.query.playtime_entries import PlaytimeEntries
from palace.manager.sqlalchemy.model.collection import Collection
Expand Down Expand Up @@ -77,7 +77,7 @@ def track_playtimes(self, collection_id, identifier_type, identifier_idn):
.join(LicensePool)
.where(
LicensePool.identifier == identifier,
Loan.patron == flask.request.patron,
Loan.patron == get_request_patron(),
Loan.start <= entry_max_start_time,
or_(Loan.end > entry_min_end_time, Loan.end == None),
)
Expand Down
3 changes: 2 additions & 1 deletion src/palace/manager/api/controller/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from palace.manager.api.controller.circulation_manager import (
CirculationManagerController,
)
from palace.manager.api.util.flask import get_request_patron
from palace.manager.core.user_profile import ProfileController as CoreProfileController
from palace.manager.util.problem_detail import ProblemDetail

Expand All @@ -21,7 +22,7 @@ def _controller(self, patron):

def protocol(self):
"""Handle a UPMP request."""
patron = flask.request.patron
patron = get_request_patron()
controller = self._controller(patron)
if flask.request.method == "GET":
result = controller.get()
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/api/controller/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SeriesLane,
)
from palace.manager.api.problem_details import NO_SUCH_LANE, NOT_FOUND_ON_REMOTE
from palace.manager.api.util.flask import get_request_library
from palace.manager.api.util.flask import get_request_library, get_request_patron
from palace.manager.core.app_server import load_pagination_from_request
from palace.manager.core.config import CannotLoadConfiguration
from palace.manager.core.metadata_layer import ContributorData
Expand Down Expand Up @@ -111,7 +111,7 @@ def permalink(self, identifier_type, identifier):
if isinstance(work, ProblemDetail):
return work

patron = flask.request.patron
patron = get_request_patron(default=None)

if patron:
pools = self.load_licensepools(library, identifier_type, identifier)
Expand Down
28 changes: 28 additions & 0 deletions src/palace/manager/api/util/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from palace.manager.core.exceptions import PalaceValueError
from palace.manager.sqlalchemy.model.library import Library
from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.util.sentinel import SentinelType

if TYPE_CHECKING:
Expand Down Expand Up @@ -104,6 +105,33 @@ def get_request_library(
return get_request_var("library", Library, default=default)


@overload
def get_request_patron() -> Patron: ...


@overload
def get_request_patron(*, default: TDefault) -> Patron | TDefault: ...


def get_request_patron(
*, default: TDefault | Literal[SentinelType.NotGiven] = SentinelType.NotGiven
) -> Patron | TDefault:
"""
Retrieve the 'patron' attribute from the current Flask request object.
This attribute should be set by using the @requires_auth or @allows_auth decorator
on the route or by calling the BaseCirculationManagerController.authenticated_patron_from_request
method.
:param default: The default value to return if the 'patron' attribute is not set
or if there is no request context. If not provided, a `PalaceValueError` will be
raised if the attribute is missing or has an incorrect type.
:return: The `Patron` object from the request, or the default value if provided.
"""
return get_request_var("patron", Patron, default=default)


class PalaceFlask(flask.Flask):
"""
A subclass of Flask sets properties used by Palace.
Expand Down
3 changes: 1 addition & 2 deletions tests/fixtures/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __init__(self, name):

def authenticated_patron_from_request(self):
if self.authenticated:
patron = object()
flask.request.patron = self.AUTHENTICATED_PATRON
setattr(flask.request, "patron", self.AUTHENTICATED_PATRON)
return self.AUTHENTICATED_PATRON
else:
return flask.Response(
Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from palace.manager.api.util.flask import PalaceFlask
from palace.manager.sqlalchemy.model.admin import Admin, AdminRole
from palace.manager.sqlalchemy.model.library import Library
from palace.manager.sqlalchemy.model.patron import Patron
from palace.manager.sqlalchemy.util import get_one_or_create
from tests.fixtures.database import DatabaseTransactionFixture

Expand All @@ -38,12 +39,14 @@ def test_request_context(
*args: Any,
admin: Admin | None = None,
library: Library | None = None,
patron: Patron | None = None,
**kwargs: Any,
) -> Generator[RequestContext]:
with self.app.test_request_context(*args, **kwargs) as c:
self.db.session.begin_nested()
setattr(c.request, "library", library)
setattr(c.request, "admin", admin)
setattr(c.request, "patron", patron)
setattr(c.request, "form", ImmutableMultiDict())
setattr(c.request, "files", ImmutableMultiDict())
try:
Expand Down
5 changes: 2 additions & 3 deletions tests/manager/api/controller/test_analytics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import flask
import pytest

from palace.manager.api.problem_details import INVALID_ANALYTICS_EVENT_TYPE
Expand Down Expand Up @@ -42,8 +41,8 @@ def test_track_event(self, analytics_fixture: AnalyticsFixture):

patron = db.patron()
patron.neighborhood = "Mars Grid 4810579"
with analytics_fixture.request_context_with_library("/"):
flask.request.patron = patron # type: ignore
with analytics_fixture.request_context_with_library("/") as ctx:
setattr(ctx.request, "patron", patron)
response = analytics_fixture.manager.analytics_controller.track_event(
analytics_fixture.identifier.type,
analytics_fixture.identifier.identifier,
Expand Down
Loading

0 comments on commit 9a6e727

Please sign in to comment.