Skip to content

Commit

Permalink
Add headersCustomizer and parametersCustomizer
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnr committed Sep 27, 2024
1 parent 6c67d10 commit ea94167
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.function.Consumer;

import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
Expand All @@ -24,6 +26,7 @@
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -66,10 +69,14 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private final Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
private Consumer<HttpHeaders> headersCustomizer = (headers) -> {
};

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;

private Consumer<MultiValueMap<String, String>> parametersCustomizer = (parameters) -> {
};

private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();

Expand Down Expand Up @@ -107,13 +114,18 @@ private void validateClientAuthenticationMethod(T grantRequest) {

private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
this.parametersCustomizer.accept(parameters);
return this.webClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
this.headersCustomizer.accept(headers);
})
.body(BodyInserters.fromFormData(parameters));
}
Expand All @@ -126,7 +138,17 @@ private RequestHeadersSpec<?> populateRequest(T grantRequest) {
* Token Request body
*/
MultiValueMap<String, String> createParameters(T grantRequest) {
return this.defaultParametersConverter.convert(grantRequest);
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return parameters;
}

/**
Expand Down Expand Up @@ -182,48 +204,39 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* headers, which allows for headers to be added, overwritten or removed.
* @param headersCustomizer the {@link Consumer} to customize the headers
* @since 6.4
*/
public final void setHeadersCustomizer(Consumer<HttpHeaders> headersCustomizer) {
Assert.notNull(headersCustomizer, "headersCustomizer cannot be null");
this.headersCustomizer = headersCustomizer;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* <p>
* For backwards compatibility with Spring Security 6.3 (and earlier), this method
* ensures that default parameters for this particular grant type are provided if the
* given parameters converter does not supply them. In order to fully override or omit
* parameters, supply this method with an instance of
* {@link DefaultOAuth2TokenRequestParametersConverter} via
* {@link DefaultOAuth2TokenRequestParametersConverter#of(Converter)} and only the
* returned parameters will be provided.
* @param parametersConverter the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
* @since 5.6
* @see DefaultOAuth2TokenRequestParametersConverter#of(Converter)
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
// Allow opting into new behavior of fully overriding parameter values when
// user provides instance of DefaultOAuth2TokenRequestParametersConverter.
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
// For backwards compatibility with 6.3, ensure default parameters are always
// populated but allow parameter values to be overridden if provided.
// TODO: Remove in Spring Security 7
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = this::createParameters;
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = this::createParameters;
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
this.requestEntityConverter = this::populateRequest;
}

Expand Down Expand Up @@ -254,6 +267,17 @@ public final void addParametersConverter(Converter<T, MultiValueMap<String, Stri
this.requestEntityConverter = this::populateRequest;
}

/**
* Sets the {@link Consumer} used for customizing all of the OAuth 2.0 Access Token
* parameters, which allows for parameters to be added, overwritten or removed.
* @param parametersCustomizer the {@link Consumer} to customize the parameters
* @since 6.4
*/
public final void setParametersCustomizer(Consumer<MultiValueMap<String, String>> parametersCustomizer) {
Assert.notNull(parametersCustomizer, "parametersCustomizer cannot be null");
this.parametersCustomizer = parametersCustomizer;
}

/**
* Sets the {@link BodyExtractor} that will be used to decode the
* {@link OAuth2AccessTokenResponse}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

import javax.crypto.spec.SecretKeySpec;
Expand Down Expand Up @@ -428,6 +429,30 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception {
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
}

@Test
public void setHeadersCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersCustomizer(null))
.withMessage("headersCustomizer cannot be null");
}

@Test
public void getTokenResponseWhenHeadersCustomizerSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Consumer<HttpHeaders> headersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setHeadersCustomizer(headersCustomizer);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(headersCustomizer).accept(any(HttpHeaders.class));
}

@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null))
Expand All @@ -441,7 +466,7 @@ public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
}

@Test
public void convertWhenParametersConverterAddedThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> addedParametersConverter = mock(
Converter.class);
Expand All @@ -466,7 +491,7 @@ public void convertWhenParametersConverterAddedThenCalled() throws Exception {
}

@Test
public void convertWhenParametersConverterSetThenCalled() throws Exception {
public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter = mock(
Converter.class);
Expand Down Expand Up @@ -506,11 +531,35 @@ public void getTokenResponseWhenParametersConverterSetThenAbleToOverrideDefaultP
parameters.set(OAuth2ParameterNames.GRANT_TYPE, "custom");
parameters.set(OAuth2ParameterNames.CODE, "custom-code");
parameters.set(OAuth2ParameterNames.REDIRECT_URI, "custom-uri");
// The client_id parameter is omitted for testing purposes
this.tokenResponseClient.setParametersConverter((grantRequest) -> parameters);
this.tokenResponseClient.getTokenResponse(request).block();
String body = this.server.takeRequest().getBody().readUtf8();
assertThat(body).contains("grant_type=custom", "code=custom-code", "redirect_uri=custom-uri");
assertThat(body).contains("grant_type=custom", "client_id=client-id", "code=custom-code",
"redirect_uri=custom-uri");
}

@Test
public void setParametersCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersCustomizer(null))
.withMessage("parametersCustomizer cannot be null");
}

@Test
public void getTokenResponseWhenParametersCustomizerSetThenCalled() throws Exception {
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
Consumer<MultiValueMap<String, String>> parametersCustomizer = mock(Consumer.class);
this.tokenResponseClient.setParametersCustomizer(parametersCustomizer);
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(request).block();
verify(parametersCustomizer).accept(any());
}

// gh-10260
Expand Down
Loading

0 comments on commit ea94167

Please sign in to comment.