From 7d89524f9903d58c75371c1c1f62ed3a5e034ae4 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Tue, 12 Sep 2023 13:58:18 -0300 Subject: [PATCH] Update OPDS2 token auth to work with any authentication mechanism. --- api/opds2.py | 27 ++++++++++++++------------- tests/api/test_opds2.py | 15 +++++++++------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/api/opds2.py b/api/opds2.py index fa38f3d82b..b49e2319bb 100644 --- a/api/opds2.py +++ b/api/opds2.py @@ -9,7 +9,7 @@ from api.circulation import CirculationFulfillmentPostProcessor, FulfillmentInfo from api.circulation_exceptions import CannotFulfill from core.lane import Facets -from core.model import ConfigurationSetting, ExternalIntegration +from core.model import ConfigurationSetting, DataSource, ExternalIntegration from core.model.edition import Edition from core.model.identifier import Identifier from core.model.licensing import LicensePoolDeliveryMechanism @@ -94,6 +94,10 @@ class TokenAuthenticationFulfillmentProcessor(CirculationFulfillmentPostProcesso def __init__(self, collection) -> None: pass + @classmethod + def logger(cls) -> logging.Logger: + return logging.getLogger(f"{cls.__module__}.{cls.__name__}") + def fulfill( self, patron: Patron, @@ -120,7 +124,9 @@ def fulfill( if not token_auth or token_auth.value is None: return fulfillment - token = self.get_authentication_token(patron, token_auth.value) + token = self.get_authentication_token( + patron, licensepool.data_source, token_auth.value + ) if isinstance(token, ProblemDetail): raise CannotFulfill() @@ -130,24 +136,19 @@ def fulfill( @classmethod def get_authentication_token( - cls, patron: Patron, token_auth_url: str + cls, patron: Patron, datasource: DataSource, token_auth_url: str ) -> ProblemDetail | str: """Get the authentication token for a patron""" - log = logging.getLogger("OPDS2API") - - patron_id = patron.username if patron.username else patron.external_identifier - if patron_id is None: - log.error( - f"Could not authenticate the patron({patron.authorization_identifier}): " - f"both username and external_identifier are None." - ) - return INVALID_CREDENTIALS + log = cls.logger() + log.debug(f"Getting authentication token for patron({patron.id})") + patron_id = patron.identifier_to_remote_service(datasource) url = URITemplate(token_auth_url).expand(patron_id=patron_id) response = HTTP.get_with_timeout(url) if response.status_code != 200: log.error( - f"Could not authenticate the patron({patron_id}): {str(response.content)}" + f"Could not authenticate the patron (authorization identifier: '{patron.authorization_identifier}' " + f"external identifier: '{patron_id}'): {str(response.content)}" ) return INVALID_CREDENTIALS diff --git a/tests/api/test_opds2.py b/tests/api/test_opds2.py index d4a896adb6..969b290261 100644 --- a/tests/api/test_opds2.py +++ b/tests/api/test_opds2.py @@ -175,7 +175,6 @@ class TestTokenAuthenticationFulfillmentProcessor: @patch("api.opds2.HTTP") def test_fulfill(self, mock_http, db: DatabaseTransactionFixture): patron = db.patron() - patron.username = "username" collection: Collection = db.collection( protocol=ExternalIntegration.OPDS2_IMPORT ) @@ -207,10 +206,14 @@ def test_fulfill(self, mock_http, db: DatabaseTransactionFixture): processor = TokenAuthenticationFulfillmentProcessor(collection) ff_info = processor.fulfill(patron, "", work.license_pools[0], None, ff_info) + patron_id = patron.identifier_to_remote_service( + work.license_pools[0].data_source + ) + assert mock_http.get_with_timeout.call_count == 1 assert ( mock_http.get_with_timeout.call_args[0][0] - == "http://example.org/token?userName=username" + == f"http://example.org/token?userName={patron_id}" ) assert ( @@ -265,9 +268,9 @@ def test_get_authentication_token(self, mock_http, db: DatabaseTransactionFixtur resp.raw = io.BytesIO(b"plaintext-auth-token") mock_http.get_with_timeout.return_value = resp patron = db.patron() - patron.username = "test" + datasource = DataSource.lookup(db.session, "test", autocreate=True) token = TokenAuthenticationFulfillmentProcessor.get_authentication_token( - patron, "http://example.org/token" + patron, datasource, "http://example.org/token" ) assert token == "plaintext-auth-token" @@ -280,9 +283,9 @@ def test_get_authentication_token_errors( resp = Response() resp.status_code = 400 mock_http.get_with_timeout.return_value = resp - + datasource = DataSource.lookup(db.session, "test", autocreate=True) token = TokenAuthenticationFulfillmentProcessor.get_authentication_token( - db.patron(), "http://example.org/token" + db.patron(), datasource, "http://example.org/token" ) assert token == INVALID_CREDENTIALS