diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index a8447e7d141..c81d7b07f8d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.Assert; @@ -171,6 +172,8 @@ public final class AuthorizationCodeGrantConfigurer { private AuthorizationRequestRepository authorizationRequestRepository; + private RedirectStrategy authorizationRedirectStrategy; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; private AuthorizationCodeGrantConfigurer() { @@ -202,6 +205,17 @@ public AuthorizationCodeGrantConfigurer authorizationRequestRepository( return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration + */ + public AuthorizationCodeGrantConfigurer authorizationRedirectStrategy( + RedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Sets the client used for requesting the access token credential from the Token * Endpoint. @@ -247,6 +261,9 @@ private OAuth2AuthorizationRequestRedirectFilter createAuthorizationRequestRedir authorizationRequestRedirectFilter .setAuthorizationRequestRepository(this.authorizationRequestRepository); } + if (this.authorizationRedirectStrategy != null) { + authorizationRequestRedirectFilter.setAuthorizationRedirectStrategy(this.authorizationRedirectStrategy); + } RequestCache requestCache = builder.getSharedObject(RequestCache.class); if (requestCache != null) { authorizationRequestRedirectFilter.setRequestCache(requestCache); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index be1707e3946..ef3d2dc9c21 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -68,6 +68,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; @@ -367,6 +368,10 @@ public void configure(B http) throws Exception { authorizationRequestFilter .setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository); } + if (this.authorizationEndpointConfig.authorizationRedirectStrategy != null) { + authorizationRequestFilter + .setAuthorizationRedirectStrategy(this.authorizationEndpointConfig.authorizationRedirectStrategy); + } RequestCache requestCache = http.getSharedObject(RequestCache.class); if (requestCache != null) { authorizationRequestFilter.setRequestCache(requestCache); @@ -539,6 +544,8 @@ public final class AuthorizationEndpointConfig { private AuthorizationRequestRepository authorizationRequestRepository; + private RedirectStrategy authorizationRedirectStrategy; + private AuthorizationEndpointConfig() { } @@ -581,6 +588,17 @@ public AuthorizationEndpointConfig authorizationRequestRepository( return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig authorizationRedirectStrategy( + RedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Returns the {@link OAuth2LoginConfigurer} for further configuration. * @return the {@link OAuth2LoginConfigurer} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java index f2c1ebd0f09..5f039548fe4 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java @@ -44,6 +44,8 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF = "authorization-redirect-strategy-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; private final BeanReference requestCache; @@ -83,6 +85,7 @@ public BeanDefinition parse(Element element, ParserContext parserContext) { } BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository( authorizationCodeGrantElt); + BeanMetadataElement authorizationRedirectStrategy = getAuthorizationRedirectStrategy(authorizationCodeGrantElt); BeanDefinitionBuilder authorizationRequestRedirectFilterBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationRequestRedirectFilter.class); String authorizationRequestResolverRef = (authorizationCodeGrantElt != null) @@ -95,6 +98,7 @@ public BeanDefinition parse(Element element, ParserContext parserContext) { } this.authorizationRequestRedirectFilter = authorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) + .addPropertyValue("authorizationRedirectStrategy", authorizationRedirectStrategy) .addPropertyValue("requestCache", this.requestCache).getBeanDefinition(); BeanDefinitionBuilder authorizationCodeGrantFilterBldr = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationCodeGrantFilter.class) @@ -126,6 +130,16 @@ private BeanMetadataElement getAuthorizationRequestRepository(Element element) { .getBeanDefinition(); } + private BeanMetadataElement getAuthorizationRedirectStrategy(Element element) { + String authorizationRedirectStrategyRef = (element != null) + ? element.getAttribute(ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF) : null; + if (StringUtils.hasText(authorizationRedirectStrategyRef)) { + return new RuntimeBeanReference(authorizationRedirectStrategyRef); + } + return BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.DefaultRedirectStrategy") + .getBeanDefinition(); + } + private BeanMetadataElement getAccessTokenResponseClient(Element element) { String accessTokenResponseClientRef = (element != null) ? element.getAttribute(ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF) : null; diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java index 288b09072e8..1b8efc6695e 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java @@ -87,6 +87,8 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF = "authorization-redirect-strategy-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; private static final String ATT_USER_AUTHORITIES_MAPPER_REF = "user-authorities-mapper-ref"; @@ -199,6 +201,7 @@ public BeanDefinition parse(Element element, ParserContext parserContext) { } oauth2AuthorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) + .addPropertyValue("authorizationRedirectStrategy", getAuthorizationRedirectStrategy(element)) .addPropertyValue("requestCache", this.requestCache); this.oauth2AuthorizationRequestRedirectFilter = oauth2AuthorizationRequestRedirectFilterBuilder .getBeanDefinition(); @@ -261,6 +264,15 @@ private BeanMetadataElement getAuthorizationRequestRepository(Element element) { .getBeanDefinition(); } + private BeanMetadataElement getAuthorizationRedirectStrategy(Element element) { + String authorizationRedirectStrategyRef = element.getAttribute(ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF); + if (StringUtils.hasText(authorizationRedirectStrategyRef)) { + return new RuntimeBeanReference(authorizationRedirectStrategyRef); + } + return BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.DefaultRedirectStrategy") + .getBeanDefinition(); + } + private BeanDefinition getOidcAuthProvider(Element element, BeanMetadataElement accessTokenResponseClient, String userAuthoritiesMapperRef) { boolean oidcAuthenticationProviderEnabled = ClassUtils diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 880c22630bd..d87bd03d86e 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -102,12 +102,14 @@ import org.springframework.security.web.PortMapper; import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint; import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; import org.springframework.security.web.server.ExchangeMatcherRedirectWebFilter; import org.springframework.security.web.server.MatcherSecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter; import org.springframework.security.web.server.authentication.AuthenticationConverterServerWebExchangeMatcher; import org.springframework.security.web.server.authentication.AuthenticationWebFilter; @@ -3375,6 +3377,8 @@ public final class OAuth2LoginSpec { private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; + private ServerRedirectStrategy authorizationRedirectStrategy; + private ServerWebExchangeMatcher authenticationMatcher; private ServerAuthenticationSuccessHandler authenticationSuccessHandler; @@ -3547,6 +3551,16 @@ public OAuth2LoginSpec authorizationRequestResolver( return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link OAuth2LoginSpec} for further configuration + */ + public OAuth2LoginSpec authorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Sets the {@link ServerWebExchangeMatcher matcher} used for determining if the * request is an authentication request. @@ -3581,7 +3595,9 @@ protected void configure(ServerHttpSecurity http) { OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); ServerAuthorizationRequestRepository authorizationRequestRepository = getAuthorizationRequestRepository(); oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); + oauthRedirectFilter.setAuthorizationRedirectStrategy(getAuthorizationRedirectStrategy()); oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); + ReactiveAuthenticationManager manager = getAuthenticationManager(); AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); @@ -3591,6 +3607,7 @@ protected void configure(ServerHttpSecurity http) { authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http)); authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + setDefaultEntryPoints(http); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); @@ -3737,6 +3754,13 @@ private ServerAuthorizationRequestRepository getAuth return this.authorizationRequestRepository; } + private ServerRedirectStrategy getAuthorizationRedirectStrategy() { + if (this.authorizationRedirectStrategy == null) { + this.authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + } + return this.authorizationRedirectStrategy; + } + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { ReactiveOAuth2AuthorizedClientService bean = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); if (bean != null) { @@ -3759,6 +3783,8 @@ public final class OAuth2ClientSpec { private ServerAuthorizationRequestRepository authorizationRequestRepository; + private ServerRedirectStrategy authorizationRedirectStrategy; + private OAuth2ClientSpec() { } @@ -3851,6 +3877,23 @@ private ServerAuthorizationRequestRepository getAuth return this.authorizationRequestRepository; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link OAuth2ClientSpec} for further configuration + */ + public OAuth2ClientSpec authorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + + private ServerRedirectStrategy getAuthorizationRedirectStrategy() { + if (this.authorizationRedirectStrategy == null) { + this.authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + } + return this.authorizationRedirectStrategy; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -3870,12 +3913,15 @@ protected void configure(ServerHttpSecurity http) { if (http.requestCache != null) { codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); } + OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( clientRegistrationRepository); oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + oauthRedirectFilter.setAuthorizationRedirectStrategy(getAuthorizationRedirectStrategy()); if (http.requestCache != null) { oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); } + http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); } diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt index 35356b09e12..735d75c0b39 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt @@ -23,6 +23,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.RedirectStrategy /** * A Kotlin DSL to configure OAuth 2.0 Authorization Code Grant. @@ -31,6 +32,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ * @since 5.3 * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]'s. * @property authorizationRequestRepository the repository used for storing [OAuth2AuthorizationRequest]'s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. * @property accessTokenResponseClient the client used for requesting the access token credential * from the Token Endpoint. */ @@ -38,12 +40,14 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ class AuthorizationCodeGrantDsl { var authorizationRequestResolver: OAuth2AuthorizationRequestResolver? = null var authorizationRequestRepository: AuthorizationRequestRepository? = null + var authorizationRedirectStrategy: RedirectStrategy? = null var accessTokenResponseClient: OAuth2AccessTokenResponseClient? = null internal fun get(): (OAuth2ClientConfigurer.AuthorizationCodeGrantConfigurer) -> Unit { return { authorizationCodeGrant -> authorizationRequestResolver?.also { authorizationCodeGrant.authorizationRequestResolver(authorizationRequestResolver) } authorizationRequestRepository?.also { authorizationCodeGrant.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { authorizationCodeGrant.authorizationRedirectStrategy(authorizationRedirectStrategy) } accessTokenResponseClient?.also { authorizationCodeGrant.accessTokenResponseClient(accessTokenResponseClient) } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt index 96289fa825d..160efb90815 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt @@ -21,6 +21,7 @@ import org.springframework.security.config.annotation.web.configurers.oauth2.cli import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.RedirectStrategy /** * A Kotlin DSL to configure the Authorization Server's Authorization Endpoint using @@ -31,18 +32,21 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ * @property baseUri the base URI used for authorization requests. * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]'s. * @property authorizationRequestRepository the repository used for storing [OAuth2AuthorizationRequest]'s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. */ @OAuth2LoginSecurityMarker class AuthorizationEndpointDsl { var baseUri: String? = null var authorizationRequestResolver: OAuth2AuthorizationRequestResolver? = null var authorizationRequestRepository: AuthorizationRequestRepository? = null + var authorizationRedirectStrategy: RedirectStrategy? = null internal fun get(): (OAuth2LoginConfigurer.AuthorizationEndpointConfig) -> Unit { return { authorizationEndpoint -> baseUri?.also { authorizationEndpoint.baseUri(baseUri) } authorizationRequestResolver?.also { authorizationEndpoint.authorizationRequestResolver(authorizationRequestResolver) } authorizationRequestRepository?.also { authorizationEndpoint.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { authorizationEndpoint.authorizationRedirectStrategy(authorizationRedirectStrategy) } } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt index 6751d242963..edd50e33457 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt @@ -22,6 +22,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.web.server.ServerWebExchange @@ -37,6 +38,7 @@ import org.springframework.web.server.ServerWebExchange * @property clientRegistrationRepository the repository of client registrations. * @property authorizedClientRepository the repository for authorized client(s). * @property authorizationRequestRepository the repository to use for storing [OAuth2AuthorizationRequest]s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. */ @ServerSecurityMarker class ServerOAuth2ClientDsl { @@ -45,6 +47,7 @@ class ServerOAuth2ClientDsl { var clientRegistrationRepository: ReactiveClientRegistrationRepository? = null var authorizedClientRepository: ServerOAuth2AuthorizedClientRepository? = null var authorizationRequestRepository: ServerAuthorizationRequestRepository? = null + var authorizationRedirectStrategy: ServerRedirectStrategy? = null internal fun get(): (ServerHttpSecurity.OAuth2ClientSpec) -> Unit { return { oauth2Client -> @@ -53,6 +56,7 @@ class ServerOAuth2ClientDsl { clientRegistrationRepository?.also { oauth2Client.clientRegistrationRepository(clientRegistrationRepository) } authorizedClientRepository?.also { oauth2Client.authorizedClientRepository(authorizedClientRepository) } authorizationRequestRepository?.also { oauth2Client.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { oauth2Client.authorizationRedirectStrategy(authorizationRedirectStrategy) } } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt index 0c24340fbb3..4ab8fcb0e45 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt @@ -24,6 +24,7 @@ import org.springframework.security.oauth2.client.web.server.ServerAuthorization import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler @@ -49,6 +50,7 @@ import org.springframework.web.server.ServerWebExchange * @property authorizedClientRepository the repository for authorized client(s). * @property authorizationRequestRepository the repository to use for storing [OAuth2AuthorizationRequest]s. * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. * @property authenticationMatcher the [ServerWebExchangeMatcher] used for determining if the request is an * authentication request. */ @@ -64,6 +66,7 @@ class ServerOAuth2LoginDsl { var authorizedClientRepository: ServerOAuth2AuthorizedClientRepository? = null var authorizationRequestRepository: ServerAuthorizationRequestRepository? = null var authorizationRequestResolver: ServerOAuth2AuthorizationRequestResolver? = null + var authorizationRedirectStrategy: ServerRedirectStrategy? = null var authenticationMatcher: ServerWebExchangeMatcher? = null internal fun get(): (ServerHttpSecurity.OAuth2LoginSpec) -> Unit { @@ -78,6 +81,7 @@ class ServerOAuth2LoginDsl { authorizedClientRepository?.also { oauth2Login.authorizedClientRepository(authorizedClientRepository) } authorizationRequestRepository?.also { oauth2Login.authorizationRequestRepository(authorizationRequestRepository) } authorizationRequestResolver?.also { oauth2Login.authorizationRequestResolver(authorizationRequestResolver) } + authorizationRedirectStrategy?.also { oauth2Login.authorizationRedirectStrategy(authorizationRedirectStrategy) } authenticationMatcher?.also { oauth2Login.authenticationMatcher(authenticationMatcher) } } } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc index 8b70128c5f3..4620e16ac9d 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc @@ -495,6 +495,9 @@ oauth2-login.attlist &= oauth2-login.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? +oauth2-login.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? oauth2-login.attlist &= ## Reference to the OAuth2AccessTokenResponseClient attribute access-token-response-client-ref {xsd:token}? @@ -542,6 +545,9 @@ authorization-code-grant = authorization-code-grant.attlist &= ## Reference to the AuthorizationRequestRepository attribute authorization-request-repository-ref {xsd:token}? +authorization-code-grant.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? authorization-code-grant.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd index 53191b5cb90..777e7a6f515 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd @@ -1603,6 +1603,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AccessTokenResponseClient @@ -1706,6 +1712,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AuthorizationRequestResolver diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index f2e6a85f9a2..faeaceed2c0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -58,6 +58,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -68,6 +70,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -95,6 +98,8 @@ public class OAuth2ClientConfigurerTests { private static OAuth2AuthorizationRequestResolver authorizationRequestResolver; + private static RedirectStrategy authorizationRedirectStrategy; + private static OAuth2AccessTokenResponseClient accessTokenResponseClient; private static RequestCache requestCache; @@ -130,6 +135,7 @@ public void setup() { authorizedClientService); authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, "/oauth2/authorization"); + authorizationRedirectStrategy = new DefaultRedirectStrategy(); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(300).build(); accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); @@ -261,6 +267,19 @@ public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationR verify(authorizationRequestResolver).resolve(any()); } + @Test + public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorizationRedirectStrategyUsed() + throws Exception { + authorizationRedirectStrategy = mock(RedirectStrategy.class); + this.spring.register(OAuth2ClientConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/oauth2/authorization/registration-1")) + .andExpect(status().isOk()) + .andReturn(); + // @formatter:on + verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { @@ -278,6 +297,7 @@ protected void configure(HttpSecurity http) throws Exception { .oauth2Client() .authorizationCodeGrant() .authorizationRequestResolver(authorizationRequestResolver) + .authorizationRedirectStrategy(authorizationRedirectStrategy) .accessTokenResponseClient(accessTokenResponseClient); // @formatter:on } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index d591b9b1070..b0f5a73ae94 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -87,6 +87,7 @@ import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; @@ -98,7 +99,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -357,6 +360,32 @@ public void requestWhenOauth2LoginWithCustomAuthorizationRequestParametersThenPa "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); } + @Test + public void oauth2LoginWithAuthorizationRedirectStrategyThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + loadConfig(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class); + RedirectStrategy redirectStrategy = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy; + String requestUri = "/oauth2/authorization/google"; + this.request = new MockHttpServletRequest("GET", requestUri); + this.request.setServletPath(requestUri); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); + } + + @Test + public void requestWhenOauth2LoginWithCustomAuthorizationRedirectStrategyThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + loadConfig(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class); + RedirectStrategy redirectStrategy = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy; + String requestUri = "/oauth2/authorization/google"; + this.request = new MockHttpServletRequest("GET", requestUri); + this.request.setServletPath(requestUri); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); + } + // gh-5347 @Test public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { @@ -858,6 +887,59 @@ protected void configure(HttpSecurity http) throws Exception { } + @EnableWebSecurity + static class OAuth2LoginConfigCustomAuthorizationRedirectStrategy extends CommonWebSecurityConfigurerAdapter { + + private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); + + RedirectStrategy redirectStrategy = mock(RedirectStrategy.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> + oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizationEndpoint((authorizationEndpoint) -> + authorizationEndpoint + .authorizationRedirectStrategy(this.redirectStrategy) + ) + ); + // @formatter:on + super.configure(http); + } + + } + + @EnableWebSecurity + static class OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda + extends CommonLambdaWebSecurityConfigurerAdapter { + + private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); + + RedirectStrategy redirectStrategy = mock(RedirectStrategy.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> + oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizationEndpoint((authorizationEndpoint) -> + authorizationEndpoint + .authorizationRedirectStrategy(this.redirectStrategy) + ) + ); + // @formatter:on + super.configure(http); + } + + } + @EnableWebSecurity static class OAuth2LoginConfigMultipleClients extends CommonWebSecurityConfigurerAdapter { diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java index 0e9806118a2..f3dbab941cf 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java @@ -44,6 +44,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.RedirectStrategy; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -55,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -90,6 +92,9 @@ public class OAuth2ClientBeanDefinitionParserTests { @Autowired(required = false) private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + @Autowired(required = false) + private RedirectStrategy authorizationRedirectStrategy; + @Autowired(required = false) private OAuth2AccessTokenResponseClient accessTokenResponseClient; @@ -148,6 +153,16 @@ public void requestWhenCustomAuthorizationRequestResolverThenCalled() throws Exc verify(this.authorizationRequestResolver).resolve(any()); } + @Test + public void requestWhenCustomAuthorizationRedirectStrategyThenCalled() throws Exception { + this.spring.configLocations(xml("CustomAuthorizationRedirectStrategy")).autowire(); + // @formatter:off + this.mvc.perform(get("/oauth2/authorization/google")) + .andExpect(status().isOk()); + // @formatter:on + verify(this.authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + @Test public void requestWhenAuthorizationResponseMatchThenProcess() throws Exception { this.spring.configLocations(xml("CustomConfiguration")).autowire(); diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java index 38f43a0911f..8b98a9a9a01 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java @@ -63,6 +63,7 @@ import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.RequestCache; @@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -116,6 +118,9 @@ public class OAuth2LoginBeanDefinitionParserTests { @Autowired(required = false) private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + @Autowired(required = false) + private RedirectStrategy authorizationRedirectStrategy; + @Autowired(required = false) private OAuth2AccessTokenResponseClient accessTokenResponseClient; @@ -373,6 +378,17 @@ public void requestWhenCustomAuthorizationRequestResolverThenCalled() throws Exc verify(this.authorizationRequestResolver).resolve(any()); } + @Test + public void requestWhenCustomAuthorizationRedirectStrategyThenCalled() throws Exception { + this.spring.configLocations(this.xml("SingleClientRegistration-WithCustomAuthorizationRedirectStrategy")) + .autowire(); + // @formatter:off + this.mvc.perform(get("/oauth2/authorization/google-login")) + .andExpect(status().isOk()); + // @formatter:on + verify(this.authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + // gh-5347 @Test public void requestWhenMultiClientRegistrationThenRedirectDefaultLoginPage() throws Exception { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 3b8c6d97f83..be097e4b07c 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -39,14 +39,18 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; @@ -76,6 +80,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -531,6 +536,90 @@ public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { verify(authorizationRequestRepository).removeAuthorizationRequest(any()); } + @Test + public void shouldUseDefaultAuthorizationRedirectStrategyForOAuth2Login() { + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange().expectStatus().is3xxRedirection(); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isInstanceOf(DefaultServerRedirectStrategy.class); + } + + @Test + public void shouldConfigureAuthorizationRedirectStrategyForOAuth2Login() { + ServerRedirectStrategy authorizationRedirectStrategy = mock(ServerRedirectStrategy.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizationRedirectStrategy.sendRedirect(any(), any())).willReturn(Mono.empty()); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository) + .authorizationRedirectStrategy(authorizationRedirectStrategy).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange(); + verify(authorizationRedirectStrategy).sendRedirect(any(), any()); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isSameAs(authorizationRedirectStrategy); + } + + @Test + public void shouldUseDefaultAuthorizationRedirectStrategyForOAuth2Client() { + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Client() + .clientRegistrationRepository(clientRegistrationRepository).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange().expectStatus().is3xxRedirection(); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isInstanceOf(DefaultServerRedirectStrategy.class); + } + + @Test + public void shouldConfigureAuthorizationRedirectStrategyForOAuth2Client() { + ServerRedirectStrategy authorizationRedirectStrategy = mock(ServerRedirectStrategy.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizationRedirectStrategy.sendRedirect(any(), any())).willReturn(Mono.empty()); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Client() + .clientRegistrationRepository(clientRegistrationRepository) + .authorizationRedirectStrategy(authorizationRedirectStrategy).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange(); + verify(authorizationRedirectStrategy).sendRedirect(any(), any()); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isSameAs(authorizationRedirectStrategy); + } + private boolean isX509Filter(WebFilter filter) { try { Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt index f4cb3ed6e59..8d2cec76c1b 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt @@ -43,6 +43,8 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames +import org.springframework.security.web.DefaultRedirectStrategy +import org.springframework.security.web.RedirectStrategy import org.springframework.security.web.SecurityFilterChain import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get @@ -104,6 +106,40 @@ class AuthorizationCodeGrantDslTests { } } + @Test + fun `oauth2Client when custom authorization redirect strategy then redirect strategy used`() { + this.spring.register(RedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(RedirectStrategyConfig.REDIRECT_STRATEGY) + every { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + + this.mockMvc.get("/oauth2/authorization/registrationId") + + verify(exactly = 1) { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + } + + @EnableWebSecurity + open class RedirectStrategyConfig { + + companion object { + val REDIRECT_STRATEGY: RedirectStrategy = DefaultRedirectStrategy() + } + + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + oauth2Client { + authorizationCodeGrant { + authorizationRedirectStrategy = REDIRECT_STRATEGY + } + } + authorizeRequests { + authorize(anyRequest, authenticated) + } + } + return http.build() + } + } + @Test fun `oauth2Client when custom access token response client then client used`() { this.spring.register(AuthorizedClientConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt index 5571688f5a2..1801f5d9546 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt @@ -37,6 +37,8 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.DefaultRedirectStrategy +import org.springframework.security.web.RedirectStrategy import org.springframework.security.web.SecurityFilterChain import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get @@ -125,6 +127,37 @@ class AuthorizationEndpointDslTests { } } + @Test + fun `oauth2Login when custom authorization redirect strategy then redirect strategy used`() { + this.spring.register(RedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(RedirectStrategyConfig.REDIRECT_STRATEGY) + every { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + + this.mockMvc.get("/oauth2/authorization/google") + + verify(exactly = 1) { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + } + + @EnableWebSecurity + open class RedirectStrategyConfig { + + companion object { + val REDIRECT_STRATEGY: RedirectStrategy = DefaultRedirectStrategy() + } + + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + oauth2Login { + authorizationEndpoint { + authorizationRedirectStrategy = REDIRECT_STRATEGY + } + } + } + return http.build() + } + } + @Test fun `oauth2Login when custom authorization uri repository then uri used`() { this.spring.register(AuthorizationUriConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt index e93816ebdb2..a1dc851c1c2 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt @@ -39,7 +39,9 @@ import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2Ser import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames import org.springframework.security.oauth2.server.resource.web.server.ServerBearerTokenAuthenticationConverter +import org.springframework.security.web.server.DefaultServerRedirectStrategy import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.reactive.config.EnableWebFlux @@ -130,6 +132,41 @@ class ServerOAuth2ClientDslTests { } } + @Test + fun `OAuth2 client when authorization redirect strategy configured then custom redirect strategy used`() { + this.spring.register(AuthorizationRedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY) + every { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/oauth2/authorization/google") + .exchange() + + verify(exactly = 1) { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthorizationRedirectStrategyConfig { + + companion object { + val AUTHORIZATION_REDIRECT_STRATEGY : ServerRedirectStrategy = DefaultServerRedirectStrategy() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + oauth2Client { + authorizationRedirectStrategy = AUTHORIZATION_REDIRECT_STRATEGY + } + } + } + } + @Test fun `OAuth2 client when authentication converter configured then custom converter used`() { this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt index 8c73b263e21..5fd23b57aa7 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt @@ -35,7 +35,9 @@ import org.springframework.security.oauth2.client.web.server.ServerAuthorization import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.server.resource.web.server.ServerBearerTokenAuthenticationConverter +import org.springframework.security.web.server.DefaultServerRedirectStrategy import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.security.web.server.util.matcher.IpAddressServerWebExchangeMatcher import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher @@ -141,6 +143,38 @@ class ServerOAuth2LoginDslTests { } } + @Test + fun `OAuth2 login when authorization redirect strategy configured then custom redirect strategy used`() { + this.spring.register(AuthorizationRedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY) + every { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } returns Mono.empty() + this.client.get() + .uri("/oauth2/authorization/google") + .exchange() + + verify(exactly = 1) { AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthorizationRedirectStrategyConfig { + + companion object { + val AUTHORIZATION_REDIRECT_STRATEGY : ServerRedirectStrategy = DefaultServerRedirectStrategy() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + oauth2Login { + authorizationRedirectStrategy = AUTHORIZATION_REDIRECT_STRATEGY + } + } + } + } + @Test fun `OAuth2 login when authentication matcher configured then custom matcher used`() { this.spring.register(AuthenticationMatcherConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml new file mode 100644 index 00000000000..d7ff4139094 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml new file mode 100644 index 00000000000..8454de28f19 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index 3b11c9872cc..3b2c950fd22 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -983,6 +983,11 @@ Reference to the `AuthorizationRequestRepository`. Reference to the `OAuth2AuthorizationRequestResolver`. +[[nsa-oauth2-login-authorization-redirect-strategy-ref]] +* **authorization-redirect-strategy-ref** +Reference to the authorization `RedirectStrategy`. + + [[nsa-oauth2-login-access-token-response-client-ref]] * **access-token-response-client-ref** Reference to the `OAuth2AccessTokenResponseClient`. @@ -1083,6 +1088,11 @@ Configures xref:servlet/oauth2/client/authorization-grants.adoc#oauth2Client-aut Reference to the `AuthorizationRequestRepository`. +[[nsa-authorization-code-grant-authorization-redirect-strategy-ref]] +* **authorization-redirect-strategy-ref** +Reference to the authorization `RedirectStrategy`. + + [[nsa-authorization-code-grant-authorization-request-resolver-ref]] * **authorization-request-resolver-ref** Reference to the `OAuth2AuthorizationRequestResolver`. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 2bf35d43b63..e780c7f846f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -95,7 +95,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); - private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); + private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); private OAuth2AuthorizationRequestResolver authorizationRequestResolver; @@ -139,6 +139,15 @@ public OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolv this.authorizationRequestResolver = authorizationRequestResolver; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + */ + public void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. * @param authorizationRequestRepository the repository used for storing diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java index 1b20821a561..f667c918ed8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -75,7 +75,7 @@ */ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { - private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + private ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; @@ -105,6 +105,15 @@ public OAuth2AuthorizationRequestRedirectWebFilter( this.authorizationRequestResolver = authorizationRequestResolver; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + */ + public void setAuthorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. * @param authorizationRequestRepository the repository used for storing diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index 6684f5510b6..6250de1d745 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web; import java.lang.reflect.Constructor; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -29,7 +30,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; @@ -39,6 +42,7 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.ClassUtils; import org.springframework.web.util.UriComponentsBuilder; @@ -115,6 +119,11 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } + @Test + public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null)); + } + @Test public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); @@ -332,4 +341,31 @@ public void doFilterWhenAuthorizationRequestAndCustomAuthorizationRequestUriSetT + "login_hint=user@provider\\.com"); } + @Test + public void doFilterWhenCustomAuthorizationRedirectStrategySetThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> { + String redirectUrl = httpResponse.encodeRedirectURL(url); + httpResponse.setStatus(HttpStatus.OK.value()); + httpResponse.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); + httpResponse.getWriter().write(redirectUrl); + httpResponse.getWriter().flush(); + }; + this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy); + this.filter.doFilter(request, response, filterChain); + verifyZeroInteractions(filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_PLAIN_VALUE); + assertThat(response.getContentAsString(StandardCharsets.UTF_8)) + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java index d515c73e06d..1821cc140fa 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web.server; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.junit.jupiter.api.BeforeEach; @@ -24,13 +25,20 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; @@ -81,6 +89,11 @@ public void constructorWhenClientRegistrationRepositoryNullThenIllegalArgumentEx .isThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository)); } + @Test + public void setterWhenAuthorizationRedirectStrategyNullThenIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null)); + } + @Test public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() { // @formatter:off @@ -195,4 +208,46 @@ public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() { verifyNoInteractions(this.requestCache); } + @Test + public void filterWhenCustomRedirectStrategySetThenRedirectUriInResponseBody() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); + ServerRedirectStrategy customRedirectStrategy = (exchange, location) -> { + ServerHttpResponse response = exchange.getResponse(); + response.setStatusCode(HttpStatus.OK); + response.getHeaders().setContentType(MediaType.TEXT_PLAIN); + DataBuffer buffer = exchange.getResponse().bufferFactory() + .wrap(location.toASCIIString().getBytes(StandardCharsets.UTF_8)); + + return exchange.getResponse().writeWith(Flux.just(buffer)); + }; + this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy); + this.filter.setRequestCache(this.requestCache); + + FluxExchangeResult result = this.client.get() + .uri("https://example.com/oauth2/authorization/registration-id").exchange().expectHeader() + .contentType(MediaType.TEXT_PLAIN).expectStatus().isOk().returnResult(String.class); + + // @formatter:off + StepVerifier.create(result.getResponseBody()) + .assertNext((uri) -> { + URI location = URI.create(uri); + + assertThat(location) + .hasScheme("https") + .hasHost("example.com") + .hasPath("/login/oauth/authorize") + .hasParameter("response_type", "code") + .hasParameter("client_id", "client-id") + .hasParameter("scope", "read:user") + .hasParameter("state") + .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id"); + }) + .verifyComplete(); + // @formatter:on + + verifyNoInteractions(this.requestCache); + } + }