Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement reactive support for JWT as an Authorization Grant #10327

Merged
merged 1 commit into from
Oct 5, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Implement reactive support for JWT as an Authorization Grant
Closes gh-10147
sjohnr committed Oct 5, 2021
commit 4fe769eca1b106c53b6b499124fa7bfff33bedd9
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

import reactor.core.publisher.Mono;

import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveJwtBearerTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;

/**
* An implementation of an {@link ReactiveOAuth2AuthorizedClientProvider} for the
* {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} grant.
*
* @author Steve Riesenberg
* @since 5.6
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveJwtBearerTokenResponseClient
*/
public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements ReactiveOAuth2AuthorizedClientProvider {

private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient();

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();

/**
* Attempt to authorize (or re-authorize) the
* {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided
* {@code context}. Returns an empty {@code Mono} if authorization (or
* re-authorization) is not supported, e.g. the client's
* {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is
* not {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} OR the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired.
* @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if
* authorization is not supported
*/
@Override
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
ClientRegistration clientRegistration = context.getClientRegistration();
if (!AuthorizationGrantType.JWT_BEARER.equals(clientRegistration.getAuthorizationGrantType())) {
return Mono.empty();
}
OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) {
// If client is already authorized but access token is NOT expired than no
// need for re-authorization
return Mono.empty();
}
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
return Mono.empty();
}
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
// As per spec, in section 4.1 Using Assertions as Authorization Grants
// https://tools.ietf.org/html/rfc7521#section-4.1
//
// An assertion used in this context is generally a short-lived
// representation of the authorization grant, and authorization servers
// SHOULD NOT issue access tokens with a lifetime that exceeds the
// validity period of the assertion by a significant period. In
// practice, that will usually mean that refresh tokens are not issued
// in response to assertion grant requests, and access tokens will be
// issued with a reasonably short lifetime. Clients can refresh an
// expired access token by requesting a new one using the same
// assertion, if it is still valid, or with a new assertion.
return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt))
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(),
ex))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken()));
}

private boolean hasTokenExpired(OAuth2Token token) {
return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew));
}

/**
* Sets the client used when requesting an access token credential at the Token
* Endpoint for the {@code jwt-bearer} grant.
* @param accessTokenResponseClient the client used when requesting an access token
* credential at the Token Endpoint for the {@code jwt-bearer} grant
*/
public void setAccessTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient) {
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
this.accessTokenResponseClient = accessTokenResponseClient;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
* 60 seconds.
*
* <p>
* An access token is considered expired if
* {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time
* {@code clock#instant()}.
* @param clockSkew the maximum acceptable clock skew
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}

/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access
* token expiry.
* @param clock the clock
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.endpoint;

import java.util.Set;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;

/**
* The default implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} for
* the {@link AuthorizationGrantType#JWT_BEARER jwt-bearer} grant. This implementation
* uses {@link WebClient} when requesting an access token credential at the Authorization
* Server's Token Endpoint.
*
* @author Steve Riesenberg
* @since 5.6
* @see ReactiveOAuth2AccessTokenResponseClient
* @see JwtBearerGrantRequest
* @see OAuth2AccessToken
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7523#section-2.1">Section
* 2.1 Using JWTs as Authorization Grants</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7521#section-4.1">Section
* 4.1 Using Assertions as Authorization Grants</a>
*/
public final class WebClientReactiveJwtBearerTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> {

@Override
ClientRegistration clientRegistration(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}

@Override
Set<String> scopes(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}

@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(JwtBearerGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.ASSERTION,
grantRequest.getJwt().getTokenValue());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link JwtBearerReactiveOAuth2AuthorizedClientProvider}.
*
* @author Steve Riesenberg
*/
public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {

private JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;

private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient;

private ClientRegistration clientRegistration;

private Jwt jwtAssertion;

private Authentication principal;

@BeforeEach
public void setup() {
this.authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider();
this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class);
this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
// @formatter:off
this.clientRegistration = ClientRegistration.withRegistrationId("jwt-bearer")
.clientId("client-id")
.clientSecret("client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.scope("read", "write")
.tokenUri("https://example.com/oauth2/token")
.build();
// @formatter:on
this.jwtAssertion = TestJwts.jwt().build();
this.principal = new TestingAuthenticationToken(this.jwtAssertion, this.jwtAssertion);
}

@Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.withMessage("accessTokenResponseClient cannot be null");
}

@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null");
// @formatter:on
}

@Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0");
// @formatter:on
}

@Test
public void setClockWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null");
// @formatter:on
}

@Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.withMessage("context cannot be null");
// @formatter:on
}

@Test
public void authorizeWhenNotJwtBearerThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}

@Test
public void authorizeWhenJwtBearerAndTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}

@Test
public void authorizeWhenJwtBearerAndTokenExpiredThenReauthorize() {
Instant now = Instant.now();
Instant issuedAt = now.minus(Duration.ofMinutes(60));
Instant expiresAt = now.minus(Duration.ofMinutes(30));
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234",
issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

@Test
public void authorizeWhenClockSetThenCalled() {
Clock clock = mock(Clock.class);
given(clock.instant()).willReturn(Instant.now());
this.authorizedClientProvider.setClock(clock);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
verify(clock).instant();
}

@Test
public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() {
Instant now = Instant.now();
Instant issuedAt = now.minus(Duration.ofMinutes(60));
Instant expiresAt = now.plus(Duration.ofMinutes(1));
OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token-1234", issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), expiresInOneMinAccessToken);
// Shorten the lifespan of the access token by 90 seconds, which will ultimately
// force it to expire on the client
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext)
.block();
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(new TestingAuthenticationToken("user", "password"))
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
}

@Test
public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() {
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(
Mono.error(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST))));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on

// @formatter:off
assertThatExceptionOfType(ClientAuthorizationException.class)
.isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block())
.withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST);
// @formatter:on
}

@Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.endpoint;

import java.util.Collections;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpInputMessage;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.client.WebClient;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link WebClientReactiveJwtBearerTokenResponseClient}.
*
* @author Steve Riesenberg
*/
public class WebClientReactiveJwtBearerTokenResponseClientTests {

// @formatter:off
private static final String DEFAULT_ACCESS_TOKEN_RESPONSE = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": 3600\n"
+ "}\n";
// @formatter:on

private WebClientReactiveJwtBearerTokenResponseClient client;

private MockWebServer server;

private ClientRegistration.Builder clientRegistration;

private Jwt jwtAssertion;

@BeforeEach
public void setup() throws Exception {
this.client = new WebClientReactiveJwtBearerTokenResponseClient();
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = TestClientRegistrations.clientCredentials()
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.tokenUri(tokenUri)
.scope("read", "write");
// @formatter:on
this.jwtAssertion = TestJwts.jwt().build();
}

@AfterEach
public void cleanup() throws Exception {
this.server.shutdown();
}

@Test
public void setWebClientWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setWebClient(null))
.withMessage("webClient cannot be null");
}

@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}

@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}

@Test
public void setBodyExtractorWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setBodyExtractor(null))
.withMessage("bodyExtractor cannot be null");
}

@Test
public void getTokenResponseWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.client.getTokenResponse(null))
.withMessage("grantRequest cannot be null");
}

@Test
public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() {
ClientRegistration registration = this.clientRegistration.build();
enqueueUnexpectedResponse();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
.satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response"))
.withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response");
}

@Test
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
ClientRegistration registration = this.clientRegistration.build();
enqueueServerErrorResponse();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
.satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR))
.withMessageContaining("[server_error]");
}

@Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"error\": \"invalid_grant\"\n"
+ "}\n";
// @formatter:on
ClientRegistration registration = this.clientRegistration.build();
enqueueJson(accessTokenResponse);
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
.satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT))
.withMessageContaining("[invalid_grant]");
}

@Test
public void getTokenResponseWhenResponseIsNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": 3600\n"
+ "}\n";
// @formatter:on
ClientRegistration registration = this.clientRegistration.build();
enqueueJson(accessTokenResponse);
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.client.getTokenResponse(request).block())
.satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response"))
.withMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response")
.withMessageContaining("Unsupported token_type: not-bearer");
}

@Test
public void getTokenResponseWhenWebClientSetThenCalled() {
WebClient customClient = mock(WebClient.class);
given(customClient.post()).willReturn(WebClient.builder().build().post());
this.client.setWebClient(customClient);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
ClientRegistration registration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion);
this.client.getTokenResponse(request).block();
verify(customClient).post();
}

@Test
public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
given(headersConverter.convert(request)).willReturn(headers);
this.client.setHeadersConverter(headersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.client.getTokenResponse(request).block();
verify(headersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

@Test
public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
Converter<JwtBearerGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
HttpHeaders headers = new HttpHeaders();
headers.put("custom-header-name", Collections.singletonList("custom-header-value"));
given(addedHeadersConverter.convert(request)).willReturn(headers);
this.client.addHeadersConverter(addedHeadersConverter);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.client.getTokenResponse(request).block();
verify(addedHeadersConverter).convert(request);
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
}

@Test
public void getTokenResponseWhenBodyExtractorSetThenCalled() {
BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = mock(
BodyExtractor.class);
OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(bodyExtractor.extract(any(), any())).willReturn(Mono.just(response));
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
this.client.setBodyExtractor(bodyExtractor);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
this.client.getTokenResponse(request).block();
verify(bodyExtractor).extract(any(), any());
}

@Test
public void getTokenResponseWhenClientSecretBasicThenSuccess() throws Exception {
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write");
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
assertThat(actualRequest.getBody().readUtf8()).isEqualTo(
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&scope=read+write&assertion=token");
}

@Test
public void getTokenResponseWhenClientSecretPostThenSuccess() throws Exception {
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
.build();
// @formatter:on
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write");
RecordedRequest actualRequest = this.server.takeRequest();
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
assertThat(actualRequest.getBody().readUtf8()).isEqualTo(
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&client_id=client-id&client_secret=client-secret&scope=read+write&assertion=token");
}

@Test
public void getTokenResponseWhenResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": 3600,\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
ClientRegistration clientRegistration = this.clientRegistration.build();
JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion);
enqueueJson(accessTokenResponse);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response).isNotNull();
assertThat(response.getAccessToken().getScopes()).containsExactly("read");
}

private void enqueueJson(String body) {
MockResponse response = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE,
MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(response);
}

private void enqueueUnexpectedResponse() {
// @formatter:off
MockResponse response = new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setResponseCode(301);
// @formatter:on
this.server.enqueue(response);
}

private void enqueueServerErrorResponse() {
// @formatter:off
MockResponse response = new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setResponseCode(500)
.setBody("{}");
// @formatter:on
this.server.enqueue(response);
}

}
Original file line number Diff line number Diff line change
@@ -61,12 +61,14 @@
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
@@ -78,6 +80,8 @@
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -87,6 +91,8 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
@@ -131,6 +137,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordTokenResponseClient;

@Mock
private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> jwtBearerTokenResponseClient;

@Mock
private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;

@@ -162,6 +171,8 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@BeforeEach
public void setup() {
// @formatter:off
JwtBearerReactiveOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider();
jwtBearerAuthorizedClientProvider.setAccessTokenResponseClient(this.jwtBearerTokenResponseClient);
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder()
.authorizationCode()
@@ -170,6 +181,7 @@ public void setup() {
.clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient))
.provider(jwtBearerAuthorizedClientProvider)
.build();
// @formatter:on
this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
@@ -396,6 +408,49 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
assertThat(getBody(request0)).isEmpty();
}

@Test
public void filterWhenJwtBearerClientNotAuthorizedThenExchangeToken() {
setupMocks();
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("exchanged-token")
.tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build();
given(this.jwtBearerTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
ClientRegistration registration = ClientRegistration.withRegistrationId("jwt-bearer")
.clientId("client-id")
.clientSecret("client-secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
.scope("read", "write")
.tokenUri("https://example.com/oauth/token")
.build();
// @formatter:on
given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId())))
.willReturn(Mono.just(registration));
Jwt jwtAssertion = TestJwts.jwt().build();
Authentication jwtAuthentication = new TestingAuthenticationToken(jwtAssertion, jwtAssertion);
given(this.authorizedClientRepository.loadAuthorizedClient(eq(registration.getRegistrationId()),
eq(jwtAuthentication), any())).willReturn(Mono.empty());
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId(registration.getRegistrationId()))
.build();
// @formatter:on
this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(jwtAuthentication))
.subscriberContext(serverWebExchange()).block();
verify(this.jwtBearerTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(registration.getRegistrationId()),
eq(jwtAuthentication), any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(jwtAuthentication), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer exchanged-token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",