Skip to content

Commit

Permalink
Enhancements for SAML2 bearer flow (#3132)
Browse files Browse the repository at this point in the history
* Test saml bearer

* Fixes for SAML2 bearer flow

* reverted test
  • Loading branch information
strehle authored Nov 15, 2024
1 parent d02c5dd commit d281b28
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.web.AuthenticationEntryPoint;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -237,9 +238,15 @@ protected Authentication attemptTokenAuthentication(HttpServletRequest request,
log.debug("Attempting SAML authentication for token endpoint.");
try {
authResult = saml2BearerGrantAuthenticationConverter.convert(request);
} catch (AuthenticationException e) {
String errorMessage = (e instanceof Saml2AuthenticationException saml2AuthenticationException) ?
saml2AuthenticationException.getSaml2Error().getDescription() : e.getMessage();
log.debug(errorMessage, e);
throw new InsufficientAuthenticationException(errorMessage);
} catch (Exception e) {
log.error("Error setting assertion in SAML filter", e);
throw new InsufficientAuthenticationException("Error setting assertion in SAML filter");
String errorMessage = "Error setting assertion in SAML filter";
log.error(errorMessage, e);
throw new InsufficientAuthenticationException(errorMessage);
}
} else {
log.debug("No assertion or filter, not attempting SAML authentication for token endpoint.");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import lombok.extern.slf4j.Slf4j;
import org.cloudfoundry.identity.uaa.provider.AbstractIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.ZoneAware;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.util.Assert;

import java.util.List;
import java.util.Optional;

@Slf4j
public class ConfiguratorRelyingPartyRegistrationRepository extends BaseUaaRelyingPartyRegistrationRepository
implements RelyingPartyRegistrationRepository, ZoneAware {
public class ConfiguratorRelyingPartyRegistrationRepository extends BaseUaaRelyingPartyRegistrationRepository {

private final SamlIdentityProviderConfigurator configurator;

Expand All @@ -38,19 +36,26 @@ public ConfiguratorRelyingPartyRegistrationRepository(String uaaWideSamlEntityID
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
IdentityZone currentZone = retrieveZone();
List<SamlIdentityProviderDefinition> identityProviderDefinitions = configurator.getIdentityProviderDefinitionsForZone(currentZone);
AbstractIdentityProviderDefinition idpDefinition = configurator.getIdentityProviderDefinitionsForOrigin(currentZone, registrationId);
if (idpDefinition == null) {
idpDefinition = configurator.getIdentityProviderDefinitionsForIssuer(currentZone, registrationId);
}
if (idpDefinition instanceof SamlIdentityProviderDefinition foundSamlIdentityProviderDefinition) {
return createRelyingPartyRegistration(foundSamlIdentityProviderDefinition.getIdpEntityAlias(), foundSamlIdentityProviderDefinition, currentZone);
}

List<SamlIdentityProviderDefinition> identityProviderDefinitions = configurator.getIdentityProviderDefinitionsForZone(currentZone);
for (SamlIdentityProviderDefinition identityProviderDefinition : identityProviderDefinitions) {
if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) {
return createRelyingPartyRegistration(registrationId, identityProviderDefinition, currentZone);
if (registrationId.equals(identityProviderDefinition.getIdpEntityAlias()) || registrationId.equals(identityProviderDefinition.getIdpEntityId())) {
return createRelyingPartyRegistration(identityProviderDefinition.getIdpEntityAlias(), identityProviderDefinition, currentZone);
}
}

if (!identityProviderDefinitions.isEmpty()) {
// TODO remove hack
if (!identityProviderDefinitions.isEmpty() && identityProviderDefinitions.size() == 1) {
SamlIdentityProviderDefinition identityProviderDefinition = identityProviderDefinitions.get(0);
return createRelyingPartyRegistration(identityProviderDefinition.getIdpEntityAlias(), identityProviderDefinition, currentZone);
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest requ
return null;
}
if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
samlRequest = Saml2Utils.samlDecodeAndInflate(samlRequest);
} else {
samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.cloudfoundry.identity.uaa.authentication.UaaPrincipal;
import org.cloudfoundry.identity.uaa.authentication.UaaSamlPrincipal;
import org.cloudfoundry.identity.uaa.authentication.event.IdentityProviderAuthenticationSuccessEvent;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.provider.IdentityProvider;
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
Expand All @@ -37,6 +36,7 @@
import org.opensaml.saml.common.assertion.ValidationContext;
import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.impl.AssertionUnmarshaller;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
Expand Down Expand Up @@ -77,10 +77,12 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;

import static org.cloudfoundry.identity.uaa.constants.OriginKeys.NotANumber;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.SAML;
import static org.cloudfoundry.identity.uaa.provider.saml.OpenSaml4AuthenticationProvider.createDefaultAssertionValidatorWithParameters;

/**
Expand Down Expand Up @@ -184,36 +186,26 @@ static Converter<OpenSaml4AuthenticationProvider.AssertionToken, AbstractAuthent

@Override
public Authentication convert(HttpServletRequest request) throws AuthenticationException {
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationResolver.resolve(request, null);

String serializedAssertion = request.getParameter("assertion");
byte[] decodedAssertion = Saml2Utils.samlDecode(serializedAssertion);
byte[] decodedAssertion = Saml2Utils.samlBearerDecode(serializedAssertion);
String assertionXml = new String(decodedAssertion, StandardCharsets.UTF_8);

Assertion assertion = parseAssertion(assertionXml);
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationResolver.resolve(request, getIssuer(assertion));
IdentityProvider<SamlIdentityProviderDefinition> idp = retrieveSamlIdpSamlIdentityProvider(relyingPartyRegistration.getRegistrationId());
Saml2AuthenticationToken authenticationToken = new Saml2AuthenticationToken(relyingPartyRegistration, assertionXml);
process(authenticationToken, assertion);

String subjectName = assertion.getSubject().getNameID().getValue();
String alias = relyingPartyRegistration.getRegistrationId();
String alias = idp.getOriginKey();
IdentityZone zone = identityZoneManager.getCurrentIdentityZone();

UaaPrincipal initialPrincipal = new UaaPrincipal(NotANumber, subjectName, subjectName,
alias, subjectName, zone.getId());

boolean addNew;
IdentityProvider<SamlIdentityProviderDefinition> idp;
SamlIdentityProviderDefinition samlConfig;
try {
idp = identityProviderProvisioning.retrieveByOrigin(alias, identityZoneManager.getCurrentIdentityZoneId());
samlConfig = idp.getConfig();
addNew = samlConfig.isAddShadowUserOnLogin();
if (!idp.isActive()) {
throw new ProviderNotFoundException("Identity Provider has been disabled by administrator for alias:" + alias);
}
} catch (EmptyResultDataAccessException x) {
throw new ProviderNotFoundException("No SAML identity provider found in zone for alias:" + alias);
}
SamlIdentityProviderDefinition samlConfig = idp.getConfig();
boolean addNew = samlConfig.isAddShadowUserOnLogin();

MultiValueMap<String, String> userAttributes = new LinkedMultiValueMap<>();

Expand All @@ -233,7 +225,7 @@ public Authentication convert(HttpServletRequest request) throws AuthenticationE
authentication.setAuthenticationMethods(Set.of("ext"));
setAuthContextClassRefs(assertion, authentication);

publish(new IdentityProviderAuthenticationSuccessEvent(user, authentication, OriginKeys.SAML, identityZoneManager.getCurrentIdentityZoneId()));
publish(new IdentityProviderAuthenticationSuccessEvent(user, authentication, SAML, identityZoneManager.getCurrentIdentityZoneId()));

AbstractSaml2AuthenticationRequest authenticationRequest = authenticationToken.getAuthenticationRequest();
if (authenticationRequest != null) {
Expand Down Expand Up @@ -309,12 +301,17 @@ private static Assertion parseAssertion(String assertion) throws Saml2Exception,
Element element = document.getDocumentElement();
return (Assertion) assertionUnmarshaller.unmarshall(element);
} catch (Exception ex) {
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INVALID_ASSERTION, ex.getMessage(), ex);
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INVALID_ASSERTION, "Unable to parse bearer assertion", ex);
}
}

private static String getIssuer(Assertion assertion) {
return Optional.ofNullable(assertion.getIssuer()).map(Issuer::getValue)
.orElseThrow(() -> new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_ASSERTION, "Missing issuer in bearer assertion")));
}

private void process(Saml2AuthenticationToken token, Assertion assertion) {
String issuer = assertion.getIssuer().getValue();
String issuer = getIssuer(assertion);
log.debug("Processing SAML response from {}", issuer);

OpenSaml4AuthenticationProvider.AssertionToken assertionToken = new OpenSaml4AuthenticationProvider.AssertionToken(assertion, token);
Expand Down Expand Up @@ -344,4 +341,16 @@ private void process(Saml2AuthenticationToken token, Assertion assertion) {
}
}

private IdentityProvider<SamlIdentityProviderDefinition> retrieveSamlIdpSamlIdentityProvider(String origin) {
try {
IdentityProvider<SamlIdentityProviderDefinition> idp = identityProviderProvisioning.retrieveByOrigin(origin,
identityZoneManager.getCurrentIdentityZoneId());
if (idp == null || !SAML.equals(idp.getType()) || !idp.isActive()) {
throw new ProviderNotFoundException("Identity Provider has been disabled by administrator for alias: " + origin);
}
return idp;
} catch (EmptyResultDataAccessException x) {
throw new ProviderNotFoundException("No SAML identity provider found in zone for alias: " + origin);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2ErrorCodes;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand Down Expand Up @@ -46,10 +47,22 @@ public static String samlEncode(byte[] b) {
return Base64.getEncoder().encodeToString(b);
}

public static String samlBearerEncode(byte[] b) {
return Base64.getUrlEncoder().encodeToString(b);
}

public static byte[] samlDecode(String s) {
return Base64.getMimeDecoder().decode(s);
}

public static byte[] samlBearerDecode(String s) {
try {
return Base64.getUrlDecoder().decode(s);
} catch (IllegalArgumentException ex) {
throw OpenSaml4AuthenticationProvider.createAuthenticationException(Saml2ErrorCodes.INVALID_ASSERTION, "Unable to urlBase64Decode bearer assertion", ex);
}
}

public static byte[] samlDeflate(String s) {
try {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Expand Down Expand Up @@ -79,8 +92,8 @@ public static String samlInflate(byte[] b) {
* Below are convenience methods not originally in the Spring-Security class
*****************************************************************************/

public static String samlEncode(String s) {
return samlEncode(s.getBytes(StandardCharsets.UTF_8));
public static String samlBearerEncode(String s) {
return samlBearerEncode(s.getBytes(StandardCharsets.UTF_8));
}

public static String samlDecodeAndInflate(String s) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.apache.commons.io.IOUtils;
import org.apache.http.client.utils.URIBuilder;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.provider.AbstractIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.provider.IdentityProvider;
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.provider.IdpAlreadyExistsException;
Expand Down Expand Up @@ -42,6 +43,22 @@ public List<SamlIdentityProviderDefinition> getIdentityProviderDefinitions() {
return getIdentityProviderDefinitionsForZone(identityZoneManager.getCurrentIdentityZone());
}

public AbstractIdentityProviderDefinition getIdentityProviderDefinitionsForOrigin(IdentityZone zone, String origin) {
try {
return providerProvisioning.retrieveByOrigin(origin, zone.getId()).getConfig();
} catch (EmptyResultDataAccessException e) {
return null;
}
}

public AbstractIdentityProviderDefinition getIdentityProviderDefinitionsForIssuer(IdentityZone zone, String issuer) {
try {
return providerProvisioning.retrieveByExternId(issuer, OriginKeys.SAML, zone.getId()).getConfig();
} catch (EmptyResultDataAccessException e) {
return null;
}
}

public List<SamlIdentityProviderDefinition> getIdentityProviderDefinitionsForZone(IdentityZone zone) {
List<SamlIdentityProviderDefinition> result = new LinkedList<>();
for (IdentityProvider<SamlIdentityProviderDefinition> provider : providerProvisioning.retrieveActive(zone.getId())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ private AuthnRequest request() {

private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) {
String xml = serialize(request);
return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))
: Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlBearerEncode(xml.getBytes(StandardCharsets.UTF_8))
: Saml2Utils.samlBearerEncode(Saml2Utils.samlDeflate(xml));
}

private Assertion assertion(String inResponseTo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,6 @@ public static String getEncodedAssertion(String issuerEntityId,
assertion = signed(assertion, signingCredential, issuerEntityId);
}
String serialized = Saml2TestUtils.serialize(assertion);
return Saml2Utils.samlEncode(serialized);
return Saml2Utils.samlBearerEncode(serialized);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ void getTokenUsingSaml2BearerGrant() throws Exception {
String idpMetadata = getIdpMetadata(host, origin);
SamlIdentityProviderDefinition idpDef = createLocalSamlIdpDefinition(
origin, testZone.getIdentityZone().getId(), idpMetadata);
idpDef.setIdpEntityId("68uexx.cloudfoundry-saml-login");
IdentityProvider<SamlIdentityProviderDefinition> provider = new IdentityProvider<>();
provider.setConfig(idpDef);
provider.setActive(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ void sendAuthnRequestFromNonDefaultZoneToIdpPostBindingMode() throws Exception {
void receiveAuthnResponseFromIdpToLegacyAliasUrl() throws Exception {

String encodedSamlResponse = serializedResponse(responseWithAssertions());
mockMvc.perform(post("/uaa/saml/SSO/alias/%s".formatted("integration-saml-entity-id"))
mockMvc.perform(post("/uaa/saml/SSO/alias/%s".formatted("testsaml-post-binding"))
.contextPath("/uaa")
.header(HOST, "localhost:8080")
.param(SAML_RESPONSE, encodedSamlResponse)
Expand Down Expand Up @@ -477,7 +477,7 @@ void AuthnResponseSucceedsWithWithInvalidInResponseTo() throws Exception {
Response response = responseWithAssertions();
response.setInResponseTo("incorrect");
String encodedSamlResponse = serializedResponse(response);
mockMvc.perform(post("/uaa/saml/SSO/alias/%s".formatted("integration-saml-entity-id"))
mockMvc.perform(post("/uaa/saml/SSO/alias/%s".formatted("testsaml-post-binding"))
.contextPath("/uaa")
.header(HOST, "localhost:8080")
.param(SAML_RESPONSE, encodedSamlResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void getTokenUsingSaml2BearerGrant() throws Exception {
final String host = "%s.localhost".formatted(subdomain);
final String fullPath = "/uaa/oauth/token/alias/%s.integration-saml-entity-id".formatted(subdomain);
final String origin = "%s.integration-saml-entity-id".formatted(subdomain);
final String entityId = "%s.cloudfoundry-saml-login".formatted(subdomain);
MockMvcUtils.IdentityZoneCreationResult testZone =
MockMvcUtils.createOtherIdentityZoneAndReturnResult(
subdomain, mockMvc, this.webApplicationContext, null,
Expand All @@ -59,6 +60,7 @@ void getTokenUsingSaml2BearerGrant() throws Exception {
String idpMetadata = getIdpMetadata(host, origin);
SamlIdentityProviderDefinition idpDef = createLocalSamlIdpDefinition(
origin, testZone.getIdentityZone().getId(), idpMetadata);
idpDef.setIdpEntityId(entityId);
IdentityProvider<SamlIdentityProviderDefinition> provider = new IdentityProvider<>();
provider.setConfig(idpDef);
provider.setActive(true);
Expand All @@ -71,7 +73,7 @@ void getTokenUsingSaml2BearerGrant() throws Exception {
IdentityZoneHolder.clear();

String spEndpoint = "http://%s:8080/uaa/oauth/token/alias/%s".formatted(host, origin);
String assertionStr = TestOpenSamlObjects.getEncodedAssertion("68uexx.cloudfoundry-saml-login", NameID.UNSPECIFIED,
String assertionStr = TestOpenSamlObjects.getEncodedAssertion(entityId, NameID.UNSPECIFIED,
"Saml2BearerIntegrationUser", spEndpoint, origin, true);

// create a client in the test zone
Expand Down

0 comments on commit d281b28

Please sign in to comment.