Skip to content

Commit

Permalink
Add additional parameters to OAuth2UserRequest
Browse files Browse the repository at this point in the history
Fixes gh-5368
  • Loading branch information
jgrandja committed Aug 14, 2018
1 parent 950a314 commit 8a0c686
Show file tree
Hide file tree
Showing 12 changed files with 311 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.util.Assert;

import java.util.Collection;
import java.util.Map;

/**
* An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login,
Expand Down Expand Up @@ -101,9 +102,10 @@ public Authentication authenticate(Authentication authentication) throws Authent
authorizationCodeAuthentication.getAuthorizationExchange()));

OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();

OAuth2User oauth2User = this.userService.loadUser(
new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken));
OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));

Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.authentication;

import java.util.Collection;
import java.util.Map;

import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.core.Authentication;
Expand Down Expand Up @@ -109,7 +110,9 @@ public Mono<Authentication> authenticate(Authentication authentication) {

private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken);
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
OAuth2UserRequest userRequest = new OAuth2UserRequest(
authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
return this.userService.loadUser(userRequest)
.flatMap(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,18 @@ public Authentication authenticate(Authentication authentication) throws Authent

ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();

if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
OAuth2Error invalidIdTokenError = new OAuth2Error(
INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
null);
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
}

OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);

OidcUser oidcUser = this.userService.loadUser(
new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken));

OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(
clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters));
Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ void setDecoderFactory(

private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();

ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();

if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
OAuth2Error invalidIdTokenError = new OAuth2Error(
INVALID_ID_TOKEN_ERROR_CODE,
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
Expand All @@ -171,7 +171,7 @@ private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenti
}

return createOidcToken(clientRegistration, accessTokenResponse)
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken))
.map(idToken -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
.flatMap(this.userService::loadUser)
.flatMap(oauth2User -> {
Collection<? extends GrantedAuthority> mappedAuthorities =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,9 @@
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.util.Assert;

import java.util.Collections;
import java.util.Map;

/**
* Represents a request the {@link OidcUserService} uses
* when initiating a request to the UserInfo Endpoint.
Expand All @@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest {
public OidcUserRequest(ClientRegistration clientRegistration,
OAuth2AccessToken accessToken, OidcIdToken idToken) {

super(clientRegistration, accessToken);
this(clientRegistration, accessToken, idToken, Collections.emptyMap());
}

/**
* Constructs an {@code OidcUserRequest} using the provided parameters.
*
* @since 5.1
* @param clientRegistration the client registration
* @param accessToken the access token credential
* @param idToken the ID Token
* @param additionalParameters the additional parameters, may be empty
*/
public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
OidcIdToken idToken, Map<String, Object> additionalParameters) {

super(clientRegistration, accessToken, additionalParameters);
Assert.notNull(idToken, "idToken cannot be null");
this.idToken = idToken;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,11 @@
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/**
* Represents a request the {@link OAuth2UserService} uses
Expand All @@ -32,6 +37,7 @@
public class OAuth2UserRequest {
private final ClientRegistration clientRegistration;
private final OAuth2AccessToken accessToken;
private final Map<String, Object> additionalParameters;

/**
* Constructs an {@code OAuth2UserRequest} using the provided parameters.
Expand All @@ -40,10 +46,26 @@ public class OAuth2UserRequest {
* @param accessToken the access token
*/
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) {
this(clientRegistration, accessToken, Collections.emptyMap());
}

/**
* Constructs an {@code OAuth2UserRequest} using the provided parameters.
*
* @since 5.1
* @param clientRegistration the client registration
* @param accessToken the access token
* @param additionalParameters the additional parameters, may be empty
*/
public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
Map<String, Object> additionalParameters) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
Assert.notNull(accessToken, "accessToken cannot be null");
this.clientRegistration = clientRegistration;
this.accessToken = accessToken;
this.additionalParameters = Collections.unmodifiableMap(
CollectionUtils.isEmpty(additionalParameters) ?
Collections.emptyMap() : new LinkedHashMap<>(additionalParameters));
}

/**
Expand All @@ -63,4 +85,14 @@ public ClientRegistration getClientRegistration() {
public OAuth2AccessToken getAccessToken() {
return this.accessToken;
}

/**
* Returns the additional parameters that may be used in the request.
*
* @since 5.1
* @return a {@code Map} of the additional parameters, may be empty.
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
Expand All @@ -35,17 +36,20 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.user.OAuth2User;

import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
Expand Down Expand Up @@ -164,11 +168,7 @@ public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizatio

@Test
public void authenticateWhenLoginSuccessThenReturnAuthentication() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);

OAuth2User principal = mock(OAuth2User.class);
Expand All @@ -187,15 +187,13 @@ public void authenticateWhenLoginSuccessThenReturnAuthentication() {
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken);
assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
}

@Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);

OAuth2User principal = mock(OAuth2User.class);
Expand All @@ -216,4 +214,42 @@ public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {

assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
}

// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);

OAuth2User principal = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
when(principal.getAuthorities()).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> authorities);
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);

this.authenticationProvider.authenticate(
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));

assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
accessTokenResponse.getAdditionalParameters());
}

private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
Instant expiresAt = Instant.now().plusSeconds(5);
Set<String> scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");

return OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(expiresAt.getEpochSecond())
.scopes(scopes)
.refreshToken("refresh-token-1234")
.additionalParameters(additionalParameters)
.build();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import static org.mockito.Mockito.when;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand Down Expand Up @@ -164,7 +167,7 @@ public void authenticationWhenOAuth2UserNotFoundThenEmpty() {
}

@Test
public void authenticationWhenOAuth2UserNotFoundThenSuccess() {
public void authenticationWhenOAuth2UserFoundThenSuccess() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.build();
Expand All @@ -179,6 +182,27 @@ public void authenticationWhenOAuth2UserNotFoundThenSuccess() {
assertThat(result.isAuthenticated()).isTrue();
}

// gh-5368
@Test
public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(additionalParameters)
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));

this.manager.authenticate(loginToken()).block();

assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}

private OAuth2LoginAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
Expand Down
Loading

0 comments on commit 8a0c686

Please sign in to comment.