Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Saml2AuthenticationRequestRepository #10060

Merged
merged 2 commits into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,13 +33,16 @@
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
marcusdacoregio marked this conversation as resolved.
Show resolved Hide resolved
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
Expand Down Expand Up @@ -206,6 +209,7 @@ public void init(B http) throws Exception {
}
this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
this.loginProcessingUrl);
setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter);
setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl);
if (StringUtils.hasText(this.loginPage)) {
Expand Down Expand Up @@ -252,6 +256,11 @@ public void configure(B http) throws Exception {
}
}

private void setAuthenticationRequestRepository(B http,
Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) {
saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
}

private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter == null) {
return new Saml2AuthenticationTokenConverter(
Expand Down Expand Up @@ -302,6 +311,16 @@ private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUr
return idps;
}

private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> getAuthenticationRequestRepository(
B http) {
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getBeanOrNull(http,
Saml2AuthenticationRequestRepository.class);
if (repository == null) {
return new HttpSessionSaml2AuthenticationRequestRepository();
}
return repository;
}

private <C> C getSharedOrBean(B http, Class<C> clazz) {
C shared = http.getSharedObject(clazz);
if (shared != null) {
Expand Down Expand Up @@ -339,8 +358,12 @@ private AuthenticationRequestEndpointConfig() {
private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
return postProcess(
new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver));
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
http);
Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
authenticationRequestResolver);
filter.setAuthenticationRequestRepository(repository);
return postProcess(filter);
}

private Saml2AuthenticationRequestFactory getResolver(B http) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2Utils;
import org.springframework.security.saml2.core.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
Expand All @@ -76,9 +77,11 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.context.HttpRequestResponseHolder;
Expand Down Expand Up @@ -183,7 +186,8 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire();
Saml2AuthenticationRequestContext context = TestSaml2AuthenticationRequestContexts
.authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver;
Saml2AuthenticationRequestContextResolver resolver = this.spring.getContext()
.getBean(Saml2AuthenticationRequestContextResolver.class);
given(resolver.resolve(any(HttpServletRequest.class))).willReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id")).andExpect(status().isFound());
verify(resolver).resolve(any(HttpServletRequest.class));
Expand Down Expand Up @@ -237,6 +241,29 @@ public void authenticateWithInvalidDeflatedSAMLResponseThenFailureHandlerUses()
assertThat(exception.getCause()).isInstanceOf(IOException.class);
}

@Test
public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id");
this.mvc.perform(request).andExpect(status().isFound());
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class),
any(HttpServletRequest.class), any(HttpServletResponse.class));
}

@Test
public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse",
SIGNED_RESPONSE);
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
.getBean(Saml2AuthenticationRequestRepository.class);
this.mvc.perform(request);
verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class));
verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}

private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
// get the OpenSamlAuthenticationProvider
Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
Expand Down Expand Up @@ -355,7 +382,7 @@ protected void configure(HttpSecurity http) throws Exception {
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {

private static final Saml2AuthenticationRequestContextResolver resolver = mock(
private final Saml2AuthenticationRequestContextResolver resolver = mock(
Saml2AuthenticationRequestContextResolver.class);

@Override
Expand All @@ -371,7 +398,7 @@ protected void configure(HttpSecurity http) throws Exception {

@Bean
Saml2AuthenticationRequestContextResolver resolver() {
return resolver;
return this.resolver;
}

}
Expand Down Expand Up @@ -420,6 +447,27 @@ protected void configure(HttpSecurity http) throws Exception {

}

@EnableWebSecurity
@Import(Saml2LoginConfigBeans.class)
static class CustomAuthenticationRequestRepository {

private final Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = mock(
Saml2AuthenticationRequestRepository.class);

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.authorizeRequests((authz) -> authz.anyRequest().authenticated());
http.saml2Login(withDefaults());
return http.build();
}

@Bean
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
return this.repository;
}

}

static class Saml2LoginConfigBeans {

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1610,3 +1610,33 @@ http {
The success handler will send logout requests to the asserting party.

The request matcher will detect logout requests from the asserting party.

[[servlet-saml2login-store-authn-request]]
=== Storing the `AuthnRequest`

The `Saml2AuthenticationRequestRepository` is responsible for the persistence of the `AuthnRequest` from the time the `AuthnRequest` <<servlet-saml2login-sp-initiated-factory,is initiated>> to the time the `SAMLResponse` <<servlet-saml2login-authenticate-responses,is received>>.
The `Saml2AuthenticationTokenConverter` is responsible for loading the `AuthnRequest` from the `Saml2AuthenticationRequestRepository` and saving it into the `Saml2AuthenticationToken`.

The default implementation of `Saml2AuthenticationRequestRepository` is `HttpSessionSaml2AuthenticationRequestRepository`, which stores the `AuthnRequest` in the `HttpSession`.

If you have a custom implementation of `Saml2AuthenticationRequestRepository`, you may configure it by exposing it as a `@Bean` as shown in the following example:

====
.Java
[source,java,role="primary"]
----
@Bean
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
return new CustomSaml2AuthenticationRequestRepository();
}
----

.Kotlin
[source,kotlin,role="secondary"]
----
@Bean
open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
return CustomSaml2AuthenticationRequestRepository()
}
----
====
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies {
testImplementation "org.junit.jupiter:junit-jupiter-params"
testImplementation "org.junit.jupiter:junit-jupiter-engine"
testImplementation "org.mockito:mockito-core"
testImplementation "org.mockito:mockito-inline"
testImplementation "org.mockito:mockito-junit-jupiter"
testImplementation "org.springframework:spring-test"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,8 +38,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {

private final String saml2Response;

private final AbstractSaml2AuthenticationRequest authenticationRequest;

/**
* Creates a {@link Saml2AuthenticationToken} with the provided parameters
* Creates a {@link Saml2AuthenticationToken} with the provided parameters.
*
* Note that the given {@link RelyingPartyRegistration} should have all its templates
* resolved at this point. See
Expand All @@ -48,15 +50,35 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
* @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
* use
* @param saml2Response the SAML 2.0 response to authenticate
* @param authenticationRequest the {@code AuthNRequest} sent to the asserting party
*
* @since 5.4
* @since 5.6
*/
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response,
AbstractSaml2AuthenticationRequest authenticationRequest) {
super(Collections.emptyList());
marcusdacoregio marked this conversation as resolved.
Show resolved Hide resolved
Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
Assert.notNull(saml2Response, "saml2Response cannot be null");
this.relyingPartyRegistration = relyingPartyRegistration;
this.saml2Response = saml2Response;
this.authenticationRequest = authenticationRequest;
}

/**
* Creates a {@link Saml2AuthenticationToken} with the provided parameters
*
* Note that the given {@link RelyingPartyRegistration} should have all its templates
* resolved at this point. See
* {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
* for an example of performing that resolution.
* @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
* use
* @param saml2Response the SAML 2.0 response to authenticate
*
* @since 5.4
*/
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
this(relyingPartyRegistration, saml2Response, null);
}

/**
Expand All @@ -81,6 +103,7 @@ public Saml2AuthenticationToken(String saml2Response, String recipientUri, Strin
.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
.build();
this.saml2Response = saml2Response;
this.authenticationRequest = null;
}

/**
Expand Down Expand Up @@ -179,4 +202,14 @@ public String getIdpEntityId() {
return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
}

/**
* Returns the authentication request sent to the assertion party or {@code null} if
* no authentication request is present
* @return the authentication request sent to the assertion party
* @since 5.6
*/
public AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
return this.authenticationRequest;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.servlet;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;

/**
* A {@link Saml2AuthenticationRequestRepository} implementation that uses
* {@link HttpSession} to store and retrieve the
* {@link AbstractSaml2AuthenticationRequest}
*
* @author Marcus Da Coregio
* @since 5.6
*/
public class HttpSessionSaml2AuthenticationRequestRepository
implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {

private static final String DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME = HttpSessionSaml2AuthenticationRequestRepository.class
.getName().concat(".SAML2_AUTHN_REQUEST");

private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME;

@Override
public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
HttpSession httpSession = request.getSession(false);
if (httpSession == null) {
return null;
}
return (AbstractSaml2AuthenticationRequest) httpSession.getAttribute(this.saml2AuthnRequestAttributeName);
}

@Override
public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
HttpServletRequest request, HttpServletResponse response) {
if (authenticationRequest == null) {
removeAuthenticationRequest(request, response);
return;
}
HttpSession httpSession = request.getSession();
httpSession.setAttribute(this.saml2AuthnRequestAttributeName, authenticationRequest);
}

@Override
public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
HttpServletResponse response) {
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
if (authenticationRequest == null) {
return null;
}
HttpSession httpSession = request.getSession();
httpSession.removeAttribute(this.saml2AuthnRequestAttributeName);
jzheaux marked this conversation as resolved.
Show resolved Hide resolved
return authenticationRequest;
}

}
Loading