diff --git a/api/admin/controller/patron.py b/api/admin/controller/patron.py index 2865696be7..31204c98c3 100644 --- a/api/admin/controller/patron.py +++ b/api/admin/controller/patron.py @@ -35,13 +35,14 @@ def _load_patrondata(self, authenticator=None): patron_data = PatronData(authorization_identifier=identifier) complete_patron_data = None + patron_lookup_providers = list(authenticator.unique_patron_lookup_providers) - if not authenticator.providers: + if not patron_lookup_providers: return NO_SUCH_PATRON.detailed( _("This library has no authentication providers, so it has no patrons.") ) - for provider in authenticator.providers: + for provider in patron_lookup_providers: complete_patron_data = provider.remote_patron_lookup(patron_data) if complete_patron_data: return complete_patron_data diff --git a/api/authenticator.py b/api/authenticator.py index 11d655b0e1..b1ce28b56c 100644 --- a/api/authenticator.py +++ b/api/authenticator.py @@ -33,8 +33,6 @@ from core.util.log import elapsed_time_logging from core.util.problem_detail import ProblemDetail, ProblemError -from .authentication.base import AuthenticationProvider -from .authentication.basic import BasicAuthenticationProvider from .config import CannotLoadConfiguration, Configuration from .integration.registry.patron_auth import PatronAuthRegistry from .problem_details import * @@ -430,6 +428,34 @@ def providers(self) -> Iterable[AuthenticationProvider]: yield self.basic_auth_provider yield from self.saml_providers_by_name.values() + @property + def unique_patron_lookup_providers(self) -> Iterable[AuthenticationProvider]: + """Iterator over unique patron data providers for registered AuthenticationProviders. + + We want to these providers to be unique to avoid performing the same + lookup multiple times. Otherwise, this would be likely to happen when + the lookup failed for a given provider. + """ + + # For `BasicTokenAuthenticationProvider`s, the lookup provider is + # its `basic_provider`. `BasicAuthenticationProvider`s are their + # own lookup providers. + basic_providers = filter( + None, + [ + ( + self.access_token_authentication_provider.basic_provider + if self.access_token_authentication_provider + else None + ), + self.basic_auth_provider, + ], + ) + # De-dupe, but preserve provider order. + unique_basic_provider = dict.fromkeys(basic_providers) + yield from unique_basic_provider + yield from self.saml_providers_by_name.values() + def authenticated_patron( self, _db: Session, auth: Authorization ) -> Patron | ProblemDetail | None: diff --git a/tests/api/test_authenticator.py b/tests/api/test_authenticator.py index eed0516412..8f2dc76ae2 100644 --- a/tests/api/test_authenticator.py +++ b/tests/api/test_authenticator.py @@ -937,6 +937,15 @@ def test_authenticated_patron_bearer_access_token( authenticator = LibraryAuthenticator( _db=db.session, library=db.default_library(), basic_auth_provider=basic ) + + token_auth_provider, basic_auth_provider = authenticator.providers + [patron_lookup_provider] = authenticator.unique_patron_lookup_providers + assert ( + cast(BasicTokenAuthenticationProvider, token_auth_provider).basic_provider + == basic_auth_provider + ) + assert patron_lookup_provider == basic_auth_provider + patron = db.patron() token = AccessTokenProvider.generate_token(db.session, patron, "pass") auth = Authorization(auth_type="bearer", token=token)