Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow access token request parameters to override defaults #15339

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -75,7 +74,7 @@ public abstract class AbstractRestClientOAuth2AccessTokenResponseClient<T extend

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();

AbstractRestClientOAuth2AccessTokenResponseClient() {
}
Expand Down Expand Up @@ -124,6 +123,11 @@ private void validateClientAuthenticationMethod(T grantRequest) {
}

private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}

return this.restClient.post()
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
Expand All @@ -132,28 +136,7 @@ private RequestHeadersSpec<?> populateRequest(T grantRequest) {
headers.addAll(headersToAdd);
}
})
.body(this.parametersConverter.convert(grantRequest));
}

/**
* Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
MultiValueMap<String, String> createParameters(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return parameters;
.body(parameters);
}

/**
Expand Down Expand Up @@ -216,7 +199,21 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
this.requestEntityConverter = this::populateRequest;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.Collections;
import java.util.Set;

import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
Expand All @@ -27,16 +24,12 @@
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec;

Expand All @@ -54,6 +47,7 @@
*
* @param <T> type of grant request
* @author Phil Clay
* @author Steve Riesenberg
* @since 5.3
* @see <a href="https://tools.ietf.org/html/rfc6749#section-3.2">RFC-6749 Token
* Endpoint</a>
Expand All @@ -72,7 +66,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::populateTokenRequestParameters;
private Converter<T, MultiValueMap<String, String>> parametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();

private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
Expand All @@ -86,18 +80,11 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
// @formatter:off
return Mono.defer(() -> this.requestEntityConverter.convert(grantRequest)
.exchange()
.flatMap((response) -> readTokenResponse(grantRequest, response))
.flatMap((response) -> response.body(this.bodyExtractor))
);
// @formatter:on
}

/**
* Returns the {@link ClientRegistration} for the given {@code grantRequest}.
* @param grantRequest the grant request
* @return the {@link ClientRegistration} for the given {@code grantRequest}.
*/
abstract ClientRegistration clientRegistration(T grantRequest);

private RequestHeadersSpec<?> validatingPopulateRequest(T grantRequest) {
validateClientAuthenticationMethod(grantRequest);
return populateRequest(grantRequest);
Expand All @@ -117,128 +104,20 @@ private void validateClientAuthenticationMethod(T grantRequest) {
}

private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}

return this.webClient.post()
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
})
.body(createTokenRequestBody(grantRequest));
}

/**
* Populates default parameters for the token request.
* @param grantRequest the grant request
* @return the parameters populated for the token request.
*/
private MultiValueMap<String, String> populateTokenRequestParameters(T grantRequest) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
return parameters;
}

/**
* Combine the results of {@code parametersConverter} and
* {@link #populateTokenRequestBody}.
*
* <p>
* This method pre-populates the body with some standard properties, and then
* delegates to
* {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)}
* for subclasses to further populate the body before returning.
* </p>
* @param grantRequest the grant request
* @return the body for the token request.
*/
private BodyInserters.FormInserter<String> createTokenRequestBody(T grantRequest) {
MultiValueMap<String, String> parameters = getParametersConverter().convert(grantRequest);
return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters));
}

/**
* Populates the body of the token request.
*
* <p>
* By default, populates properties that are common to all grant types. Subclasses can
* extend this method to populate grant type specific properties.
* </p>
* @param grantRequest the grant request
* @param body the body to populate
* @return the populated body
*/
BodyInserters.FormInserter<String> populateTokenRequestBody(T grantRequest,
BodyInserters.FormInserter<String> body) {
ClientRegistration clientRegistration = clientRegistration(grantRequest);
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
Set<String> scopes = scopes(grantRequest);
if (!CollectionUtils.isEmpty(scopes)) {
body.with(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " "));
}
return body;
}

/**
* Returns the scopes to include as a property in the token request.
* @param grantRequest the grant request
* @return the scopes to include as a property in the token request.
*/
abstract Set<String> scopes(T grantRequest);

/**
* Returns the scopes to include in the response if the authorization server returned
* no scopes in the response.
*
* <p>
* As per <a href="https://tools.ietf.org/html/rfc6749#section-5.1">RFC-6749 Section
* 5.1 Successful Access Token Response</a>, if AccessTokenResponse.scope is empty,
* then default to the scope originally requested by the client in the Token Request.
* </p>
* @param grantRequest the grant request
* @return the scopes to include in the response if the authorization server returned
* no scopes.
*/
Set<String> defaultScopes(T grantRequest) {
return Collections.emptySet();
}

/**
* Reads the token response from the response body.
* @param grantRequest the request for which the response was received.
* @param response the client response from which to read
* @return the token response from the response body.
*/
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(this.bodyExtractor)
.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
}

/**
* Populates the given {@link OAuth2AccessTokenResponse} with additional details from
* the grant request.
* @param grantRequest the request for which the response was received.
* @param tokenResponse the original token response
* @return a token response optionally populated with additional details from the
* request.
*/
OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) {
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
Set<String> defaultScopes = defaultScopes(grantRequest);
// @formatter:off
tokenResponse = OAuth2AccessTokenResponse
.withResponse(tokenResponse)
.scopes(defaultScopes)
.build();
// @formatter:on
}
return tokenResponse;
.body(BodyInserters.fromFormData(parameters));
}

/**
Expand All @@ -247,22 +126,11 @@ OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessToke
* @param webClient the {@link WebClient} used when requesting the Access Token
* Response
*/
public void setWebClient(WebClient webClient) {
public final void setWebClient(WebClient webClient) {
Assert.notNull(webClient, "webClient cannot be null");
this.webClient = webClient;
}

/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
*/
final Converter<T, HttpHeaders> getHeadersConverter() {
return this.headersConverter;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
Expand Down Expand Up @@ -305,17 +173,6 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
this.requestEntityConverter = this::populateRequest;
}

/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
*/
final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
return this.parametersConverter;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
Expand All @@ -326,7 +183,21 @@ final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
if (parametersConverter instanceof DefaultOAuth2TokenRequestParametersConverter) {
this.parametersConverter = parametersConverter;
}
else {
Converter<T, MultiValueMap<String, String>> defaultParametersConverter = new DefaultOAuth2TokenRequestParametersConverter<>();
this.parametersConverter = (authorizationGrantRequest) -> {
MultiValueMap<String, String> parameters = defaultParametersConverter
.convert(authorizationGrantRequest);
MultiValueMap<String, String> parametersToSet = parametersConverter.convert(authorizationGrantRequest);
if (parametersToSet != null) {
parameters.putAll(parametersToSet);
}
return parameters;
};
}
this.requestEntityConverter = this::populateRequest;
}

Expand Down
Loading