From bcc1cfc28afb92597a2273503f063aecfaadbd35 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 31 May 2023 15:04:08 -0600 Subject: [PATCH] Restore OAuth2AuthorizedClientRepository Test Instrumentation Closes gh-13113 --- .../server/SecurityMockServerConfigurers.java | 163 +++++++++++++----- .../SecurityMockMvcRequestPostProcessors.java | 140 +++++++++++---- ...ockServerConfigurersOAuth2ClientTests.java | 31 +++- ...equestPostProcessorsOAuth2ClientTests.java | 25 ++- 4 files changed, 281 insertions(+), 78 deletions(-) diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index 174a3188799..d8a87eccda2 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,10 +46,11 @@ import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository; @@ -214,8 +215,8 @@ public static OidcLoginMutator mockOidcLogin() { * tokens to be valid. * *

- * The support works by associating the authorized client to the ServerWebExchange via - * the {@link WebSessionServerOAuth2AuthorizedClientRepository} + * The support works by associating the authorized client to the ServerWebExchange + * using a {@link ServerOAuth2AuthorizedClientRepository} *

* @return the {@link OAuth2ClientMutator} to further configure or use * @since 5.3 @@ -230,8 +231,8 @@ public static OAuth2ClientMutator mockOAuth2Client() { * tokens to be valid. * *

- * The support works by associating the authorized client to the ServerWebExchange via - * the {@link WebSessionServerOAuth2AuthorizedClientRepository} + * The support works by associating the authorized client to the ServerWebExchange + * using a {@link ServerOAuth2AuthorizedClientRepository} *

* @param registrationId The registration id associated with the * {@link OAuth2AuthorizedClient} @@ -715,8 +716,6 @@ public static final class OAuth2LoginMutator implements WebTestClientConfigurer, private Supplier oauth2User = this::defaultPrincipal; - private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); - private OAuth2LoginMutator(OAuth2AccessToken accessToken) { this.accessToken = accessToken; this.clientRegistration = clientRegistrationBuilder().build(); @@ -776,12 +775,8 @@ public OAuth2LoginMutator oauth2User(OAuth2User oauth2User) { /** * Use the provided {@link ClientRegistration} as the client to authorize. *

- * The supplied {@link ClientRegistration} will be registered into an - * {@link WebSessionServerOAuth2AuthorizedClientRepository}. Tests relying on - * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an - * {@link WebSessionServerOAuth2AuthorizedClientRepository} bean to the - * application context. + * The supplied {@link ClientRegistration} will be registered into a + * {@link ServerOAuth2AuthorizedClientRepository}. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OAuth2LoginMutator} for further configuration */ @@ -866,8 +861,6 @@ public static final class OidcLoginMutator implements WebTestClientConfigurer, M private Collection authorities; - ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); - private OidcLoginMutator(OAuth2AccessToken accessToken) { this.accessToken = accessToken; this.clientRegistration = clientRegistrationBuilder().build(); @@ -942,12 +935,8 @@ public OidcLoginMutator oidcUser(OidcUser oidcUser) { /** * Use the provided {@link ClientRegistration} as the client to authorize. *

- * The supplied {@link ClientRegistration} will be registered into an - * {@link WebSessionServerOAuth2AuthorizedClientRepository}. Tests relying on - * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an - * {@link WebSessionServerOAuth2AuthorizedClientRepository} bean to the - * application context. + * The supplied {@link ClientRegistration} will be registered into a + * {@link ServerOAuth2AuthorizedClientRepository}. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OidcLoginMutator} for further configuration */ @@ -1037,8 +1026,6 @@ public static final class OAuth2ClientMutator implements WebTestClientConfigurer private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token", null, null, Collections.singleton("read")); - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); - private OAuth2ClientMutator() { } @@ -1116,16 +1103,15 @@ public void afterConfigurerAdded(WebTestClient.Builder builder, private Consumer> addAuthorizedClientFilter() { OAuth2AuthorizedClient client = getClient(); return (filters) -> filters.add(0, (exchange, chain) -> { - ReactiveOAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServerTestUtils - .getOAuth2AuthorizedClientManager(exchange); - if (!(authorizationClientManager instanceof TestReactiveOAuth2AuthorizedClientManager)) { - authorizationClientManager = new TestReactiveOAuth2AuthorizedClientManager( - authorizationClientManager); - OAuth2ClientServerTestUtils.setOAuth2AuthorizedClientManager(exchange, authorizationClientManager); + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = OAuth2ClientServerTestUtils + .getAuthorizedClientRepository(exchange); + if (!(authorizedClientRepository instanceof TestOAuth2AuthorizedClientRepository)) { + authorizedClientRepository = new TestOAuth2AuthorizedClientRepository(authorizedClientRepository); + OAuth2ClientServerTestUtils.setAuthorizedClientRepository(exchange, authorizedClientRepository); } - TestReactiveOAuth2AuthorizedClientManager.enable(exchange); - exchange.getAttributes().put(TestReactiveOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client); - return chain.filter(exchange); + TestOAuth2AuthorizedClientRepository.enable(exchange); + return authorizedClientRepository.saveAuthorizedClient(client, null, exchange) + .then(chain.filter(exchange)); }); } @@ -1142,21 +1128,19 @@ private ClientRegistration.Builder clientRegistrationBuilder() { } /** - * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for - * testing when the request is wrapped + * Used to wrap the {@link OAuth2AuthorizedClientRepository} to provide support + * for testing when the request is wrapped */ - private static final class TestReactiveOAuth2AuthorizedClientManager - implements ReactiveOAuth2AuthorizedClientManager { + private static final class TestOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { - static final String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class.getName() - .concat(".TOKEN"); - - static final String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class.getName() + static final String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName() .concat(".ENABLED"); private final ReactiveOAuth2AuthorizedClientManager delegate; - private TestReactiveOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClientManager delegate) { + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + TestOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClientManager delegate) { this.delegate = delegate; } @@ -1164,8 +1148,8 @@ private TestReactiveOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClient public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { ServerWebExchange exchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); if (isEnabled(exchange)) { - OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME); - return Mono.just(client); + return this.authorizedClientRepository.loadAuthorizedClient( + authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange); } return this.delegate.authorize(authorizeRequest); } @@ -1180,6 +1164,62 @@ boolean isEnabled(ServerWebExchange exchange) { } + /** + * Used to wrap the {@link OAuth2AuthorizedClientRepository} to provide support + * for testing when the request is wrapped + */ + static final class TestOAuth2AuthorizedClientRepository implements ServerOAuth2AuthorizedClientRepository { + + static final String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientRepository.class.getName().concat(".TOKEN"); + + static final String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientRepository.class.getName() + .concat(".ENABLED"); + + private final ServerOAuth2AuthorizedClientRepository delegate; + + TestOAuth2AuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository delegate) { + this.delegate = delegate; + } + + @Override + public Mono loadAuthorizedClient(String clientRegistrationId, + Authentication principal, ServerWebExchange exchange) { + if (isEnabled(exchange)) { + return Mono.just(exchange.getAttribute(TOKEN_ATTR_NAME)); + } + return this.delegate.loadAuthorizedClient(clientRegistrationId, principal, exchange); + } + + @Override + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange exchange) { + if (isEnabled(exchange)) { + exchange.getAttributes().put(TOKEN_ATTR_NAME, authorizedClient); + return Mono.empty(); + } + return this.delegate.saveAuthorizedClient(authorizedClient, principal, exchange); + } + + @Override + public Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { + if (isEnabled(exchange)) { + exchange.getAttributes().remove(TOKEN_ATTR_NAME); + return Mono.empty(); + } + return this.delegate.removeAuthorizedClient(clientRegistrationId, principal, exchange); + } + + static void enable(ServerWebExchange exchange) { + exchange.getAttributes().put(ENABLED_ATTR_NAME, Boolean.TRUE); + } + + boolean isEnabled(ServerWebExchange exchange) { + return Boolean.TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME)); + } + + } + private static final class OAuth2ClientServerTestUtils { private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = new WebSessionServerOAuth2AuthorizedClientRepository(); @@ -1188,7 +1228,7 @@ private OAuth2ClientServerTestUtils() { } /** - * Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified + * Gets the {@link ServerOAuth2AuthorizedClientRepository} for the specified * {@link ServerWebExchange}. If one is not found, one based off of * {@link WebSessionServerOAuth2AuthorizedClientRepository} is used. * @param exchange the {@link ServerWebExchange} to obtain the @@ -1196,6 +1236,39 @@ private OAuth2ClientServerTestUtils() { * @return the {@link ReactiveOAuth2AuthorizedClientManager} for the specified * {@link ServerWebExchange} */ + static ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository(ServerWebExchange exchange) { + ReactiveOAuth2AuthorizedClientManager manager = getOAuth2AuthorizedClientManager(exchange); + if (manager == null) { + return DEFAULT_CLIENT_REPO; + } + if (manager instanceof DefaultReactiveOAuth2AuthorizedClientManager) { + return (ServerOAuth2AuthorizedClientRepository) ReflectionTestUtils.getField(manager, + "authorizedClientRepository"); + } + if (manager instanceof TestOAuth2AuthorizedClientManager) { + return ((TestOAuth2AuthorizedClientManager) manager).authorizedClientRepository; + } + return DEFAULT_CLIENT_REPO; + } + + static void setAuthorizedClientRepository(ServerWebExchange exchange, + ServerOAuth2AuthorizedClientRepository repository) { + ReactiveOAuth2AuthorizedClientManager manager = getOAuth2AuthorizedClientManager(exchange); + if (manager == null) { + return; + } + if (manager instanceof DefaultReactiveOAuth2AuthorizedClientManager) { + ReflectionTestUtils.setField(manager, "authorizedClientRepository", repository); + return; + } + if (!(manager instanceof TestOAuth2AuthorizedClientManager)) { + manager = new TestOAuth2AuthorizedClientManager(manager); + setOAuth2AuthorizedClientManager(exchange, manager); + } + TestOAuth2AuthorizedClientManager.enable(exchange); + ((TestOAuth2AuthorizedClientManager) manager).authorizedClientRepository = repository; + } + static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) { OAuth2AuthorizedClientArgumentResolver resolver = findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class); diff --git a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java index 33c2db2066c..dd419b53554 100644 --- a/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java +++ b/test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,6 +62,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; @@ -430,7 +431,7 @@ public static OidcLoginRequestPostProcessor oidcLogin() { * *

* The support works by associating the authorized client to the HttpServletRequest - * via the {@link HttpSessionOAuth2AuthorizedClientRepository} + * using an {@link OAuth2AuthorizedClientRepository} *

* @return the {@link OAuth2ClientRequestPostProcessor} for additional customization * @since 5.3 @@ -445,7 +446,7 @@ public static OAuth2ClientRequestPostProcessor oauth2Client() { * *

* The support works by associating the authorized client to the HttpServletRequest - * via the {@link HttpSessionOAuth2AuthorizedClientRepository} + * using an {@link OAuth2AuthorizedClientRepository} *

* @param registrationId The registration id for the {@link OAuth2AuthorizedClient} * @return the {@link OAuth2ClientRequestPostProcessor} for additional customization @@ -1317,11 +1318,7 @@ public OAuth2LoginRequestPostProcessor oauth2User(OAuth2User oauth2User) { * Use the provided {@link ClientRegistration} as the client to authorize. * * The supplied {@link ClientRegistration} will be registered into an - * {@link HttpSessionOAuth2AuthorizedClientRepository}. Tests relying on - * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an - * {@link HttpSessionOAuth2AuthorizedClientRepository} bean to the application - * context. + * {@link OAuth2AuthorizedClientRepository}. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OAuth2LoginRequestPostProcessor} for further configuration */ @@ -1456,11 +1453,7 @@ public OidcLoginRequestPostProcessor oidcUser(OidcUser oidcUser) { * Use the provided {@link ClientRegistration} as the client to authorize. * * The supplied {@link ClientRegistration} will be registered into an - * {@link HttpSessionOAuth2AuthorizedClientRepository}. Tests relying on - * {@link org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient} - * annotations should register an - * {@link HttpSessionOAuth2AuthorizedClientRepository} bean to the application - * context. + * {@link HttpSessionOAuth2AuthorizedClientRepository}. * @param clientRegistration the {@link ClientRegistration} to use * @return the {@link OidcLoginRequestPostProcessor} for further configuration */ @@ -1586,14 +1579,14 @@ public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) } OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); - OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils - .getOAuth2AuthorizedClientManager(request); - if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) { - authorizationClientManager = new TestOAuth2AuthorizedClientManager(authorizationClientManager); - OAuth2ClientServletTestUtils.setOAuth2AuthorizedClientManager(request, authorizationClientManager); + OAuth2AuthorizedClientRepository authorizedClientRepository = OAuth2ClientServletTestUtils + .getAuthorizedClientRepository(request); + if (!(authorizedClientRepository instanceof TestOAuth2AuthorizedClientRepository)) { + authorizedClientRepository = new TestOAuth2AuthorizedClientRepository(authorizedClientRepository); + OAuth2ClientServletTestUtils.setAuthorizedClientRepository(request, authorizedClientRepository); } - TestOAuth2AuthorizedClientManager.enable(request); - request.setAttribute(TestOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client); + TestOAuth2AuthorizedClientRepository.enable(request); + authorizedClientRepository.saveAuthorizedClient(client, null, request, new MockHttpServletResponse()); return request; } @@ -1604,19 +1597,19 @@ private ClientRegistration.Builder clientRegistrationBuilder() { } /** - * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for - * testing when the request is wrapped + * Used to wrap the {@link OAuth2AuthorizedClientRepository} to provide support + * for testing when the request is wrapped */ private static final class TestOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { - static final String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName().concat(".TOKEN"); - static final String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName() .concat(".ENABLED"); private final OAuth2AuthorizedClientManager delegate; - private TestOAuth2AuthorizedClientManager(OAuth2AuthorizedClientManager delegate) { + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + TestOAuth2AuthorizedClientManager(OAuth2AuthorizedClientManager delegate) { this.delegate = delegate; } @@ -1624,7 +1617,8 @@ private TestOAuth2AuthorizedClientManager(OAuth2AuthorizedClientManager delegate public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { HttpServletRequest request = authorizeRequest.getAttribute(HttpServletRequest.class.getName()); if (isEnabled(request)) { - return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME); + return this.authorizedClientRepository.loadAuthorizedClient( + authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request); } return this.delegate.authorize(authorizeRequest); } @@ -1639,6 +1633,62 @@ boolean isEnabled(HttpServletRequest request) { } + /** + * Used to wrap the {@link OAuth2AuthorizedClientRepository} to provide support + * for testing when the request is wrapped + */ + static final class TestOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { + + static final String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientRepository.class.getName().concat(".TOKEN"); + + static final String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientRepository.class.getName() + .concat(".ENABLED"); + + private final OAuth2AuthorizedClientRepository delegate; + + TestOAuth2AuthorizedClientRepository(OAuth2AuthorizedClientRepository delegate) { + this.delegate = delegate; + } + + @Override + public T loadAuthorizedClient(String clientRegistrationId, + Authentication principal, HttpServletRequest request) { + if (isEnabled(request)) { + return (T) request.getAttribute(TOKEN_ATTR_NAME); + } + return this.delegate.loadAuthorizedClient(clientRegistrationId, principal, request); + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + if (isEnabled(request)) { + request.setAttribute(TOKEN_ATTR_NAME, authorizedClient); + return; + } + this.delegate.saveAuthorizedClient(authorizedClient, principal, request, response); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + if (isEnabled(request)) { + request.removeAttribute(TOKEN_ATTR_NAME); + return; + } + this.delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + + static void enable(HttpServletRequest request) { + request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE); + } + + boolean isEnabled(HttpServletRequest request) { + return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME)); + } + + } + private static final class OAuth2ClientServletTestUtils { private static final OAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO = new HttpSessionOAuth2AuthorizedClientRepository(); @@ -1647,7 +1697,7 @@ private OAuth2ClientServletTestUtils() { } /** - * Gets the {@link OAuth2AuthorizedClientManager} for the specified + * Gets the {@link OAuth2AuthorizedClientRepository} for the specified * {@link HttpServletRequest}. If one is not found, one based off of * {@link HttpSessionOAuth2AuthorizedClientRepository} is used. * @param request the {@link HttpServletRequest} to obtain the @@ -1655,12 +1705,44 @@ private OAuth2ClientServletTestUtils() { * @return the {@link OAuth2AuthorizedClientManager} for the specified * {@link HttpServletRequest} */ + static OAuth2AuthorizedClientRepository getAuthorizedClientRepository(HttpServletRequest request) { + OAuth2AuthorizedClientManager manager = getOAuth2AuthorizedClientManager(request); + if (manager == null) { + return DEFAULT_CLIENT_REPO; + } + if (manager instanceof DefaultOAuth2AuthorizedClientManager) { + return (OAuth2AuthorizedClientRepository) ReflectionTestUtils.getField(manager, + "authorizedClientRepository"); + } + if (manager instanceof TestOAuth2AuthorizedClientManager) { + return ((TestOAuth2AuthorizedClientManager) manager).authorizedClientRepository; + } + return DEFAULT_CLIENT_REPO; + } + + static void setAuthorizedClientRepository(HttpServletRequest request, + OAuth2AuthorizedClientRepository repository) { + OAuth2AuthorizedClientManager manager = getOAuth2AuthorizedClientManager(request); + if (manager == null) { + return; + } + if (manager instanceof DefaultOAuth2AuthorizedClientManager) { + ReflectionTestUtils.setField(manager, "authorizedClientRepository", repository); + return; + } + if (!(manager instanceof TestOAuth2AuthorizedClientManager)) { + manager = new TestOAuth2AuthorizedClientManager(manager); + setOAuth2AuthorizedClientManager(request, manager); + } + TestOAuth2AuthorizedClientManager.enable(request); + ((TestOAuth2AuthorizedClientManager) manager).authorizedClientRepository = repository; + } + static OAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(HttpServletRequest request) { OAuth2AuthorizedClientArgumentResolver resolver = findResolver(request, OAuth2AuthorizedClientArgumentResolver.class); if (resolver == null) { - return (authorizeRequest) -> DEFAULT_CLIENT_REPO.loadAuthorizedClient( - authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request); + return null; } return (OAuth2AuthorizedClientManager) ReflectionTestUtils.getField(resolver, "authorizedClientManager"); diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java index c7910c7ebff..c925cb3c964 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,15 +27,19 @@ import org.springframework.http.MediaType; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.OAuth2ClientMutator.TestOAuth2AuthorizedClientRepository; import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; @@ -61,16 +65,21 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private WebTestClient client; @BeforeEach public void setup() { + this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); this.client = WebTestClient.bindToController(this.controller) - .argumentResolvers((c) -> c.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver( - this.clientRegistrationRepository, this.authorizedClientRepository))) + .argumentResolvers((c) -> c + .addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientManager))) .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(SecurityMockServerConfigurers.springSecurity()).configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE).build(); + } @Test @@ -160,6 +169,22 @@ public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exc any(ServerWebExchange.class)); } + // gh-13113 + @Test + public void oauth2ClientWhenUsedThenSetsClientToRepository() { + this.client.mutateWith(SecurityMockServerConfigurers.mockOAuth2Client("registration-id")) + .mutateWith((clientBuilder, httpBuilder, connector) -> httpBuilder + .filters((filters) -> filters.add((exchange, chain) -> { + ServerOAuth2AuthorizedClientRepository repository = (ServerOAuth2AuthorizedClientRepository) ReflectionTestUtils + .getField(this.authorizedClientManager, "authorizedClientRepository"); + assertThat(repository).isInstanceOf(TestOAuth2AuthorizedClientRepository.class); + return repository.loadAuthorizedClient("registration-id", null, exchange) + .switchIfEmpty(Mono.error(new AssertionError("no authorized client found"))) + .then(chain.filter(exchange)); + }))) + .get().uri("/client").exchange().expectStatus().isOk(); + } + @RestController static class OAuth2LoginController { diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java index 2d556581778..e93277036d8 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,17 +31,21 @@ import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.test.context.TestSecurityContextHolder; +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor.TestOAuth2AuthorizedClientRepository; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.web.WebAppConfiguration; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.bind.annotation.GetMapping; @@ -49,6 +53,7 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -148,6 +153,18 @@ public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exc any(HttpServletRequest.class)); } + // gh-13113 + @Test + public void oauth2ClientWhenUsedThenSetsClientToRepository() throws Exception { + HttpServletRequest request = this.mvc.perform(get("/client-id").with(oauth2Client("registration-id"))) + .andExpect(content().string("test-client")).andReturn().getRequest(); + OAuth2AuthorizedClientManager manager = this.context.getBean(OAuth2AuthorizedClientManager.class); + OAuth2AuthorizedClientRepository repository = (OAuth2AuthorizedClientRepository) ReflectionTestUtils + .getField(manager, "authorizedClientRepository"); + assertThat(repository).isInstanceOf(TestOAuth2AuthorizedClientRepository.class); + assertThat((OAuth2AuthorizedClient) repository.loadAuthorizedClient("id", null, request)).isNotNull(); + } + @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { @@ -163,6 +180,12 @@ protected void configure(HttpSecurity http) throws Exception { // @formatter:on } + @Bean + OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository clients, + OAuth2AuthorizedClientRepository authorizedClients) { + return new DefaultOAuth2AuthorizedClientManager(clients, authorizedClients); + } + @Bean ClientRegistrationRepository clientRegistrationRepository() { return mock(ClientRegistrationRepository.class);