diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..eb60c3c4bb8 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import reactor.core.publisher.Mono; + +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveJwtBearerTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} grant. + * + * @author Steve Riesenberg + * @since 5.6 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see WebClientReactiveJwtBearerTokenResponseClient + */ +public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient(); + + private Duration clockSkew = Duration.ofSeconds(60); + + private Clock clock = Clock.systemUTC(); + + /** + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if authorization (or + * re-authorization) is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.JWT_BEARER.equals(clientRegistration.getAuthorizationGrantType())) { + return Mono.empty(); + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no + // need for re-authorization + return Mono.empty(); + } + if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) { + return Mono.empty(); + } + Jwt jwt = (Jwt) context.getPrincipal().getPrincipal(); + // As per spec, in section 4.1 Using Assertions as Authorization Grants + // https://tools.ietf.org/html/rfc7521#section-4.1 + // + // An assertion used in this context is generally a short-lived + // representation of the authorization grant, and authorization servers + // SHOULD NOT issue access tokens with a lifetime that exceeds the + // validity period of the assertion by a significant period. In + // practice, that will usually mean that refresh tokens are not issued + // in response to assertion grant requests, and access tokens will be + // issued with a reasonably short lifetime. Clients can refresh an + // expired access token by requesting a new one using the same + // assertion, if it is still valid, or with a new assertion. + return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt)) + .flatMap(this.accessTokenResponseClient::getTokenResponse) + .onErrorMap(OAuth2AuthorizationException.class, + (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), + ex)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken())); + } + + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code jwt-bearer} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code jwt-bearer} grant + */ + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. + * + *

+ * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java new file mode 100644 index 00000000000..157f00be510 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Set; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * The default implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} for + * the {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} grant. This implementation + * uses {@link WebClient} when requesting an access token credential at the Authorization + * Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 5.6 + * @see ReactiveOAuth2AccessTokenResponseClient + * @see JwtBearerGrantRequest + * @see OAuth2AccessToken + * @see Section + * 2.1 Using JWTs as Authorization Grants + * @see Section + * 4.1 Using Assertions as Authorization Grants + */ +public final class WebClientReactiveJwtBearerTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { + + @Override + ClientRegistration clientRegistration(JwtBearerGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); + } + + @Override + Set scopes(JwtBearerGrantRequest grantRequest) { + return grantRequest.getClientRegistration().getScopes(); + } + + @Override + BodyInserters.FormInserter populateTokenRequestBody(JwtBearerGrantRequest grantRequest, + BodyInserters.FormInserter body) { + return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.ASSERTION, + grantRequest.getJwt().getTokenValue()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..33279c6f947 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,269 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link JwtBearerReactiveOAuth2AuthorizedClientProvider}. + * + * @author Steve Riesenberg + */ +public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests { + + private JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + + private ClientRegistration clientRegistration; + + private Jwt jwtAssertion; + + private Authentication principal; + + @BeforeEach + public void setup() { + this.authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + // @formatter:off + this.clientRegistration = ClientRegistration.withRegistrationId("jwt-bearer") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .scope("read", "write") + .tokenUri("https://example.com/oauth2/token") + .build(); + // @formatter:on + this.jwtAssertion = TestJwts.jwt().build(); + this.principal = new TestingAuthenticationToken(this.jwtAssertion, this.jwtAssertion); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on + } + + @Test + public void setClockWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenNotJwtBearerThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenJwtBearerAndTokenNotExpiredThenNotReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenJwtBearerAndTokenExpiredThenReauthorize() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.minus(Duration.ofMinutes(30)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenClockSetThenCalled() { + Clock clock = mock(Clock.class); + given(clock.instant()).willReturn(Instant.now()); + this.authorizedClientProvider.setClock(clock); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verify(clock).instant(); + } + + @Test + public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.plus(Duration.ofMinutes(1)); + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client + this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + + @Test + public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(new TestingAuthenticationToken("user", "password")) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + } + + @Test + public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() { + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn( + Mono.error(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)))); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + + // @formatter:off + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + } + + @Test + public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java new file mode 100644 index 00000000000..9470556046b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java @@ -0,0 +1,320 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Collections; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebClientReactiveJwtBearerTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class WebClientReactiveJwtBearerTokenResponseClientTests { + + // @formatter:off + private static final String DEFAULT_ACCESS_TOKEN_RESPONSE = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": 3600\n" + + "}\n"; + // @formatter:on + + private WebClientReactiveJwtBearerTokenResponseClient client; + + private MockWebServer server; + + private ClientRegistration.Builder clientRegistration; + + private Jwt jwtAssertion; + + @BeforeEach + public void setup() throws Exception { + this.client = new WebClientReactiveJwtBearerTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.jwtAssertion = TestJwts.jwt().build(); + } + + @AfterEach + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setWebClientWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.setWebClient(null)) + .withMessage("webClient cannot be null"); + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + } + + @Test + public void setBodyExtractorWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.setBodyExtractor(null)) + .withMessage("bodyExtractor cannot be null"); + } + + @Test + public void getTokenResponseWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + ClientRegistration registration = this.clientRegistration.build(); + enqueueUnexpectedResponse(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.client.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + ClientRegistration registration = this.clientRegistration.build(); + enqueueServerErrorResponse(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.client.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR)) + .withMessageContaining("[server_error]"); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenResponse = "{\n" + + " \"error\": \"invalid_grant\"\n" + + "}\n"; + // @formatter:on + ClientRegistration registration = this.clientRegistration.build(); + enqueueJson(accessTokenResponse); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.client.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessageContaining("[invalid_grant]"); + } + + @Test + public void getTokenResponseWhenResponseIsNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": 3600\n" + + "}\n"; + // @formatter:on + ClientRegistration registration = this.clientRegistration.build(); + enqueueJson(accessTokenResponse); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.client.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .withMessageContaining("Unsupported token_type: not-bearer"); + } + + @Test + public void getTokenResponseWhenWebClientSetThenCalled() { + WebClient customClient = mock(WebClient.class); + given(customClient.post()).willReturn(WebClient.builder().build().post()); + this.client.setWebClient(customClient); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + ClientRegistration registration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); + this.client.getTokenResponse(request).block(); + verify(customClient).post(); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + given(headersConverter.convert(request)).willReturn(headers); + this.client.setHeadersConverter(headersConverter); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.client.getTokenResponse(request).block(); + verify(headersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter addedHeadersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(addedHeadersConverter.convert(request)).willReturn(headers); + this.client.addHeadersConverter(addedHeadersConverter); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.client.getTokenResponse(request).block(); + verify(addedHeadersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); + assertThat(actualRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenBodyExtractorSetThenCalled() { + BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( + BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(bodyExtractor.extract(any(), any())).willReturn(Mono.just(response)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + this.client.setBodyExtractor(bodyExtractor); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.client.getTokenResponse(request).block(); + verify(bodyExtractor).extract(any(), any()); + } + + @Test + public void getTokenResponseWhenClientSecretBasicThenSuccess() throws Exception { + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + assertThat(response).isNotNull(); + assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write"); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)) + .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); + assertThat(actualRequest.getBody().readUtf8()).isEqualTo( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&scope=read+write&assertion=token"); + } + + @Test + public void getTokenResponseWhenClientSecretPostThenSuccess() throws Exception { + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + // @formatter:on + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + assertThat(response).isNotNull(); + assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write"); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).isEqualTo( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&client_id=client-id&client_secret=client-secret&scope=read+write&assertion=token"); + } + + @Test + public void getTokenResponseWhenResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { + // @formatter:off + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": 3600,\n" + + " \"scope\": \"read\"\n" + + "}\n"; + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + enqueueJson(accessTokenResponse); + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + assertThat(response).isNotNull(); + assertThat(response.getAccessToken().getScopes()).containsExactly("read"); + } + + private void enqueueJson(String body) { + MockResponse response = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, + MediaType.APPLICATION_JSON_VALUE); + this.server.enqueue(response); + } + + private void enqueueUnexpectedResponse() { + // @formatter:off + MockResponse response = new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setResponseCode(301); + // @formatter:on + this.server.enqueue(response); + } + + private void enqueueServerErrorResponse() { + // @formatter:off + MockResponse response = new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setResponseCode(500) + .setBody("{}"); + // @formatter:on + this.server.enqueue(response); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 343a0f47ef4..6faf6657e3c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -61,12 +61,14 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationException; +import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; @@ -78,6 +80,8 @@ import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -87,6 +91,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; @@ -131,6 +137,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ReactiveOAuth2AccessTokenResponseClient passwordTokenResponseClient; + @Mock + private ReactiveOAuth2AccessTokenResponseClient jwtBearerTokenResponseClient; + @Mock private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; @@ -162,6 +171,8 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @BeforeEach public void setup() { // @formatter:off + JwtBearerReactiveOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider(); + jwtBearerAuthorizedClientProvider.setAccessTokenResponseClient(this.jwtBearerTokenResponseClient); ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder .builder() .authorizationCode() @@ -170,6 +181,7 @@ public void setup() { .clientCredentials( (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) + .provider(jwtBearerAuthorizedClientProvider) .build(); // @formatter:on this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( @@ -396,6 +408,49 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() assertThat(getBody(request0)).isEmpty(); } + @Test + public void filterWhenJwtBearerClientNotAuthorizedThenExchangeToken() { + setupMocks(); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("exchanged-token") + .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build(); + given(this.jwtBearerTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + ClientRegistration registration = ClientRegistration.withRegistrationId("jwt-bearer") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .scope("read", "write") + .tokenUri("https://example.com/oauth/token") + .build(); + // @formatter:on + given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))) + .willReturn(Mono.just(registration)); + Jwt jwtAssertion = TestJwts.jwt().build(); + Authentication jwtAuthentication = new TestingAuthenticationToken(jwtAssertion, jwtAssertion); + given(this.authorizedClientRepository.loadAuthorizedClient(eq(registration.getRegistrationId()), + eq(jwtAuthentication), any())).willReturn(Mono.empty()); + // @formatter:off + ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(registration.getRegistrationId())) + .build(); + // @formatter:on + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(jwtAuthentication)) + .subscriberContext(serverWebExchange()).block(); + verify(this.jwtBearerTokenResponseClient).getTokenResponse(any()); + verify(this.authorizedClientRepository).loadAuthorizedClient(eq(registration.getRegistrationId()), + eq(jwtAuthentication), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(jwtAuthentication), any()); + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer exchanged-token"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",