diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java index 94c19425206..597b05200db 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractRestClientOAuth2AccessTokenResponseClient.java @@ -25,7 +25,6 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -75,7 +74,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>(); - private Converter> parametersConverter = this::createParameters; + private Converter> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>(); AbstractRestClientOAuth2AccessTokenResponseClient() { } @@ -124,6 +123,11 @@ private void validateClientAuthenticationMethod(T grantRequest) { } private RequestHeadersSpec populateRequest(T grantRequest) { + MultiValueMap parameters = this.parametersConverter.convert(grantRequest); + if (parameters == null) { + parameters = new LinkedMultiValueMap<>(); + } + return this.restClient.post() .uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri()) .headers((headers) -> { @@ -132,28 +136,7 @@ private RequestHeadersSpec populateRequest(T grantRequest) { headers.addAll(headersToAdd); } }) - .body(this.parametersConverter.convert(grantRequest)); - } - - /** - * Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access - * Token Request body. - * @param grantRequest the authorization grant request - * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access - * Token Request body - */ - MultiValueMap createParameters(T grantRequest) { - ClientRegistration clientRegistration = grantRequest.getClientRegistration(); - MultiValueMap parameters = new LinkedMultiValueMap<>(); - parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); - if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC - .equals(clientRegistration.getClientAuthenticationMethod())) { - parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - } - if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { - parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - return parameters; + .body(parameters); } /** @@ -216,7 +199,21 @@ public final void addHeadersConverter(Converter headersConverter */ public final void setParametersConverter(Converter> parametersConverter) { Assert.notNull(parametersConverter, "parametersConverter cannot be null"); - this.parametersConverter = parametersConverter; + if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) { + this.parametersConverter = parametersConverter; + } + else { + Converter> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>(); + this.parametersConverter = (authorizationGrantRequest) -> { + MultiValueMap parameters = defaultParametersConverter + .convert(authorizationGrantRequest); + MultiValueMap parametersToSet = parametersConverter.convert(authorizationGrantRequest); + if (parametersToSet != null) { + parameters.putAll(parametersToSet); + } + return parameters; + }; + } this.requestEntityConverter = this::populateRequest; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java index 14d403d9880..24241aabeca 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java @@ -16,9 +16,6 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Collections; -import java.util.Set; - import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; @@ -27,16 +24,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.BodyInserters; -import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec; @@ -54,6 +47,7 @@ * * @param type of grant request * @author Phil Clay + * @author Steve Riesenberg * @since 5.3 * @see RFC-6749 Token * Endpoint @@ -72,7 +66,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>(); - private Converter> parametersConverter = this::populateTokenRequestParameters; + private Converter> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>(); private BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); @@ -86,18 +80,11 @@ public Mono getTokenResponse(T grantRequest) { // @formatter:off return Mono.defer(() -> this.requestEntityConverter.convert(grantRequest) .exchange() - .flatMap((response) -> readTokenResponse(grantRequest, response)) + .flatMap((response) -> response.body(this.bodyExtractor)) ); // @formatter:on } - /** - * Returns the {@link ClientRegistration} for the given {@code grantRequest}. - * @param grantRequest the grant request - * @return the {@link ClientRegistration} for the given {@code grantRequest}. - */ - abstract ClientRegistration clientRegistration(T grantRequest); - private RequestHeadersSpec validatingPopulateRequest(T grantRequest) { validateClientAuthenticationMethod(grantRequest); return populateRequest(grantRequest); @@ -117,128 +104,20 @@ private void validateClientAuthenticationMethod(T grantRequest) { } private RequestHeadersSpec populateRequest(T grantRequest) { + MultiValueMap parameters = this.parametersConverter.convert(grantRequest); + if (parameters == null) { + parameters = new LinkedMultiValueMap<>(); + } + return this.webClient.post() - .uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri()) + .uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri()) .headers((headers) -> { - HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest); + HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest); if (headersToAdd != null) { headers.addAll(headersToAdd); } }) - .body(createTokenRequestBody(grantRequest)); - } - - /** - * Populates default parameters for the token request. - * @param grantRequest the grant request - * @return the parameters populated for the token request. - */ - private MultiValueMap populateTokenRequestParameters(T grantRequest) { - MultiValueMap parameters = new LinkedMultiValueMap<>(); - parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); - return parameters; - } - - /** - * Combine the results of {@code parametersConverter} and - * {@link #populateTokenRequestBody}. - * - *

- * This method pre-populates the body with some standard properties, and then - * delegates to - * {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)} - * for subclasses to further populate the body before returning. - *

- * @param grantRequest the grant request - * @return the body for the token request. - */ - private BodyInserters.FormInserter createTokenRequestBody(T grantRequest) { - MultiValueMap parameters = getParametersConverter().convert(grantRequest); - return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters)); - } - - /** - * Populates the body of the token request. - * - *

- * By default, populates properties that are common to all grant types. Subclasses can - * extend this method to populate grant type specific properties. - *

- * @param grantRequest the grant request - * @param body the body to populate - * @return the populated body - */ - BodyInserters.FormInserter populateTokenRequestBody(T grantRequest, - BodyInserters.FormInserter body) { - ClientRegistration clientRegistration = clientRegistration(grantRequest); - if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC - .equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - } - if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - Set scopes = scopes(grantRequest); - if (!CollectionUtils.isEmpty(scopes)) { - body.with(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " ")); - } - return body; - } - - /** - * Returns the scopes to include as a property in the token request. - * @param grantRequest the grant request - * @return the scopes to include as a property in the token request. - */ - abstract Set scopes(T grantRequest); - - /** - * Returns the scopes to include in the response if the authorization server returned - * no scopes in the response. - * - *

- * As per RFC-6749 Section - * 5.1 Successful Access Token Response, if AccessTokenResponse.scope is empty, - * then default to the scope originally requested by the client in the Token Request. - *

- * @param grantRequest the grant request - * @return the scopes to include in the response if the authorization server returned - * no scopes. - */ - Set defaultScopes(T grantRequest) { - return Collections.emptySet(); - } - - /** - * Reads the token response from the response body. - * @param grantRequest the request for which the response was received. - * @param response the client response from which to read - * @return the token response from the response body. - */ - private Mono readTokenResponse(T grantRequest, ClientResponse response) { - return response.body(this.bodyExtractor) - .map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse)); - } - - /** - * Populates the given {@link OAuth2AccessTokenResponse} with additional details from - * the grant request. - * @param grantRequest the request for which the response was received. - * @param tokenResponse the original token response - * @return a token response optionally populated with additional details from the - * request. - */ - OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) { - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { - Set defaultScopes = defaultScopes(grantRequest); - // @formatter:off - tokenResponse = OAuth2AccessTokenResponse - .withResponse(tokenResponse) - .scopes(defaultScopes) - .build(); - // @formatter:on - } - return tokenResponse; + .body(BodyInserters.fromFormData(parameters)); } /** @@ -247,22 +126,11 @@ OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessToke * @param webClient the {@link WebClient} used when requesting the Access Token * Response */ - public void setWebClient(WebClient webClient) { + public final void setWebClient(WebClient webClient) { Assert.notNull(webClient, "webClient cannot be null"); this.webClient = webClient; } - /** - * Returns the {@link Converter} used for converting the - * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} - * used in the OAuth 2.0 Access Token Request headers. - * @return the {@link Converter} used for converting the - * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders} - */ - final Converter getHeadersConverter() { - return this.headersConverter; - } - /** * Sets the {@link Converter} used for converting the * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders} @@ -305,17 +173,6 @@ public final void addHeadersConverter(Converter headersConverter this.requestEntityConverter = this::populateRequest; } - /** - * Returns the {@link Converter} used for converting the - * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} - * used in the OAuth 2.0 Access Token Request body. - * @return the {@link Converter} used for converting the - * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap} - */ - final Converter> getParametersConverter() { - return this.parametersConverter; - } - /** * Sets the {@link Converter} used for converting the * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} @@ -326,7 +183,21 @@ final Converter> getParametersConverter() { */ public final void setParametersConverter(Converter> parametersConverter) { Assert.notNull(parametersConverter, "parametersConverter cannot be null"); - this.parametersConverter = parametersConverter; + if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) { + this.parametersConverter = parametersConverter; + } + else { + Converter> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>(); + this.parametersConverter = (authorizationGrantRequest) -> { + MultiValueMap parameters = defaultParametersConverter + .convert(authorizationGrantRequest); + MultiValueMap parametersToSet = parametersConverter.convert(authorizationGrantRequest); + if (parametersToSet != null) { + parameters.putAll(parametersToSet); + } + return parameters; + }; + } this.requestEntityConverter = this::populateRequest; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverter.java new file mode 100644 index 00000000000..07677a8f39e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverter.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.function.Consumer; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * Default {@link Converter} used to convert an + * {@link AbstractOAuth2AuthorizationGrantRequest} to the default {@link MultiValueMap + * parameters} of an OAuth 2.0 Access Token Request. + *

+ * This implementation provides grant-type specific parameters for the following grant + * types: + * + *

    + *
  • {@code authorization_code}
  • + *
  • {@code refresh_token}
  • + *
  • {@code client_credentials}
  • + *
  • {@code password}
  • + *
  • {@code urn:ietf:params:oauth:grant-type:jwt-bearer}
  • + *
  • {@code urn:ietf:params:oauth:grant-type:token-exchange}
  • + *
+ * + * In addition, the following default parameters are provided: + * + *
    + *
  • {@code grant_type} - always provided
  • + *
  • {@code client_id} - provided unless the {@code clientAuthenticationMethod} is + * {@code client_secret_basic}
  • + *
  • {@code client_secret} - provided when the {@code clientAuthenticationMethod} is + * {@code client_secret_post}
  • + *
+ * + * @param type of grant request + * @author Steve Riesenberg + * @since 6.4 + * @see AbstractWebClientReactiveOAuth2AccessTokenResponseClient + * @see AbstractRestClientOAuth2AccessTokenResponseClient + */ +public final class DefaultOAuth2TokenRequestParametersConverter + implements Converter> { + + private final Converter> defaultParametersConverter = createDefaultParametersConverter(); + + private Consumer> parametersCustomizer = (parameters) -> { + }; + + /** + * Sets the {@link Consumer} used for customizing the OAuth 2.0 Access Token + * parameters, which allows for parameters to be added, overwritten or removed. + * @param parametersCustomizer the {@link Consumer} to customize the parameters + */ + public void setParametersCustomizer(Consumer> parametersCustomizer) { + Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null"); + this.parametersCustomizer = parametersCustomizer; + } + + @Override + public MultiValueMap convert(T grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC + .equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + } + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + + MultiValueMap defaultParameters = this.defaultParametersConverter.convert(grantRequest); + if (defaultParameters != null) { + parameters.addAll(defaultParameters); + } + + this.parametersCustomizer.accept(parameters); + return parameters; + } + + private static Converter> createDefaultParametersConverter() { + return (grantRequest) -> { + if (grantRequest instanceof OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { + return OAuth2AuthorizationCodeGrantRequest.defaultParameters(authorizationCodeGrantRequest); + } + else if (grantRequest instanceof OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { + return OAuth2ClientCredentialsGrantRequest.defaultParameters(clientCredentialsGrantRequest); + } + else if (grantRequest instanceof OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + return OAuth2RefreshTokenGrantRequest.defaultParameters(refreshTokenGrantRequest); + } + else if (grantRequest instanceof OAuth2PasswordGrantRequest passwordGrantRequest) { + return OAuth2PasswordGrantRequest.defaultParameters(passwordGrantRequest); + } + else if (grantRequest instanceof JwtBearerGrantRequest jwtBearerGrantRequest) { + return JwtBearerGrantRequest.defaultParameters(jwtBearerGrantRequest); + } + else if (grantRequest instanceof TokenExchangeGrantRequest tokenExchangeGrantRequest) { + return TokenExchangeGrantRequest.defaultParameters(tokenExchangeGrantRequest); + } + return null; + }; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequest.java index d5fed5d87a0..52ace9b200a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,13 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * A JWT Bearer Grant request that holds a {@link Jwt} assertion. @@ -57,4 +62,21 @@ public Jwt getJwt() { return this.jwt; } + /** + * Populate default parameters for the JWT Bearer Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(JwtBearerGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue()); + return parameters; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequestEntityConverter.java index 6c8cf4fbc6e..cb187e1655d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtBearerGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,9 @@ * @see RequestEntity * @see Section * 2.1 Using JWTs as Authorization Grants + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class JwtBearerGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java index 698ebeec212..7b142fe38ac 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,11 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; /** * An OAuth 2.0 Authorization Code Grant request that holds an Authorization Code @@ -60,4 +64,26 @@ public OAuth2AuthorizationExchange getAuthorizationExchange() { return this.authorizationExchange; } + /** + * Populate default parameters for the Authorization Code Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) { + OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); + String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); + if (redirectUri != null) { + parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri); + } + String codeVerifier = authorizationExchange.getAuthorizationRequest() + .getAttribute(PkceParameterNames.CODE_VERIFIER); + if (codeVerifier != null) { + parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier); + } + return parameters; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index cbce66bcb60..52bab99b134 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,9 @@ * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2AuthorizationCodeGrantRequest * @see RequestEntity + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class OAuth2AuthorizationCodeGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java index b1ab0f1f3b3..b37ed156329 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * An OAuth 2.0 Client Credentials Grant request that holds the client's credentials in @@ -45,4 +50,20 @@ public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration "clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); } + /** + * Populate default parameters for the Client Credentials Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(OAuth2ClientCredentialsGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + return parameters; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java index 3c8347c7077..4e246484861 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,9 @@ * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2ClientCredentialsGrantRequest * @see RequestEntity + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class OAuth2ClientCredentialsGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java index 5710214f21d..f192d8b1eca 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * An OAuth 2.0 Resource Owner Password Credentials Grant request that holds the resource @@ -74,4 +79,22 @@ public String getPassword() { return this.password; } + /** + * Populate default parameters for the Password Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(OAuth2PasswordGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.USERNAME, grantRequest.getUsername()); + parameters.set(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword()); + return parameters; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java index aa649fdf101..9d25da6e7ef 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,9 @@ * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2PasswordGrantRequest * @see RequestEntity + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class OAuth2PasswordGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java index bcd08b1c2de..2245f168093 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,12 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * An OAuth 2.0 Refresh Token Grant request that holds the {@link OAuth2RefreshToken @@ -98,4 +103,20 @@ public Set getScopes() { return this.scopes; } + /** + * Populate default parameters for the Refresh Token Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(OAuth2RefreshTokenGrantRequest grantRequest) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + if (!CollectionUtils.isEmpty(grantRequest.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue()); + return parameters; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index 07582583312..de08dd6fa9f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,9 @@ * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter * @see OAuth2RefreshTokenGrantRequest * @see RequestEntity + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class OAuth2RefreshTokenGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java index a63d997a9a1..37964a6ff51 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClient.java @@ -17,10 +17,6 @@ package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; -import org.springframework.util.MultiValueMap; /** * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" @@ -43,21 +39,4 @@ public final class RestClientAuthorizationCodeTokenResponseClient extends AbstractRestClientOAuth2AccessTokenResponseClient { - @Override - MultiValueMap createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) { - OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); - MultiValueMap parameters = super.createParameters(grantRequest); - parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); - String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); - if (redirectUri != null) { - parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri); - } - String codeVerifier = authorizationExchange.getAuthorizationRequest() - .getAttribute(PkceParameterNames.CODE_VERIFIER); - if (codeVerifier != null) { - parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier); - } - return parameters; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java index 7aa896e913a..6a47a612015 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClient.java @@ -16,12 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" @@ -42,15 +37,4 @@ public final class RestClientClientCredentialsTokenResponseClient extends AbstractRestClientOAuth2AccessTokenResponseClient { - @Override - MultiValueMap createParameters(OAuth2ClientCredentialsGrantRequest grantRequest) { - ClientRegistration clientRegistration = grantRequest.getClientRegistration(); - MultiValueMap parameters = super.createParameters(grantRequest); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - parameters.set(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); - } - return parameters; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java index 65102410675..bf7cff5c730 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClient.java @@ -16,12 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" @@ -40,16 +35,4 @@ public final class RestClientJwtBearerTokenResponseClient extends AbstractRestClientOAuth2AccessTokenResponseClient { - @Override - MultiValueMap createParameters(JwtBearerGrantRequest grantRequest) { - ClientRegistration clientRegistration = grantRequest.getClientRegistration(); - MultiValueMap parameters = super.createParameters(grantRequest); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - parameters.set(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); - } - parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue()); - return parameters; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java index 02519ca8aa6..83a695a295c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClient.java @@ -17,10 +17,7 @@ package org.springframework.security.oauth2.client.endpoint; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" @@ -43,17 +40,6 @@ public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest return populateTokenResponse(grantRequest, accessTokenResponse); } - @Override - MultiValueMap createParameters(OAuth2RefreshTokenGrantRequest grantRequest) { - MultiValueMap parameters = super.createParameters(grantRequest); - if (!CollectionUtils.isEmpty(grantRequest.getScopes())) { - parameters.set(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " ")); - } - parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue()); - return parameters; - } - private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, OAuth2AccessTokenResponse accessTokenResponse) { if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java index e0e6544ad94..4148e98ad65 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClient.java @@ -16,14 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; /** * An implementation of {@link OAuth2AccessTokenResponseClient} that "exchanges" @@ -43,32 +36,4 @@ public final class RestClientTokenExchangeTokenResponseClient extends AbstractRestClientOAuth2AccessTokenResponseClient { - private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; - - private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; - - @Override - MultiValueMap createParameters(TokenExchangeGrantRequest grantRequest) { - ClientRegistration clientRegistration = grantRequest.getClientRegistration(); - MultiValueMap parameters = super.createParameters(grantRequest); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { - parameters.set(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); - } - parameters.set(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); - OAuth2Token subjectToken = grantRequest.getSubjectToken(); - parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); - parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); - OAuth2Token actorToken = grantRequest.getActorToken(); - if (actorToken != null) { - parameters.set(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); - parameters.set(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); - } - return parameters; - } - - private static String tokenType(OAuth2Token token) { - return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java index 0a026a56724..d225a5c91bc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java @@ -19,7 +19,13 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; /** * A Token Exchange Grant request that holds the {@link OAuth2Token subject token} and @@ -39,6 +45,10 @@ */ public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + private final OAuth2Token subjectToken; private final OAuth2Token actorToken; @@ -75,4 +85,33 @@ public OAuth2Token getActorToken() { return this.actorToken; } + /** + * Populate default parameters for the Token Exchange Grant. + * @param grantRequest the authorization grant request + * @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access + * Token Request body + */ + static MultiValueMap defaultParameters(TokenExchangeGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.set(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + parameters.set(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + parameters.set(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + parameters.set(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + return parameters; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java index c8f72e4adb4..ea24499d30b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java @@ -39,7 +39,9 @@ * @see RequestEntity * @see Section * 1.1 Delegation vs. Impersonation Semantics + * @deprecated Use {@link DefaultOAuth2TokenRequestParametersConverter} instead */ +@Deprecated(since = "6.4") public class TokenExchangeGrantRequestEntityConverter extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java index ba5ad0cf69a..4f713f8090d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Collections; -import java.util.Set; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; -import org.springframework.web.reactive.function.BodyInserters; /** * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that @@ -55,33 +46,4 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - @Override - ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); - } - - @Override - Set scopes(OAuth2AuthorizationCodeGrantRequest grantRequest) { - return Collections.emptySet(); - } - - @Override - BodyInserters.FormInserter populateTokenRequestBody(OAuth2AuthorizationCodeGrantRequest grantRequest, - BodyInserters.FormInserter body) { - super.populateTokenRequestBody(grantRequest, body); - OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); - OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); - body.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode()); - String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); - if (redirectUri != null) { - body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri); - } - String codeVerifier = authorizationExchange.getAuthorizationRequest() - .getAttribute(PkceParameterNames.CODE_VERIFIER); - if (codeVerifier != null) { - body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier); - } - return body; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java index f7df261b6b2..b0252aeda72 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,6 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Set; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; /** @@ -44,14 +41,4 @@ public class WebClientReactiveClientCredentialsTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - @Override - ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); - } - - @Override - Set scopes(OAuth2ClientCredentialsGrantRequest grantRequest) { - return grantRequest.getClientRegistration().getScopes(); - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java index 157f00be510..50a8f0db4bd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,8 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Set; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; /** @@ -44,21 +39,4 @@ public final class WebClientReactiveJwtBearerTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - @Override - ClientRegistration clientRegistration(JwtBearerGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); - } - - @Override - Set scopes(JwtBearerGrantRequest grantRequest) { - return grantRequest.getClientRegistration().getScopes(); - } - - @Override - BodyInserters.FormInserter populateTokenRequestBody(JwtBearerGrantRequest grantRequest, - BodyInserters.FormInserter body) { - return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.ASSERTION, - grantRequest.getJwt().getTokenValue()); - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java index e175b3b37c8..1de84db719e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,8 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Set; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; /** @@ -51,22 +46,4 @@ public final class WebClientReactivePasswordTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - @Override - ClientRegistration clientRegistration(OAuth2PasswordGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); - } - - @Override - Set scopes(OAuth2PasswordGrantRequest grantRequest) { - return grantRequest.getClientRegistration().getScopes(); - } - - @Override - BodyInserters.FormInserter populateTokenRequestBody(OAuth2PasswordGrantRequest grantRequest, - BodyInserters.FormInserter body) { - return super.populateTokenRequestBody(grantRequest, body) - .with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername()) - .with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword()); - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java index 0c814a13e5c..7337c68d9b6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,14 +16,11 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Set; +import reactor.core.publisher.Mono; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.CollectionUtils; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; /** @@ -44,29 +41,12 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override - ClientRegistration clientRegistration(OAuth2RefreshTokenGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); + public Mono getTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest) { + return super.getTokenResponse(grantRequest) + .map((accessTokenResponse) -> populateTokenResponse(grantRequest, accessTokenResponse)); } - @Override - Set scopes(OAuth2RefreshTokenGrantRequest grantRequest) { - return grantRequest.getScopes(); - } - - @Override - Set defaultScopes(OAuth2RefreshTokenGrantRequest grantRequest) { - return grantRequest.getAccessToken().getScopes(); - } - - @Override - BodyInserters.FormInserter populateTokenRequestBody(OAuth2RefreshTokenGrantRequest grantRequest, - BodyInserters.FormInserter body) { - return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.REFRESH_TOKEN, - grantRequest.getRefreshToken().getTokenValue()); - } - - @Override - OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, + private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, OAuth2AccessTokenResponse accessTokenResponse) { if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) && accessTokenResponse.getRefreshToken() != null) { @@ -75,7 +55,7 @@ OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest g OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse .withResponse(accessTokenResponse); if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { - tokenResponseBuilder.scopes(defaultScopes(grantRequest)); + tokenResponseBuilder.scopes(grantRequest.getAccessToken().getScopes()); } if (accessTokenResponse.getRefreshToken() == null) { // Reuse existing refresh token diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java index abc9ad751b8..ce6bae6bc2c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java @@ -16,15 +16,8 @@ package org.springframework.security.oauth2.client.endpoint; -import java.util.Set; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2Token; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; /** @@ -46,38 +39,4 @@ public final class WebClientReactiveTokenExchangeTokenResponseClient extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; - - private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; - - @Override - ClientRegistration clientRegistration(TokenExchangeGrantRequest grantRequest) { - return grantRequest.getClientRegistration(); - } - - @Override - Set scopes(TokenExchangeGrantRequest grantRequest) { - return grantRequest.getClientRegistration().getScopes(); - } - - @Override - BodyInserters.FormInserter populateTokenRequestBody(TokenExchangeGrantRequest grantRequest, - BodyInserters.FormInserter body) { - super.populateTokenRequestBody(grantRequest, body); - body.with(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); - OAuth2Token subjectToken = grantRequest.getSubjectToken(); - body.with(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); - body.with(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); - OAuth2Token actorToken = grantRequest.getActorToken(); - if (actorToken != null) { - body.with(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); - body.with(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); - } - return body; - } - - private static String tokenType(OAuth2Token token) { - return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; - } - } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/MockResponses.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/MockResponses.java new file mode 100644 index 00000000000..ba17268d343 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/MockResponses.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import okhttp3.mockwebserver.MockResponse; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +/** + * @author Steve Riesenberg + */ +public final class MockResponses { + + private MockResponses() { + } + + public static MockResponse json(String path) { + try { + String json = new ClassPathResource(path).getContentAsString(StandardCharsets.UTF_8); + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } + catch (IOException ex) { + throw new RuntimeException("Unable to read %s as a classpath resource".formatted(path), ex); + } + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverterTests.java new file mode 100644 index 00000000000..541c0a80862 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultOAuth2TokenRequestParametersConverterTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultOAuth2TokenRequestParametersConverter}. + * + * @author Steve Riesenberg + */ +public class DefaultOAuth2TokenRequestParametersConverterTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private ClientRegistration.Builder clientRegistration; + + @BeforeEach + public void setUp() { + this.clientRegistration = TestClientRegistrations.clientRegistration() + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .clientId("client-1") + .clientSecret("secret") + .scope("read", "write"); + } + + @Test + public void convertWhenGrantRequestIsAuthorizationCodeThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .build(); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .clientId("client-1") + .state("state") + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(clientRegistration.getRedirectUri()) + .attributes(Map.of(PkceParameterNames.CODE_VERIFIER, "code-verifier")) + .scopes(clientRegistration.getScopes()) + .build(); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code") + .state("state") + .redirectUri(clientRegistration.getRedirectUri()) + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + authorizationExchange); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(6); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.CODE)).containsExactly(authorizationResponse.getCode()); + assertThat(parameters.get(OAuth2ParameterNames.REDIRECT_URI)) + .containsExactly(clientRegistration.getRedirectUri()); + assertThat(parameters.get(PkceParameterNames.CODE_VERIFIER)) + .containsExactly(authorizationRequest.getAttribute(PkceParameterNames.CODE_VERIFIER)); + } + + @Test + public void convertWhenGrantRequestIsClientCredentialsThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(4); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.SCOPE)) + .containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @Test + public void convertWhenGrantRequestIsRefreshTokenThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) + .build(); + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + accessToken, refreshToken, clientRegistration.getScopes()); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(5); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.REFRESH_TOKEN.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.REFRESH_TOKEN)).containsExactly(refreshToken.getTokenValue()); + assertThat(parameters.get(OAuth2ParameterNames.SCOPE)) + .containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @Test + public void convertWhenGrantRequestIsPasswordThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.PASSWORD) + .build(); + OAuth2PasswordGrantRequest grantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user", + "password"); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(6); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.PASSWORD.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.USERNAME)).containsExactly(grantRequest.getUsername()); + assertThat(parameters.get(OAuth2ParameterNames.PASSWORD)).containsExactly(grantRequest.getPassword()); + assertThat(parameters.get(OAuth2ParameterNames.SCOPE)) + .containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @Test + public void convertWhenGrantRequestIsJwtBearerThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .build(); + Jwt jwt = TestJwts.jwt().build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, jwt); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(5); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.JWT_BEARER.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.ASSERTION)).containsExactly(jwt.getTokenValue()); + assertThat(parameters.get(OAuth2ParameterNames.SCOPE)) + .containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @Test + public void convertWhenGrantRequestIsTokenExchangeThenParametersProvided() { + ClientRegistration clientRegistration = this.clientRegistration + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .build(); + OAuth2Token subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2Token actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, subjectToken, + actorToken); + // @formatter:off + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); + // @formatter:on + MultiValueMap parameters = parametersConverter.convert(grantRequest); + assertThat(parameters).hasSize(9); + assertThat(parameters.get(OAuth2ParameterNames.GRANT_TYPE)) + .containsExactly(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_ID)).containsExactly(clientRegistration.getClientId()); + assertThat(parameters.get(OAuth2ParameterNames.CLIENT_SECRET)) + .containsExactly(clientRegistration.getClientSecret()); + assertThat(parameters.get(OAuth2ParameterNames.SCOPE)) + .containsExactly(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + assertThat(parameters.get(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)).containsExactly(ACCESS_TOKEN_TYPE_VALUE); + assertThat(parameters.get(OAuth2ParameterNames.SUBJECT_TOKEN)).containsExactly(subjectToken.getTokenValue()); + assertThat(parameters.get(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).containsExactly(ACCESS_TOKEN_TYPE_VALUE); + assertThat(parameters.get(OAuth2ParameterNames.ACTOR_TOKEN)).containsExactly(actorToken.getTokenValue()); + assertThat(parameters.get(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).containsExactly(JWT_TOKEN_TYPE_VALUE); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java index 95d6bb188e7..226ca94b0d1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientAuthorizationCodeTokenResponseClientTests.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; +import java.util.function.Consumer; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -34,6 +35,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -54,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -80,13 +83,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientRegistration() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .tokenUri(tokenUri) - .scope("read", "write"); + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(tokenUri) + .scope("read", "write"); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .clientId("client-1") @@ -99,7 +101,6 @@ public void setUp() throws IOException { .state("state") .redirectUri(clientRegistration.getRedirectUri()) .build(); - // @formatter:on this.authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); } @@ -164,15 +165,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, @@ -201,14 +194,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -219,14 +205,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -235,19 +214,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -262,15 +239,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -281,14 +250,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -313,8 +275,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -328,8 +289,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest request = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -371,18 +331,11 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -396,18 +349,11 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -421,19 +367,11 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -442,20 +380,13 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except verify(parametersConverter).convert(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value")); } @Test public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); @@ -463,7 +394,6 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); parameters.set(OAuth2ParameterNames.CODE, "custom-code"); parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri"); - // The client_id parameter is omitted for testing purposes this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters); this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -471,27 +401,20 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP // @formatter:off assertThat(formParameters).contains( param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), param(OAuth2ParameterNames.CODE, "custom-code"), - param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri")); + param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri") + ); // @formatter:on - assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID); } @Test public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, this.authorizationExchange); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -510,15 +433,25 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce } @Test - public void getTokenResponseWhenRestClientSetThenCalled() { + public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration, + this.authorizationExchange); + Consumer> parametersCustomizer = mock(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + parametersConverter.setParametersCustomizer(parametersCustomizer); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersCustomizer).accept(any()); + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + this.server.enqueue(MockResponses.json("access-token-response.json")); RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { messageConverters.add(0, new FormHttpMessageConverter()); messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); @@ -532,10 +465,6 @@ public void getTokenResponseWhenRestClientSetThenCalled() { verify(customClient).post(); } - private static MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java index bd9fd031139..f1936fcce62 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientClientCredentialsTokenResponseClientTests.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.Collections; import java.util.Set; +import java.util.function.Consumer; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -35,6 +36,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -53,6 +55,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -77,14 +80,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri(tokenUri) + .scope("read", "write"); } @AfterEach @@ -148,15 +149,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); @@ -185,14 +178,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); this.tokenResponseClient.getTokenResponse(grantRequest); @@ -202,14 +188,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -217,19 +196,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); // @formatter:off @@ -243,15 +220,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); @@ -261,14 +230,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); @@ -278,14 +240,7 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessToke @Test public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasNoScope() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("no-scope") .clientId("client-1") @@ -328,8 +283,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); // @formatter:off @@ -342,8 +296,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); // @formatter:off @@ -382,17 +335,10 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -406,17 +352,10 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -430,18 +369,10 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -450,28 +381,19 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except verify(parametersConverter).convert(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value")); } @Test public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); parameters.set(OAuth2ParameterNames.SCOPE, "one two"); - // The client_id parameter is omitted for testing purposes given(parametersConverter.convert(grantRequest)).willReturn(parameters); this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters); this.tokenResponseClient.getTokenResponse(grantRequest); @@ -480,26 +402,19 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP // @formatter:off assertThat(formParameters).contains( param(OAuth2ParameterNames.GRANT_TYPE, "custom"), - param(OAuth2ParameterNames.SCOPE, "one two")); + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.SCOPE, "one two") + ); // @formatter:on - assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID); } @Test public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -518,15 +433,24 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce } @Test - public void getTokenResponseWhenRestClientSetThenCalled() { + public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + Consumer> parametersCustomizer = mock(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + parametersConverter.setParametersCustomizer(parametersCustomizer); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersCustomizer).accept(any()); + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + this.server.enqueue(MockResponses.json("access-token-response.json")); RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { messageConverters.add(0, new FormHttpMessageConverter()); messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); @@ -539,10 +463,6 @@ public void getTokenResponseWhenRestClientSetThenCalled() { verify(customClient).post(); } - private static MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java index 91d2649942b..77a00712e09 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientJwtBearerTokenResponseClientTests.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.Collections; import java.util.Set; +import java.util.function.Consumer; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -34,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -53,6 +55,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -78,14 +81,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .tokenUri(tokenUri) + .scope("read", "write"); this.jwtAssertion = TestJwts.jwt().build(); } @@ -150,15 +151,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); @@ -188,14 +181,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); this.tokenResponseClient.getTokenResponse(grantRequest); @@ -205,14 +191,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -220,19 +199,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); // @formatter:off @@ -246,15 +223,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); @@ -264,14 +233,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); @@ -294,8 +256,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); // @formatter:off @@ -308,8 +269,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); // @formatter:off @@ -348,17 +308,10 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -372,17 +325,10 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -397,14 +343,7 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -419,26 +358,20 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except // @formatter:off assertThat(formParameters).contains( param(OAuth2ParameterNames.GRANT_TYPE, "custom"), - param(OAuth2ParameterNames.ASSERTION, "custom-assertion"), - param(OAuth2ParameterNames.SCOPE, "one two")); + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.SCOPE, "one two"), + param(OAuth2ParameterNames.ASSERTION, "custom-assertion") + ); // @formatter:on - assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID); } @Test public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -447,23 +380,16 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP verify(parametersConverter).convert(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value")); } @Test public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -483,16 +409,25 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce } @Test - public void getTokenResponseWhenRestClientSetThenCalled() { + public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest grantRequest = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Consumer> parametersCustomizer = mock(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - RestClient customClient = mock(RestClient.class); + parametersConverter.setParametersCustomizer(parametersCustomizer); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersCustomizer).accept(any()); + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + this.server.enqueue(MockResponses.json("access-token-response.json")); + RestClient customClient = mock(); given(customClient.post()).willReturn(RestClient.builder().build().post()); this.tokenResponseClient.setRestClient(customClient); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -501,10 +436,6 @@ public void getTokenResponseWhenRestClientSetThenCalled() { verify(customClient).post(); } - private static MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java index 3a2e0cf5f71..5d55f5160d8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientRefreshTokenTokenResponseClientTests.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.Collections; import java.util.Set; +import java.util.function.Consumer; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -35,6 +36,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -56,6 +58,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -84,14 +87,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) + .tokenUri(tokenUri) + .scope("read", "write"); this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); } @@ -157,15 +158,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); @@ -196,14 +189,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -214,14 +200,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -230,19 +209,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -257,15 +234,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -276,14 +245,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasRequestedScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, @@ -295,15 +257,7 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessToke @Test public void getTokenResponseWhenRequestDoesNotIncludeScopeThenAccessTokenHasResponseScope() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -341,8 +295,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -356,8 +309,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -399,18 +351,11 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -424,18 +369,11 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -449,19 +387,11 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -470,20 +400,13 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except verify(parametersConverter).convert(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value")); } @Test public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); @@ -491,7 +414,6 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"); parameters.set(OAuth2ParameterNames.SCOPE, "one two"); - // The client_id parameter is omitted for testing purposes this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters); this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -499,28 +421,21 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP // @formatter:off assertThat(formParameters).contains( param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), param(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"), - param(OAuth2ParameterNames.SCOPE, "one two")); + param(OAuth2ParameterNames.SCOPE, "one two") + ); // @formatter:on - assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID); } @Test public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken, scopes); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -540,15 +455,25 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce } @Test - public void getTokenResponseWhenRestClientSetThenCalled() { + public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); + ClientRegistration clientRegistration = this.clientRegistration.build(); + OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + Consumer> parametersCustomizer = mock(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + parametersConverter.setParametersCustomizer(parametersCustomizer); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersCustomizer).accept(any()); + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + this.server.enqueue(MockResponses.json("access-token-response.json")); RestClient restClient = RestClient.builder().messageConverters((messageConverters) -> { messageConverters.add(0, new FormHttpMessageConverter()); messageConverters.add(1, new OAuth2AccessTokenResponseHttpMessageConverter()); @@ -562,10 +487,6 @@ public void getTokenResponseWhenRestClientSetThenCalled() { verify(customClient).post(); } - private static MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java index 1792a57e59f..23fe4828a78 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/RestClientTokenExchangeTokenResponseClientTests.java @@ -22,6 +22,7 @@ import java.time.Instant; import java.util.Collections; import java.util.Set; +import java.util.function.Consumer; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -34,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -54,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -85,14 +88,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); this.actorToken = null; } @@ -158,15 +159,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); @@ -199,15 +192,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.subjectToken = TestJwts.jwt().build(); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -241,15 +226,7 @@ public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() thr @Test public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.actorToken = TestOAuth2AccessTokens.noScopes(); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -285,15 +262,7 @@ public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() @Test public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.actorToken = TestJwts.jwt().build(); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -329,14 +298,7 @@ public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); @@ -347,14 +309,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -363,19 +318,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); @@ -390,15 +343,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); @@ -409,14 +354,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); @@ -440,8 +378,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); // @formatter:off @@ -454,8 +391,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); // @formatter:off @@ -496,18 +432,11 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -521,18 +450,11 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -546,18 +468,11 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -566,20 +481,13 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except verify(parametersConverter).convert(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + assertThat(formParameters).contains(param("custom-parameter-name", "custom-parameter-value")); } @Test public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); @@ -587,7 +495,6 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); parameters.set(OAuth2ParameterNames.SCOPE, "one two"); parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token"); - // The client_id parameter is omitted for testing purposes this.tokenResponseClient.setParametersConverter((authorizationGrantRequest) -> parameters); this.tokenResponseClient.getTokenResponse(grantRequest); RecordedRequest recordedRequest = this.server.takeRequest(); @@ -595,27 +502,22 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP // @formatter:off assertThat(formParameters).contains( param(OAuth2ParameterNames.GRANT_TYPE, "custom"), - param(OAuth2ParameterNames.SCOPE, "one two"), - param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token")); + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token"), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, "one two") + ); // @formatter:on - assertThat(formParameters).doesNotContain(OAuth2ParameterNames.CLIENT_ID); } @Test public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration.build(); Set scopes = clientRegistration.getScopes(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, this.actorToken); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -637,16 +539,26 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce } @Test - public void getTokenResponseWhenRestClientSetThenCalled() { + public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + Consumer> parametersCustomizer = mock(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + DefaultOAuth2TokenRequestParametersConverter parametersConverter = + new DefaultOAuth2TokenRequestParametersConverter<>(); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - RestClient customClient = mock(RestClient.class); + parametersConverter.setParametersCustomizer(parametersCustomizer); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(parametersCustomizer).accept(any()); + } + + @Test + public void getTokenResponseWhenRestClientSetThenCalled() { + this.server.enqueue(MockResponses.json("access-token-response.json")); + RestClient customClient = mock(); given(customClient.post()).willReturn(RestClient.builder().build().post()); this.tokenResponseClient.setRestClient(customClient); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -656,10 +568,6 @@ public void getTokenResponseWhenRestClientSetThenCalled() { verify(customClient).post(); } - private static MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index cd7c1d31f58..5da4552bdfc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; +import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; @@ -26,7 +27,6 @@ import javax.crypto.spec.SecretKeySpec; import com.nimbusds.jose.jwk.JWK; -import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterEach; @@ -36,9 +36,8 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -48,6 +47,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jose.TestJwks; @@ -93,18 +93,7 @@ public void cleanup() throws Exception { @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"openid profile\",\n" - + " \"refresh_token\": \"refresh-token-1234\",\n" - + " \"custom_parameter_1\": \"custom-value-1\",\n" - + " \"custom_parameter_2\": \"custom-value-2\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-openid-profile-2.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(authorizationCodeGrantRequest()) @@ -125,14 +114,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistration @@ -158,14 +140,7 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA @Test public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistration @@ -194,9 +169,7 @@ private void configureJwtClientAuthenticationConverter(Function this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("unauthorized_client")) @@ -206,9 +179,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti // gh-5594 @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{}"; - this.server - .enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) .withMessageContaining("server_error"); @@ -216,14 +187,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + "\"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) .withMessageContaining("invalid_token_response"); @@ -231,15 +195,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + "\"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"openid profile\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-openid-profile.json")); this.clientRegistration.scope("openid", "profile", "email", "address"); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(authorizationCodeGrantRequest()) @@ -249,14 +205,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessToke @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseWithNoScopes() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.clientRegistration.scope("openid", "profile", "email", "address"); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient .getTokenResponse(authorizationCodeGrantRequest()) @@ -285,10 +234,6 @@ private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(Client return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); } - private MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - @Test public void setWebClientNullThenIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setWebClient(null)); @@ -296,18 +241,10 @@ public void setWebClientNullThenIllegalArgumentException() { @Test public void setCustomWebClientThenCustomWebClientIsUsed() { - WebClient customClient = mock(WebClient.class); + WebClient customClient = mock(); given(customClient.post()).willReturn(WebClient.builder().build().post()); this.tokenResponseClient.setWebClient(customClient); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"openid profile\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.clientRegistration.scope("openid", "profile", "email", "address"); OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()) .block(); @@ -317,14 +254,7 @@ public void setCustomWebClientThenCustomWebClientIsUsed() { @Test public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block(); String body = this.server.takeRequest().getBody().readUtf8(); assertThat(body).isEqualTo( @@ -379,20 +309,12 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { @Test public void convertWhenHeadersConverterAddedThenCalled() throws Exception { OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); - Converter addedHeadersConverter = mock(Converter.class); + Converter addedHeadersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(addedHeadersConverter.convert(request)).willReturn(headers); this.tokenResponseClient.addHeadersConverter(addedHeadersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"openid profile\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedHeadersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -406,20 +328,12 @@ public void convertWhenHeadersConverterAddedThenCalled() throws Exception { public void convertWhenHeadersConverterSetThenCalled() throws Exception { OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); ClientRegistration clientRegistration = request.getClientRegistration(); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); given(headersConverter.convert(request)).willReturn(headers); this.tokenResponseClient.setHeadersConverter(headersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"openid profile\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(headersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -440,23 +354,14 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { } @Test - public void convertWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); - Converter> addedParametersConverter = mock( - Converter.class); + Converter> addedParametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(addedParametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.addParametersConverter(addedParametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedParametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -465,44 +370,55 @@ public void convertWhenParametersConverterAddedThenCalled() throws Exception { } @Test - public void convertWhenParametersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.setParametersConverter(parametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(parametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); } + @Test + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); + this.server.enqueue(MockResponses.json("access-token-response.json")); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.CODE, "custom-code"); + parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri"); + this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters); + this.tokenResponseClient.getTokenResponse(request).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-id"), + param(OAuth2ParameterNames.CODE, "custom-code"), + param(OAuth2ParameterNames.REDIRECT_URI, "custom-uri") + ); + // @formatter:on + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { - String accessTokenSuccessResponse = "{}"; - WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient(); - BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(extractor.extract(any(), any())).willReturn(Mono.just(response)); customClient.setBodyExtractor(extractor); - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest()) .block(); assertThat(accessTokenResponse.getAccessToken()).isNotNull(); @@ -533,4 +449,8 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest).block()); } + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index 380240b1892..994a375f56b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,11 +37,13 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; @@ -88,15 +90,7 @@ public void cleanup() throws Exception { @Test public void getTokenResponseWhenHeaderThenSuccess() throws Exception { - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" - + " \"scope\":\"create\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response-create.json")); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( this.clientRegistration.build()); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); @@ -112,15 +106,7 @@ public void getTokenResponseWhenHeaderThenSuccess() throws Exception { // gh-9610 @Test public void getTokenResponseWhenSpecialCharactersThenSuccessWithEncodedClientCredentials() throws Exception { - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" - + " \"scope\":\"create\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response-create.json")); String clientCredentialWithAnsiKeyboardSpecialCharacters = "~!@#$%^&*()_+{}|:\"<>?`-=[]\\;',./ "; OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( this.clientRegistration.clientId(clientCredentialWithAnsiKeyboardSpecialCharacters) @@ -145,15 +131,7 @@ public void getTokenResponseWhenPostThenSuccess() throws Exception { ClientRegistration registration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" - + " \"scope\":\"create\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response-create.json")); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); RecordedRequest actualRequest = this.server.takeRequest(); @@ -167,13 +145,7 @@ public void getTokenResponseWhenPostThenSuccess() throws Exception { @Test public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - enqueueJson("{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistration @@ -200,13 +172,7 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA @Test public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - enqueueJson("{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistration @@ -237,14 +203,7 @@ private void configureJwtClientAuthenticationConverter(Function addedHeadersConverter = mock(Converter.class); + Converter addedHeadersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(addedHeadersConverter.convert(request)).willReturn(headers); this.client.addHeadersConverter(addedHeadersConverter); - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(addedHeadersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -347,19 +286,12 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( this.clientRegistration.build()); ClientRegistration clientRegistration = request.getClientRegistration(); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); given(headersConverter.convert(request)).willReturn(headers); this.client.setHeadersConverter(headersConverter); - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(headersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -380,23 +312,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { } @Test - public void convertWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( this.clientRegistration.build()); - Converter> addedParametersConverter = mock( - Converter.class); + Converter> addedParametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(addedParametersConverter.convert(request)).willReturn(parameters); this.client.addParametersConverter(addedParametersConverter); - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(addedParametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -405,38 +329,51 @@ public void convertWhenParametersConverterAddedThenCalled() throws Exception { } @Test - public void convertWhenParametersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( this.clientRegistration.build()); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(request)).willReturn(parameters); this.client.setParametersConverter(parametersConverter); - // @formatter:off - enqueueJson("{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"); - // @formatter:on + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(parametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); } + @Test + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration.build()); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.SCOPE, "one two"); + this.client.setParametersConverter((grantRequest) -> parameters); + this.server.enqueue(MockResponses.json("access-token-response.json")); + this.client.getTokenResponse(request).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-id"), + param(OAuth2ParameterNames.SCOPE, "one two") + ); + // @formatter:on + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { - enqueueJson("{}"); + this.server.enqueue(MockResponses.json("access-token-response.json")); WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient(); - BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(extractor.extract(any(), any())).willReturn(Mono.just(response)); @@ -474,4 +411,8 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal .isThrownBy(() -> this.client.getTokenResponse(clientCredentialsGrantRequest).block()); } + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java index b615c4924b7..21b7d2154e3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.client.endpoint; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; import java.util.Collections; import okhttp3.mockwebserver.MockResponse; @@ -30,6 +32,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -37,6 +40,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.TestJwts; @@ -82,12 +86,10 @@ public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .authorizationGrantType(AuthorizationGrantType.JWT_BEARER) + .tokenUri(tokenUri) + .scope("read", "write"); this.jwtAssertion = TestJwts.jwt().build(); } @@ -150,13 +152,8 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenResponse = "{\n" - + " \"error\": \"invalid_grant\"\n" - + "}\n"; - // @formatter:on ClientRegistration registration = this.clientRegistration.build(); - enqueueJson(accessTokenResponse); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.client.getTokenResponse(request).block()) @@ -166,15 +163,8 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti @Test public void getTokenResponseWhenResponseIsNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": 3600\n" - + "}\n"; - // @formatter:on ClientRegistration registration = this.clientRegistration.build(); - enqueueJson(accessTokenResponse); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); assertThatExceptionOfType(OAuth2AuthorizationException.class) .isThrownBy(() -> this.client.getTokenResponse(request).block()) @@ -185,10 +175,10 @@ public void getTokenResponseWhenResponseIsNotBearerTokenTypeThenThrowOAuth2Autho @Test public void getTokenResponseWhenWebClientSetThenCalled() { - WebClient customClient = mock(WebClient.class); + WebClient customClient = mock(); given(customClient.post()).willReturn(WebClient.builder().build().post()); this.client.setWebClient(customClient); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration registration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(registration, this.jwtAssertion); this.client.getTokenResponse(request).block(); @@ -199,12 +189,12 @@ public void getTokenResponseWhenWebClientSetThenCalled() { public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); given(headersConverter.convert(request)).willReturn(headers); this.client.setHeadersConverter(headersConverter); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(headersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -216,12 +206,12 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter addedHeadersConverter = mock(Converter.class); + Converter addedHeadersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(addedHeadersConverter.convert(request)).willReturn(headers); this.client.addHeadersConverter(addedHeadersConverter); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(addedHeadersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -243,16 +233,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { } @Test - public void convertWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter> addedParametersConverter = mock( - Converter.class); + Converter> addedParametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(addedParametersConverter.convert(request)).willReturn(parameters); this.client.addParametersConverter(addedParametersConverter); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(addedParametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -262,48 +251,62 @@ public void convertWhenParametersConverterAddedThenCalled() throws Exception { } @Test - public void convertWhenParametersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(request)).willReturn(parameters); this.client.setParametersConverter(parametersConverter); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(parametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); } + @Test + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.ASSERTION, "custom-assertion"); + parameters.set(OAuth2ParameterNames.SCOPE, "one two"); + this.client.setParametersConverter((grantRequest) -> parameters); + this.server.enqueue(MockResponses.json("access-token-response.json")); + this.client.getTokenResponse(request).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-id"), + param(OAuth2ParameterNames.SCOPE, "one two"), + param(OAuth2ParameterNames.ASSERTION, "custom-assertion") + ); + // @formatter:on + } + @Test public void getTokenResponseWhenBodyExtractorSetThenCalled() { - BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( - BodyExtractor.class); + BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(bodyExtractor.extract(any(), any())).willReturn(Mono.just(response)); ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); this.client.setBodyExtractor(bodyExtractor); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.client.getTokenResponse(request).block(); verify(bodyExtractor).extract(any(), any()); } @Test public void getTokenResponseWhenClientSecretBasicThenSuccess() throws Exception { - // @formatter:off - String accessTokenResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": 3600,\n" - + " \"scope\": \"read write\"" - + "}\n"; - // @formatter:on ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - enqueueJson(accessTokenResponse); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); assertThat(response).isNotNull(); assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write"); @@ -317,18 +320,12 @@ public void getTokenResponseWhenClientSecretBasicThenSuccess() throws Exception @Test public void getTokenResponseWhenClientSecretPostThenSuccess() throws Exception { // @formatter:off - String accessTokenResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": 3600,\n" - + " \"scope\": \"read write\"" - + "}\n"; ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); // @formatter:on JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - enqueueJson(accessTokenResponse); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); assertThat(response).isNotNull(); assertThat(response.getAccessToken().getScopes()).containsExactly("read", "write"); @@ -340,17 +337,9 @@ public void getTokenResponseWhenClientSecretPostThenSuccess() throws Exception { @Test public void getTokenResponseWhenResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { - // @formatter:off - String accessTokenResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": 3600,\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - enqueueJson(accessTokenResponse); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); assertThat(response).isNotNull(); assertThat(response.getAccessToken().getScopes()).containsExactly("read"); @@ -361,7 +350,7 @@ public void getTokenResponseWhenResponseDoesNotIncludeScopeThenReturnAccessToken throws Exception { ClientRegistration clientRegistration = this.clientRegistration.build(); JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); - enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.server.enqueue(MockResponses.json("access-token-response.json")); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); assertThat(response).isNotNull(); assertThat(response.getAccessToken().getScopes()).isEmpty(); @@ -389,12 +378,6 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal .isThrownBy(() -> this.client.getTokenResponse(jwtBearerGrantRequest).block()); } - private void enqueueJson(String body) { - MockResponse response = new MockResponse().setBody(body) - .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); - this.server.enqueue(response); - } - private void enqueueUnexpectedResponse() { // @formatter:off MockResponse response = new MockResponse() @@ -414,4 +397,8 @@ private void enqueueServerErrorResponse() { this.server.enqueue(response); } + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java index da8b2353193..e4a970f8e51 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; +import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; @@ -37,12 +38,14 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; @@ -101,14 +104,7 @@ public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseWithNoScope() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, @@ -135,15 +131,7 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAcce @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, @@ -171,14 +159,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessToke @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistrationBuilder .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -194,14 +175,7 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen @Test public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder @@ -229,14 +203,7 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA @Test public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder @@ -267,14 +234,7 @@ private void configureJwtClientAuthenticationConverter(Function addedHeadersConverter = mock(Converter.class); + Converter addedHeadersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(addedHeadersConverter.convert(request)).willReturn(headers); this.tokenResponseClient.addHeadersConverter(addedHeadersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedHeadersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -389,20 +320,12 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), this.username, this.password); ClientRegistration clientRegistration = request.getClientRegistration(); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); given(headersConverter.convert(request)).willReturn(headers); this.tokenResponseClient.setHeadersConverter(headersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(headersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -423,65 +346,75 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { } @Test - public void convertWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), this.username, this.password); - Converter> addedParametersConverter = mock( - Converter.class); + Converter> addedParametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(addedParametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.addParametersConverter(addedParametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedParametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); - assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password", - "custom-parameter-name=custom-parameter-value"); + String formParameters = actualRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "password"), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on } @Test - public void convertWhenParametersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), this.username, this.password); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.setParametersConverter(parametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(parametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); } + @Test + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistrationBuilder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), + this.username, this.password); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.USERNAME, "user"); + parameters.set(OAuth2ParameterNames.PASSWORD, "password"); + parameters.set(OAuth2ParameterNames.SCOPE, "one two"); + this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters); + this.server.enqueue(MockResponses.json("access-token-response.json")); + this.tokenResponseClient.getTokenResponse(request).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-id"), + param(OAuth2ParameterNames.SCOPE, "one two"), + param(OAuth2ParameterNames.USERNAME, "user"), + param(OAuth2ParameterNames.PASSWORD, "password") + ); + // @formatter:on + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { - String accessTokenSuccessResponse = "{}"; - WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient(); - BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(extractor.extract(any(), any())).willReturn(Mono.just(response)); @@ -491,11 +424,15 @@ public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenRespon OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, this.username, this.password); - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block(); assertThat(accessTokenResponse.getAccessToken()).isNotNull(); } + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java index 204080be82a..068a27d6e1e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.endpoint; +import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; @@ -37,6 +38,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -46,6 +48,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; @@ -105,14 +108,7 @@ public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); @@ -139,14 +135,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistrationBuilder .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -162,14 +151,7 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen @Test public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder @@ -197,14 +179,7 @@ public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersA @Test public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder @@ -235,14 +210,7 @@ private void configureJwtClientAuthenticationConverter(Function addedHeadersConverter = mock(Converter.class); + Converter addedHeadersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(addedHeadersConverter.convert(request)).willReturn(headers); this.tokenResponseClient.addHeadersConverter(addedHeadersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedHeadersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -354,24 +293,16 @@ public void convertWhenHeadersConverterAddedThenCalled() throws Exception { // gh-10130 @Test - public void convertWhenHeadersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); ClientRegistration clientRegistration = request.getClientRegistration(); - Converter headersConverter1 = mock(Converter.class); + Converter headersConverter1 = mock(); HttpHeaders headers = new HttpHeaders(); headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); given(headersConverter1.convert(request)).willReturn(headers); this.tokenResponseClient.setHeadersConverter(headersConverter1); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(headersConverter1).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -392,24 +323,15 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { } @Test - public void convertWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - Converter> addedParametersConverter = mock( - Converter.class); + Converter> addedParametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(addedParametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.addParametersConverter(addedParametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(addedParametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); @@ -418,39 +340,51 @@ public void convertWhenParametersConverterAddedThenCalled() throws Exception { } @Test - public void convertWhenParametersConverterSetThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - Converter> parametersConverter = mock( - Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(request)).willReturn(parameters); this.tokenResponseClient.setParametersConverter(parametersConverter); - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" - + " \"token_type\":\"bearer\",\n" - + " \"expires_in\":3600,\n" - + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" - + "}"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); this.tokenResponseClient.getTokenResponse(request).block(); verify(parametersConverter).convert(request); RecordedRequest actualRequest = this.server.takeRequest(); assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); } + @Test + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistrationBuilder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"); + parameters.set(OAuth2ParameterNames.SCOPE, "one two"); + this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters); + this.server.enqueue(MockResponses.json("access-token-response.json")); + this.tokenResponseClient.getTokenResponse(request).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-id"), + param(OAuth2ParameterNames.REFRESH_TOKEN, "custom-token"), + param(OAuth2ParameterNames.SCOPE, "one two") + ); + // @formatter:on + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { - String accessTokenSuccessResponse = "{}"; - WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient(); - BodyExtractor, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class); + BodyExtractor, ReactiveHttpInputMessage> extractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(extractor.extract(any(), any())).willReturn(Mono.just(response)); @@ -459,7 +393,7 @@ public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenRespon OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block(); assertThat(accessTokenResponse.getAccessToken()).isNotNull(); @@ -489,4 +423,8 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()); } + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java index 2e3d32bb170..59e9e94df48 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java @@ -35,6 +35,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.MockResponses; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -89,14 +90,12 @@ public void setUp() throws IOException { this.server = new MockWebServer(); this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); - // @formatter:off this.clientRegistration = TestClientRegistrations.clientCredentials() - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) - .tokenUri(tokenUri) - .scope("read", "write"); - // @formatter:on + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); this.actorToken = null; } @@ -171,15 +170,7 @@ public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentExcept @Test public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); @@ -210,15 +201,7 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t @Test public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.subjectToken = TestJwts.jwt().build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), @@ -250,15 +233,7 @@ public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() thr @Test public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.actorToken = TestOAuth2AccessTokens.noScopes(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), @@ -292,15 +267,7 @@ public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() @Test public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read write\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read-write.json")); Instant expiresAtBefore = Instant.now().plusSeconds(3600); this.actorToken = TestJwts.jwt().build(); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), @@ -334,14 +301,7 @@ public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws @Test public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); this.tokenResponseClient.getTokenResponse(grantRequest).block(); @@ -351,14 +311,7 @@ public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorization @Test public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); ClientRegistration clientRegistration = this.clientRegistration .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .build(); @@ -367,19 +320,17 @@ public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParameters this.tokenResponseClient.getTokenResponse(grantRequest).block(); RecordedRequest recordedRequest = this.server.takeRequest(); String formParameters = recordedRequest.getBody().readUtf8(); - assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.CLIENT_SECRET, "secret") + ); + // @formatter:on } @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"not-bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("invalid-token-type-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); // @formatter:off @@ -393,15 +344,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu @Test public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\",\n" - + " \"scope\": \"read\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response-read.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); @@ -411,14 +354,7 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe @Test public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); @@ -441,8 +377,7 @@ public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationExcep @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + this.server.enqueue(MockResponses.json("server-error-response.json").setResponseCode(500)); TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); // @formatter:off @@ -455,8 +390,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { - String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; - this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + this.server.enqueue(MockResponses.json("invalid-grant-response.json").setResponseCode(400)); TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); // @formatter:off @@ -497,17 +431,10 @@ public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegal @Test public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -521,17 +448,10 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti @Test public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); - Converter headersConverter = mock(Converter.class); + Converter headersConverter = mock(); HttpHeaders headers = new HttpHeaders(); headers.put("custom-header-name", Collections.singletonList("custom-header-value")); given(headersConverter.convert(grantRequest)).willReturn(headers); @@ -545,17 +465,10 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception @Test public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -568,18 +481,34 @@ public void getTokenResponseWhenParametersConverterSetThenCalled() throws Except } @Test - public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultParameters() throws Exception { + this.clientRegistration.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom"); + parameters.set(OAuth2ParameterNames.SCOPE, "one two"); + parameters.set(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token"); + this.tokenResponseClient.setParametersConverter((request) -> parameters); + this.server.enqueue(MockResponses.json("access-token-response.json")); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + String formParameters = this.server.takeRequest().getBody().readUtf8(); // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, "custom"), + param(OAuth2ParameterNames.CLIENT_ID, "client-1"), + param(OAuth2ParameterNames.SCOPE, "one two"), + param(OAuth2ParameterNames.SUBJECT_TOKEN, "custom-token") + ); // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + this.server.enqueue(MockResponses.json("access-token-response.json")); TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), this.subjectToken, this.actorToken); - Converter> parametersConverter = mock(Converter.class); + Converter> parametersConverter = mock(); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.add("custom-parameter-name", "custom-parameter-value"); given(parametersConverter.convert(grantRequest)).willReturn(parameters); @@ -602,16 +531,8 @@ public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exce @Test public void getTokenResponseWhenBodyExtractorSetThenCalled() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( - BodyExtractor.class); + this.server.enqueue(MockResponses.json("access-token-response.json")); + BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock(); OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(bodyExtractor.extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class))) .willReturn(Mono.just(response)); @@ -625,15 +546,8 @@ public void getTokenResponseWhenBodyExtractorSetThenCalled() { @Test public void getTokenResponseWhenWebClientSetThenCalled() { - // @formatter:off - String accessTokenSuccessResponse = "{\n" - + " \"access_token\": \"access-token-1234\",\n" - + " \"token_type\": \"bearer\",\n" - + " \"expires_in\": \"3600\"\n" - + "}\n"; - // @formatter:on - this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - WebClient customClient = mock(WebClient.class); + this.server.enqueue(MockResponses.json("access-token-response.json")); + WebClient customClient = mock(); given(customClient.post()).willReturn(WebClient.builder().build().post()); this.tokenResponseClient.setWebClient(customClient); ClientRegistration clientRegistration = this.clientRegistration.build(); @@ -643,10 +557,6 @@ public void getTokenResponseWhenWebClientSetThenCalled() { verify(customClient).post(); } - private MockResponse jsonResponse(String json) { - return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); - } - private static String param(String parameterName, String parameterValue) { return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); } diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-create.json b/oauth2/oauth2-client/src/test/resources/access-token-response-create.json new file mode 100644 index 00000000000..55080cbecbc --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-create.json @@ -0,0 +1,6 @@ +{ + "access_token": "access-token-1234", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "create" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile-2.json b/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile-2.json new file mode 100644 index 00000000000..d836b7eff59 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile-2.json @@ -0,0 +1,9 @@ +{ + "access_token": "access-token-1234", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid profile", + "refresh_token": "refresh-token-1234", + "custom_parameter_1": "custom-value-1", + "custom_parameter_2": "custom-value-2" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile.json b/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile.json new file mode 100644 index 00000000000..992439eefc5 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-openid-profile.json @@ -0,0 +1,6 @@ +{ + "access_token": "access-token-1234", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid profile" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-read-write.json b/oauth2/oauth2-client/src/test/resources/access-token-response-read-write.json new file mode 100644 index 00000000000..840de942adc --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-read-write.json @@ -0,0 +1,6 @@ +{ + "access_token": "access-token-1234", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "read write" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-read.json b/oauth2/oauth2-client/src/test/resources/access-token-response-read.json new file mode 100644 index 00000000000..fc2219f16a1 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-read.json @@ -0,0 +1,6 @@ +{ + "access_token": "access-token-1234", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "read" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response.json b/oauth2/oauth2-client/src/test/resources/access-token-response.json index 78c6dbd77b8..934c77b1b7a 100644 --- a/oauth2/oauth2-client/src/test/resources/access-token-response.json +++ b/oauth2/oauth2-client/src/test/resources/access-token-response.json @@ -1,5 +1,5 @@ { - "access_token": "token", + "access_token": "access-token-1234", "token_type": "Bearer", "expires_in": 3600 } \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/invalid-grant-response.json b/oauth2/oauth2-client/src/test/resources/invalid-grant-response.json new file mode 100644 index 00000000000..5e2adcd35f0 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/invalid-grant-response.json @@ -0,0 +1,4 @@ +{ + "error": "invalid_grant", + "error_description": "Invalid grant" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/invalid-token-type-response.json b/oauth2/oauth2-client/src/test/resources/invalid-token-type-response.json new file mode 100644 index 00000000000..c1fda00390f --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/invalid-token-type-response.json @@ -0,0 +1,5 @@ +{ + "access_token": "access-token-1234", + "token_type": "not-bearer", + "expires_in": 3600 +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/server-error-response.json b/oauth2/oauth2-client/src/test/resources/server-error-response.json new file mode 100644 index 00000000000..6a3c3263741 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/server-error-response.json @@ -0,0 +1,4 @@ +{ + "error": "server_error", + "error_description": "A server error occurred" +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/unauthorized-client-response.json b/oauth2/oauth2-client/src/test/resources/unauthorized-client-response.json new file mode 100644 index 00000000000..e955ec4c4c3 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/unauthorized-client-response.json @@ -0,0 +1,4 @@ +{ + "error": "unauthorized_client", + "error_description": "Unauthorized client" +} \ No newline at end of file