Skip to content

Commit

Permalink
Polish InResponseTo support
Browse files Browse the repository at this point in the history
- Moved methods so methods are listed before the methods they call
- Adjusted exception handling so no exceptions are eaten
- Adjusted so that malformed_request_data is returned with request data is malformed
- Refactored methods to have only immutable method parameters
- Removed usage of Stream API
- Moved AuthnRequestUnmarshaller into static block so that only looked
up once

Issue gh-9174
  • Loading branch information
jzheaux committed Mar 15, 2022
1 parent 3c87854 commit cf29bf9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
import org.opensaml.saml.saml2.core.OneTimeUse;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.StatusCode;
import org.opensaml.saml.saml2.core.Subject;
import org.opensaml.saml.saml2.core.SubjectConfirmation;
import org.opensaml.saml.saml2.core.SubjectConfirmationData;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
Expand Down Expand Up @@ -146,6 +145,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv

private final ResponseUnmarshaller responseUnmarshaller;

private static final AuthnRequestUnmarshaller authnRequestUnmarshaller;
static {
XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
authnRequestUnmarshaller = (AuthnRequestUnmarshaller) registry.getUnmarshallerFactory()
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
}

private final ParserPool parserPool;

private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
Expand Down Expand Up @@ -355,37 +361,6 @@ public void setResponseAuthenticationConverter(
this.responseAuthenticationConverter = responseAuthenticationConverter;
}

private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
AuthnRequest request;
try {
request = parseRequest(storedRequest);
}
catch (Exception ex) {
String message = "The stored AuthNRequest could not be properly deserialized [" + ex.getMessage() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message));
}
if (request == null) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved AuthNRequest request was found";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
else if (!request.getID().equals(inResponseTo)) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "AuthNRequest [" + request.getID() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
else {
return Saml2ResponseValidatorResult.success();
}
}

/**
* Construct a default strategy for validating the SAML 2.0 Response
* @return the default response validator strategy
Expand Down Expand Up @@ -428,6 +403,27 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
};
}

private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
String inResponseTo) {
if (!StringUtils.hasText(inResponseTo)) {
return Saml2ResponseValidatorResult.success();
}
AuthnRequest request = parseRequest(storedRequest);
if (request == null) {
String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+ " but no saved authentication request was found";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
if (!inResponseTo.equals(request.getID())) {
String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+ "authentication request [" + request.getID() + "]";
return Saml2ResponseValidatorResult
.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
}
return Saml2ResponseValidatorResult.success();
}

/**
* Construct a default strategy for validating each SAML 2.0 Assertion and associated
* {@link Authentication} token
Expand Down Expand Up @@ -522,28 +518,6 @@ private Response parseResponse(String response) throws Saml2Exception, Saml2Auth
}
}

private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) throws Exception {
if (request == null) {
return null;
}
String samlRequest = request.getSamlRequest();
if (!StringUtils.hasText(samlRequest)) {
return null;
}
if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
}
else {
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
}
Document document = XMLObjectProviderRegistrySupport.getParserPool()
.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport
.getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
return (AuthnRequest) unmarshaller.unmarshall(element);
}

private void process(Saml2AuthenticationToken token, Response response) {
String issuer = response.getIssuer().getValue();
this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
Expand Down Expand Up @@ -748,40 +722,18 @@ private static Converter<AssertionToken, Saml2ResponseValidatorResult> createAss
};
}

private static boolean assertionContainsInResponseTo(Assertion assertion) {
Subject subject = (assertion != null) ? assertion.getSubject() : null;
List<SubjectConfirmation> confirmations = (subject != null) ? subject.getSubjectConfirmations()
: new ArrayList<>();
return confirmations.stream().filter((confirmation) -> {
SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
return confirmationData != null && StringUtils.hasText(confirmationData.getInResponseTo());
}).findFirst().orElse(null) != null;
}

private static void addRequestIdToValidationContext(AbstractSaml2AuthenticationRequest storedRequest,
Map<String, Object> context) {
String requestId = null;
try {
AuthnRequest request = parseRequest(storedRequest);
requestId = (request != null) ? request.getID() : null;
}
catch (Exception ex) {
}
if (StringUtils.hasText(requestId)) {
context.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
}
}

private static ValidationContext createValidationContext(AssertionToken assertionToken,
Consumer<Map<String, Object>> paramsConsumer) {
RelyingPartyRegistration relyingPartyRegistration = assertionToken.token.getRelyingPartyRegistration();
Saml2AuthenticationToken token = assertionToken.token;
RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration();
String audience = relyingPartyRegistration.getEntityId();
String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation();
String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
Map<String, Object> params = new HashMap<>();
Assertion assertion = assertionToken.getAssertion();
if (assertionContainsInResponseTo(assertion)) {
addRequestIdToValidationContext(assertionToken.token.getAuthenticationRequest(), params);
String requestId = getAuthnRequestId(token.getAuthenticationRequest());
params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
}
params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
Expand All @@ -790,6 +742,56 @@ private static ValidationContext createValidationContext(AssertionToken assertio
return new ValidationContext(params);
}

private static boolean assertionContainsInResponseTo(Assertion assertion) {
if (assertion.getSubject() == null) {
return false;
}
for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
if (confirmationData == null) {
continue;
}
if (StringUtils.hasText(confirmationData.getInResponseTo())) {
return true;
}
}
return false;
}

private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
AuthnRequest request = parseRequest(serialized);
if (request == null) {
return null;
}
return request.getID();
}

private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) {
if (request == null) {
return null;
}
String samlRequest = request.getSamlRequest();
if (!StringUtils.hasText(samlRequest)) {
return null;
}
if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
}
else {
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
}
try {
Document document = XMLObjectProviderRegistrySupport.getParserPool()
.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
Element element = document.getDocumentElement();
return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element);
}
catch (Exception ex) {
String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]";
throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex);
}
}

private static class SAML20AssertionValidators {

private static final Collection<ConditionValidator> conditions = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorrupted
Saml2MessageBinding.POST, true);
Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
assertThatExceptionOfType(Saml2AuthenticationException.class)
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
}

@Test
Expand Down

0 comments on commit cf29bf9

Please sign in to comment.