Skip to content

Commit

Permalink
Add RelyingPartyRegistrationResolver
Browse files Browse the repository at this point in the history
Closes gh-9486
  • Loading branch information
jzheaux committed Sep 13, 2021
1 parent f5a525e commit 6488295
Show file tree
Hide file tree
Showing 14 changed files with 239 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand Down Expand Up @@ -264,7 +265,8 @@ private void setAuthenticationRequestRepository(B http,
private AuthenticationConverter getAuthenticationConverter(B http) {
if (this.authenticationConverter == null) {
return new Saml2AuthenticationTokenConverter(
new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository));
(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
this.relyingPartyRegistrationRepository));
}
return this.authenticationConverter;
}
Expand Down Expand Up @@ -390,8 +392,9 @@ private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
return new DefaultSaml2AuthenticationRequestContextResolver(new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
}
return resolver;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ There are a number of reasons you may want to customize. Among them:
* You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme
* You may identify tenants in a way other than by the URI path

To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `Converter<HttpServletRequest, RelyingPartyRegistration>`.
To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `RelyingPartyRegistrationResolver`.
The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`.

You can provide a simpler resolver that, for example, always returns the same relying party:
Expand All @@ -736,22 +736,27 @@ You can provide a simpler resolver that, for example, always returns the same re
.Java
[source,java,role="primary"]
----
public class SingleRelyingPartyRegistrationResolver
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
public class SingleRelyingPartyRegistrationResolver implements RelyingPartyRegistrationResolver {
private final RelyingPartyRegistrationResolver delegate;
public SingleRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository registrations) {
this.delegate = new DefaultRelyingPartyRegistrationResolver(registrations);
}
@Override
public RelyingPartyRegistration convert(HttpServletRequest request) {
return this.relyingParty;
public RelyingPartyRegistration resolve(HttpServletRequest request, String registrationId) {
return this.delegate.resolve(request, "single");
}
}
----
.Kotlin
[source,kotlin,role="secondary"]
----
class SingleRelyingPartyRegistrationResolver : Converter<HttpServletRequest?, RelyingPartyRegistration?> {
override fun convert(request: HttpServletRequest?): RelyingPartyRegistration? {
return this.relyingParty
class SingleRelyingPartyRegistrationResolver(delegate: RelyingPartyRegistrationResolver) : RelyingPartyRegistrationResolver {
override fun resolve(request: HttpServletRequest?, registrationId: String?): RelyingPartyRegistration? {
return this.delegate.resolve(request, "single")
}
}
----
Expand Down Expand Up @@ -1544,7 +1549,7 @@ You can publish a metadata endpoint by adding the `Saml2MetadataFilter` to the f
.Java
[source,java,role="primary"]
----
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver =
DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver =
new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository);
Saml2MetadataFilter filter = new Saml2MetadataFilter(
relyingPartyRegistrationResolver,
Expand Down Expand Up @@ -1594,8 +1599,6 @@ filter.setRequestMatcher(AntPathRequestMatcher("/saml2/metadata/{registrationId}
----
====

ensuring that the `registrationId` hint is at the end of the path.

Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so:

====
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
Expand Down Expand Up @@ -67,7 +68,9 @@ public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyin
public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
String filterProcessesUrl) {
this(new Saml2AuthenticationTokenConverter(
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl);
(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
relyingPartyRegistrationRepository)),
filterProcessesUrl);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand Down Expand Up @@ -96,7 +97,9 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
public Saml2WebSsoAuthenticationRequestFilter(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this(new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), requestFactory());
(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
relyingPartyRegistrationRepository)),
requestFactory());
}

private static Saml2AuthenticationRequestFactory requestFactory() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

import javax.servlet.http.HttpServletRequest;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
Expand All @@ -42,28 +45,51 @@
* @since 5.4
*/
public final class DefaultRelyingPartyRegistrationResolver
implements Converter<HttpServletRequest, RelyingPartyRegistration> {
implements RelyingPartyRegistrationResolver, Converter<HttpServletRequest, RelyingPartyRegistration> {

private Log logger = LogFactory.getLog(getClass());

private static final char PATH_DELIMITER = '/';

private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;

private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();
private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");

public DefaultRelyingPartyRegistrationResolver(
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
}

/**
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration convert(HttpServletRequest request) {
String registrationId = this.registrationIdResolver.convert(request);
if (registrationId == null) {
return resolve(request, null);
}

/**
* {@inheritDoc}
*/
@Override
public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) {
if (relyingPartyRegistrationId == null) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Attempting to resolve from " + this.registrationRequestMatcher
+ " since registrationId is null");
}
relyingPartyRegistrationId = this.registrationRequestMatcher.matcher(request).getVariables()
.get("registrationId");
}
if (relyingPartyRegistrationId == null) {
if (this.logger.isTraceEnabled()) {
this.logger.trace("Returning null registration since registrationId is null");
}
return null;
}
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
.findByRegistrationId(registrationId);
.findByRegistrationId(relyingPartyRegistrationId);
if (relyingPartyRegistration == null) {
return null;
}
Expand Down Expand Up @@ -111,16 +137,4 @@ private static String getApplicationUri(HttpServletRequest request) {
return uriComponents.toUriString();
}

private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {

private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");

@Override
public String convert(HttpServletRequest request) {
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
return result.getVariables().get("registrationId");
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,24 @@ public final class DefaultSaml2AuthenticationRequestContextResolver

private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;

/**
* Construct a {@link DefaultSaml2AuthenticationRequestContextResolver}
* @param relyingPartyRegistrationResolver
* @deprecated Use
* {@link DefaultSaml2AuthenticationRequestContextResolver#DefaultSaml2AuthenticationRequestContextResolver(RelyingPartyRegistrationResolver)}
* instead
*/
@Deprecated
public DefaultSaml2AuthenticationRequestContextResolver(
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
}

public DefaultSaml2AuthenticationRequestContextResolver(
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
this.relyingPartyRegistrationResolver = (request) -> relyingPartyRegistrationResolver.resolve(request, null);
}

@Override
public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.web;

import javax.servlet.http.HttpServletRequest;

import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;

/**
* A contract for resolving a {@link RelyingPartyRegistration} from the HTTP request
*
* @author Josh Cummings
* @since 5.6
*/
public interface RelyingPartyRegistrationResolver {

/**
* Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the
* {@code relyingPartyRegistrationId}, if it is provided
* @param request the HTTP request
* @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier
* @return the resolved {@link RelyingPartyRegistration}
*/
RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId);

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,28 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
* resolving {@link RelyingPartyRegistration}s
* @param relyingPartyRegistrationResolver the strategy for resolving
* {@link RelyingPartyRegistration}s
* @deprecated Use
* {@link Saml2AuthenticationTokenConverter#Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver)}
* instead
*/
@Deprecated
public Saml2AuthenticationTokenConverter(
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
}

public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
this(adaptToConverter(relyingPartyRegistrationResolver));
}

private static Converter<HttpServletRequest, RelyingPartyRegistration> adaptToConverter(
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
return (request) -> relyingPartyRegistrationResolver.resolve(request, null);
}

@Override
public Saml2AuthenticationToken convert(HttpServletRequest request) {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {

public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml";

private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter;
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;

private final Saml2MetadataResolver saml2MetadataResolver;

Expand All @@ -55,11 +55,26 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
private RequestMatcher requestMatcher = new AntPathRequestMatcher(
"/saml2/service-provider-metadata/{registrationId}");

public Saml2MetadataFilter(
Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
/**
* Construct a {@link Saml2MetadataFilter}
* @param relyingPartyRegistrationResolver
* @param saml2MetadataResolver
* @deprecated Use
* {@link Saml2MetadataFilter#Saml2MetadataFilter(RelyingPartyRegistrationResolver)}
* instead
*/
@Deprecated
public Saml2MetadataFilter(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver,
Saml2MetadataResolver saml2MetadataResolver) {
this.relyingPartyRegistrationResolver = (request, id) -> relyingPartyRegistrationResolver.convert(request);
this.saml2MetadataResolver = saml2MetadataResolver;
}

this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter;
public Saml2MetadataFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
Saml2MetadataResolver saml2MetadataResolver) {
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
Assert.notNull(saml2MetadataResolver, "saml2MetadataResolver cannot be null");
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
this.saml2MetadataResolver = saml2MetadataResolver;
}

Expand All @@ -71,14 +86,15 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
chain.doFilter(request, response);
return;
}
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request);
String registrationId = matcher.getVariables().get("registrationId");
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
registrationId);
if (relyingPartyRegistration == null) {
response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
String registrationId = relyingPartyRegistration.getRegistrationId();
writeMetadataToResponse(response, registrationId, metadata);
writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata);
}

private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)
Expand Down
Loading

0 comments on commit 6488295

Please sign in to comment.