Skip to content

Commit

Permalink
Add Saml2AuthenticationRequestResolver
Browse files Browse the repository at this point in the history
Closes gh-10355
  • Loading branch information
jzheaux committed Jan 24, 2022
1 parent cca35bd commit d538423
Show file tree
Hide file tree
Showing 15 changed files with 1,404 additions and 187 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,8 +19,6 @@
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.Filter;

import org.opensaml.core.Version;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
Expand Down Expand Up @@ -50,6 +48,7 @@
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
Expand Down Expand Up @@ -115,9 +114,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>

private String loginPage;

private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
private String authenticationRequestUri = "/saml2/authenticate/{registrationId}";

private Saml2AuthenticationRequestResolver authenticationRequestResolver;

private AuthenticationRequestEndpointConfig authenticationRequestEndpoint = new AuthenticationRequestEndpointConfig();
private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;

private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;

Expand Down Expand Up @@ -176,6 +177,20 @@ public Saml2LoginConfigurer<B> loginPage(String loginPage) {
return this;
}

/**
* Use this {@link Saml2AuthenticationRequestResolver} for generating SAML 2.0
* Authentication Requests.
* @param authenticationRequestResolver
* @return the {@link Saml2LoginConfigurer} for further configuration
* @since 5.7
*/
public Saml2LoginConfigurer<B> authenticationRequestResolver(
Saml2AuthenticationRequestResolver authenticationRequestResolver) {
Assert.notNull(authenticationRequestResolver, "authenticationRequestResolver cannot be null");
this.authenticationRequestResolver = authenticationRequestResolver;
return this;
}

/**
* Specifies the URL to validate the credentials. If specified a custom URL, consider
* specifying a custom {@link AuthenticationConverter} via
Expand All @@ -200,7 +215,7 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU

/**
* {@inheritDoc}
*
* <p>
* Initializes this filter chain for SAML 2 Login. The following actions are taken:
* <ul>
* <li>The WebSSO endpoint has CSRF disabled, typically {@code /login/saml2/sso}</li>
Expand All @@ -226,8 +241,8 @@ public void init(B http) throws Exception {
super.init(http);
}
else {
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository);
Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
this.relyingPartyRegistrationRepository);
boolean singleProvider = providerUrlMap.size() == 1;
if (singleProvider) {
// Setup auto-redirect to provider login page
Expand All @@ -247,14 +262,16 @@ public void init(B http) throws Exception {

/**
* {@inheritDoc}
*
* <p>
* During the {@code configure} phase, a
* {@link Saml2WebSsoAuthenticationRequestFilter} is added to handle SAML 2.0
* AuthNRequest redirects
*/
@Override
public void configure(B http) throws Exception {
http.addFilter(this.authenticationRequestEndpoint.build(http));
Saml2WebSsoAuthenticationRequestFilter filter = getAuthenticationRequestFilter(http);
filter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
http.addFilter(postProcess(filter));
super.configure(http);
if (this.authenticationManager == null) {
registerDefaultAuthenticationProvider(http);
Expand All @@ -264,6 +281,11 @@ public void configure(B http) throws Exception {
}
}

private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver(B http) {
RelyingPartyRegistrationRepository registrations = relyingPartyRegistrationRepository(http);
return new DefaultRelyingPartyRegistrationResolver(registrations);
}

RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(B http) {
if (this.relyingPartyRegistrationRepository == null) {
this.relyingPartyRegistrationRepository = getSharedOrBean(http, RelyingPartyRegistrationRepository.class);
Expand All @@ -276,6 +298,46 @@ private void setAuthenticationRequestRepository(B http,
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}

private Saml2WebSsoAuthenticationRequestFilter getAuthenticationRequestFilter(B http) {
Saml2AuthenticationRequestResolver authenticationRequestResolver = getAuthenticationRequestResolver(http);
if (authenticationRequestResolver != null) {
return new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestResolver);
}
return new Saml2WebSsoAuthenticationRequestFilter(getAuthenticationRequestContextResolver(http),
getAuthenticationRequestFactory(http));
}

private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B http) {
if (this.authenticationRequestResolver != null) {
return this.authenticationRequestResolver;
}
return getBeanOrNull(http, Saml2AuthenticationRequestResolver.class);
}

private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http) {
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
if (resolver != null) {
return resolver;
}
if (version().startsWith("4")) {
return new OpenSaml4AuthenticationRequestFactory();
}
else {
return new OpenSamlAuthenticationRequestFactory();
}
}

private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
Saml2AuthenticationRequestContextResolver.class);
if (resolver != null) {
return resolver;
}
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
this.relyingPartyRegistrationRepository);
return new DefaultSaml2AuthenticationRequestContextResolver(registrationResolver);
}

private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter != null) {
return this.authenticationConverter;
Expand Down Expand Up @@ -324,8 +386,8 @@ private void initDefaultLoginFilter(B http) {
return;
}
loginPageGeneratingFilter.setSaml2LoginEnabled(true);
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(
this.authenticationRequestEndpoint.filterProcessingUrl, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
}
Expand Down Expand Up @@ -379,46 +441,4 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {
}
}

private final class AuthenticationRequestEndpointConfig {

private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";

private AuthenticationRequestEndpointConfig() {
}

private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
http);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
authenticationRequestResolver);
filter.setAuthenticationRequestRepository(repository);
return postProcess(filter);
}

private Saml2AuthenticationRequestFactory getResolver(B http) {
Saml2AuthenticationRequestFactory resolver = getSharedOrBean(http, Saml2AuthenticationRequestFactory.class);
if (resolver == null) {
if (version().startsWith("4")) {
return new OpenSaml4AuthenticationRequestFactory();
}
return new OpenSamlAuthenticationRequestFactory();
}
return resolver;
}

private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
}
return resolver;
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -80,9 +80,13 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand All @@ -104,6 +108,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
Expand Down Expand Up @@ -211,6 +216,41 @@ public void authenticationRequestWhenAuthnRequestContextConverterThenUses() thro
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverBeanThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestResolverBean.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverDslThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestResolverDsl.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
}

@Test
public void authenticationRequestWhenAuthenticationRequestResolverAndFactoryThenResolverTakesPrecedence()
throws Exception {
this.spring.register(CustomAuthenticationRequestResolverPrecedence.class).autowire();
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id")).andReturn();
UriComponents components = UriComponentsBuilder.fromHttpUrl(result.getResponse().getRedirectedUrl()).build();
String samlRequest = components.getQueryParams().getFirst("SAMLRequest");
String decoded = URLDecoder.decode(samlRequest, "UTF-8");
String inflated = Saml2Utils.samlInflate(Saml2Utils.samlDecode(decoded));
assertThat(inflated).contains("ForceAuthn=\"true\"");
verifyNoInteractions(this.spring.getContext().getBean(Saml2AuthenticationRequestFactory.class));
}

@Test
public void authenticateWhenCustomAuthenticationConverterThenUses() throws Exception {
this.spring.register(CustomAuthenticationConverter.class).autowire();
Expand Down Expand Up @@ -506,6 +546,103 @@ Saml2AuthenticationRequestFactory authenticationRequestFactory() {

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverBean {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login(Customizer.withDefaults());
// @formatter:on

return http.build();
}

@Bean
Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverDsl {

@Bean
SecurityFilterChain filterChain(HttpSecurity http, RelyingPartyRegistrationRepository registrations)
throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login((saml2) -> saml2
.authenticationRequestResolver(authenticationRequestResolver(registrations))
);
// @formatter:on

return http.build();
}

Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestResolverPrecedence {

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests((authz) -> authz
.anyRequest().authenticated()
)
.saml2Login(Customizer.withDefaults());
// @formatter:on

return http.build();
}

@Bean
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
return mock(Saml2AuthenticationRequestFactory.class);
}

@Bean
Saml2AuthenticationRequestResolver authenticationRequestResolver(
RelyingPartyRegistrationRepository registrations) {
RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(
registrations);
OpenSaml4AuthenticationRequestResolver delegate = new OpenSaml4AuthenticationRequestResolver(
registrationResolver);
delegate.setAuthnRequestCustomizer((parameters) -> parameters.getAuthnRequest().setForceAuthn(true));
return delegate;
}

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationConverter extends WebSecurityConfigurerAdapter {
Expand Down
Loading

0 comments on commit d538423

Please sign in to comment.