Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose RestOperations in DefaultOAuth2UserService and CustomUserTypesOAuth2UserService #5641

Merged
merged 3 commits into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package org.springframework.security.oauth2.client.http;

import com.nimbusds.oauth2.sdk.token.BearerTokenError;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.util.StringUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.ResponseErrorHandler;

Expand All @@ -44,10 +48,39 @@ public boolean hasError(ClientHttpResponse response) throws IOException {

@Override
public void handleError(ClientHttpResponse response) throws IOException {
if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
this.defaultErrorHandler.handleError(response);
}
this.defaultErrorHandler.handleError(response);

// A Bearer Token Error may be in the WWW-Authenticate response header
// See https://tools.ietf.org/html/rfc6750#section-3
OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
if (oauth2Error == null) {
oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
}

throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
if (!StringUtils.hasText(wwwAuthenticateHeader)) {
return null;
}

BearerTokenError bearerTokenError;
try {
bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
} catch (Exception ex) {
return null;
}

String errorCode = bearerTokenError.getCode() != null ?
bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
String errorDescription = bearerTokenError.getDescription();
String errorUri = bearerTokenError.getURI() != null ?
bearerTokenError.getURI().toString() : null;

return new OAuth2Error(errorCode, errorDescription, errorUri);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,19 @@
*/
package org.springframework.security.oauth2.client.userinfo;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.Assert;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import java.util.Collections;
import java.util.LinkedHashMap;
Expand All @@ -39,8 +48,13 @@
* @see ClientRegistration
*/
public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";

private final Map<String, Class<? extends OAuth2User>> customUserTypes;
private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();

private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();

private RestOperations restOperations;

/**
* Constructs a {@code CustomUserTypesOAuth2UserService} using the provided parameters.
Expand All @@ -50,6 +64,9 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth
public CustomUserTypesOAuth2UserService(Map<String, Class<? extends OAuth2User>> customUserTypes) {
Assert.notEmpty(customUserTypes, "customUserTypes cannot be empty");
this.customUserTypes = Collections.unmodifiableMap(new LinkedHashMap<>(customUserTypes));
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
this.restOperations = restTemplate;
}

@Override
Expand All @@ -60,6 +77,49 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
if ((customUserType = this.customUserTypes.get(registrationId)) == null) {
return null;
}
return this.userInfoResponseClient.getUserInfoResponse(userRequest, customUserType);

RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);

ResponseEntity<? extends OAuth2User> response;
try {
response = this.restOperations.exchange(request, customUserType);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
}

OAuth2User oauth2User = response.getBody();

return oauth2User;
}

/**
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of the UserInfo Request.
*
* @since 5.1
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
*/
public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
this.requestEntityConverter = requestEntityConverter;
}

/**
* Sets the {@link RestOperations} used when requesting the UserInfo resource.
*
* <p>
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
* <ol>
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
* </ol>
*
* @since 5.1
* @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
*/
public final void setRestOperations(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,16 +16,25 @@
package org.springframework.security.oauth2.client.userinfo;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import java.util.HashSet;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

Expand All @@ -34,7 +43,7 @@
* <p>
* For standard OAuth 2.0 Provider's, the attribute name used to access the user's name
* from the UserInfo response is required and therefore must be available via
* {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* <p>
* <b>NOTE:</b> Attribute names are <b>not</b> standardized between providers and therefore will vary.
* Please consult the provider's API documentation for the set of supported user attribute names.
Expand All @@ -48,8 +57,23 @@
*/
public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri";

private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute";
private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();

private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";

private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
new ParameterizedTypeReference<Map<String, Object>>() {};

private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();

private RestOperations restOperations;

public DefaultOAuth2UserService() {
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
this.restOperations = restTemplate;
}

@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
Expand All @@ -64,7 +88,8 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
.getUserInfoEndpoint().getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
Expand All @@ -75,13 +100,63 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

ParameterizedTypeReference<Map<String, Object>> typeReference =
new ParameterizedTypeReference<Map<String, Object>>() {};
Map<String, Object> userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference);
GrantedAuthority authority = new OAuth2UserAuthority(userAttributes);
Set<GrantedAuthority> authorities = new HashSet<>();
authorities.add(authority);
RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);

ResponseEntity<Map<String, Object>> response;
try {
response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE);
} catch (OAuth2AuthenticationException ex) {
OAuth2Error oauth2Error = ex.getError();
StringBuilder errorDetails = new StringBuilder();
errorDetails.append("Error details: [");
errorDetails.append("UserInfo Uri: ").append(
userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri());
errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode());
if (oauth2Error.getDescription() != null) {
errorDetails.append(", Error Description: ").append(oauth2Error.getDescription());
}
errorDetails.append("]");
oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
}

Map<String, Object> userAttributes = response.getBody();
Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));

return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
}

/**
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of the UserInfo Request.
*
* @since 5.1
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
*/
public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
this.requestEntityConverter = requestEntityConverter;
}

/**
* Sets the {@link RestOperations} used when requesting the UserInfo resource.
*
* <p>
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
* <ol>
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
* </ol>
*
* @since 5.1
* @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
*/
public final void setRestOperations(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}
}
Loading