Skip to content

Commit

Permalink
Align RestClient parameters with WebClient
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnr committed Oct 1, 2024
1 parent b4d5ba7 commit e10cda0
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 278 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.function.Consumer;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.converter.FormHttpMessageConverter;
Expand All @@ -27,7 +25,6 @@
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -77,13 +74,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private Consumer<HttpHeaders> headersCustomizer = (headers) -> {
};

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;

private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();

AbstractRestClientOAuth2AccessTokenResponseClient() {
}
Expand Down Expand Up @@ -136,40 +127,18 @@ private RequestHeadersSpec<?> populateRequest(T grantRequest) {
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
this.parametersCustomizer.accept(parameters);

return this.restClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
this.headersCustomizer.accept(headers);
})
.body(parameters);
}

/**
* Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
MultiValueMap<String, String> createParameters(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return parameters;
}

/**
* Sets the {@link RestClient} used when requesting the OAuth 2.0 Access Token
* Response.
Expand Down Expand Up @@ -221,17 +190,6 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* headers, which allows for headers to be added, overwritten or removed.
* @param headersCustomizer the {@link Consumer} to customize the headers
* @since 6.4
*/
public final void setHeadersCustomizer(Consumer<HttpHeaders> headersCustomizer) {
Assert.notNull(headersCustomizer, "headersCustomizer cannot be null");
this.headersCustomizer = headersCustomizer;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
Expand All @@ -241,18 +199,24 @@ public final void setHeadersCustomizer(Consumer<HttpHeaders> headersCustomizer)
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = this::createParameters;
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
this.requestEntityConverter = this::populateRequest;
}

Expand Down Expand Up @@ -282,15 +246,4 @@ public final void addParametersConverter(Converter<T, MultiValueMap<String, Stri
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
* @since 6.4
*/
public final void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
Expand All @@ -43,10 +44,9 @@
public final class RestClientAuthorizationCodeTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {

@Override
MultiValueMap<String, String> createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) {
static MultiValueMap<String, String> createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) {
OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
if (redirectUri != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

Expand All @@ -42,10 +43,9 @@
public final class RestClientClientCredentialsTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {

@Override
MultiValueMap<String, String> createParameters(OAuth2ClientCredentialsGrantRequest grantRequest) {
static MultiValueMap<String, String> createParameters(OAuth2ClientCredentialsGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

Expand All @@ -40,10 +41,9 @@
public final class RestClientJwtBearerTokenResponseClient
extends AbstractRestClientOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> {

@Override
MultiValueMap<String, String> createParameters(JwtBearerGrantRequest grantRequest) {
static MultiValueMap<String, String> createParameters(JwtBearerGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

Expand All @@ -43,17 +44,6 @@ public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest
return populateTokenResponse(grantRequest, accessTokenResponse);
}

@Override
MultiValueMap<String, String> createParameters(OAuth2RefreshTokenGrantRequest grantRequest) {
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
if (!CollectionUtils.isEmpty(grantRequest.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue());
return parameters;
}

private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest,
OAuth2AccessTokenResponse accessTokenResponse) {
if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())
Expand All @@ -72,4 +62,14 @@ private OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantR
return tokenResponseBuilder.build();
}

static MultiValueMap<String, String> createParameters(OAuth2RefreshTokenGrantRequest grantRequest) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(grantRequest.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(grantRequest.getScopes(), " "));
}
parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue());
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

Expand All @@ -47,10 +48,9 @@ public final class RestClientTokenExchangeTokenResponseClient

private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt";

@Override
MultiValueMap<String, String> createParameters(TokenExchangeGrantRequest grantRequest) {
static MultiValueMap<String, String> createParameters(TokenExchangeGrantRequest grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,6 @@ public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
.withMessage("headersCustomizer cannot be null");
// @formatter:on
}

@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand All @@ -164,15 +155,6 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
// @formatter:on
}

@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
// @formatter:on
}

@Test
public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() {
// @formatter:off
Expand Down Expand Up @@ -439,25 +421,6 @@ public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception
assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value");
}

@Test
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(headersCustomizer).accept(any(HttpHeaders.class));
}

@Test
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
// @formatter:off
Expand Down Expand Up @@ -562,7 +525,9 @@ public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Excep
OAuth2AuthorizationCodeGrantRequest grantRequest = new OAuth2AuthorizationCodeGrantRequest(clientRegistration,
this.authorizationExchange);
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
DefaultOAuth2TokenRequestParametersConverter<OAuth2AuthorizationCodeGrantRequest> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
parametersConverter.setParametersCustomizer(parametersCustomizer);
this.tokenResponseClient.setParametersConverter(parametersConverter);
this.tokenResponseClient.getTokenResponse(grantRequest);
verify(parametersCustomizer).accept(any());
}
Expand Down
Loading

0 comments on commit e10cda0

Please sign in to comment.