From b85a42f94008d7108c857310bb70d00d2a31f17e Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 28 Aug 2018 07:53:49 -0400 Subject: [PATCH 1/3] Provide RestOperations in DefaultOAuth2UserService Fixes gh-5600 --- .../http/OAuth2ErrorResponseErrorHandler.java | 41 ++++- .../userinfo/DefaultOAuth2UserService.java | 97 ++++++++-- .../OAuth2UserRequestEntityConverter.java | 81 ++++++++ .../OAuth2ErrorResponseErrorHandlerTests.java | 16 +- .../oidc/userinfo/OidcUserServiceTests.java | 166 ++++++----------- .../DefaultOAuth2UserServiceTests.java | 173 ++++++++++-------- ...OAuth2UserRequestEntityConverterTests.java | 125 +++++++++++++ 7 files changed, 497 insertions(+), 202 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java index dca68594025..a6ec8d13124 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java @@ -15,11 +15,15 @@ */ package org.springframework.security.oauth2.client.http; +import com.nimbusds.oauth2.sdk.token.BearerTokenError; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.client.ClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.util.StringUtils; import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.ResponseErrorHandler; @@ -44,10 +48,39 @@ public boolean hasError(ClientHttpResponse response) throws IOException { @Override public void handleError(ClientHttpResponse response) throws IOException { - if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) { - OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) { + this.defaultErrorHandler.handleError(response); } - this.defaultErrorHandler.handleError(response); + + // A Bearer Token Error may be in the WWW-Authenticate response header + // See https://tools.ietf.org/html/rfc6750#section-3 + OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders()); + if (oauth2Error == null) { + oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response); + } + + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + + private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) { + String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE); + if (!StringUtils.hasText(wwwAuthenticateHeader)) { + return null; + } + + BearerTokenError bearerTokenError; + try { + bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader); + } catch (Exception ex) { + return null; + } + + String errorCode = bearerTokenError.getCode() != null ? + bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR; + String errorDescription = bearerTokenError.getDescription(); + String errorUri = bearerTokenError.getURI() != null ? + bearerTokenError.getURI().toString() : null; + + return new OAuth2Error(errorCode, errorDescription, errorUri); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index 8d211c000db..6a97b7f449d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,12 @@ package org.springframework.security.oauth2.client.userinfo; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; @@ -24,8 +29,12 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; -import java.util.HashSet; +import java.util.Collections; import java.util.Map; import java.util.Set; @@ -34,7 +43,7 @@ *

* For standard OAuth 2.0 Provider's, the attribute name used to access the user's name * from the UserInfo response is required and therefore must be available via - * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}. + * {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}. *

* NOTE: Attribute names are not standardized between providers and therefore will vary. * Please consult the provider's API documentation for the set of supported user attribute names. @@ -48,8 +57,23 @@ */ public class DefaultOAuth2UserService implements OAuth2UserService { private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri"; + private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute"; - private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient(); + + private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; + + private static final ParameterizedTypeReference> PARAMETERIZED_RESPONSE_TYPE = + new ParameterizedTypeReference>() {}; + + private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); + + private RestOperations restOperations; + + public DefaultOAuth2UserService() { + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; + } @Override public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException { @@ -64,7 +88,8 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic ); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); + String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails() + .getUserInfoEndpoint().getUserNameAttributeName(); if (!StringUtils.hasText(userNameAttributeName)) { OAuth2Error oauth2Error = new OAuth2Error( MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE, @@ -75,13 +100,63 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - ParameterizedTypeReference> typeReference = - new ParameterizedTypeReference>() {}; - Map userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference); - GrantedAuthority authority = new OAuth2UserAuthority(userAttributes); - Set authorities = new HashSet<>(); - authorities.add(authority); + RequestEntity request = this.requestEntityConverter.convert(userRequest); + + ResponseEntity> response; + try { + response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE); + } catch (OAuth2AuthenticationException ex) { + OAuth2Error oauth2Error = ex.getError(); + StringBuilder errorDetails = new StringBuilder(); + errorDetails.append("Error details: ["); + errorDetails.append("UserInfo Uri: ").append( + userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri()); + errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode()); + if (oauth2Error.getDescription() != null) { + errorDetails.append(", Error Description: ").append(oauth2Error.getDescription()); + } + errorDetails.append("]"); + oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + + Map userAttributes = response.getBody(); + Set authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes)); return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); } + + /** + * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} + * to a {@link RequestEntity} representation of the UserInfo Request. + * + * @since 5.1 + * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request + */ + public final void setRequestEntityConverter(Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the UserInfo resource. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + *

    + *
  1. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  2. + *
+ * + * @since 5.1 + * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource + */ + public final void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java new file mode 100644 index 00000000000..777ac39004d --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.userinfo; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; +import java.util.Collections; + +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * A {@link Converter} that converts the provided {@link OAuth2UserRequest} + * to a {@link RequestEntity} representation of a request for the UserInfo Endpoint. + * + * @author Joe Grandja + * @since 5.1 + * @see Converter + * @see OAuth2UserRequest + * @see RequestEntity + */ +public class OAuth2UserRequestEntityConverter implements Converter> { + private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + + /** + * Returns the {@link RequestEntity} used for the UserInfo Request. + * + * @param userRequest the user request + * @return the {@link RequestEntity} used for the UserInfo Request + */ + @Override + public RequestEntity convert(OAuth2UserRequest userRequest) { + ClientRegistration clientRegistration = userRequest.getClientRegistration(); + + HttpMethod httpMethod = HttpMethod.GET; + if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) { + httpMethod = HttpMethod.POST; + } + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()) + .build() + .toUri(); + + RequestEntity request; + if (HttpMethod.POST.equals(httpMethod)) { + headers.setContentType(DEFAULT_CONTENT_TYPE); + MultiValueMap formParameters = new LinkedMultiValueMap<>(); + formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue()); + request = new RequestEntity<>(formParameters, headers, httpMethod, uri); + } else { + headers.setBearerAuth(userRequest.getAccessToken().getTokenValue()); + request = new RequestEntity<>(headers, httpMethod, uri); + } + + return request; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java index 5dc310f3a59..122241f5bc4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.http; import org.junit.Test; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -31,7 +32,7 @@ public class OAuth2ErrorResponseErrorHandlerTests { private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler(); @Test - public void handleErrorWhenStatusBadRequestThenHandled() { + public void handleErrorWhenErrorResponseBodyThenHandled() { String errorResponse = "{\n" + " \"error\": \"unauthorized_client\",\n" + " \"error_description\": \"The client is not authorized\"\n" + @@ -44,4 +45,17 @@ public void handleErrorWhenStatusBadRequestThenHandled() { .isInstanceOf(OAuth2AuthenticationException.class) .hasMessage("[unauthorized_client] The client is not authorized"); } + + @Test + public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() { + String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; + + MockClientHttpResponse response = new MockClientHttpResponse( + new byte[0], HttpStatus.BAD_REQUEST); + response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); + + assertThatThrownBy(() -> this.errorHandler.handleError(response)) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessage("[insufficient_scope] The access token expired"); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index e388a800fdf..786298d7f6a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -18,6 +18,7 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -29,7 +30,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; @@ -71,12 +71,15 @@ public class OidcUserServiceTests { private OAuth2AccessToken accessToken; private OidcIdToken idToken; private OidcUserService userService = new OidcUserService(); + private MockWebServer server; @Rule public ExpectedException exception = ExpectedException.none(); @Before - public void setUp() throws Exception { + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); this.clientRegistration = mock(ClientRegistration.class); this.providerDetails = mock(ClientRegistration.ProviderDetails.class); this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); @@ -101,6 +104,11 @@ public void setUp() throws Exception { this.userService.setOauth2UserService(new DefaultOAuth2UserService()); } + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + @Test public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.userService.setOauth2UserService(null)) @@ -135,9 +143,7 @@ public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfo } @Test - public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception { - MockWebServer server = new MockWebServer(); - + public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" + @@ -146,13 +152,9 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); @@ -160,8 +162,6 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception OidcUser user = this.userService.loadUser( new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - server.shutdown(); - assertThat(user.getIdToken()).isNotNull(); assertThat(user.getUserInfo()).isNotNull(); assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6); @@ -184,69 +184,47 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception // gh-5447 @Test - public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_user_info_response")); - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"email\": \"full_name@provider.com\",\n" + " \"name\": \"full name\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); } @Test - public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_user_info_response")); - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"sub\": \"other-subject\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); } @Test - public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); - - MockWebServer server = new MockWebServer(); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + @@ -256,48 +234,35 @@ public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2Authenticat " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n"; // "}\n"; // Make the JSON invalid/malformed - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); } @Test - public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - MockWebServer server = new MockWebServer(); - server.enqueue(new MockResponse().setResponseCode(500)); - server.start(); + this.server.enqueue(new MockResponse().setResponseCode(500)); String userInfoUri = server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); } @Test - public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception { - this.exception.expect(AuthenticationServiceException.class); + public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "http://invalid-provider.com/user"; @@ -308,9 +273,7 @@ public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceExceptio } @Test - public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() throws Exception { - MockWebServer server = new MockWebServer(); - + public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() { String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" + @@ -319,13 +282,9 @@ public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserN " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL); @@ -334,16 +293,12 @@ public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserN OidcUser user = this.userService.loadUser( new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - server.shutdown(); - assertThat(user.getName()).isEqualTo("user1@example.com"); } // gh-5294 @Test public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" + @@ -352,28 +307,21 @@ public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exc " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - server.shutdown(); - assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) - .isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" + @@ -382,31 +330,24 @@ public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodG " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - server.shutdown(); - RecordedRequest request = server.takeRequest(); + RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); - assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" + @@ -415,24 +356,25 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM); when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); - server.shutdown(); - RecordedRequest request = server.takeRequest(); + RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); - assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE); assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue()); } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 99ca960c2ee..dd5a5dd8913 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -18,7 +18,7 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; - +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -30,7 +30,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -59,12 +58,15 @@ public class DefaultOAuth2UserServiceTests { private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint; private OAuth2AccessToken accessToken; private DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + private MockWebServer server; @Rule public ExpectedException exception = ExpectedException.none(); @Before - public void setUp() throws Exception { + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); this.clientRegistration = mock(ClientRegistration.class); this.providerDetails = mock(ClientRegistration.ProviderDetails.class); this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); @@ -73,6 +75,23 @@ public void setUp() throws Exception { this.accessToken = mock(OAuth2AccessToken.class); } + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + this.userService.setRequestEntityConverter(null); + } + + @Test + public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + this.userService.setRestOperations(null); + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); @@ -99,9 +118,7 @@ public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2Authentication } @Test - public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception { - MockWebServer server = new MockWebServer(); - + public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" + @@ -110,13 +127,9 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); @@ -125,8 +138,6 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - server.shutdown(); - assertThat(user.getName()).isEqualTo("user1"); assertThat(user.getAttributes().size()).isEqualTo(6); assertThat(user.getAttributes().get("user-name")).isEqualTo("user1"); @@ -144,11 +155,9 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception } @Test - public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); - - MockWebServer server = new MockWebServer(); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + @@ -158,52 +167,83 @@ public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2Authenticat " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n"; // "}\n"; // Make the JSON invalid/malformed - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); + + String userInfoUri = this.server.url("/user").toString(); + + when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); + when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); + when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); + when(this.accessToken.getTokenValue()).thenReturn("access-token"); + + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + } + + @Test + public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + this.exception.expectMessage(containsString("Error Code: insufficient_scope, Error Description: The access token expired")); - server.start(); + String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; - String userInfoUri = server.url("/user").toString(); + MockResponse response = new MockResponse(); + response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); + response.setResponseCode(400); + this.server.enqueue(response); + + String userInfoUri = this.server.url("/user").toString(); + + when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); + when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); + when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); + when(this.accessToken.getTokenValue()).thenReturn("access-token"); + + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + } + + @Test + public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); + this.exception.expectMessage(containsString("Error Code: invalid_token")); + + String userInfoErrorResponse = "{\n" + + " \"error\": \"invalid_token\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400)); + + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); } @Test - public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - MockWebServer server = new MockWebServer(); - server.enqueue(new MockResponse().setResponseCode(500)); - server.start(); + this.server.enqueue(new MockResponse().setResponseCode(500)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); } @Test - public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception { - this.exception.expect(AuthenticationServiceException.class); + public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "http://invalid-provider.com/user"; @@ -218,8 +258,6 @@ public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceExceptio // gh-5294 @Test public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" + @@ -228,13 +266,9 @@ public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exc " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); @@ -242,16 +276,13 @@ public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exc when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - server.shutdown(); - assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) - .isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" + @@ -260,13 +291,9 @@ public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodG " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); - - server.start(); + this.server.enqueue(jsonResponse(userInfoResponse)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); @@ -274,18 +301,15 @@ public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodG when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - server.shutdown(); - RecordedRequest request = server.takeRequest(); + RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); - assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); } // gh-5500 @Test public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { - MockWebServer server = new MockWebServer(); - String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" + @@ -294,13 +318,9 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM); @@ -308,11 +328,16 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos when(this.accessToken.getTokenValue()).thenReturn("access-token"); this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - server.shutdown(); - RecordedRequest request = server.takeRequest(); + RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); - assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE); assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue()); } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java new file mode 100644 index 00000000000..b798c1a786e --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.userinfo; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.MultiValueMap; + +import java.time.Instant; +import java.util.Arrays; +import java.util.LinkedHashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * Tests for {@link OAuth2UserRequestEntityConverter}. + * + * @author Joe Grandja + */ +public class OAuth2UserRequestEntityConverterTests { + private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter(); + private OAuth2UserRequest userRequest; + + @Before + public void setup() { + ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("https://client.com/callback/client-1") + .scope("read", "write") + .authorizationUri("https://provider.com/oauth2/authorize") + .tokenUri("https://provider.com/oauth2/token") + .userInfoUri("https://provider.com/user") + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("id") + .build(); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), + Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write"))); + this.userRequest = new OAuth2UserRequest(clientRegistration, accessToken); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenAuthenticationMethodHeaderThenGetRequest() { + RequestEntity requestEntity = this.converter.convert(this.userRequest); + + ClientRegistration clientRegistration = this.userRequest.getClientRegistration(); + + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET); + assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( + clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); + + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo( + "Bearer " + this.userRequest.getAccessToken().getTokenValue()); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenAuthenticationMethodFormThenPostRequest() { + ClientRegistration clientRegistration = this.from(this.userRequest.getClientRegistration()) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .build(); + OAuth2UserRequest userRequest = new OAuth2UserRequest( + clientRegistration, this.userRequest.getAccessToken()); + + RequestEntity requestEntity = this.converter.convert(userRequest); + + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( + clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); + + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getContentType()).isEqualTo( + MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo( + this.userRequest.getAccessToken().getTokenValue()); + } + + private ClientRegistration.Builder from(ClientRegistration registration) { + return ClientRegistration.withRegistrationId(registration.getRegistrationId()) + .clientId(registration.getClientId()) + .clientSecret(registration.getClientSecret()) + .clientAuthenticationMethod(registration.getClientAuthenticationMethod()) + .authorizationGrantType(registration.getAuthorizationGrantType()) + .redirectUriTemplate(registration.getRedirectUriTemplate()) + .scope(registration.getScopes()) + .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) + .tokenUri(registration.getProviderDetails().getTokenUri()) + .userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri()) + .userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) + .clientName(registration.getClientName()); + } +} From 86dc84a85973b5a5f469a8dde9b5a36c10ff31f3 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 28 Aug 2018 08:48:23 -0400 Subject: [PATCH 2/3] Provide RestOperations in CustomUserTypesOAuth2UserService Fixes gh-5602 --- .../CustomUserTypesOAuth2UserService.java | 66 ++++++- .../NimbusUserInfoResponseClient.java | 169 ------------------ ...CustomUserTypesOAuth2UserServiceTests.java | 85 ++++----- 3 files changed, 107 insertions(+), 213 deletions(-) delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java index 7322fa36f4d..034ed05e4e3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,19 @@ */ package org.springframework.security.oauth2.client.userinfo; +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; import java.util.Collections; import java.util.LinkedHashMap; @@ -39,8 +48,13 @@ * @see ClientRegistration */ public class CustomUserTypesOAuth2UserService implements OAuth2UserService { + private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; + private final Map> customUserTypes; - private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient(); + + private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); + + private RestOperations restOperations; /** * Constructs a {@code CustomUserTypesOAuth2UserService} using the provided parameters. @@ -50,6 +64,9 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService> customUserTypes) { Assert.notEmpty(customUserTypes, "customUserTypes cannot be empty"); this.customUserTypes = Collections.unmodifiableMap(new LinkedHashMap<>(customUserTypes)); + RestTemplate restTemplate = new RestTemplate(); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; } @Override @@ -60,6 +77,49 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic if ((customUserType = this.customUserTypes.get(registrationId)) == null) { return null; } - return this.userInfoResponseClient.getUserInfoResponse(userRequest, customUserType); + + RequestEntity request = this.requestEntityConverter.convert(userRequest); + + ResponseEntity response; + try { + response = this.restOperations.exchange(request, customUserType); + } catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); + } + + OAuth2User oauth2User = response.getBody(); + + return oauth2User; + } + + /** + * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} + * to a {@link RequestEntity} representation of the UserInfo Request. + * + * @since 5.1 + * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request + */ + public final void setRequestEntityConverter(Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the UserInfo resource. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + *

    + *
  1. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  2. + *
+ * + * @since 5.1 + * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource + */ + public final void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java deleted file mode 100644 index cbbe7597882..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright 2002-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.userinfo; - -import com.nimbusds.oauth2.sdk.ErrorObject; -import com.nimbusds.oauth2.sdk.ParseException; -import com.nimbusds.oauth2.sdk.http.HTTPRequest; -import com.nimbusds.oauth2.sdk.http.HTTPResponse; -import com.nimbusds.oauth2.sdk.token.BearerAccessToken; -import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse; -import com.nimbusds.openid.connect.sdk.UserInfoRequest; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; -import org.springframework.http.client.AbstractClientHttpResponse; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.http.converter.GenericHttpMessageConverter; -import org.springframework.http.converter.HttpMessageNotReadableException; -import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; -import org.springframework.security.authentication.AuthenticationServiceException; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.util.Assert; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.nio.charset.Charset; - -/** - * @author Joe Grandja - * @since 5.0 - */ -final class NimbusUserInfoResponseClient { - private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response"; - private final GenericHttpMessageConverter genericHttpMessageConverter = new MappingJackson2HttpMessageConverter(); - - T getUserInfoResponse(OAuth2UserRequest userInfoRequest, Class returnType) throws OAuth2AuthenticationException { - ClientHttpResponse userInfoResponse = this.getUserInfoResponse( - userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken()); - try { - return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse); - } catch (IOException | HttpMessageNotReadableException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, - "An error occurred reading the UserInfo Success response: " + ex.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } - } - - T getUserInfoResponse(OAuth2UserRequest userInfoRequest, ParameterizedTypeReference typeReference) throws OAuth2AuthenticationException { - ClientHttpResponse userInfoResponse = this.getUserInfoResponse( - userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken()); - try { - return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse); - } catch (IOException | HttpMessageNotReadableException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, - "An error occurred reading the UserInfo Success response: " + ex.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } - } - - private ClientHttpResponse getUserInfoResponse(ClientRegistration clientRegistration, - OAuth2AccessToken oauth2AccessToken) throws OAuth2AuthenticationException { - URI userInfoUri = URI.create(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()); - BearerAccessToken accessToken = new BearerAccessToken(oauth2AccessToken.getTokenValue()); - AuthenticationMethod authenticationMethod = clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod(); - HTTPRequest.Method httpMethod = AuthenticationMethod.FORM.equals(authenticationMethod) - ? HTTPRequest.Method.POST : HTTPRequest.Method.GET; - - UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoUri, httpMethod, accessToken); - HTTPRequest httpRequest = userInfoRequest.toHTTPRequest(); - httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE); - httpRequest.setConnectTimeout(30000); - httpRequest.setReadTimeout(30000); - HTTPResponse httpResponse; - - try { - httpResponse = httpRequest.send(); - } catch (IOException ex) { - throw new AuthenticationServiceException("An error occurred while sending the UserInfo Request: " + - ex.getMessage(), ex); - } - - if (httpResponse.getStatusCode() == HTTPResponse.SC_OK) { - return new NimbusClientHttpResponse(httpResponse); - } - - UserInfoErrorResponse userInfoErrorResponse; - try { - userInfoErrorResponse = UserInfoErrorResponse.parse(httpResponse); - } catch (ParseException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, - "An error occurred parsing the UserInfo Error response: " + ex.getMessage(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); - } - ErrorObject errorObject = userInfoErrorResponse.getErrorObject(); - - StringBuilder errorDescription = new StringBuilder(); - errorDescription.append("An error occurred while attempting to access the UserInfo Endpoint -> "); - errorDescription.append("Error details: ["); - errorDescription.append("UserInfo Uri: ").append(userInfoUri.toString()); - errorDescription.append(", Http Status: ").append(errorObject.getHTTPStatusCode()); - if (errorObject.getCode() != null) { - errorDescription.append(", Error Code: ").append(errorObject.getCode()); - } - if (errorObject.getDescription() != null) { - errorDescription.append(", Error Description: ").append(errorObject.getDescription()); - } - errorDescription.append("]"); - - OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorDescription.toString(), null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - - private static class NimbusClientHttpResponse extends AbstractClientHttpResponse { - private final HTTPResponse httpResponse; - private final HttpHeaders headers; - - private NimbusClientHttpResponse(HTTPResponse httpResponse) { - Assert.notNull(httpResponse, "httpResponse cannot be null"); - this.httpResponse = httpResponse; - this.headers = new HttpHeaders(); - this.headers.setAll(httpResponse.getHeaders()); - } - - @Override - public int getRawStatusCode() throws IOException { - return this.httpResponse.getStatusCode(); - } - - @Override - public String getStatusText() throws IOException { - return String.valueOf(this.getRawStatusCode()); - } - - @Override - public void close() { - } - - @Override - public InputStream getBody() throws IOException { - InputStream inputStream = new ByteArrayInputStream( - this.httpResponse.getContent().getBytes(Charset.forName("UTF-8"))); - return inputStream; - } - - @Override - public HttpHeaders getHeaders() { - return this.headers; - } - } -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java index 0e19da5723a..1b3a9bc9c99 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -27,7 +28,6 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -60,12 +60,15 @@ public class CustomUserTypesOAuth2UserServiceTests { private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint; private OAuth2AccessToken accessToken; private CustomUserTypesOAuth2UserService userService; + private MockWebServer server; @Rule public ExpectedException exception = ExpectedException.none(); @Before public void setUp() throws Exception { + this.server = new MockWebServer(); + this.server.start(); this.clientRegistration = mock(ClientRegistration.class); this.providerDetails = mock(ClientRegistration.ProviderDetails.class); this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); @@ -80,6 +83,11 @@ public void setUp() throws Exception { this.userService = new CustomUserTypesOAuth2UserService(customUserTypes); } + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + @Test public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); @@ -92,6 +100,18 @@ public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentExcepti new CustomUserTypesOAuth2UserService(Collections.emptyMap()); } + @Test + public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + this.userService.setRequestEntityConverter(null); + } + + @Test + public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { + this.exception.expect(IllegalArgumentException.class); + this.userService.setRestOperations(null); + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); @@ -107,30 +127,22 @@ public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() { } @Test - public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception { - MockWebServer server = new MockWebServer(); - + public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { String userInfoResponse = "{\n" + " \"id\": \"12345\",\n" + " \"name\": \"first last\",\n" + " \"login\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - server.shutdown(); - assertThat(user.getName()).isEqualTo("first last"); assertThat(user.getAttributes().size()).isEqualTo(4); assertThat(user.getAttributes().get("id")).isEqualTo("12345"); @@ -143,11 +155,9 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception } @Test - public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); - - MockWebServer server = new MockWebServer(); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoResponse = "{\n" + " \"id\": \"12345\",\n" + @@ -155,48 +165,35 @@ public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2Authenticat " \"login\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n"; // "}\n"; // Make the JSON invalid/malformed - server.enqueue(new MockResponse() - .setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .setBody(userInfoResponse)); + this.server.enqueue(jsonResponse(userInfoResponse)); - server.start(); - - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); } @Test - public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception { + public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); - MockWebServer server = new MockWebServer(); - server.enqueue(new MockResponse().setResponseCode(500)); - server.start(); + this.server.enqueue(new MockResponse().setResponseCode(500)); - String userInfoUri = server.url("/user").toString(); + String userInfoUri = this.server.url("/user").toString(); when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); when(this.accessToken.getTokenValue()).thenReturn("access-token"); - try { - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); - } finally { - server.shutdown(); - } + this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); } @Test - public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception { - this.exception.expect(AuthenticationServiceException.class); + public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { + this.exception.expect(OAuth2AuthenticationException.class); + this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "http://invalid-provider.com/user"; @@ -206,6 +203,12 @@ public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceExceptio this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); } + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } + public static class CustomOAuth2User implements OAuth2User { private List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); private String id; From 3b05112305f4e7217da9ae0ece90424d738d6e53 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 4 Sep 2018 15:23:11 -0400 Subject: [PATCH 3/3] Polish --- .../http/OAuth2ErrorResponseErrorHandler.java | 2 +- ...OAuth2UserRequestEntityConverterTests.java | 58 +++++-------------- 2 files changed, 15 insertions(+), 45 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java index a6ec8d13124..28f1027e638 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java @@ -54,7 +54,7 @@ public void handleError(ClientHttpResponse response) throws IOException { // A Bearer Token Error may be in the WWW-Authenticate response header // See https://tools.ietf.org/html/rfc6750#section-3 - OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders()); + OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders()); if (oauth2Error == null) { oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java index b798c1a786e..b6ca08140a7 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java @@ -15,16 +15,14 @@ */ package org.springframework.security.oauth2.client.userinfo; -import org.junit.Before; import org.junit.Test; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthenticationMethod; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.MultiValueMap; @@ -43,35 +41,15 @@ */ public class OAuth2UserRequestEntityConverterTests { private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter(); - private OAuth2UserRequest userRequest; - - @Before - public void setup() { - ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") - .clientId("client-1") - .clientSecret("secret") - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUriTemplate("https://client.com/callback/client-1") - .scope("read", "write") - .authorizationUri("https://provider.com/oauth2/authorize") - .tokenUri("https://provider.com/oauth2/token") - .userInfoUri("https://provider.com/user") - .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) - .userNameAttributeName("id") - .build(); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), - Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write"))); - this.userRequest = new OAuth2UserRequest(clientRegistration, accessToken); - } @SuppressWarnings("unchecked") @Test public void convertWhenAuthenticationMethodHeaderThenGetRequest() { - RequestEntity requestEntity = this.converter.convert(this.userRequest); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2UserRequest userRequest = new OAuth2UserRequest( + clientRegistration, this.createAccessToken()); - ClientRegistration clientRegistration = this.userRequest.getClientRegistration(); + RequestEntity requestEntity = this.converter.convert(userRequest); assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET); assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( @@ -80,17 +58,17 @@ public void convertWhenAuthenticationMethodHeaderThenGetRequest() { HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo( - "Bearer " + this.userRequest.getAccessToken().getTokenValue()); + "Bearer " + userRequest.getAccessToken().getTokenValue()); } @SuppressWarnings("unchecked") @Test public void convertWhenAuthenticationMethodFormThenPostRequest() { - ClientRegistration clientRegistration = this.from(this.userRequest.getClientRegistration()) + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .userInfoAuthenticationMethod(AuthenticationMethod.FORM) .build(); OAuth2UserRequest userRequest = new OAuth2UserRequest( - clientRegistration, this.userRequest.getAccessToken()); + clientRegistration, this.createAccessToken()); RequestEntity requestEntity = this.converter.convert(userRequest); @@ -105,21 +83,13 @@ public void convertWhenAuthenticationMethodFormThenPostRequest() { MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo( - this.userRequest.getAccessToken().getTokenValue()); + userRequest.getAccessToken().getTokenValue()); } - private ClientRegistration.Builder from(ClientRegistration registration) { - return ClientRegistration.withRegistrationId(registration.getRegistrationId()) - .clientId(registration.getClientId()) - .clientSecret(registration.getClientSecret()) - .clientAuthenticationMethod(registration.getClientAuthenticationMethod()) - .authorizationGrantType(registration.getAuthorizationGrantType()) - .redirectUriTemplate(registration.getRedirectUriTemplate()) - .scope(registration.getScopes()) - .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) - .tokenUri(registration.getProviderDetails().getTokenUri()) - .userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri()) - .userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) - .clientName(registration.getClientName()); + private OAuth2AccessToken createAccessToken() { + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), + Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write"))); + return accessToken; } }