From 0373b505608703fdc65d6ec1e61fbfca33804115 Mon Sep 17 00:00:00 2001 From: Xiaobing Zhu <71206407+ZhuXiaoBing-cn@users.noreply.github.com> Date: Fri, 5 Mar 2021 10:41:46 +0800 Subject: [PATCH] For AAD resource-server, create grantedAuthority by both "roles" and "claims" by default. (#19412) --- ...JwtBearerTokenAuthenticationConverter.java | 7 +- .../AADJwtGrantedAuthoritiesConverter.java | 58 ++++++++++++++++ ...earerTokenAuthenticationConverterTest.java | 68 +++++++++++++++---- 3 files changed, 118 insertions(+), 15 deletions(-) create mode 100644 sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtGrantedAuthoritiesConverter.java diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverter.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverter.java index feebcac92b04..0744ec42f96c 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverter.java @@ -2,7 +2,6 @@ // Licensed under the MIT License. package com.azure.spring.aad.webapi; -import java.util.Collection; import org.springframework.core.convert.converter.Converter; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.core.GrantedAuthority; @@ -12,6 +11,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; import org.springframework.util.Assert; +import java.util.Collection; + /** * A {@link Converter} that takes a {@link Jwt} and converts it into a {@link BearerTokenAuthentication}. */ @@ -19,10 +20,10 @@ public class AADJwtBearerTokenAuthenticationConverter implements Converter> converter - = new JwtGrantedAuthoritiesConverter(); + private Converter> converter; public AADJwtBearerTokenAuthenticationConverter() { + this.converter = new AADJwtGrantedAuthoritiesConverter(); } public AADJwtBearerTokenAuthenticationConverter(String authoritiesClaimName) { diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtGrantedAuthoritiesConverter.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtGrantedAuthoritiesConverter.java new file mode 100644 index 000000000000..8871e88e6360 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/webapi/AADJwtGrantedAuthoritiesConverter.java @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.aad.webapi; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.stream.Collectors; + +/** + * Extracts the {@link GrantedAuthority}s from scope attributes typically found in a {@link Jwt}. + */ +public class AADJwtGrantedAuthoritiesConverter implements Converter> { + + private static final String DEFAULT_SCP_AUTHORITY_PREFIX = "SCOPE_"; + + private static final String DEFAULT_ROLES_AUTHORITY_PREFIX = "APPROLE_"; + + private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = Arrays.asList("scp", "roles"); + + @Override + public Collection convert(Jwt jwt) { + Collection grantedAuthorities = new ArrayList<>(); + for (String authority : getAuthorities(jwt)) { + grantedAuthorities.add(new SimpleGrantedAuthority(authority)); + } + return grantedAuthorities; + } + + private Collection getAuthorities(Jwt jwt) { + Collection authoritiesList = new ArrayList(); + for (String claimName : WELL_KNOWN_AUTHORITIES_CLAIM_NAMES) { + if (jwt.containsClaim(claimName)) { + if (jwt.getClaim(claimName) instanceof String) { + if (StringUtils.hasText(jwt.getClaim(claimName))) { + authoritiesList.addAll(Arrays.asList(((String) jwt.getClaim(claimName)).split(" ")) + .stream() + .map(s -> DEFAULT_SCP_AUTHORITY_PREFIX + s) + .collect(Collectors.toList())); + } + } else if (jwt.getClaim(claimName) instanceof Collection) { + authoritiesList.addAll(((Collection) jwt.getClaim(claimName)) + .stream() + .filter(s -> StringUtils.hasText((String) s)) + .map(s -> DEFAULT_ROLES_AUTHORITY_PREFIX + s) + .collect(Collectors.toList())); + } + } + } + return authoritiesList; + } +} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverterTest.java index 00be2accc482..54583e376311 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverterTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/aad/webapi/AADJwtBearerTokenAuthenticationConverterTest.java @@ -2,23 +2,26 @@ // Licensed under the MIT License. package com.azure.spring.aad.webapi; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; +import net.minidev.json.JSONArray; import org.junit.Before; import org.junit.Test; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.oauth2.jwt.Jwt; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class AADJwtBearerTokenAuthenticationConverterTest { private Jwt jwt = mock(Jwt.class); private Map claims = new HashMap<>(); private Map headers = new HashMap<>(); + private JSONArray jsonArray = new JSONArray().appendElement("User.read").appendElement("User.write"); @Before public void init() { @@ -26,13 +29,12 @@ public void init() { claims.put("tid", "fake-tid"); headers.put("kid", "kg2LYs2T0CTjIfj4rt6JIynen38"); when(jwt.getClaim("scp")).thenReturn("Order.read Order.write"); - when(jwt.getClaim("roles")).thenReturn("User.read User.write"); + when(jwt.getClaim("roles")).thenReturn(jsonArray); when(jwt.getTokenValue()).thenReturn("fake-token-value"); when(jwt.getIssuedAt()).thenReturn(Instant.now()); when(jwt.getHeaders()).thenReturn(headers); when(jwt.getExpiresAt()).thenReturn(Instant.MAX); when(jwt.getClaims()).thenReturn(claims); - when(jwt.containsClaim("scp")).thenReturn(true); } @Test @@ -48,7 +50,22 @@ public void testCreateUserPrincipal() { } @Test - public void testExtractDefaultScopeAuthorities() { + public void testNoArgumentsConstructorDefaultScopeAndRoleAuthorities() { + when(jwt.containsClaim("scp")).thenReturn(true); + when(jwt.containsClaim("roles")).thenReturn(true); + AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter(); + AbstractAuthenticationToken authenticationToken = converter.convert(jwt); + assertThat(authenticationToken.getPrincipal()).isExactlyInstanceOf(AADOAuth2AuthenticatedPrincipal.class); + AADOAuth2AuthenticatedPrincipal principal = (AADOAuth2AuthenticatedPrincipal) authenticationToken + .getPrincipal(); + assertThat(principal.getAttributes()).isNotEmpty(); + assertThat(principal.getAttributes()).hasSize(2); + assertThat(principal.getAuthorities()).hasSize(4); + } + + @Test + public void testNoArgumentsConstructorExtractScopeAuthorities() { + when(jwt.containsClaim("scp")).thenReturn(true); AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter(); AbstractAuthenticationToken authenticationToken = converter.convert(jwt); assertThat(authenticationToken.getPrincipal()).isExactlyInstanceOf(AADOAuth2AuthenticatedPrincipal.class); @@ -56,19 +73,46 @@ public void testExtractDefaultScopeAuthorities() { .getPrincipal(); assertThat(principal.getAttributes()).isNotEmpty(); assertThat(principal.getAttributes()).hasSize(2); + assertThat(principal.getAuthorities()).hasSize(2); } @Test - public void testExtractCustomScopeAuthorities() { + public void testNoArgumentsConstructorExtractRoleAuthorities() { when(jwt.containsClaim("roles")).thenReturn(true); - AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter("roles", "ROLE_"); + AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter(); AbstractAuthenticationToken authenticationToken = converter.convert(jwt); assertThat(authenticationToken.getPrincipal()).isExactlyInstanceOf(AADOAuth2AuthenticatedPrincipal.class); AADOAuth2AuthenticatedPrincipal principal = (AADOAuth2AuthenticatedPrincipal) authenticationToken .getPrincipal(); assertThat(principal.getAttributes()).isNotEmpty(); assertThat(principal.getAttributes()).hasSize(2); + assertThat(principal.getAuthorities()).hasSize(2); } + @Test + public void testParameterConstructorExtractScopeAuthorities() { + when(jwt.containsClaim("scp")).thenReturn(true); + AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter("scp"); + AbstractAuthenticationToken authenticationToken = converter.convert(jwt); + assertThat(authenticationToken.getPrincipal()).isExactlyInstanceOf(AADOAuth2AuthenticatedPrincipal.class); + AADOAuth2AuthenticatedPrincipal principal = (AADOAuth2AuthenticatedPrincipal) authenticationToken + .getPrincipal(); + assertThat(principal.getAttributes()).isNotEmpty(); + assertThat(principal.getAttributes()).hasSize(2); + assertThat(principal.getAuthorities()).hasSize(2); + } + @Test + public void testParameterConstructorExtractRoleAuthorities() { + when(jwt.containsClaim("roles")).thenReturn(true); + AADJwtBearerTokenAuthenticationConverter converter = new AADJwtBearerTokenAuthenticationConverter("roles", + "APPROLE_"); + AbstractAuthenticationToken authenticationToken = converter.convert(jwt); + assertThat(authenticationToken.getPrincipal()).isExactlyInstanceOf(AADOAuth2AuthenticatedPrincipal.class); + AADOAuth2AuthenticatedPrincipal principal = (AADOAuth2AuthenticatedPrincipal) authenticationToken + .getPrincipal(); + assertThat(principal.getAttributes()).isNotEmpty(); + assertThat(principal.getAttributes()).hasSize(2); + assertThat(principal.getAuthorities()).hasSize(2); + } }