Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose RestOperations in DefaultOAuth2UserService and CustomUserTypesOAuth2UserService #5641

Merged
merged 3 commits into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package org.springframework.security.oauth2.client.http;

import com.nimbusds.oauth2.sdk.token.BearerTokenError;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.util.StringUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.ResponseErrorHandler;

Expand All @@ -44,10 +48,39 @@ public boolean hasError(ClientHttpResponse response) throws IOException {

@Override
public void handleError(ClientHttpResponse response) throws IOException {
if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
this.defaultErrorHandler.handleError(response);
}
this.defaultErrorHandler.handleError(response);

// A Bearer Token Error may be in the WWW-Authenticate response header
// See https://tools.ietf.org/html/rfc6750#section-3
OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GitHub doesn't show it, but there is a tab between OAuth2Error oauth2Error. Please change this to a space to be consistent formatting.

if (oauth2Error == null) {
oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
}

throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
if (!StringUtils.hasText(wwwAuthenticateHeader)) {
return null;
}

BearerTokenError bearerTokenError;
try {
bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
} catch (Exception ex) {
return null;
}

String errorCode = bearerTokenError.getCode() != null ?
bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
String errorDescription = bearerTokenError.getDescription();
String errorUri = bearerTokenError.getURI() != null ?
bearerTokenError.getURI().toString() : null;

return new OAuth2Error(errorCode, errorDescription, errorUri);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,16 +16,25 @@
package org.springframework.security.oauth2.client.userinfo;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import java.util.HashSet;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

Expand All @@ -34,7 +43,7 @@
* <p>
* For standard OAuth 2.0 Provider's, the attribute name used to access the user's name
* from the UserInfo response is required and therefore must be available via
* {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* <p>
* <b>NOTE:</b> Attribute names are <b>not</b> standardized between providers and therefore will vary.
* Please consult the provider's API documentation for the set of supported user attribute names.
Expand All @@ -48,8 +57,23 @@
*/
public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri";

private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute";
private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();

private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";

private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
new ParameterizedTypeReference<Map<String, Object>>() {};

private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();

private RestOperations restOperations;

public DefaultOAuth2UserService() {
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
this.restOperations = restTemplate;
}

@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
Expand All @@ -64,7 +88,8 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
.getUserInfoEndpoint().getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
Expand All @@ -75,13 +100,63 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

ParameterizedTypeReference<Map<String, Object>> typeReference =
new ParameterizedTypeReference<Map<String, Object>>() {};
Map<String, Object> userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference);
GrantedAuthority authority = new OAuth2UserAuthority(userAttributes);
Set<GrantedAuthority> authorities = new HashSet<>();
authorities.add(authority);
RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);

ResponseEntity<Map<String, Object>> response;
try {
response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE);
} catch (OAuth2AuthenticationException ex) {
OAuth2Error oauth2Error = ex.getError();
StringBuilder errorDetails = new StringBuilder();
errorDetails.append("Error details: [");
errorDetails.append("UserInfo Uri: ").append(
userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri());
errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode());
if (oauth2Error.getDescription() != null) {
errorDetails.append(", Error Description: ").append(oauth2Error.getDescription());
}
errorDetails.append("]");
oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
}

Map<String, Object> userAttributes = response.getBody();
Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));

return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
}

/**
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of the UserInfo Request.
*
* @since 5.1
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
*/
public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
this.requestEntityConverter = requestEntityConverter;
}

/**
* Sets the {@link RestOperations} used when requesting the UserInfo resource.
*
* <p>
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
* <ol>
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
* </ol>
*
* @since 5.1
* @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
*/
public final void setRestOperations(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;

import java.net.URI;
import java.util.Collections;

import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;

/**
* A {@link Converter} that converts the provided {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of a request for the UserInfo Endpoint.
*
* @author Joe Grandja
* @since 5.1
* @see Converter
* @see OAuth2UserRequest
* @see RequestEntity
*/
public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> {
private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");

/**
* Returns the {@link RequestEntity} used for the UserInfo Request.
*
* @param userRequest the user request
* @return the {@link RequestEntity} used for the UserInfo Request
*/
@Override
public RequestEntity<?> convert(OAuth2UserRequest userRequest) {
ClientRegistration clientRegistration = userRequest.getClientRegistration();

HttpMethod httpMethod = HttpMethod.GET;
if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) {
httpMethod = HttpMethod.POST;
}
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())
.build()
.toUri();

RequestEntity<?> request;
if (HttpMethod.POST.equals(httpMethod)) {
headers.setContentType(DEFAULT_CONTENT_TYPE);
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue());
request = new RequestEntity<>(formParameters, headers, httpMethod, uri);
} else {
headers.setBearerAuth(userRequest.getAccessToken().getTokenValue());
request = new RequestEntity<>(headers, httpMethod, uri);
}

return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.http;

import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
Expand All @@ -31,7 +32,7 @@ public class OAuth2ErrorResponseErrorHandlerTests {
private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler();

@Test
public void handleErrorWhenStatusBadRequestThenHandled() {
public void handleErrorWhenErrorResponseBodyThenHandled() {
String errorResponse = "{\n" +
" \"error\": \"unauthorized_client\",\n" +
" \"error_description\": \"The client is not authorized\"\n" +
Expand All @@ -44,4 +45,17 @@ public void handleErrorWhenStatusBadRequestThenHandled() {
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("[unauthorized_client] The client is not authorized");
}

@Test
public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() {
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";

MockClientHttpResponse response = new MockClientHttpResponse(
new byte[0], HttpStatus.BAD_REQUEST);
response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);

assertThatThrownBy(() -> this.errorHandler.handleError(response))
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("[insufficient_scope] The access token expired");
}
}
Loading