From cf29bf996c9c8ef88162f6e0612ac44cbf81e013 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Tue, 15 Mar 2022 13:05:39 -0600 Subject: [PATCH] Polish InResponseTo support - Moved methods so methods are listed before the methods they call - Adjusted exception handling so no exceptions are eaten - Adjusted so that malformed_request_data is returned with request data is malformed - Refactored methods to have only immutable method parameters - Removed usage of Stream API - Moved AuthnRequestUnmarshaller into static block so that only looked up once Issue gh-9174 --- .../OpenSaml4AuthenticationProvider.java | 162 +++++++++--------- .../OpenSaml4AuthenticationProviderTests.java | 2 +- 2 files changed, 83 insertions(+), 81 deletions(-) diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 835fb616778..31acccfa743 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -65,7 +65,6 @@ import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.StatusCode; -import org.opensaml.saml.saml2.core.Subject; import org.opensaml.saml.saml2.core.SubjectConfirmation; import org.opensaml.saml.saml2.core.SubjectConfirmationData; import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; @@ -146,6 +145,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private final ResponseUnmarshaller responseUnmarshaller; + private static final AuthnRequestUnmarshaller authnRequestUnmarshaller; + static { + XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class); + authnRequestUnmarshaller = (AuthnRequestUnmarshaller) registry.getUnmarshallerFactory() + .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + } + private final ParserPool parserPool; private final Converter responseSignatureValidator = createDefaultResponseSignatureValidator(); @@ -355,37 +361,6 @@ public void setResponseAuthenticationConverter( this.responseAuthenticationConverter = responseAuthenticationConverter; } - private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, - String inResponseTo) { - if (!StringUtils.hasText(inResponseTo)) { - return Saml2ResponseValidatorResult.success(); - } - AuthnRequest request; - try { - request = parseRequest(storedRequest); - } - catch (Exception ex) { - String message = "The stored AuthNRequest could not be properly deserialized [" + ex.getMessage() + "]"; - return Saml2ResponseValidatorResult - .failure(new Saml2Error(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message)); - } - if (request == null) { - String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" - + " but no saved AuthNRequest request was found"; - return Saml2ResponseValidatorResult - .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); - } - else if (!request.getID().equals(inResponseTo)) { - String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " - + "AuthNRequest [" + request.getID() + "]"; - return Saml2ResponseValidatorResult - .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); - } - else { - return Saml2ResponseValidatorResult.success(); - } - } - /** * Construct a default strategy for validating the SAML 2.0 Response * @return the default response validator strategy @@ -428,6 +403,27 @@ public static Converter createDefau }; } + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, + String inResponseTo) { + if (!StringUtils.hasText(inResponseTo)) { + return Saml2ResponseValidatorResult.success(); + } + AuthnRequest request = parseRequest(storedRequest); + if (request == null) { + String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" + + " but no saved authentication request was found"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + if (!inResponseTo.equals(request.getID())) { + String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " + + "authentication request [" + request.getID() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + return Saml2ResponseValidatorResult.success(); + } + /** * Construct a default strategy for validating each SAML 2.0 Assertion and associated * {@link Authentication} token @@ -522,28 +518,6 @@ private Response parseResponse(String response) throws Saml2Exception, Saml2Auth } } - private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) throws Exception { - if (request == null) { - return null; - } - String samlRequest = request.getSamlRequest(); - if (!StringUtils.hasText(samlRequest)) { - return null; - } - if (request.getBinding() == Saml2MessageBinding.REDIRECT) { - samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); - } - else { - samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); - } - Document document = XMLObjectProviderRegistrySupport.getParserPool() - .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); - Element element = document.getDocumentElement(); - AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport - .getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); - return (AuthnRequest) unmarshaller.unmarshall(element); - } - private void process(Saml2AuthenticationToken token, Response response) { String issuer = response.getIssuer().getValue(); this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); @@ -748,40 +722,18 @@ private static Converter createAss }; } - private static boolean assertionContainsInResponseTo(Assertion assertion) { - Subject subject = (assertion != null) ? assertion.getSubject() : null; - List confirmations = (subject != null) ? subject.getSubjectConfirmations() - : new ArrayList<>(); - return confirmations.stream().filter((confirmation) -> { - SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); - return confirmationData != null && StringUtils.hasText(confirmationData.getInResponseTo()); - }).findFirst().orElse(null) != null; - } - - private static void addRequestIdToValidationContext(AbstractSaml2AuthenticationRequest storedRequest, - Map context) { - String requestId = null; - try { - AuthnRequest request = parseRequest(storedRequest); - requestId = (request != null) ? request.getID() : null; - } - catch (Exception ex) { - } - if (StringUtils.hasText(requestId)) { - context.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); - } - } - private static ValidationContext createValidationContext(AssertionToken assertionToken, Consumer> paramsConsumer) { - RelyingPartyRegistration relyingPartyRegistration = assertionToken.token.getRelyingPartyRegistration(); + Saml2AuthenticationToken token = assertionToken.token; + RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration(); String audience = relyingPartyRegistration.getEntityId(); String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); Map params = new HashMap<>(); Assertion assertion = assertionToken.getAssertion(); if (assertionContainsInResponseTo(assertion)) { - addRequestIdToValidationContext(assertionToken.token.getAuthenticationRequest(), params); + String requestId = getAuthnRequestId(token.getAuthenticationRequest()); + params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); } params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); @@ -790,6 +742,56 @@ private static ValidationContext createValidationContext(AssertionToken assertio return new ValidationContext(params); } + private static boolean assertionContainsInResponseTo(Assertion assertion) { + if (assertion.getSubject() == null) { + return false; + } + for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { + SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); + if (confirmationData == null) { + continue; + } + if (StringUtils.hasText(confirmationData.getInResponseTo())) { + return true; + } + } + return false; + } + + private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) { + AuthnRequest request = parseRequest(serialized); + if (request == null) { + return null; + } + return request.getID(); + } + + private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) { + if (request == null) { + return null; + } + String samlRequest = request.getSamlRequest(); + if (!StringUtils.hasText(samlRequest)) { + return null; + } + if (request.getBinding() == Saml2MessageBinding.REDIRECT) { + samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); + } + else { + samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); + } + try { + Document document = XMLObjectProviderRegistrySupport.getParserPool() + .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element); + } + catch (Exception ex) { + String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]"; + throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex); + } + } + private static class SAML20AssertionValidators { private static final Collection conditions = new ArrayList<>(); diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 2cf7c62ba75..6a7eb1ff0c5 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -252,7 +252,7 @@ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorrupted Saml2MessageBinding.POST, true); Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); assertThatExceptionOfType(Saml2AuthenticationException.class) - .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data"); } @Test