From 8a0c6868cd7f61b8e64c815910a12ed0d485ef8e Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 13 Aug 2018 07:51:06 -0400 Subject: [PATCH] Add additional parameters to OAuth2UserRequest Fixes gh-5368 --- .../OAuth2LoginAuthenticationProvider.java | 6 +- ...th2LoginReactiveAuthenticationManager.java | 5 +- ...thorizationCodeAuthenticationProvider.java | 9 ++- ...tionCodeReactiveAuthenticationManager.java | 6 +- .../client/oidc/userinfo/OidcUserRequest.java | 22 +++++- .../client/userinfo/OAuth2UserRequest.java | 34 +++++++++- ...Auth2LoginAuthenticationProviderTests.java | 58 +++++++++++++--- ...ginReactiveAuthenticationManagerTests.java | 26 ++++++- ...zationCodeAuthenticationProviderTests.java | 67 +++++++++++++++---- ...odeReactiveAuthenticationManagerTests.java | 34 ++++++++++ .../oidc/userinfo/OidcUserRequestTests.java | 64 +++++++++++++----- .../userinfo/OAuth2UserRequestTests.java | 51 ++++++++++---- 12 files changed, 311 insertions(+), 71 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java index 843424df63e..6c032fb073a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java @@ -30,6 +30,7 @@ import org.springframework.util.Assert; import java.util.Collection; +import java.util.Map; /** * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login, @@ -101,9 +102,10 @@ public Authentication authenticate(Authentication authentication) throws Authent authorizationCodeAuthentication.getAuthorizationExchange())); OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); - OAuth2User oauth2User = this.userService.loadUser( - new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken)); + OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest( + authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters)); Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java index 03aeca3397c..eb2f161c77d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.authentication; import java.util.Collection; +import java.util.Map; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; @@ -109,7 +110,9 @@ public Mono authenticate(Authentication authentication) { private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); - OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken); + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); + OAuth2UserRequest userRequest = new OAuth2UserRequest( + authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters); return this.userService.loadUser(userRequest) .flatMap(oauth2User -> { Collection mappedAuthorities = diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java index ff1361f4d04..87a64227ead 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java @@ -139,19 +139,18 @@ public Authentication authenticate(Authentication authentication) throws Authent ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); - if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) { + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); + if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { OAuth2Error invalidIdTokenError = new OAuth2Error( INVALID_ID_TOKEN_ERROR_CODE, "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), null); throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); } - OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); - OidcUser oidcUser = this.userService.loadUser( - new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken)); - + OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest( + clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters)); Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index 877e60e8676..8f9e4bcbbae 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -159,10 +159,10 @@ void setDecoderFactory( private Mono authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); - ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); + Map additionalParameters = accessTokenResponse.getAdditionalParameters(); - if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) { + if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { OAuth2Error invalidIdTokenError = new OAuth2Error( INVALID_ID_TOKEN_ERROR_CODE, "Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(), @@ -171,7 +171,7 @@ private Mono authenticationResult(OAuth2LoginAuthenti } return createOidcToken(clientRegistration, accessTokenResponse) - .map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken)) + .map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) .flatMap(this.userService::loadUser) .flatMap(oauth2User -> { Collection mappedAuthorities = diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java index 201ba577e2d..92158890b28 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.util.Assert; +import java.util.Collections; +import java.util.Map; + /** * Represents a request the {@link OidcUserService} uses * when initiating a request to the UserInfo Endpoint. @@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest { public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OidcIdToken idToken) { - super(clientRegistration, accessToken); + this(clientRegistration, accessToken, idToken, Collections.emptyMap()); + } + + /** + * Constructs an {@code OidcUserRequest} using the provided parameters. + * + * @since 5.1 + * @param clientRegistration the client registration + * @param accessToken the access token credential + * @param idToken the ID Token + * @param additionalParameters the additional parameters, may be empty + */ + public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + OidcIdToken idToken, Map additionalParameters) { + + super(clientRegistration, accessToken, additionalParameters); Assert.notNull(idToken, "idToken cannot be null"); this.idToken = idToken; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java index 949fb7699b9..b887c7aa440 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,11 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; /** * Represents a request the {@link OAuth2UserService} uses @@ -32,6 +37,7 @@ public class OAuth2UserRequest { private final ClientRegistration clientRegistration; private final OAuth2AccessToken accessToken; + private final Map additionalParameters; /** * Constructs an {@code OAuth2UserRequest} using the provided parameters. @@ -40,10 +46,26 @@ public class OAuth2UserRequest { * @param accessToken the access token */ public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) { + this(clientRegistration, accessToken, Collections.emptyMap()); + } + + /** + * Constructs an {@code OAuth2UserRequest} using the provided parameters. + * + * @since 5.1 + * @param clientRegistration the client registration + * @param accessToken the access token + * @param additionalParameters the additional parameters, may be empty + */ + public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + Map additionalParameters) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(accessToken, "accessToken cannot be null"); this.clientRegistration = clientRegistration; this.accessToken = accessToken; + this.additionalParameters = Collections.unmodifiableMap( + CollectionUtils.isEmpty(additionalParameters) ? + Collections.emptyMap() : new LinkedHashMap<>(additionalParameters)); } /** @@ -63,4 +85,14 @@ public ClientRegistration getClientRegistration() { public OAuth2AccessToken getAccessToken() { return this.accessToken; } + + /** + * Returns the additional parameters that may be used in the request. + * + * @since 5.1 + * @return a {@code Map} of the additional parameters, may be empty. + */ + public Map getAdditionalParameters() { + return this.additionalParameters; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java index 668f007f156..69949edb672 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -35,17 +36,20 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.user.OAuth2User; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; @@ -164,11 +168,7 @@ public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizatio @Test public void authenticateWhenLoginSuccessThenReturnAuthentication() { - OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); - OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class); - OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class); - when(accessTokenResponse.getAccessToken()).thenReturn(accessToken); - when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken); + OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2User principal = mock(OAuth2User.class); @@ -187,15 +187,13 @@ public void authenticateWhenLoginSuccessThenReturnAuthentication() { assertThat(authentication.getAuthorities()).isEqualTo(authorities); assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); - assertThat(authentication.getAccessToken()).isEqualTo(accessToken); - assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken); + assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @Test public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { - OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); - OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class); - when(accessTokenResponse.getAccessToken()).thenReturn(accessToken); + OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2User principal = mock(OAuth2User.class); @@ -216,4 +214,42 @@ public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); } + + // gh-5368 + @Test + public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { + OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2User principal = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + when(principal.getAuthorities()).thenAnswer( + (Answer>) invocation -> authorities); + ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class); + when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal); + + this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf( + accessTokenResponse.getAdditionalParameters()); + } + + private OAuth2AccessTokenResponse accessTokenSuccessResponse() { + Instant expiresAt = Instant.now().plusSeconds(5); + Set scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); + Map additionalParameters = new HashMap<>(); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + + return OAuth2AccessTokenResponse + .withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(expiresAt.getEpochSecond()) + .scopes(scopes) + .refreshToken("refresh-token-1234") + .additionalParameters(additionalParameters) + .build(); + + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java index cbc24b4606a..073d34230d7 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java @@ -23,11 +23,14 @@ import static org.mockito.Mockito.when; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -164,7 +167,7 @@ public void authenticationWhenOAuth2UserNotFoundThenEmpty() { } @Test - public void authenticationWhenOAuth2UserNotFoundThenSuccess() { + public void authenticationWhenOAuth2UserFoundThenSuccess() { OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") .tokenType(OAuth2AccessToken.TokenType.BEARER) .build(); @@ -179,6 +182,27 @@ public void authenticationWhenOAuth2UserNotFoundThenSuccess() { assertThat(result.isAuthenticated()).isTrue(); } + // gh-5368 + @Test + public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(additionalParameters) + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user"); + ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class); + when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); + + this.manager.authenticate(loginToken()).block(); + + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) + .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); + } + private OAuth2LoginAuthenticationToken loginToken() { ClientRegistration clientRegistration = this.registration.build(); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java index 6be263b1bd4..cc86577ce05 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -55,6 +55,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; @@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { private OAuth2AuthorizationExchange authorizationExchange; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private OAuth2AccessTokenResponse accessTokenResponse; - private OAuth2AccessToken accessToken; - private OAuth2RefreshToken refreshToken; private OAuth2UserService userService; private OidcAuthorizationCodeAuthenticationProvider authenticationProvider; @@ -95,9 +94,7 @@ public void setUp() throws Exception { this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class); - this.accessToken = mock(OAuth2AccessToken.class); - this.refreshToken = mock(OAuth2RefreshToken.class); + this.accessTokenResponse = this.accessTokenSuccessResponse(); this.userService = mock(OAuth2UserService.class); this.authenticationProvider = PowerMockito.spy( new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService)); @@ -111,11 +108,6 @@ public void setUp() throws Exception { when(this.authorizationResponse.getState()).thenReturn("12345"); when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); - when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken); - when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken); - Map additionalParameters = new HashMap<>(); - additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); - when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse); } @@ -194,7 +186,11 @@ public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2Aut this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_id_token")); - when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap()); + OAuth2AccessTokenResponse accessTokenResponse = + OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse()) + .additionalParameters(Collections.emptyMap()) + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); this.authenticationProvider.authenticate( new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); @@ -368,8 +364,8 @@ public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Except assertThat(authentication.getAuthorities()).isEqualTo(authorities); assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); - assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken); - assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken); + assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken()); + assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken()); } @Test @@ -400,6 +396,30 @@ public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() th assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities); } + // gh-5368 + @Test + public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception { + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.ISS, "https://provider.com"); + claims.put(IdTokenClaimNames.SUB, "subject1"); + claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2")); + claims.put(IdTokenClaimNames.AZP, "client1"); + this.setUpIdToken(claims); + + OidcUser principal = mock(OidcUser.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + when(principal.getAuthorities()).thenAnswer( + (Answer>) invocation -> authorities); + ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal); + + this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken( + this.clientRegistration, this.authorizationExchange)); + + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf( + this.accessTokenResponse.getAdditionalParameters()); + } + private void setUpIdToken(Map claims) throws Exception { Instant issuedAt = Instant.now(); Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); @@ -416,4 +436,23 @@ private void setUpIdToken(Map claims, Instant issuedAt, Instant when(jwtDecoder.decode(anyString())).thenReturn(idToken); PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class)); } + + private OAuth2AccessTokenResponse accessTokenSuccessResponse() { + Instant expiresAt = Instant.now().plusSeconds(5); + Set scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email")); + Map additionalParameters = new HashMap<>(); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); + + return OAuth2AccessTokenResponse + .withToken("access-token-1234") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(expiresAt.getEpochSecond()) + .scopes(scopes) + .refreshToken("refresh-token-1234") + .additionalParameters(additionalParameters) + .build(); + + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java index 362beed4da6..2b9bc3061e3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java @@ -19,6 +19,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -217,6 +218,39 @@ public void authenticationWhenOAuth2UserFoundThenSuccess() { assertThat(result.isAuthenticated()).isTrue(); } + // gh-5368 + @Test + public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .additionalParameters(additionalParameters) + .build(); + + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); + claims.put(IdTokenClaimNames.SUB, "rob"); + claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId")); + Instant issuedAt = Instant.now(); + Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600); + Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims); + + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse)); + DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken); + ArgumentCaptor userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class); + when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user)); + when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken)); + this.manager.setDecoderFactory(c -> this.jwtDecoder); + + this.manager.authenticate(loginToken()).block(); + + assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()) + .containsAllEntriesOf(accessTokenResponse.getAdditionalParameters()); + } + private OAuth2LoginAuthenticationToken loginToken() { ClientRegistration clientRegistration = this.registration.build(); OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java index ee43d5b1671..afb5a412955 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,57 +17,87 @@ import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link OidcUserRequest}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest(ClientRegistration.class) public class OidcUserRequestTests { private ClientRegistration clientRegistration; private OAuth2AccessToken accessToken; private OidcIdToken idToken; + private Map additionalParameters; @Before public void setUp() { - this.clientRegistration = mock(ClientRegistration.class); - this.accessToken = mock(OAuth2AccessToken.class); - this.idToken = mock(OidcIdToken.class); + this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("https://client.com") + .scope(new LinkedHashSet<>(Arrays.asList("openid", "profile"))) + .authorizationUri("https://provider.com/oauth2/authorization") + .tokenUri("https://provider.com/oauth2/token") + .jwkSetUri("https://provider.com/keys") + .clientName("Client 1") + .build(); + this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", Instant.now(), Instant.now().plusSeconds(60), + new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); + Map claims = new HashMap<>(); + claims.put(IdTokenClaimNames.ISS, "https://provider.com"); + claims.put(IdTokenClaimNames.SUB, "subject1"); + claims.put(IdTokenClaimNames.AZP, "client-1"); + this.idToken = new OidcIdToken("id-token-1234", Instant.now(), + Instant.now().plusSeconds(3600), claims); + this.additionalParameters = new HashMap<>(); + this.additionalParameters.put("param1", "value1"); + this.additionalParameters.put("param2", "value2"); } - @Test(expected = IllegalArgumentException.class) + @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - new OidcUserRequest(null, this.accessToken, this.idToken); + assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken)) + .isInstanceOf(IllegalArgumentException.class); } - @Test(expected = IllegalArgumentException.class) + @Test public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - new OidcUserRequest(this.clientRegistration, null, this.idToken); + assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken)) + .isInstanceOf(IllegalArgumentException.class); } - @Test(expected = IllegalArgumentException.class) + @Test public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() { - new OidcUserRequest(this.clientRegistration, this.accessToken, null); + assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null)) + .isInstanceOf(IllegalArgumentException.class); } @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { OidcUserRequest userRequest = new OidcUserRequest( - this.clientRegistration, this.accessToken, this.idToken); + this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); assertThat(userRequest.getIdToken()).isEqualTo(this.idToken); + assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java index 59a054aaa76..6415721604d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,47 +17,70 @@ import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link OAuth2UserRequest}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest(ClientRegistration.class) public class OAuth2UserRequestTests { private ClientRegistration clientRegistration; private OAuth2AccessToken accessToken; + private Map additionalParameters; @Before public void setUp() { - this.clientRegistration = mock(ClientRegistration.class); - this.accessToken = mock(OAuth2AccessToken.class); + this.clientRegistration = ClientRegistration.withRegistrationId("registration-1") + .clientId("client-1") + .clientSecret("secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate("https://client.com") + .scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))) + .authorizationUri("https://provider.com/oauth2/authorization") + .tokenUri("https://provider.com/oauth2/token") + .clientName("Client 1") + .build(); + this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", Instant.now(), Instant.now().plusSeconds(60), + new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); + this.additionalParameters = new HashMap<>(); + this.additionalParameters.put("param1", "value1"); + this.additionalParameters.put("param2", "value2"); } - @Test(expected = IllegalArgumentException.class) + @Test public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - new OAuth2UserRequest(null, this.accessToken); + assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken)) + .isInstanceOf(IllegalArgumentException.class); } - @Test(expected = IllegalArgumentException.class) + @Test public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { - new OAuth2UserRequest(this.clientRegistration, null); + assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null)) + .isInstanceOf(IllegalArgumentException.class); } @Test public void constructorWhenAllParametersProvidedAndValidThenCreated() { - OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken); + OAuth2UserRequest userRequest = new OAuth2UserRequest( + this.clientRegistration, this.accessToken, this.additionalParameters); assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration); assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken); + assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters); } }