diff --git a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AppAutoConfigTest.java b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AppAutoConfigTest.java deleted file mode 100644 index 36f41f1acceb..000000000000 --- a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AppAutoConfigTest.java +++ /dev/null @@ -1,249 +0,0 @@ -package com.azure.test.aad.auth; - -import com.azure.spring.autoconfigure.aad.AuthorizationServerEndpoints; -import com.azure.spring.autoconfigure.aad.AzureClientRegistrationRepository; -import com.azure.spring.autoconfigure.aad.DefaultClient; -import com.azure.test.utils.AppRunner; -import org.junit.jupiter.api.Test; -import org.springframework.boot.autoconfigure.EnableAutoConfiguration; -import org.springframework.context.annotation.Configuration; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class AppAutoConfigTest { - - @Test - public void clientRegistered() { - try (AppRunner appRunner = createApp()) { - appRunner.start(); - - ClientRegistrationRepository clientRegistrationRepository = - appRunner.getBean(ClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); - - assertNotNull(azureClientRegistration); - assertEquals("fake-client-id", azureClientRegistration.getClientId()); - assertEquals("fake-client-secret", azureClientRegistration.getClientSecret()); - - AuthorizationServerEndpoints authorizationServerEndpoints = new AuthorizationServerEndpoints(); - assertEquals( - authorizationServerEndpoints.authorizationEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getAuthorizationUri() - ); - assertEquals( - authorizationServerEndpoints.tokenEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getTokenUri() - ); - assertEquals( - authorizationServerEndpoints.jwkSetEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getJwkSetUri() - ); - assertEquals( - "{baseUrl}/login/oauth2/code/{registrationId}", - azureClientRegistration.getRedirectUriTemplate() - ); - assertDefaultScopes(azureClientRegistration, "openid", "profile"); - } - } - - @Test - public void clientRequiresPermissionRegistered() { - try (AppRunner appRunner = createApp()) { - appRunner.property( - "azure.activedirectory.authorization.graph.scope", - "https://graph.microsoft.com/Calendars.Read" - ); - appRunner.start(); - - ClientRegistrationRepository clientRegistrationRepository = - appRunner.getBean(ClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); - ClientRegistration graphClientRegistration = clientRegistrationRepository.findByRegistrationId("graph"); - - assertNotNull(azureClientRegistration); - assertDefaultScopes( - azureClientRegistration, - "openid", "profile", "offline_access", "https://graph.microsoft.com/Calendars.Read" - ); - - assertNotNull(graphClientRegistration); - assertDefaultScopes(graphClientRegistration, "https://graph.microsoft.com/Calendars.Read"); - } - } - - @Test - public void clientRequiresMultiPermissions() { - try (AppRunner appRunner = createApp()) { - appRunner.property( - "azure.activedirectory.authorization.graph.scope", - "https://graph.microsoft.com/Calendars.Read" - ); - appRunner.property( - "azure.activedirectory.authorization.arm.scope", - "https://management.core.windows.net/user_impersonation" - ); - appRunner.start(); - - ClientRegistrationRepository clientRegistrationRepository = - appRunner.getBean(ClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); - ClientRegistration graphClientRegistration = clientRegistrationRepository.findByRegistrationId("graph"); - - assertNotNull(azureClientRegistration); - assertDefaultScopes( - azureClientRegistration, - "openid", - "profile", - "offline_access", - "https://graph.microsoft.com/Calendars.Read", - "https://management.core.windows.net/user_impersonation" - ); - - assertNotNull(graphClientRegistration); - assertDefaultScopes(graphClientRegistration, "https://graph.microsoft.com/Calendars.Read"); - } - } - - @Test - public void clientRequiresPermissionInDefaultClient() { - try (AppRunner appRunner = createApp()) { - appRunner.property( - "azure.activedirectory.authorization.azure.scope", - "https://graph.microsoft.com/Calendars.Read" - ); - appRunner.start(); - - ClientRegistrationRepository clientRegistrationRepository = - appRunner.getBean(ClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); - - assertNotNull(azureClientRegistration); - assertDefaultScopes( - azureClientRegistration, - "openid", "profile", "offline_access", "https://graph.microsoft.com/Calendars.Read" - ); - } - } - - @Test - public void aadAwareClientRepository() { - try (AppRunner appRunner = createApp()) { - appRunner.property( - "azure.activedirectory.authorization.graph.scope", - "https://graph.microsoft.com/Calendars.Read") - ; - appRunner.start(); - - AzureClientRegistrationRepository azureClientRegistrationRepository = - (AzureClientRegistrationRepository) appRunner.getBean(ClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = - azureClientRegistrationRepository.findByRegistrationId("azure"); - ClientRegistration graphClientRegistration = - azureClientRegistrationRepository.findByRegistrationId("graph"); - - assertDefaultScopes( - azureClientRegistrationRepository.defaultClient(), - "openid", "profile", "offline_access" - ); - assertEquals(azureClientRegistrationRepository.defaultClient().getClientRegistration(), azureClientRegistration); - - assertFalse(azureClientRegistrationRepository.isAuthorizedClient(azureClientRegistration)); - assertTrue(azureClientRegistrationRepository.isAuthorizedClient(graphClientRegistration)); - assertFalse(azureClientRegistrationRepository.isAuthorizedClient("azure")); - assertTrue(azureClientRegistrationRepository.isAuthorizedClient("graph")); - - List clientRegistrations = collectClients(azureClientRegistrationRepository); - assertEquals(1, clientRegistrations.size()); - assertEquals("azure", clientRegistrations.get(0).getRegistrationId()); - } - } - - @Test - public void defaultClientWithAuthzScope() { - try (AppRunner appRunner = createApp()) { - appRunner.property( - "azure.activedirectory.authorization.azure.scope", - "https://graph.microsoft.com/Calendars.Read" - ); - appRunner.start(); - - AzureClientRegistrationRepository azureClientRegistrationRepository = - appRunner.getBean(AzureClientRegistrationRepository.class); - assertDefaultScopes( - azureClientRegistrationRepository.defaultClient(), - "openid", "profile", "offline_access", "https://graph.microsoft.com/Calendars.Read" - ); - } - } - - @Test - public void customizeEnvironment() { - try (AppRunner appRunner = createApp()) { - appRunner.property("azure.activedirectory.environment", "cn-v2-graph"); - appRunner.start(); - - AzureClientRegistrationRepository azureClientRegistrationRepository = - appRunner.getBean(AzureClientRegistrationRepository.class); - ClientRegistration azureClientRegistration = - azureClientRegistrationRepository.findByRegistrationId("azure"); - - AuthorizationServerEndpoints authorizationServerEndpoints = - new AuthorizationServerEndpoints("https://login.partner.microsoftonline.cn"); - assertEquals( - authorizationServerEndpoints.authorizationEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getAuthorizationUri() - ); - assertEquals( - authorizationServerEndpoints.tokenEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getTokenUri() - ); - assertEquals( - authorizationServerEndpoints.jwkSetEndpoint("fake-tenant-id"), - azureClientRegistration.getProviderDetails().getJwkSetUri() - ); - } - } - - private AppRunner createApp() { - AppRunner result = new AppRunner(DumbApp.class); - result.property("azure.activedirectory.tenant-id", "fake-tenant-id"); - result.property("azure.activedirectory.client-id", "fake-client-id"); - result.property("azure.activedirectory.client-secret", "fake-client-secret"); - result.property("azure.activedirectory.user-group.allowed-groups", "group1"); - return result; - } - - private void assertDefaultScopes(ClientRegistration client, String ... scopes) { - assertEquals(scopes.length, client.getScopes().size()); - for (String s : scopes) { - assertTrue(client.getScopes().contains(s)); - } - } - - private void assertDefaultScopes(DefaultClient client, String ... expected) { - assertEquals(expected.length, client.getScopeList().size()); - for (String e : expected) { - assertTrue(client.getScopeList().contains(e)); - } - } - - private List collectClients(Iterable iterable) { - List result = new ArrayList<>(); - iterable.forEach(result::add); - return result; - } - - @Configuration - @EnableAutoConfiguration - @EnableWebSecurity - public static class DumbApp {} -} diff --git a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthorizedClientRepoTest.java b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthorizedClientRepoTest.java index e2c3fe29d060..6a7283908e39 100644 --- a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthorizedClientRepoTest.java +++ b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthorizedClientRepoTest.java @@ -1,11 +1,14 @@ package com.azure.test.aad.auth; -import com.azure.spring.autoconfigure.aad.AzureClientRegistrationRepository; -import com.azure.spring.autoconfigure.aad.AzureOAuth2AuthorizedClientRepository; +import java.time.Instant; + +import com.azure.spring.aad.implementation.AzureAuthorizedClientRepository; +import com.azure.spring.aad.implementation.AzureClientRegistrationRepository; import com.azure.test.utils.AppRunner; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.context.annotation.Configuration; import org.springframework.mock.web.MockHttpServletRequest; @@ -15,114 +18,104 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; -import java.time.Instant; -import java.util.Optional; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; public class AuthorizedClientRepoTest { - private AppRunner appRunner; + private AppRunner runner; - private ClientRegistration azureClientRegistration; - private ClientRegistration graphClientRegistration; + private ClientRegistration azure; + private ClientRegistration graph; - private OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository; - private MockHttpServletRequest mockHttpServletRequest; - private MockHttpServletResponse mockHttpServletResponse; + private OAuth2AuthorizedClientRepository repo; + private MockHttpServletRequest request; + private MockHttpServletResponse response; @BeforeEach public void setup() { - appRunner = createApp(); - appRunner.start(); + runner = createApp(); + runner.start(); - AzureClientRegistrationRepository azureClientRegistrationRepository = - appRunner.getBean(AzureClientRegistrationRepository.class); - azureClientRegistration = azureClientRegistrationRepository.findByRegistrationId("azure"); - graphClientRegistration = azureClientRegistrationRepository.findByRegistrationId("graph"); + AzureClientRegistrationRepository clientRepo = runner.getBean(AzureClientRegistrationRepository.class); + azure = clientRepo.findByRegistrationId("azure"); + graph = clientRepo.findByRegistrationId("graph"); - oAuth2AuthorizedClientRepository = new AzureOAuth2AuthorizedClientRepository(azureClientRegistrationRepository); - mockHttpServletRequest = new MockHttpServletRequest(); - mockHttpServletResponse = new MockHttpServletResponse(); + repo = new AzureAuthorizedClientRepository(clientRepo); + request = new MockHttpServletRequest(); + response = new MockHttpServletResponse(); } private AppRunner createApp() { - AppRunner result = new AppRunner(AppAutoConfigTest.DumbApp.class); - result.property("azure.activedirectory.tenant-id", "fake-tenant-id"); - result.property("azure.activedirectory.client-id", "fake-client-id"); - result.property("azure.activedirectory.client-secret", "fake-client-secret"); - result.property("azure.activedirectory.user-group.allowed-groups", "group1"); - result.property("azure.activedirectory.authorization.graph.scope", "Calendars.Read"); + AppRunner result = new AppRunner(AzureActiveDirectoryConfigurationTest.DumbApp.class); + result.property("azure.active.directory.uri", "fake-uri"); + result.property("azure.active.directory.tenant-id", "fake-tenant-id"); + result.property("azure.active.directory.client-id", "fake-client-id"); + result.property("azure.active.directory.client-secret", "fake-client-secret"); + result.property("azure.active.directory.authorization.graph.scopes", "Calendars.Read"); return result; } @AfterEach public void tearDown() { - appRunner.stop(); + runner.stop(); } @Test public void loadInitAzureAuthzClient() { - oAuth2AuthorizedClientRepository.saveAuthorizedClient( - toOAuthAuthorizedClient(azureClientRegistration), + repo.saveAuthorizedClient( + createAuthorizedClient(azure), createAuthentication(), - mockHttpServletRequest, - mockHttpServletResponse - ); + request, + response); - OAuth2AuthorizedClient oAuth2AuthorizedClient = - oAuth2AuthorizedClientRepository.loadAuthorizedClient( - "graph", - createAuthentication(), - mockHttpServletRequest - ); + OAuth2AuthorizedClient client = repo.loadAuthorizedClient( + "graph", + createAuthentication(), + request); - assertNotNull(oAuth2AuthorizedClient); - assertNotNull(oAuth2AuthorizedClient.getAccessToken()); - assertNotNull(oAuth2AuthorizedClient.getRefreshToken()); + assertNotNull(client); + assertNotNull(client.getAccessToken()); + assertNotNull(client.getRefreshToken()); - assertTrue(isTokenExpired(oAuth2AuthorizedClient.getAccessToken())); - assertEquals("fake-refresh-token", oAuth2AuthorizedClient.getRefreshToken().getTokenValue()); + assertTrue(isTokenExpired(client.getAccessToken())); + assertEquals("fake-refresh-token", client.getRefreshToken().getTokenValue()); } @Test public void saveAndLoadAzureAuthzClient() { - oAuth2AuthorizedClientRepository.saveAuthorizedClient( - toOAuthAuthorizedClient(graphClientRegistration), + repo.saveAuthorizedClient( + createAuthorizedClient(graph), createAuthentication(), - mockHttpServletRequest, - mockHttpServletResponse - ); + request, + response); - OAuth2AuthorizedClient oAuth2AuthorizedClient = - oAuth2AuthorizedClientRepository.loadAuthorizedClient( - "graph", - createAuthentication(), - mockHttpServletRequest - ); + OAuth2AuthorizedClient client = repo.loadAuthorizedClient( + "graph", + createAuthentication(), + request); - assertNotNull(oAuth2AuthorizedClient); - assertNotNull(oAuth2AuthorizedClient.getAccessToken()); - assertNotNull(oAuth2AuthorizedClient.getRefreshToken()); + assertNotNull(client); + assertNotNull(client.getAccessToken()); + assertNotNull(client.getRefreshToken()); - assertEquals("fake-access-token", oAuth2AuthorizedClient.getAccessToken().getTokenValue()); - assertEquals("fake-refresh-token", oAuth2AuthorizedClient.getRefreshToken().getTokenValue()); + assertEquals("fake-access-token", client.getAccessToken().getTokenValue()); + assertEquals("fake-refresh-token", client.getRefreshToken().getTokenValue()); } - private OAuth2AuthorizedClient toOAuthAuthorizedClient(ClientRegistration clientRegistration) { - return new OAuth2AuthorizedClient( - clientRegistration, + private OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration client) { + OAuth2AuthorizedClient result = new OAuth2AuthorizedClient( + client, "fake-principal-name", createAccessToken(), - createRefreshToken() - ); + createRefreshToken()); + + return result; } private OAuth2AccessToken createAccessToken() { @@ -142,16 +135,12 @@ private Authentication createAuthentication() { return new PreAuthenticatedAuthenticationToken("fake-user", "fake-crednetial"); } - private boolean isTokenExpired(OAuth2AccessToken oAuth2AccessToken) { - return Optional.ofNullable(oAuth2AccessToken) - .map(AbstractOAuth2Token::getExpiresAt) - .map(expiredAt -> expiredAt.isBefore(Instant.now())) - .orElse(false); + private boolean isTokenExpired(OAuth2AccessToken token) { + return token.getExpiresAt().isBefore(Instant.now()); } @Configuration @EnableAutoConfiguration @EnableWebSecurity - public static class DumbApp { - } + public static class DumbApp {} } diff --git a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthzCodeGrantRequestEntityConverterTest.java b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthzCodeGrantRequestEntityConverterTest.java index c7f16e2377ee..08e132c5adf7 100644 --- a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthzCodeGrantRequestEntityConverterTest.java +++ b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AuthzCodeGrantRequestEntityConverterTest.java @@ -1,14 +1,15 @@ package com.azure.test.aad.auth; -import com.azure.spring.autoconfigure.aad.AzureClientRegistrationRepository; -import com.azure.spring.autoconfigure.aad.AzureOAuth2AuthorizationCodeGrantRequestEntityConverter; +import com.azure.spring.aad.implementation.AuthzCodeGrantRequestEntityConverter; +import com.azure.spring.aad.implementation.AzureClientRegistrationRepository; import com.azure.test.utils.AppRunner; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.context.annotation.Configuration; -import org.springframework.http.HttpEntity; +import org.springframework.http.RequestEntity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -17,102 +18,89 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.util.MultiValueMap; -import java.util.Optional; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class AuthzCodeGrantRequestEntityConverterTest { - private AppRunner appRunner; - private AzureClientRegistrationRepository azureClientRegistrationRepository; - private ClientRegistration azureClientRegistration; - private ClientRegistration graphClientRegistration; + private AppRunner runner; + private AzureClientRegistrationRepository repo; + private ClientRegistration azure; + private ClientRegistration graph; @BeforeEach public void setupApp() { - appRunner = createApp(); - appRunner.start(); + runner = createApp(); + runner.start(); - azureClientRegistrationRepository = appRunner.getBean(AzureClientRegistrationRepository.class); - azureClientRegistration = azureClientRegistrationRepository.findByRegistrationId("azure"); - graphClientRegistration = azureClientRegistrationRepository.findByRegistrationId("graph"); + repo = runner.getBean(AzureClientRegistrationRepository.class); + azure = repo.findByRegistrationId("azure"); + graph = repo.findByRegistrationId("graph"); } private AppRunner createApp() { AppRunner result = new AppRunner(DumbApp.class); - result.property("azure.activedirectory.tenant-id", "fake-tenant-id"); - result.property("azure.activedirectory.client-id", "fake-client-id"); - result.property("azure.activedirectory.client-secret", "fake-client-secret"); - result.property("azure.activedirectory.user-group.allowed-groups", "group1"); - result.property("azure.activedirectory.authorization.graph.scope", "Calendars.Read"); + result.property("azure.active.directory.uri", "http://localhost"); + result.property("azure.active.directory.tenant-id", "fake-tenant-id"); + result.property("azure.active.directory.client-id", "fake-client-id"); + result.property("azure.active.directory.client-secret", "fake-client-secret"); + result.property("azure.active.directory.authorization.graph.scopes", "Calendars.Read"); return result; } @AfterEach public void tearDownApp() { - appRunner.stop(); + runner.stop(); } @Test public void addScopeForDefaultClient() { - MultiValueMap multiValueMap = toMultiValueMap(createCodeGrantRequest(azureClientRegistration)); - assertEquals("openid profile offline_access", multiValueMap.getFirst("scope")); + MultiValueMap body = convertedBodyOf(createCodeGrantRequest(azure)); + assertEquals("openid profile offline_access", body.getFirst("scope")); } @Test public void noScopeParamForOtherClient() { - MultiValueMap multiValueMap = toMultiValueMap(createCodeGrantRequest(graphClientRegistration)); - assertNull(multiValueMap.get("scope")); + MultiValueMap body = convertedBodyOf(createCodeGrantRequest(graph)); + assertNull(body.get("scope")); } - @SuppressWarnings("unchecked") - private MultiValueMap toMultiValueMap(OAuth2AuthorizationCodeGrantRequest request) { - return (MultiValueMap) - Optional.ofNullable(azureClientRegistrationRepository) - .map(AzureClientRegistrationRepository::defaultClient) - .map(AzureOAuth2AuthorizationCodeGrantRequestEntityConverter::new) - .map(converter -> converter.convert(request)) - .map(HttpEntity::getBody) - .orElse(null); + private MultiValueMap convertedBodyOf(OAuth2AuthorizationCodeGrantRequest request) { + AuthzCodeGrantRequestEntityConverter converter = new AuthzCodeGrantRequestEntityConverter(repo.defaultClient()); + RequestEntity entity = converter.convert(request); + return (MultiValueMap) entity.getBody(); } - private OAuth2AuthorizationCodeGrantRequest createCodeGrantRequest(ClientRegistration clientRegistration) { - return new OAuth2AuthorizationCodeGrantRequest( - clientRegistration, - toOAuth2AuthorizationExchange(clientRegistration) - ); + private OAuth2AuthorizationCodeGrantRequest createCodeGrantRequest(ClientRegistration client) { + return new OAuth2AuthorizationCodeGrantRequest(client, createExchange(client)); } - private OAuth2AuthorizationExchange toOAuth2AuthorizationExchange(ClientRegistration clientRegistration) { + private OAuth2AuthorizationExchange createExchange(ClientRegistration client) { return new OAuth2AuthorizationExchange( - toOAuth2AuthorizationRequest(clientRegistration), - toOAuth2AuthorizationResponse() - ); + createAuthorizationRequest(client), + createAuthorizationResponse()); } - private OAuth2AuthorizationRequest toOAuth2AuthorizationRequest(ClientRegistration clientRegistration) { - return OAuth2AuthorizationRequest.authorizationCode() - .authorizationUri( - clientRegistration.getProviderDetails().getAuthorizationUri() - ) - .clientId(clientRegistration.getClientId()) - .scopes(clientRegistration.getScopes()) - .state("fake-state") - .redirectUri("http://localhost") - .build(); + private OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration client) { + OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.authorizationCode(); + builder.authorizationUri(client.getProviderDetails().getAuthorizationUri()); + builder.clientId(client.getClientId()); + builder.scopes(client.getScopes()); + builder.state("fake-state"); + builder.redirectUri("http://localhost"); + return builder.build(); } - private OAuth2AuthorizationResponse toOAuth2AuthorizationResponse() { - return OAuth2AuthorizationResponse.success("fake-code") - .redirectUri("http://localhost") - .state("fake-state") - .build(); + private OAuth2AuthorizationResponse createAuthorizationResponse() { + OAuth2AuthorizationResponse.Builder builder = OAuth2AuthorizationResponse.success("fake-code"); + builder.redirectUri("http://localhost"); + builder.state("fake-state"); + return builder.build(); } @Configuration @EnableAutoConfiguration @EnableWebSecurity - public static class DumbApp { - } + public static class DumbApp {} } diff --git a/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AzureActiveDirectoryConfigurationTest.java b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AzureActiveDirectoryConfigurationTest.java new file mode 100644 index 000000000000..99e8e56b205a --- /dev/null +++ b/sdk/spring/azure-spring-boot-test-aad/src/test/java/com/azure/test/aad/auth/AzureActiveDirectoryConfigurationTest.java @@ -0,0 +1,188 @@ +package com.azure.test.aad.auth; + +import java.util.ArrayList; +import java.util.List; + +import com.azure.spring.aad.implementation.AzureClientRegistrationRepository; +import com.azure.spring.aad.implementation.DefaultClient; +import com.azure.spring.aad.implementation.IdentityEndpoints; +import com.azure.test.utils.AppRunner; +import org.junit.jupiter.api.Test; + +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AzureActiveDirectoryConfigurationTest { + + @Test + public void clientRegistered() { + try (AppRunner runner = createApp()) { + runner.start(); + + ClientRegistrationRepository repo = runner.getBean(ClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + + assertNotNull(azure); + assertEquals("fake-client-id", azure.getClientId()); + assertEquals("fake-client-secret", azure.getClientSecret()); + + IdentityEndpoints endpoints = new IdentityEndpoints(); + assertEquals(endpoints.authorizationEndpoint("fake-tenant-id"), azure.getProviderDetails().getAuthorizationUri()); + assertEquals(endpoints.tokenEndpoint("fake-tenant-id"), azure.getProviderDetails().getTokenUri()); + assertEquals(endpoints.jwkSetEndpoint("fake-tenant-id"), azure.getProviderDetails().getJwkSetUri()); + assertEquals("{baseUrl}/login/oauth2/code/{registrationId}", azure.getRedirectUriTemplate()); + assertDefaultScopes(azure, "openid", "profile"); + } + } + + @Test + public void clientRequiresPermissionRegistered() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.authorization.graph.scopes", "Calendars.Read"); + runner.start(); + + ClientRegistrationRepository repo = runner.getBean(ClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + ClientRegistration graph = repo.findByRegistrationId("graph"); + + assertNotNull(azure); + assertDefaultScopes(azure, "openid", "profile", "offline_access", "Calendars.Read"); + + assertNotNull(graph); + assertDefaultScopes(graph, "Calendars.Read"); + } + } + + @Test + public void clientRequiresMultiPermissions() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.authorization.graph.scopes", "Calendars.Read"); + runner.property("azure.active.directory.authorization.arm.scopes", "https://management.core.windows.net/user_impersonation"); + runner.start(); + + ClientRegistrationRepository repo = runner.getBean(ClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + ClientRegistration graph = repo.findByRegistrationId("graph"); + + assertNotNull(azure); + assertDefaultScopes( + azure, + "openid", + "profile", + "offline_access", + "Calendars.Read", + "https://management.core.windows.net/user_impersonation"); + + assertNotNull(graph); + assertDefaultScopes(graph, "Calendars.Read"); + } + } + + @Test + public void clientRequiresPermissionInDefaultClient() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.authorization.azure.scopes", "Calendars.Read"); + runner.start(); + + ClientRegistrationRepository repo = runner.getBean(ClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + + assertNotNull(azure); + assertDefaultScopes(azure, "openid", "profile", "offline_access", "Calendars.Read"); + } + } + + @Test + public void aadAwareClientRepository() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.authorization.graph.scopes", "Calendars.Read"); + runner.start(); + + AzureClientRegistrationRepository repo = (AzureClientRegistrationRepository) runner.getBean(ClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + ClientRegistration graph = repo.findByRegistrationId("graph"); + + assertDefaultScopes(repo.defaultClient(), "openid", "profile", "offline_access"); + assertEquals(repo.defaultClient().client(), azure); + + assertFalse(repo.isAuthzClient(azure)); + assertTrue(repo.isAuthzClient(graph)); + assertFalse(repo.isAuthzClient("azure")); + assertTrue(repo.isAuthzClient("graph")); + + List clients = collectClients((Iterable) repo); + assertEquals(1, clients.size()); + assertEquals("azure", clients.get(0).getRegistrationId()); + } + } + + @Test + public void defaultClientWithAuthzScope() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.authorization.azure.scopes", "Calendars.Read"); + runner.start(); + + AzureClientRegistrationRepository repo = runner.getBean(AzureClientRegistrationRepository.class); + assertDefaultScopes(repo.defaultClient(), "openid", "profile", "offline_access", "Calendars.Read"); + } + } + + @Test + public void customizeUri() { + try (AppRunner runner = createApp()) { + runner.property("azure.active.directory.uri", "http://localhost/"); + runner.start(); + + AzureClientRegistrationRepository repo = runner.getBean(AzureClientRegistrationRepository.class); + ClientRegistration azure = repo.findByRegistrationId("azure"); + + IdentityEndpoints endpoints = new IdentityEndpoints("http://localhost/"); + assertEquals(endpoints.authorizationEndpoint("fake-tenant-id"), azure.getProviderDetails().getAuthorizationUri()); + assertEquals(endpoints.tokenEndpoint("fake-tenant-id"), azure.getProviderDetails().getTokenUri()); + assertEquals(endpoints.jwkSetEndpoint("fake-tenant-id"), azure.getProviderDetails().getJwkSetUri()); + } + } + + private AppRunner createApp() { + AppRunner result = new AppRunner(DumbApp.class); + result.property("azure.active.directory.uri", "https://login.microsoftonline.com"); + result.property("azure.active.directory.tenant-id", "fake-tenant-id"); + result.property("azure.active.directory.client-id", "fake-client-id"); + result.property("azure.active.directory.client-secret", "fake-client-secret"); + result.property("azure.active.directory.user-group.allowed-groups", "groupA, groupB"); + return result; + } + + private void assertDefaultScopes(ClientRegistration client, String ... scopes) { + assertEquals(scopes.length, client.getScopes().size()); + for (String s : scopes) { + assertTrue(client.getScopes().contains(s)); + } + } + + private void assertDefaultScopes(DefaultClient client, String ... expected) { + assertEquals(expected.length, client.scopes().size()); + for (String e : expected) { + assertTrue(client.scopes().contains(e)); + } + } + + private List collectClients(Iterable itr) { + List result = new ArrayList<>(); + itr.forEach(c -> result.add(c)); + return result; + } + + @Configuration + @EnableAutoConfiguration + @EnableWebSecurity + public static class DumbApp {} +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthorizationProperties.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthorizationProperties.java new file mode 100644 index 000000000000..0e30f2be233c --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthorizationProperties.java @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import java.util.List; + +public class AuthorizationProperties { + + private List scopes; + + public void setScopes(List scopes) { + this.scopes = scopes; + } + + public List getScopes() { + return scopes; + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizationCodeGrantRequestEntityConverter.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthzCodeGrantRequestEntityConverter.java similarity index 51% rename from sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizationCodeGrantRequestEntityConverter.java rename to sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthzCodeGrantRequestEntityConverter.java index 15f35fc00aeb..769fb62119e5 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AuthzCodeGrantRequestEntityConverter.java @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package com.azure.spring.autoconfigure.aad; +package com.azure.spring.aad.implementation; import org.springframework.http.HttpEntity; import org.springframework.http.RequestEntity; @@ -11,26 +11,12 @@ import java.util.Optional; -/** - * This converter is to add 'scope' parameter when request for access token. - * - * In default oidc flow, when request for access token by authorization code, 'scope' parameter is not necessary. - * Because one consent operation only create one authorizedClient. - * - * But for Microsoft Authorization Server, one consent can created multiple authorizedClient. - * So scope parameter is necessary when request for access token. - * - * Refs: - * 1. https://tools.ietf.org/html/rfc6749#section-4.1.3 - * 2. https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-access-token - */ -public class AzureOAuth2AuthorizationCodeGrantRequestEntityConverter - extends OAuth2AuthorizationCodeGrantRequestEntityConverter { +public class AuthzCodeGrantRequestEntityConverter extends OAuth2AuthorizationCodeGrantRequestEntityConverter { private final DefaultClient defaultClient; - public AzureOAuth2AuthorizationCodeGrantRequestEntityConverter(DefaultClient defaultClient) { - this.defaultClient = defaultClient; + public AuthzCodeGrantRequestEntityConverter(DefaultClient client) { + defaultClient = client; } @Override @@ -41,16 +27,16 @@ public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest request) { Optional.ofNullable(result) .map(HttpEntity::getBody) .map(b -> (MultiValueMap) b) - .ifPresent(map -> map.add("scope", scopeValue())); + .ifPresent(body -> body.add("scope", scopeValue())); } return result; } private boolean isRequestForDefaultClient(OAuth2AuthorizationCodeGrantRequest request) { - return request.getClientRegistration().equals(defaultClient.getClientRegistration()); + return request.getClientRegistration().equals(defaultClient.client()); } private String scopeValue() { - return String.join(" ", defaultClient.getScope()); + return String.join(" ", defaultClient.scope()); } } diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryConfiguration.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryConfiguration.java new file mode 100644 index 000000000000..6468cc53df3c --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryConfiguration.java @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; + +import java.util.ArrayList; +import java.util.List; + +@Configuration +@ConditionalOnClass(ClientRegistrationRepository.class) +@EnableConfigurationProperties(AzureActiveDirectoryProperties.class) +@ConditionalOnExpression("#{'${azure.active.directory.uri:notExist}' != 'notExist'}") +public class AzureActiveDirectoryConfiguration { + + private static final String DEFAULT_CLIENT = "azure"; + + @Autowired + private AzureActiveDirectoryProperties config; + + @Bean + @ConditionalOnMissingBean({ ClientRegistrationRepository.class, AzureClientRegistrationRepository.class }) + public AzureClientRegistrationRepository clientRegistrationRepository() { + return new AzureClientRegistrationRepository( + createDefaultClient(), + createAuthzClients()); + } + + private DefaultClient createDefaultClient() { + ClientRegistration.Builder builder = createClientBuilder(DEFAULT_CLIENT); + builder.scope(allScopes()); + ClientRegistration client = builder.build(); + + return new DefaultClient(client, defaultScopes()); + } + + private String[] allScopes() { + List result = openidScopes(); + for (AuthorizationProperties authz : config.getAuthorization().values()) { + result.addAll(authz.getScopes()); + } + return result.toArray(new String[0]); + } + + private List defaultScopes() { + List result = openidScopes(); + addAuthzDefaultScope(result); + return result; + } + + private void addAuthzDefaultScope(List result) { + AuthorizationProperties authz = config.getAuthorization().get(DEFAULT_CLIENT); + if (authz != null) { + result.addAll(authz.getScopes()); + } + } + + private List openidScopes() { + List result = new ArrayList<>(); + result.add("openid"); + result.add("profile"); + + if (!config.getAuthorization().isEmpty()) { + result.add("offline_access"); + } + return result; + } + + private List createAuthzClients() { + List result = new ArrayList<>(); + for (String name : config.getAuthorization().keySet()) { + if (DEFAULT_CLIENT.equals(name)) { + continue; + } + + AuthorizationProperties authz = config.getAuthorization().get(name); + result.add(createClientBuilder(name, authz)); + } + return result; + } + + private ClientRegistration createClientBuilder(String id, AuthorizationProperties authz) { + ClientRegistration.Builder result = createClientBuilder(id); + result.scope(authz.getScopes()); + return result.build(); + } + + private ClientRegistration.Builder createClientBuilder(String id) { + ClientRegistration.Builder result = ClientRegistration.withRegistrationId(id); + result.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); + result.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}"); + + result.clientId(config.getClientId()); + result.clientSecret(config.getClientSecret()); + + IdentityEndpoints endpoints = new IdentityEndpoints(config.getUri()); + result.authorizationUri(endpoints.authorizationEndpoint(config.getTenantId())); + result.tokenUri(endpoints.tokenEndpoint(config.getTenantId())); + result.jwkSetUri(endpoints.jwkSetEndpoint(config.getTenantId())); + + return result; + } + + @Bean + @ConditionalOnMissingBean + public OAuth2AuthorizedClientRepository authorizedClientRepository(AzureClientRegistrationRepository repo) { + return new AzureAuthorizedClientRepository(repo); + } + + @Configuration + @ConditionalOnMissingBean(WebSecurityConfigurerAdapter.class) + public static class DefaultAzureOAuth2Configuration extends AzureOAuth2Configuration { + + @Override + protected void configure(HttpSecurity http) throws Exception { + super.configure(http); + http.authorizeRequests().anyRequest().authenticated(); + } + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryProperties.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryProperties.java new file mode 100644 index 000000000000..757308c344dd --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureActiveDirectoryProperties.java @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +import java.util.HashMap; +import java.util.Map; + +@ConfigurationProperties("azure.active.directory") +public class AzureActiveDirectoryProperties { + + private String uri; + private String tenantId; + private String clientId; + private String clientSecret; + + private Map authorization = new HashMap<>(); + + public void setUri(String uri) { + this.uri = uri; + } + + public String getUri() { + return uri; + } + + public void setTenantId(String tenantId) { + this.tenantId = tenantId; + } + + public String getTenantId() { + return tenantId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getClientId() { + return clientId; + } + + public void setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + } + + public String getClientSecret() { + return clientSecret; + } + + public void setAuthorization(Map authorization) { + this.authorization = authorization; + } + + public Map getAuthorization() { + return authorization; + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureAuthorizedClientRepository.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureAuthorizedClientRepository.java new file mode 100644 index 000000000000..137d1595ef8b --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureAuthorizedClientRepository.java @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +public class AzureAuthorizedClientRepository implements OAuth2AuthorizedClientRepository { + + private AzureClientRegistrationRepository repo; + private OAuth2AuthorizedClientRepository delegate; + + private static OAuth2AuthorizedClientRepository createDefaultDelegate(ClientRegistrationRepository repo) { + return new HttpSessionOAuth2AuthorizedClientRepository(); + } + + public AzureAuthorizedClientRepository(AzureClientRegistrationRepository repo) { + this(repo, createDefaultDelegate(repo)); + } + + public AzureAuthorizedClientRepository(AzureClientRegistrationRepository repo, + OAuth2AuthorizedClientRepository delegate) { + this.repo = repo; + this.delegate = delegate; + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient client, + Authentication principal, + HttpServletRequest request, + HttpServletResponse response) { + delegate.saveAuthorizedClient(client, principal, request, response); + } + + @Override + @SuppressWarnings("unchecked") + public T loadAuthorizedClient(String id, + Authentication principal, + HttpServletRequest request) { + OAuth2AuthorizedClient result = delegate.loadAuthorizedClient(id, principal, request); + if (result != null) { + return (T) result; + } + + if (repo.isAuthzClient(id)) { + OAuth2AuthorizedClient client = loadAuthorizedClient(defaultClientRegistrationId(), principal, request); + return (T) createInitAuthzClient(client, id, principal); + } + return null; + } + + private String defaultClientRegistrationId() { + return repo.defaultClient().client().getRegistrationId(); + } + + private OAuth2AuthorizedClient createInitAuthzClient(OAuth2AuthorizedClient client, + String id, + Authentication principal) { + if (client == null || client.getRefreshToken() == null) { + return null; + } + + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, + "non-access-token", + Instant.MIN, + Instant.now().minus(100, ChronoUnit.DAYS)); + + return new OAuth2AuthorizedClient( + repo.findByRegistrationId(id), + principal.getName(), + accessToken, + client.getRefreshToken() + ); + } + + @Override + public void removeAuthorizedClient(String id, + Authentication principal, + HttpServletRequest request, + HttpServletResponse response) { + delegate.removeAuthorizedClient(id, principal, request, response); + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureClientRegistrationRepository.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureClientRegistrationRepository.java new file mode 100644 index 000000000000..987430fb8753 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureClientRegistrationRepository.java @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import org.jetbrains.annotations.NotNull; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class AzureClientRegistrationRepository implements ClientRegistrationRepository, Iterable { + + private DefaultClient defaultClient; + private List authzClients; + + private Map clients; + + public AzureClientRegistrationRepository(DefaultClient defaultClient, List authzClients) { + this.defaultClient = defaultClient; + this.authzClients = new ArrayList<>(authzClients); + + clients = new HashMap<>(); + addClientRegistration(defaultClient.client()); + for (ClientRegistration c : authzClients) { + addClientRegistration(c); + } + } + + private void addClientRegistration(ClientRegistration client) { + clients.put(client.getRegistrationId(), client); + } + + @Override + public ClientRegistration findByRegistrationId(String registrationId) { + return clients.get(registrationId); + } + + @NotNull + @Override + public Iterator iterator() { + return Collections.singleton(defaultClient.client()).iterator(); + } + + public DefaultClient defaultClient() { + return defaultClient; + } + + public boolean isAuthzClient(ClientRegistration client) { + return authzClients.contains(client); + } + + public boolean isAuthzClient(String id) { + ClientRegistration client = findByRegistrationId(id); + return client != null && isAuthzClient(client); + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2WebSecurityConfigurerAdapter.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureOAuth2Configuration.java similarity index 62% rename from sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2WebSecurityConfigurerAdapter.java rename to sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureOAuth2Configuration.java index c2a016e8b072..99d7954b74e7 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2WebSecurityConfigurerAdapter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/AzureOAuth2Configuration.java @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package com.azure.spring.autoconfigure.aad; +package com.azure.spring.aad.implementation; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -10,15 +10,10 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; -import java.util.Optional; - -/** - * The main purpose of this class is to make AzureOAuth2AuthorizationCodeGrantRequestEntityConverter take effect. - */ -public abstract class AzureOAuth2WebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter { +public abstract class AzureOAuth2Configuration extends WebSecurityConfigurerAdapter { @Autowired - private AzureClientRegistrationRepository azureClientRegistrationRepository; + private AzureClientRegistrationRepository repo; @Override protected void configure(HttpSecurity http) throws Exception { @@ -27,10 +22,7 @@ protected void configure(HttpSecurity http) throws Exception { protected OAuth2AccessTokenResponseClient accessTokenResponseClient() { DefaultAuthorizationCodeTokenResponseClient result = new DefaultAuthorizationCodeTokenResponseClient(); - Optional.ofNullable(azureClientRegistrationRepository) - .map(AzureClientRegistrationRepository::defaultClient) - .map(AzureOAuth2AuthorizationCodeGrantRequestEntityConverter::new) - .ifPresent(result::setRequestEntityConverter); + result.setRequestEntityConverter(new AuthzCodeGrantRequestEntityConverter(repo.defaultClient())); return result; } } diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/DefaultClient.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/DefaultClient.java new file mode 100644 index 000000000000..4aedcf167b31 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/DefaultClient.java @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.aad.implementation; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; + +import java.util.List; + +public class DefaultClient { + + private final ClientRegistration client; + private final List scopes; + + public DefaultClient(ClientRegistration client, List scopes) { + this.client = client; + this.scopes = scopes; + } + + public ClientRegistration client() { + return client; + } + + public List scope() { + return scopes; + } + + public List scopes() { + return scopes; + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationServerEndpoints.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/IdentityEndpoints.java similarity index 56% rename from sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationServerEndpoints.java rename to sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/IdentityEndpoints.java index 65981c98bab2..ae02c2a78b82 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationServerEndpoints.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/aad/implementation/IdentityEndpoints.java @@ -1,31 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package com.azure.spring.autoconfigure.aad; +package com.azure.spring.aad.implementation; import com.nimbusds.oauth2.sdk.util.StringUtils; -/** - * Util class used to create authorization server endpoints. - */ -public class AuthorizationServerEndpoints { +public class IdentityEndpoints { + + private static final String IDENTITY_PLATFORM = "https://login.microsoftonline.com/"; - private static final String DEFAULT_AUTHORIZATION_SERVER_URI = "https://login.microsoftonline.com/"; private static final String AUTHORIZATION_ENDPOINT = "/oauth2/v2.0/authorize"; private static final String TOKEN_ENDPOINT = "/oauth2/v2.0/token"; private static final String JWK_SET_ENDPOINT = "/discovery/v2.0/keys"; - private final String baseUri; + private String baseUri; - public AuthorizationServerEndpoints() { - this(DEFAULT_AUTHORIZATION_SERVER_URI); + public IdentityEndpoints() { + this(IDENTITY_PLATFORM); } - public AuthorizationServerEndpoints(String authorizationServerUri) { - if (StringUtils.isBlank(authorizationServerUri)) { - authorizationServerUri = DEFAULT_AUTHORIZATION_SERVER_URI; + public IdentityEndpoints(String baseUri) { + if (StringUtils.isBlank(baseUri)) { + baseUri = IDENTITY_PLATFORM; } - this.baseUri = addSlash(authorizationServerUri); + this.baseUri = addSlash(baseUri); } private String addSlash(String uri) { diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFailureHandler.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFailureHandler.java new file mode 100644 index 000000000000..6bab923cedb8 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFailureHandler.java @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.autoconfigure.aad; + +import com.microsoft.aad.msal4j.MsalServiceException; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler; +import org.springframework.security.web.savedrequest.DefaultSavedRequest; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Optional; + +/** + * Strategy used to handle a failed authentication attempt. + *

+ * To redirect the user to the authentication page to allow them to try again when conditional access policy is + * configured on Azure Active Directory. + */ +public class AADAuthenticationFailureHandler implements AuthenticationFailureHandler { + private static final String DEFAULT_FAILURE_URL = "/login?error"; + private final AuthenticationFailureHandler defaultHandler; + + public AADAuthenticationFailureHandler() { + this.defaultHandler = new SimpleUrlAuthenticationFailureHandler(DEFAULT_FAILURE_URL); + } + + @Override + public void onAuthenticationFailure(HttpServletRequest request, + HttpServletResponse response, + AuthenticationException exception) throws IOException, ServletException { + // Handle conditional access policy, step 3. + MsalServiceException msalServiceException = (MsalServiceException) + Optional.of(exception) + .filter(e -> e instanceof OAuth2AuthenticationException) + .map(e -> (OAuth2AuthenticationException) e) + .filter(e -> AADOAuth2ErrorCode.CONDITIONAL_ACCESS_POLICY.equals((e.getError().getErrorCode()))) + .map(Throwable::getCause) + .filter(cause -> cause instanceof MsalServiceException) + .orElse(null); + if (msalServiceException == null) { + // Default handle logic + defaultHandler.onAuthenticationFailure(request, response, exception); + } else { + // Put claims into session + Optional.of(msalServiceException) + .map(MsalServiceException::claims) + .ifPresent(claims -> request.getSession() + .setAttribute(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS, claims)); + // Redirect + response.setStatus(302); + String redirectUrl = Optional.of(request) + .map(HttpServletRequest::getSession) + .map(s -> s.getAttribute(Constants.SAVED_REQUEST)) + .map(r -> (DefaultSavedRequest) r) + .map(DefaultSavedRequest::getRedirectUrl) + .orElse(null); + response.sendRedirect(redirectUrl); + } + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java index b110c366cf10..cb4733b03911 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java @@ -42,7 +42,7 @@ public class AADAuthenticationFilter extends OncePerRequestFilter { private static final String CURRENT_USER_PRINCIPAL = "CURRENT_USER_PRINCIPAL"; private final UserPrincipalManager userPrincipalManager; - private final GraphOboClient graphOboClient; + private final AzureADGraphClient azureADGraphClient; public AADAuthenticationFilter(AADAuthenticationProperties aadAuthenticationProperties, ServiceEndpointsProperties serviceEndpointsProperties, @@ -80,18 +80,16 @@ public AADAuthenticationFilter(AADAuthenticationProperties aadAuthenticationProp ServiceEndpointsProperties serviceEndpointsProperties, UserPrincipalManager userPrincipalManager) { this.userPrincipalManager = userPrincipalManager; - this.graphOboClient = new GraphOboClient( + this.azureADGraphClient = new AzureADGraphClient( aadAuthenticationProperties, serviceEndpointsProperties ); } @Override - protected void doFilterInternal( - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse, - FilterChain filterChain - ) throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse, + FilterChain filterChain) throws ServletException, IOException { String aadIssuedBearerToken = Optional.of(httpServletRequest) .map(r -> r.getHeader(HttpHeaders.AUTHORIZATION)) .map(String::trim) @@ -108,19 +106,21 @@ protected void doFilterInternal( UserPrincipal userPrincipal = (UserPrincipal) httpSession.getAttribute(CURRENT_USER_PRINCIPAL); if (userPrincipal == null || !userPrincipal.getAadIssuedBearerToken().equals(aadIssuedBearerToken) + || userPrincipal.getAccessTokenForGraphApi() == null ) { userPrincipal = userPrincipalManager.buildUserPrincipal(aadIssuedBearerToken); String tenantId = userPrincipal.getClaim(AADTokenClaim.TID).toString(); - String accessTokenForGraphApi = graphOboClient + String accessTokenForGraphApi = azureADGraphClient .acquireTokenForGraphApi(aadIssuedBearerToken, tenantId) .accessToken(); - userPrincipal.setGroups(graphOboClient.getGroups(accessTokenForGraphApi)); + userPrincipal.setAccessTokenForGraphApi(accessTokenForGraphApi); + userPrincipal.setGroups(azureADGraphClient.getGroups(accessTokenForGraphApi)); httpSession.setAttribute(CURRENT_USER_PRINCIPAL, userPrincipal); } final Authentication authentication = new PreAuthenticatedAuthenticationToken( userPrincipal, null, - graphOboClient.toGrantedAuthoritySet(userPrincipal.getGroups()) + azureADGraphClient.toGrantedAuthoritySet(userPrincipal.getGroups()) ); LOGGER.info("Request token verification success. {}", authentication); SecurityContextHolder.getContext().setAuthentication(authentication); @@ -137,6 +137,7 @@ protected void doFilterInternal( } catch (MsalServiceException ex) { // Handle conditional access policy, step 2. // No step 3 any more, because ServletException will not be caught. + // TODO: Do we need to return 401 instead of 500? if (ex.claims() != null && !ex.claims().isEmpty()) { throw new ServletException("Handle conditional access policy", ex); } else { diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterAutoConfiguration.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterAutoConfiguration.java index 05489f4fd1f9..1ce1a4718abf 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterAutoConfiguration.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterAutoConfiguration.java @@ -35,9 +35,8 @@ *

* The configuration will not be activated if no {@literal azure.activedirectory.client-id} property provided. *

- * A stateless filter {@link AADAppRoleStatelessAuthenticationFilter} will be auto-configured by specifying - * {@literal azure.activedirectory.session-stateless=true}. Otherwise, {@link AADAuthenticationFilter} will be - * configured. + * A stateless filter {@link AADAppRoleStatelessAuthenticationFilter} will be auto-configured by specifying {@literal + * azure.activedirectory.session-stateless=true}. Otherwise, {@link AADAuthenticationFilter} will be configured. */ @Configuration @ConditionalOnWebApplication @@ -45,6 +44,7 @@ @ConditionalOnProperty(prefix = AADAuthenticationFilterAutoConfiguration.PROPERTY_PREFIX, value = { "client-id" }) @EnableConfigurationProperties({ AADAuthenticationProperties.class, ServiceEndpointsProperties.class }) @PropertySource(value = "classpath:service-endpoints.properties") +@ConditionalOnExpression("#{'${azure.active.directory.uri:notExist}' == 'notExist'}") public class AADAuthenticationFilterAutoConfiguration { public static final String PROPERTY_PREFIX = "azure.activedirectory"; private static final Logger LOG = LoggerFactory.getLogger(AADAuthenticationProperties.class); @@ -67,7 +67,7 @@ public AADAuthenticationFilterAutoConfiguration(AADAuthenticationProperties aadA @ConditionalOnMissingBean(AADAuthenticationFilter.class) @ConditionalOnExpression("${azure.activedirectory.session-stateless:false} == false") // client-id and client-secret used to: get graphApiToken -> groups - @ConditionalOnProperty(prefix = PROPERTY_PREFIX, value = {"client-id", "client-secret"}) + @ConditionalOnProperty(prefix = PROPERTY_PREFIX, value = { "client-id", "client-secret" }) public AADAuthenticationFilter azureADJwtTokenFilter() { LOG.info("AzureADJwtTokenFilter Constructor."); return new AADAuthenticationFilter( @@ -82,7 +82,7 @@ public AADAuthenticationFilter azureADJwtTokenFilter() { @ConditionalOnMissingBean(AADAppRoleStatelessAuthenticationFilter.class) @ConditionalOnExpression("${azure.activedirectory.session-stateless:false} == true") // client-id used to: userPrincipalManager.getValidator - @ConditionalOnProperty(prefix = PROPERTY_PREFIX, value = {"client-id"}) + @ConditionalOnProperty(prefix = PROPERTY_PREFIX, value = { "client-id" }) public AADAppRoleStatelessAuthenticationFilter azureADStatelessAuthFilter(ResourceRetriever resourceRetriever) { LOG.info("Creating AzureADStatelessAuthFilter bean."); return new AADAppRoleStatelessAuthenticationFilter( diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationProperties.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationProperties.java index 3ff646b13c51..1925493d60c6 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationProperties.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADAuthenticationProperties.java @@ -13,10 +13,9 @@ import javax.annotation.PostConstruct; import javax.validation.constraints.NotEmpty; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; @@ -35,8 +34,6 @@ public class AADAuthenticationProperties { private static final String GROUP_RELATIONSHIP_DIRECT = "direct"; private static final String GROUP_RELATIONSHIP_TRANSITIVE = "transitive"; - private Map authorization = new HashMap<>(); - /** * Default UserGroup configuration. */ @@ -48,17 +45,27 @@ public class AADAuthenticationProperties { private String environment = DEFAULT_SERVICE_ENVIRONMENT; /** - * Registered application ID in Azure AD. - * Must be configured when OAuth2 authentication is done in front end + * Registered application ID in Azure AD. Must be configured when OAuth2 authentication is done in front end */ private String clientId; /** - * API Access Key of the registered application. - * Must be configured when OAuth2 authentication is done in front end + * API Access Key of the registered application. Must be configured when OAuth2 authentication is done in front end */ private String clientSecret; + /** + * Redirection Endpoint: Used by the authorization server to return responses containing authorization credentials + * to the client via the resource owner user-agent. + */ + private String redirectUriTemplate; + + /** + * Optional. scope doc: https://docs.microsoft + * .com/en-us/azure/active-directory/develop/v2-permissions-and-consent#scopes-and-permissions + */ + private List scope = Arrays.asList("openid", "https://graph.microsoft.com/user.read", "profile"); + /** * App ID URI which might be used in the "aud" claim of an id_token. */ @@ -146,10 +153,9 @@ public static class UserGroupProperties { /** - * The way to obtain group relationship.
- * direct: the default value, get groups that the user is a direct member of;
- * transitive: Get groups that the user is a member of, and will also return all - * groups the user is a nested member of; + * The way to obtain group relationship.
direct: the default value, get groups that the user is a direct + * member of;
transitive: Get groups that the user is a member of, and will also return all groups the user + * is a nested member of; */ @NotEmpty private String groupRelationship = GROUP_RELATIONSHIP_DIRECT; @@ -197,12 +203,12 @@ public void setGroupRelationship(String groupRelationship) { @Override public String toString() { return "UserGroupProperties{" - + "allowedGroups=" + allowedGroups - + ", key='" + key + '\'' - + ", value='" + value + '\'' - + ", objectIDKey='" + objectIDKey + '\'' - + ", groupRelationship='" + groupRelationship + '\'' - + '}'; + + "allowedGroups=" + allowedGroups + + ", key='" + key + '\'' + + ", value='" + value + '\'' + + ", objectIDKey='" + objectIDKey + '\'' + + ", groupRelationship='" + groupRelationship + '\'' + + '}'; } @Override @@ -257,14 +263,6 @@ public void validateUserGroupProperties() { } } - public void setAuthorization(Map authorization) { - this.authorization = authorization; - } - - public Map getAuthorization() { - return authorization; - } - public UserGroupProperties getUserGroup() { return userGroup; } @@ -297,6 +295,22 @@ public void setClientSecret(String clientSecret) { this.clientSecret = clientSecret; } + public String getRedirectUriTemplate() { + return redirectUriTemplate; + } + + public void setRedirectUriTemplate(String redirectUriTemplate) { + this.redirectUriTemplate = redirectUriTemplate; + } + + public void setScope(List scope) { + this.scope = scope; + } + + public List getScope() { + return scope; + } + @Deprecated public void setActiveDirectoryGroups(List activeDirectoryGroups) { this.userGroup.setAllowedGroups(activeDirectoryGroups); diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2AutoConfiguration.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2AutoConfiguration.java new file mode 100644 index 000000000000..1e540b631ea6 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2AutoConfiguration.java @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.autoconfigure.aad; + +import com.azure.spring.telemetry.TelemetrySender; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.condition.ConditionalOnResource; +import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.PropertySource; +import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; + +import javax.annotation.PostConstruct; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.azure.spring.telemetry.TelemetryData.SERVICE_NAME; +import static com.azure.spring.telemetry.TelemetryData.getClassPackageSimpleName; + +/** + * {@link EnableAutoConfiguration Auto-configuration} for Azure Active Authentication OAuth 2.0. + *

+ * The configuration will be activated when configured: + * 1. {@literal azure.activedirectory.client-id} + * 2. {@literal azure.activedirectory.client-secret} + * 3. {@literal azure.activedirectory.tenant-id} + * client-id, client-secret, tenant-id used in ClientRegistration. + * client-id, client-secret also used to get graphApiToken, then get groups. + *

+ * A OAuth2 user service {@link AADOAuth2UserService} will be auto-configured by specifying {@literal + * azure.activedirectory.user-group.allowed-groups} property. + */ +@Configuration +@ConditionalOnResource(resources = "classpath:aad.enable.config") +@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) +@ConditionalOnProperty(prefix = "azure.activedirectory", value = { "client-id", "client-secret", "tenant-id" }) +@PropertySource(value = "classpath:service-endpoints.properties") +@EnableConfigurationProperties({ AADAuthenticationProperties.class, ServiceEndpointsProperties.class }) +@ConditionalOnExpression("#{'${azure.active.directory.uri:notExist}' == 'notExist'}") +public class AADOAuth2AutoConfiguration { + + private static final Logger LOGGER = LoggerFactory.getLogger(AADOAuth2AutoConfiguration.class); + private final AADAuthenticationProperties aadAuthenticationProperties; + private final ServiceEndpointsProperties serviceEndpointsProperties; + + public AADOAuth2AutoConfiguration(AADAuthenticationProperties aadAuthProperties, + ServiceEndpointsProperties serviceEndpointsProperties) { + this.aadAuthenticationProperties = aadAuthProperties; + this.serviceEndpointsProperties = serviceEndpointsProperties; + } + + @Bean + @ConditionalOnProperty(prefix = "azure.activedirectory.user-group", value = "allowed-groups") + public OAuth2UserService oidcUserService() { + return new AADOAuth2UserService(aadAuthenticationProperties, serviceEndpointsProperties); + } + + @Bean + public ClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryClientRegistrationRepository(azureClientRegistration()); + } + + private ClientRegistration azureClientRegistration() { + String tenantId = aadAuthenticationProperties.getTenantId().trim(); + Assert.hasText(tenantId, "azure.activedirectory.tenant-id should have text."); + Assert.doesNotContain(tenantId, " ", "azure.activedirectory.tenant-id should not contain ' '."); + Assert.doesNotContain(tenantId, "/", "azure.activedirectory.tenant-id should not contain '/'."); + + String redirectUriTemplate = Optional.of(aadAuthenticationProperties) + .map(AADAuthenticationProperties::getRedirectUriTemplate) + .orElse("{baseUrl}/login/oauth2/code/{registrationId}"); + + List scope = aadAuthenticationProperties.getScope(); + if (!scope.toString().contains(".default")) { + if (aadAuthenticationProperties.allowedGroupsConfigured() + && !scope.contains("https://graph.microsoft.com/user.read") + ) { + scope.add("https://graph.microsoft.com/user.read"); + LOGGER.warn("scope 'https://graph.microsoft.com/user.read' has been added."); + } + if (!scope.contains("openid")) { + scope.add("openid"); + LOGGER.warn("scope 'openid' has been added."); + } + if (!scope.contains("profile")) { + scope.add("profile"); + LOGGER.warn("scope 'profile' has been added."); + } + } + + return ClientRegistration.withRegistrationId("azure") + .clientId(aadAuthenticationProperties.getClientId()) + .clientSecret(aadAuthenticationProperties.getClientSecret()) + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUriTemplate(redirectUriTemplate) + .scope(scope) + .authorizationUri( + String.format( + "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", + tenantId + ) + ) + .tokenUri( + String.format( + "https://login.microsoftonline.com/%s/oauth2/v2.0/token", + tenantId + ) + ) + .userInfoUri("https://graph.microsoft.com/oidc/userinfo") + .userNameAttributeName(AADTokenClaim.NAME) + .jwkSetUri( + String.format( + "https://login.microsoftonline.com/%s/discovery/v2.0/keys", + tenantId + ) + ) + .clientName("Azure") + .build(); + } + + @PostConstruct + private void sendTelemetry() { + if (aadAuthenticationProperties.isAllowTelemetry()) { + final Map events = new HashMap<>(); + final TelemetrySender sender = new TelemetrySender(); + events.put(SERVICE_NAME, getClassPackageSimpleName(AADOAuth2AutoConfiguration.class)); + sender.send(ClassUtils.getUserClass(getClass()).getSimpleName(), events); + } + } +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2ErrorCode.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2ErrorCode.java new file mode 100644 index 000000000000..aaa9b557acf9 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2ErrorCode.java @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.autoconfigure.aad; + +/** + * entity class of AADOAuth2ErrorCode + */ +public class AADOAuth2ErrorCode { + public static final String CONDITIONAL_ACCESS_POLICY = "conditional_access_policy"; + public static final String INVALID_REQUEST = "invalid_request"; + public static final String SERVER_SERVER = "server_error"; +} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2UserService.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2UserService.java index 70edf71aa5fe..33bdc74fa093 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2UserService.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AADOAuth2UserService.java @@ -3,6 +3,7 @@ package com.azure.spring.autoconfigure.aad; +import com.microsoft.aad.msal4j.MsalServiceException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; @@ -11,24 +12,32 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import javax.naming.ServiceUnavailableException; +import java.io.IOException; +import java.net.MalformedURLException; import java.util.Optional; import java.util.Set; +import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.CONDITIONAL_ACCESS_POLICY; +import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.INVALID_REQUEST; +import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.SERVER_SERVER; + /** - * This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link - * GrantedAuthority}. + * This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link GrantedAuthority}. */ public class AADOAuth2UserService implements OAuth2UserService { + private final AADAuthenticationProperties aadAuthenticationProperties; + private final ServiceEndpointsProperties serviceEndpointsProperties; private final OidcUserService oidcUserService; - private final GraphWebClient graphWebClient; - public AADOAuth2UserService( - GraphWebClient graphWebClient - ) { - this.graphWebClient = graphWebClient; + public AADOAuth2UserService(AADAuthenticationProperties aadAuthenticationProperties, + ServiceEndpointsProperties serviceEndpointsProperties) { + this.aadAuthenticationProperties = aadAuthenticationProperties; + this.serviceEndpointsProperties = serviceEndpointsProperties; this.oidcUserService = new OidcUserService(); } @@ -36,7 +45,36 @@ public AADOAuth2UserService( public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { // Delegate to the default implementation for loading a user OidcUser oidcUser = oidcUserService.loadUser(userRequest); - final Set mappedAuthorities = graphWebClient.getGrantedAuthorities(); + final Set mappedAuthorities; + try { + // https://github.com/MicrosoftDocs/azure-docs/issues/8121#issuecomment-387090099 + // In AAD App Registration configure oauth2AllowImplicitFlow to true + final AzureADGraphClient azureADGraphClient = new AzureADGraphClient( + aadAuthenticationProperties, + serviceEndpointsProperties + ); + String graphApiToken = azureADGraphClient + .acquireTokenForGraphApi( + userRequest.getIdToken().getTokenValue(), + aadAuthenticationProperties.getTenantId() + ) + .accessToken(); + mappedAuthorities = azureADGraphClient.getGrantedAuthorities(graphApiToken); + } catch (MalformedURLException e) { + throw toOAuth2AuthenticationException(INVALID_REQUEST, "Failed to acquire token for Graph API.", e); + } catch (ServiceUnavailableException e) { + throw toOAuth2AuthenticationException(SERVER_SERVER, "Failed to acquire token for Graph API.", e); + } catch (IOException e) { + throw toOAuth2AuthenticationException(SERVER_SERVER, "Failed to map group to authorities.", e); + } catch (MsalServiceException e) { + // Handle conditional access policy, step 2. + // OAuth2AuthenticationException will be caught by AADAuthenticationFailureHandler. + if (e.claims() != null && !e.claims().isEmpty()) { + throw toOAuth2AuthenticationException(CONDITIONAL_ACCESS_POLICY, "Handle conditional access policy", e); + } else { + throw e; + } + } String nameAttributeKey = Optional.of(userRequest) .map(OAuth2UserRequest::getClientRegistration) @@ -48,4 +86,11 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio // Create a copy of oidcUser but use the mappedAuthorities instead return new DefaultOidcUser(mappedAuthorities, oidcUser.getIdToken(), nameAttributeKey); } + + private OAuth2AuthenticationException toOAuth2AuthenticationException(String errorCode, + String description, + Exception cause) { + OAuth2Error oAuth2Error = new OAuth2Error(errorCode, description, null); + return new OAuth2AuthenticationException(oAuth2Error, cause); + } } diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationProperties.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationProperties.java deleted file mode 100644 index d16dbf9be56b..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AuthorizationProperties.java +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import java.util.Arrays; -import java.util.List; - -/** - * Properties for a authorized client. - */ -public class AuthorizationProperties { - - private String[] scope = new String[0]; - - public void setScope(String[] scope) { - this.scope = scope.clone(); - } - - public String[] getScope() { - return scope.clone(); - } - - public List scopes() { - return Arrays.asList(scope); - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphOboClient.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureADGraphClient.java similarity index 90% rename from sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphOboClient.java rename to sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureADGraphClient.java index f36a1293286a..9e9274ad5bfd 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphOboClient.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureADGraphClient.java @@ -42,9 +42,9 @@ /** * Microsoft Graph client encapsulation. */ -public class GraphOboClient { +public class AzureADGraphClient { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphOboClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AzureADGraphClient.class); private static final String MICROSOFT_GRAPH_SCOPE = "https://graph.microsoft.com/user.read"; private static final String AAD_GRAPH_API_SCOPE = "https://graph.windows.net/user.read"; // We use "aadfeed5" as suffix when client library is ADAL, upgrade to "aadfeed6" for MSAL @@ -55,14 +55,14 @@ public class GraphOboClient { private final AADAuthenticationProperties aadAuthenticationProperties; private final boolean graphApiVersionIsV2; - public GraphOboClient(AADAuthenticationProperties aadAuthenticationProperties, - ServiceEndpointsProperties serviceEndpointsProps) { + public AzureADGraphClient(AADAuthenticationProperties aadAuthenticationProperties, + ServiceEndpointsProperties serviceEndpointsProps) { this.aadAuthenticationProperties = aadAuthenticationProperties; this.serviceEndpoints = serviceEndpointsProps.getServiceEndpoints(aadAuthenticationProperties.getEnvironment()); this.graphApiVersionIsV2 = Optional.of(aadAuthenticationProperties) - .map(AADAuthenticationProperties::getEnvironment) - .map(environment -> environment.contains(V2_VERSION_ENV_FLAG)) - .orElse(false); + .map(AADAuthenticationProperties::getEnvironment) + .map(environment -> environment.contains(V2_VERSION_ENV_FLAG)) + .orElse(false); } private String getUserMemberships(String accessToken, String urlString) throws IOException { @@ -119,7 +119,7 @@ private static String getResponseString(HttpURLConnection connection) throws IOE public Set getGroups(String graphApiToken) throws IOException { final Set groups = new LinkedHashSet<>(); final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance(); - String aadMembershipRestUri = getAadMembershipRestUri(); + String aadMembershipRestUri = serviceEndpoints.getAadMembershipRestUri(); while (aadMembershipRestUri != null) { String membershipsJson = getUserMemberships(graphApiToken, aadMembershipRestUri); Memberships memberships = objectMapper.readValue(membershipsJson, Memberships.class); @@ -136,20 +136,6 @@ public Set getGroups(String graphApiToken) throws IOException { return groups; } - /** - * Get the rest url to get the groups that the user is a member of. - * @return rest url - */ - private String getAadMembershipRestUri() { - if (AADAuthenticationProperties.getDirectGroupRelationship() - .equalsIgnoreCase(aadAuthenticationProperties - .getUserGroup().getGroupRelationship())) { - return serviceEndpoints.getAadMembershipRestUri(); - } else { - return serviceEndpoints.getAadTransitiveMemberRestUri(); - } - } - private boolean isGroupObject(final Membership membership) { return membership.getObjectType().equals(aadAuthenticationProperties.getUserGroup().getValue()); } @@ -176,6 +162,7 @@ public Set toGrantedAuthoritySet(final Set group /** * Acquire access token for calling Graph API. + * * @param idToken The token used to perform an OBO request. * @param tenantId The tenant id. * @return The access token for Graph service. diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureActiveDirectoryAutoConfiguration.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureActiveDirectoryAutoConfiguration.java deleted file mode 100644 index 6b158bafbd0a..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureActiveDirectoryAutoConfiguration.java +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import com.azure.spring.telemetry.TelemetrySender; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.boot.autoconfigure.condition.ConditionalOnResource; -import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; -import org.springframework.boot.context.properties.EnableConfigurationProperties; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.PropertySource; -import org.springframework.security.config.annotation.web.builders.HttpSecurity; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; -import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; -import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.oidc.user.OidcUser; -import org.springframework.util.ClassUtils; -import org.springframework.web.reactive.function.client.WebClient; - -import javax.annotation.PostConstruct; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static com.azure.spring.telemetry.TelemetryData.SERVICE_NAME; -import static com.azure.spring.telemetry.TelemetryData.getClassPackageSimpleName; - -/** - * Provide necessary beans used for AAD authentication and authorization. - */ -@Configuration -@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) -@ConditionalOnResource(resources = "classpath:aad.enable.config") -@ConditionalOnClass(ClientRegistrationRepository.class) -@ConditionalOnProperty(prefix = "azure.activedirectory", value = {"client-id", "client-secret", "tenant-id"}) -@PropertySource(value = "classpath:service-endpoints.properties") -@EnableConfigurationProperties({ AADAuthenticationProperties.class, ServiceEndpointsProperties.class }) -public class AzureActiveDirectoryAutoConfiguration { - - private static final String DEFAULT_CLIENT = "azure"; - - @Autowired - private AADAuthenticationProperties aadAuthenticationProperties; - @Autowired - private ServiceEndpointsProperties serviceEndpointsProperties; - - @Bean - @ConditionalOnMissingBean({ ClientRegistrationRepository.class, AzureClientRegistrationRepository.class }) - public AzureClientRegistrationRepository clientRegistrationRepository() { - return new AzureClientRegistrationRepository( - createDefaultClient(), - createClientRegistrations() - ); - } - - @Bean - @ConditionalOnMissingBean - public OAuth2AuthorizedClientRepository authorizedClientRepository(AzureClientRegistrationRepository repo) { - return new AzureOAuth2AuthorizedClientRepository(repo); - } - - @Bean - @ConditionalOnMissingBean - WebClient webClient( - ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository - ) { - OAuth2AuthorizedClientManager oAuth2AuthorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - clientRegistrationRepository, - oAuth2AuthorizedClientRepository - ); - ServletOAuth2AuthorizedClientExchangeFilterFunction servletOAuth2AuthorizedClientExchangeFilterFunction = - new ServletOAuth2AuthorizedClientExchangeFilterFunction(oAuth2AuthorizedClientManager); - return WebClient.builder() - .apply(servletOAuth2AuthorizedClientExchangeFilterFunction.oauth2Configuration()) - .build(); - } - - @Bean - @ConditionalOnMissingBean - GraphWebClient graphWebClient(WebClient webClient) { - return new GraphWebClient( - aadAuthenticationProperties, - serviceEndpointsProperties, - webClient - ); - } - - @Bean - @ConditionalOnProperty(prefix = "azure.activedirectory.user-group", value = "allowed-groups") - public OAuth2UserService oidcUserService(GraphWebClient graphWebClient) { - return new AADOAuth2UserService(graphWebClient); - } - - private DefaultClient createDefaultClient() { - ClientRegistration clientRegistration = toClientRegistrationBuilder(DEFAULT_CLIENT) - .scope(allScopes()) - .build(); - return new DefaultClient(clientRegistration, defaultScopes()); - } - - private String[] allScopes() { - List result = openidScopes(); - for (AuthorizationProperties properties : aadAuthenticationProperties.getAuthorization().values()) { - result.addAll(properties.scopes()); - } - return result.toArray(new String[0]); - } - - private String[] defaultScopes() { - List result = openidScopes(); - AuthorizationProperties authorizationProperties = - aadAuthenticationProperties.getAuthorization().get(DEFAULT_CLIENT); - if (authorizationProperties != null) { - result.addAll(authorizationProperties.scopes()); - } - return result.toArray(new String[0]); - } - - private List openidScopes() { - List result = new ArrayList<>(); - result.add("openid"); - result.add("profile"); - if (!aadAuthenticationProperties.getAuthorization().isEmpty()) { - result.add("offline_access"); - } - return result; - } - - private List createClientRegistrations() { - List result = new ArrayList<>(); - for (String name : aadAuthenticationProperties.getAuthorization().keySet()) { - if (DEFAULT_CLIENT.equals(name)) { - continue; - } - AuthorizationProperties authorizationProperties = - aadAuthenticationProperties.getAuthorization().get(name); - result.add(toClientRegistration(name, authorizationProperties)); - } - return result; - } - - private ClientRegistration toClientRegistration(String id, AuthorizationProperties authorizationProperties) { - return toClientRegistrationBuilder(id) - .scope(authorizationProperties.getScope()) - .build(); - } - - private ClientRegistration.Builder toClientRegistrationBuilder(String registrationId) { - String authorizationServerUri = - serviceEndpointsProperties.getServiceEndpoints(aadAuthenticationProperties.getEnvironment()) - .getAadSigninUri(); - AuthorizationServerEndpoints endpoints = new AuthorizationServerEndpoints(authorizationServerUri); - String tenantId = aadAuthenticationProperties.getTenantId(); - return ClientRegistration.withRegistrationId(registrationId) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") - .clientId(aadAuthenticationProperties.getClientId()) - .clientSecret(aadAuthenticationProperties.getClientSecret()) - .authorizationUri(endpoints.authorizationEndpoint(tenantId)) - .tokenUri(endpoints.tokenEndpoint(tenantId)) - .jwkSetUri(endpoints.jwkSetEndpoint(tenantId)); - } - - /** - * Default configuration class for using AAD authentication and authorization. - * - * User can write another configuration bean to override it. - * If user write another configuration bean, to make sure `AzureOAuth2AuthorizationCodeGrantRequestEntityConverter` - * take effect, please: - * 1. Extends AzureOAuth2WebSecurityConfigurerAdapter instead of WebSecurityConfigurerAdapter. - * 2. Call `super.configure(http)` in your configure() - */ - @Configuration - @ConditionalOnMissingBean(WebSecurityConfigurerAdapter.class) - @EnableWebSecurity - public static class DefaultAzureOAuth2WebSecurityConfigurerAdapter extends AzureOAuth2WebSecurityConfigurerAdapter { - - @Override - protected void configure(HttpSecurity http) throws Exception { - super.configure(http); - http.authorizeRequests().anyRequest().authenticated(); - } - } - - @PostConstruct - private void sendTelemetry() { - if (aadAuthenticationProperties.isAllowTelemetry()) { - final Map events = new HashMap<>(); - final TelemetrySender sender = new TelemetrySender(); - events.put(SERVICE_NAME, getClassPackageSimpleName(AzureActiveDirectoryAutoConfiguration.class)); - sender.send(ClassUtils.getUserClass(getClass()).getSimpleName(), events); - } - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureClientRegistrationRepository.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureClientRegistrationRepository.java deleted file mode 100644 index 47b2dec6b166..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureClientRegistrationRepository.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import org.jetbrains.annotations.NotNull; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -/** - * A ClientRegistrationRepository that manage all AAD's ClientRegistrations. - */ -public class AzureClientRegistrationRepository implements ClientRegistrationRepository, Iterable { - - private final DefaultClient defaultClient; - private final List authorizedClientRegistrations; - - private final Map clientRegistrations; - - public AzureClientRegistrationRepository(DefaultClient defaultClient, - List authorizedClientRegistrations) { - this.defaultClient = defaultClient; - this.authorizedClientRegistrations = new ArrayList<>(authorizedClientRegistrations); - clientRegistrations = new HashMap<>(); - addClientRegistration(defaultClient.getClientRegistration()); - for (ClientRegistration clientRegistration : authorizedClientRegistrations) { - addClientRegistration(clientRegistration); - } - } - - private void addClientRegistration(ClientRegistration clientRegistration) { - clientRegistrations.put(clientRegistration.getRegistrationId(), clientRegistration); - } - - @Override - public ClientRegistration findByRegistrationId(String registrationId) { - return clientRegistrations.get(registrationId); - } - - @NotNull - @Override - public Iterator iterator() { - return Collections.singleton(defaultClient.getClientRegistration()).iterator(); - } - - public DefaultClient defaultClient() { - return defaultClient; - } - - public boolean isAuthorizedClient(ClientRegistration clientRegistration) { - return authorizedClientRegistrations.contains(clientRegistration); - } - - public boolean isAuthorizedClient(String id) { - return Optional.of(id) - .map(this::findByRegistrationId) - .map(this::isAuthorizedClient) - .orElse(false); - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizedClientRepository.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizedClientRepository.java deleted file mode 100644 index c850e1e7a156..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/AzureOAuth2AuthorizedClientRepository.java +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Optional; - -/** - * An OAuth2AuthorizedClientRepository that manage all AAD authorizedClients. - */ -public class AzureOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository { - - private final AzureClientRegistrationRepository azureClientRegistrationRepository; - private final OAuth2AuthorizedClientRepository delegatedOAuth2AuthorizedClientRepository; - - private static OAuth2AuthorizedClientRepository createDefaultDelegate() { - return new HttpSessionOAuth2AuthorizedClientRepository(); - } - - public AzureOAuth2AuthorizedClientRepository(AzureClientRegistrationRepository azureClientRegistrationRepository) { - this(azureClientRegistrationRepository, createDefaultDelegate()); - } - - public AzureOAuth2AuthorizedClientRepository( - AzureClientRegistrationRepository azureClientRegistrationRepository, - OAuth2AuthorizedClientRepository delegatedOAuth2AuthorizedClientRepository - ) { - this.azureClientRegistrationRepository = azureClientRegistrationRepository; - this.delegatedOAuth2AuthorizedClientRepository = delegatedOAuth2AuthorizedClientRepository; - } - - @Override - public void saveAuthorizedClient( - OAuth2AuthorizedClient oAuth2AuthorizedClient, - Authentication principal, - HttpServletRequest request, - HttpServletResponse response - ) { - delegatedOAuth2AuthorizedClientRepository.saveAuthorizedClient( - oAuth2AuthorizedClient, principal, request, response); - } - - @Override - @SuppressWarnings("unchecked") - public T loadAuthorizedClient( - String id, - Authentication principal, - HttpServletRequest request - ) { - OAuth2AuthorizedClient result = - delegatedOAuth2AuthorizedClientRepository.loadAuthorizedClient(id, principal, request); - if (result != null) { - return (T) result; - } - if (azureClientRegistrationRepository.isAuthorizedClient(id)) { - OAuth2AuthorizedClient defaultOAuth2AuthorizedClient = - loadAuthorizedClient(defaultClientRegistrationId(), principal, request); - return (T) toOauth2AuthorizedClient(defaultOAuth2AuthorizedClient, id, principal); - } - return null; - } - - private String defaultClientRegistrationId() { - return azureClientRegistrationRepository.defaultClient().getClientRegistration().getRegistrationId(); - } - - private OAuth2AuthorizedClient toOauth2AuthorizedClient( - OAuth2AuthorizedClient oAuth2AuthorizedClient, - String id, - Authentication principal - ) { - OAuth2RefreshToken oAuth2RefreshToken = Optional.ofNullable(oAuth2AuthorizedClient) - .map(OAuth2AuthorizedClient::getRefreshToken) - .orElse(null); - if (oAuth2RefreshToken == null) { - return null; - } - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, - "non-access-token", - Instant.MIN, - Instant.now().minus(100, ChronoUnit.DAYS) - ); - return new OAuth2AuthorizedClient( - azureClientRegistrationRepository.findByRegistrationId(id), - principal.getName(), - accessToken, - oAuth2RefreshToken - ); - } - - @Override - public void removeAuthorizedClient( - String id, - Authentication principal, - HttpServletRequest request, - HttpServletResponse response - ) { - delegatedOAuth2AuthorizedClientRepository.removeAuthorizedClient(id, principal, request, response); - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Constants.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Constants.java index 289b3ebd679f..87be1cd2f4e2 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Constants.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Constants.java @@ -18,6 +18,7 @@ public class Constants { public static final String CLAIMS = "claims"; public static final Set DEFAULT_AUTHORITY_SET; public static final String ROLE_PREFIX = "ROLE_"; + public static final String SAVED_REQUEST = "SPRING_SECURITY_SAVED_REQUEST"; static { Set authoritySet = new HashSet<>(); diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/DefaultClient.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/DefaultClient.java deleted file mode 100644 index 084c83dda4a6..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/DefaultClient.java +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; - -import java.util.Arrays; -import java.util.List; - -/** - * The default client after user consent on microsoft login page. - * DefaultClient.clientRegistration.scopes contains all scopes consented in the login page. - * DefaultClient.scope only contains the scopes for defaultClientRegistration. - */ -public class DefaultClient { - - private final ClientRegistration clientRegistration; - private final String[] scope; - - public DefaultClient(ClientRegistration clientRegistration, String[] scope) { - this.clientRegistration = clientRegistration; - this.scope = scope.clone(); - } - - public ClientRegistration getClientRegistration() { - return clientRegistration; - } - - public String[] getScope() { - return scope.clone(); - } - - public List getScopeList() { - return Arrays.asList(scope); - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphWebClient.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphWebClient.java deleted file mode 100644 index 5dc64db581aa..000000000000 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/GraphWebClient.java +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.nimbusds.oauth2.sdk.http.HTTPResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; -import org.springframework.security.core.authority.SimpleGrantedAuthority; -import org.springframework.web.reactive.function.client.WebClient; - -import java.util.Collections; -import java.util.LinkedHashSet; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -import static com.azure.spring.autoconfigure.aad.Constants.DEFAULT_AUTHORITY_SET; -import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX; -import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; - - -/** - * Microsoft Graph web client implemented by OAuth2 WebClient. - */ -public class GraphWebClient { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphWebClient.class); - - private final ServiceEndpoints serviceEndpoints; - private final AADAuthenticationProperties aadAuthenticationProperties; - private final boolean graphApiVersionIsV2; - private final WebClient webClient; - - public GraphWebClient( - AADAuthenticationProperties aadAuthenticationProperties, - ServiceEndpointsProperties serviceEndpointsProps, - WebClient webClient - ) { - this.aadAuthenticationProperties = aadAuthenticationProperties; - this.serviceEndpoints = serviceEndpointsProps.getServiceEndpoints(aadAuthenticationProperties.getEnvironment()); - this.webClient = webClient; - this.graphApiVersionIsV2 = Optional.of(aadAuthenticationProperties) - .map(AADAuthenticationProperties::getEnvironment) - .map(environment -> environment.contains("v2-graph")) - .orElse(false); - } - - public Set getGrantedAuthorities() { - return toGrantedAuthoritySet(getGroupsFromGraphApi()); - } - - public Set getGroupsFromGraphApi() { - final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance(); - Set groups = new LinkedHashSet<>(); - String aadMembershipRestUri = getAadMembershipRestUri(); - while (aadMembershipRestUri != null) { - String membershipsJson = getUserMembershipsJson(aadMembershipRestUri); - Memberships memberships; - try { - memberships = objectMapper.readValue(membershipsJson, Memberships.class); - } catch (JsonProcessingException e) { - LOGGER.error("Can not get groups.", e); - return Collections.emptySet(); - } - groups = memberships.getValue() - .stream() - .filter(this::isGroupObject) - .map(Membership::getDisplayName) - .collect(Collectors.toSet()); - aadMembershipRestUri = Optional.of(memberships) - .map(Memberships::getOdataNextLink) - .map(this::getUrlStringFromODataNextLink) - .orElse(null); - } - return groups; - } - - private String getUserMembershipsJson(String urlString) { - String responseInJson; - if (graphApiVersionIsV2) { - responseInJson = webClient - .get() - .uri(urlString) - .attributes(clientRegistrationId("graph")) - .accept(MediaType.APPLICATION_JSON) - .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) - .retrieve() - .bodyToMono(String.class) - .block(); - } else { - responseInJson = webClient - .get() - .uri(urlString) - .attributes(clientRegistrationId("graph")) - .header(HttpHeaders.ACCEPT, "application/json;odata=minimalmetadata") - .header("api-version", "1.6") - .retrieve() - .bodyToMono(String.class) - .block(); - } - if (responseInJson == null || responseInJson.isEmpty()) { - throw new IllegalStateException( - "Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson); - } - return responseInJson; - } - - private String getUrlStringFromODataNextLink(String odataNextLink) { - if (this.graphApiVersionIsV2) { - return odataNextLink; - } else { - String skipToken = odataNextLink.split("/memberOf\\?")[1]; - return serviceEndpoints.getAadMembershipRestUri() + "&" + skipToken; - } - } - - private String getAadMembershipRestUri() { - if (AADAuthenticationProperties.getDirectGroupRelationship() - .equalsIgnoreCase( - aadAuthenticationProperties.getUserGroup().getGroupRelationship() - ) - ) { - return serviceEndpoints.getAadMembershipRestUri(); - } else { - return serviceEndpoints.getAadTransitiveMemberRestUri(); - } - } - - private boolean isGroupObject(final Membership membership) { - return membership.getObjectType().equals(aadAuthenticationProperties.getUserGroup().getValue()); - } - - public Set toGrantedAuthoritySet(final Set groups) { - Set grantedAuthoritySet = - groups.stream() - .filter(aadAuthenticationProperties::isAllowedGroup) - .map(group -> new SimpleGrantedAuthority(ROLE_PREFIX + group)) - .collect(Collectors.toSet()); - return Optional.of(grantedAuthoritySet) - .filter(g -> !g.isEmpty()) - .orElse(DEFAULT_AUTHORITY_SET); - } -} diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/JacksonObjectMapperFactory.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/JacksonObjectMapperFactory.java index bee1a10331ed..059d3fcf4968 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/JacksonObjectMapperFactory.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/JacksonObjectMapperFactory.java @@ -6,7 +6,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** - * Factory class of JacksonObjectMapper + * factoty class of JacksonObjectMapper */ public final class JacksonObjectMapperFactory { diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Membership.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Membership.java index 9911094a1ab8..f37c0714cddc 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Membership.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Membership.java @@ -56,8 +56,8 @@ public boolean equals(Object o) { } final Membership group = (Membership) o; return this.getDisplayName().equals(group.getDisplayName()) - && this.getObjectID().equals(group.getObjectID()) - && this.getObjectType().equals(group.getObjectType()); + && this.getObjectID().equals(group.getObjectID()) + && this.getObjectType().equals(group.getObjectType()); } @Override diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Memberships.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Memberships.java index 813914d711ac..0b0597871189 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Memberships.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/Memberships.java @@ -24,7 +24,7 @@ public class Memberships { @JsonCreator public Memberships( @JsonAlias("odata.nextLink") - @JsonProperty("@odata.nextLink") String odataNextLink, + @JsonProperty("odata.nextLink") String odataNextLink, @JsonProperty("value") List value) { this.odataNextLink = odataNextLink; this.value = value; @@ -48,7 +48,7 @@ public boolean equals(Object o) { } final Memberships groups = (Memberships) o; return this.getOdataNextLink().equals(groups.getOdataNextLink()) - && this.getValue().equals(groups.getValue()); + && this.getValue().equals(groups.getValue()); } @Override diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipal.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipal.java index f15bf599eb0f..fea3c0d87d12 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipal.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipal.java @@ -66,6 +66,14 @@ public void setRoles(Set roles) { this.roles = roles; } + public String getAccessTokenForGraphApi() { + return accessTokenForGraphApi; + } + + public void setAccessTokenForGraphApi(String accessTokenForGraphApi) { + this.accessTokenForGraphApi = accessTokenForGraphApi; + } + public boolean isMemberOf(AADAuthenticationProperties aadAuthenticationProperties, String group) { return aadAuthenticationProperties.isAllowedGroup(group) && Optional.of(groups) diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipalManager.java b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipalManager.java index ab2120024ce0..1b43952b74a3 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipalManager.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/azure/spring/autoconfigure/aad/UserPrincipalManager.java @@ -50,8 +50,8 @@ public class UserPrincipalManager { private final Boolean explicitAudienceCheck; private final Set validAudiences = new HashSet<>(); - /**ø - * Creates a new {@link UserPrincipalManager} with a predefined {@link JWKSource}. + /** + * ø Creates a new {@link UserPrincipalManager} with a predefined {@link JWKSource}. *

* This is helpful in cases the JWK is not a remote JWKSet or for unit testing. * @@ -138,6 +138,7 @@ public UserPrincipalManager(ServiceEndpointsProperties serviceEndpointsProps, /** * Parse the id token to {@link UserPrincipal}. + * * @param aadIssuedBearerToken The token issued by AAD. * @return The parsed {@link UserPrincipal}. * @throws ParseException If the token couldn't be parsed to a valid JWS object. @@ -163,13 +164,6 @@ public UserPrincipal buildUserPrincipal(String aadIssuedBearerToken) throws Pars } public boolean isTokenIssuedByAAD(String token) { - return staticIsTokenIssuedByAAD(token); - } - - public static boolean staticIsTokenIssuedByAAD(String token) { - if (token == null) { - return false; - } try { final JWT jwt = JWTParser.parse(token); return isAADIssuer(jwt.getJWTClaimsSet().getIssuer()); diff --git a/sdk/spring/azure-spring-boot/src/main/resources/META-INF/spring.factories b/sdk/spring/azure-spring-boot/src/main/resources/META-INF/spring.factories index 6aca6335a723..a852e020f0b2 100644 --- a/sdk/spring/azure-spring-boot/src/main/resources/META-INF/spring.factories +++ b/sdk/spring/azure-spring-boot/src/main/resources/META-INF/spring.factories @@ -1,7 +1,8 @@ org.springframework.boot.env.EnvironmentPostProcessor=com.azure.spring.cloudfoundry.environment.VcapProcessor org.springframework.boot.autoconfigure.EnableAutoConfiguration=\ -com.azure.spring.autoconfigure.aad.AADAuthenticationFilterAutoConfiguration, \ -com.azure.spring.autoconfigure.aad.AzureActiveDirectoryAutoConfiguration, \ +com.azure.spring.autoconfigure.aad.AADAuthenticationFilterAutoConfiguration,\ +com.azure.spring.autoconfigure.aad.AADOAuth2AutoConfiguration, \ +com.azure.spring.aad.implementation.AzureActiveDirectoryConfiguration,\ com.azure.spring.autoconfigure.b2c.AADB2CAutoConfiguration,\ com.azure.spring.autoconfigure.cosmos.CosmosAutoConfiguration,\ com.azure.spring.autoconfigure.cosmos.CosmosHealthConfiguration,\ diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationAutoConfigurationTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationAutoConfigurationTest.java index 7b428e0c6060..9a7bffa20688 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationAutoConfigurationTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationAutoConfigurationTest.java @@ -16,13 +16,11 @@ public class AADAuthenticationAutoConfigurationTest { private final WebApplicationContextRunner contextRunner = new WebApplicationContextRunner() .withConfiguration(AutoConfigurations.of(AADAuthenticationFilterAutoConfiguration.class)) - .withPropertyValues( - "azure.activedirectory.client-id=fake-client-id", + .withPropertyValues("azure.activedirectory.client-id=fake-client-id", "azure.activedirectory.client-secret=fake-client-secret", "azure.activedirectory.user-group.allowed-groups=fake-group", "azure.service.endpoints.global.aadKeyDiscoveryUri=http://fake.aad.discovery.uri", - TestConstants.ALLOW_TELEMETRY_PROPERTY + "=false" - ); + TestConstants.ALLOW_TELEMETRY_PROPERTY + "=false"); @Test public void createAADAuthenticationFilter() { diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterPropertiesTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterPropertiesTest.java index 6c6d92229d48..3f7e0ee8673b 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterPropertiesTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterPropertiesTest.java @@ -57,7 +57,7 @@ public void defaultEnvironmentIsGlobal() { final AADAuthenticationProperties properties = context.getBean(AADAuthenticationProperties.class); - assertThat(properties.getEnvironment()).isEqualTo("global"); + assertThat(properties.getEnvironment()).isEqualTo(TestConstants.DEFAULT_ENVIRONMENT); } } @@ -92,7 +92,7 @@ public void emptySettingsNotAllowed() { final BindValidationException bindException = (BindValidationException) exception.getCause().getCause(); final List errors = bindException.getValidationErrors().getAllErrors(); - final List errorStrings = errors.stream().map(ObjectError::toString).collect(Collectors.toList()); + final List errorStrings = errors.stream().map(e -> e.toString()).collect(Collectors.toList()); final List errorStringsExpected = Arrays.asList( "Field error in object 'azure.activedirectory' on field 'activeDirectoryGroups': " diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java index 3251c92b510a..0424b8835c5d 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java @@ -5,6 +5,7 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.proc.BadJOSEException; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Test; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -29,7 +30,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.springframework.http.HttpHeaders.AUTHORIZATION; public class AADAuthenticationFilterTest { private static final String TOKEN = "dummy-token"; @@ -51,6 +51,12 @@ public AADAuthenticationFilterTest() { ); } + @Ignore + public void beforeEveryMethod() { + Assume.assumeTrue(!TestConstants.CLIENT_ID.contains("real_client_id")); + Assume.assumeTrue(!TestConstants.CLIENT_SECRET.contains("real_client_secret")); + Assume.assumeTrue(!TestConstants.BEARER_TOKEN.contains("real_jtw_bearer_token")); + } //TODO (Zhou Liu): current test case is out of date, a new test case need to cover here, do it later. @Test @@ -64,7 +70,7 @@ public void doFilterInternal() { this.contextRunner.run(context -> { final HttpServletRequest request = mock(HttpServletRequest.class); - when(request.getHeader(AUTHORIZATION)).thenReturn(TestConstants.BEARER_TOKEN); + when(request.getHeader(TestConstants.TOKEN_HEADER)).thenReturn(TestConstants.BEARER_TOKEN); final HttpServletResponse response = mock(HttpServletResponse.class); final FilterChain filterChain = mock(FilterChain.class); diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADOAuth2ConfigTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADOAuth2ConfigTest.java index d31a731963bf..ce59323e9596 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADOAuth2ConfigTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AADOAuth2ConfigTest.java @@ -16,18 +16,18 @@ import org.springframework.core.io.support.ResourcePropertySource; import org.springframework.mock.env.MockPropertySource; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.test.context.support.TestPropertySourceUtils; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; import java.util.Map; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class AADOAuth2ConfigTest { private static final String AAD_OAUTH2_MINIMUM_PROPS = "aad-backend-oauth2-minimum.properties"; @@ -55,7 +55,6 @@ public void clear() { @Test public void noOAuth2UserServiceBeanCreatedIfPropsNotConfigured() { final AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext(); - context.register(AzureActiveDirectoryAutoConfiguration.class); context.refresh(); exception.expect(NoSuchBeanDefinitionException.class); @@ -70,7 +69,7 @@ public void testOAuth2UserServiceBeanCreatedIfPropsConfigured() { @Test public void noOAuth2UserServiceBeanCreatedIfTenantIdNotConfigured() { - testPropResource.getSource().remove("azure.activedirectory.tenant-id"); + testPropResource.getSource().remove(TestConstants.TENANT_ID_PROPERTY); testContext = initTestContext(); exception.expect(NoSuchBeanDefinitionException.class); @@ -122,38 +121,37 @@ public void testEndpointsPropertiesLoadAndOverridable() { @Test public void testScopePropertyConfiguredWithDynamicPermissions() { - testContext = initTestContext("azure.activedirectory.authorization.graph.scope=email"); + testContext = initTestContext("azure.activedirectory.scope=email"); + final Environment environment = testContext.getEnvironment(); - assertThat(environment.getProperty("azure.activedirectory.authorization.graph.scope")).isEqualTo("email"); - - final AzureClientRegistrationRepository azureClientRegistrationRepository = - testContext.getBean(AzureClientRegistrationRepository.class); - final ClientRegistration clientRegistration = azureClientRegistrationRepository.findByRegistrationId("azure"); - final Set actualScopes = clientRegistration.getScopes(); - final Set expectedScopes = new HashSet<>(Arrays.asList("openid", "profile", "offline_access", "email")); - assertEquals(expectedScopes, actualScopes); + assertThat(environment.getProperty("azure.activedirectory.scope")) + .isEqualTo("email"); + + final ClientRegistrationRepository clientRegistrationRepository = + testContext.getBean(ClientRegistrationRepository.class); + final ClientRegistration clientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); + final Set createdScopes = clientRegistration.getScopes(); + final Set expectedScopes = new HashSet<>(Arrays.asList("email", "openid", "profile", + "https://graph.microsoft.com/user.read")); + assertTrue(createdScopes.equals(expectedScopes)); + } @Test public void testScopePropertyConfiguredWithStaticPermissions() { - testContext = initTestContext("azure.activedirectory.authorization.graph.scope=1111/.default"); + testContext = initTestContext("azure.activedirectory.scope=1111/.default"); + final Environment environment = testContext.getEnvironment(); - assertThat(environment.getProperty("azure.activedirectory.authorization.graph.scope")).isEqualTo("1111/" - + ".default"); - - final AzureClientRegistrationRepository azureClientRegistrationRepository = - testContext.getBean(AzureClientRegistrationRepository.class); - final ClientRegistration clientRegistration = azureClientRegistrationRepository.findByRegistrationId("azure"); - final Set actualScopes = clientRegistration.getScopes(); - final Set expectedScopes = - new HashSet<>(Arrays.asList("openid", "profile", "offline_access", "1111/.default")); - assertEquals(expectedScopes, actualScopes); - - final ClientRegistration graphClientRegistration = - azureClientRegistrationRepository.findByRegistrationId("graph"); - final Set graphActualScopes = graphClientRegistration.getScopes(); - final Set graphExpectedScopes = new HashSet<>(Collections.singletonList("1111/.default")); - assertEquals(graphExpectedScopes, graphActualScopes); + assertThat(environment.getProperty("azure.activedirectory.scope")) + .isEqualTo("1111/.default"); + + final ClientRegistrationRepository clientRegistrationRepository = + testContext.getBean(ClientRegistrationRepository.class); + final ClientRegistration clientRegistration = clientRegistrationRepository.findByRegistrationId("azure"); + final Set createdScopes = clientRegistration.getScopes(); + final Set expectedScopes = new HashSet<>(Arrays.asList("1111/.default")); + assertTrue(createdScopes.equals(expectedScopes)); + } private AnnotationConfigWebApplicationContext initTestContext(String... environment) { @@ -166,7 +164,7 @@ private AnnotationConfigWebApplicationContext initTestContext(String... environm TestPropertySourceUtils.addInlinedPropertiesToEnvironment(context, environment); } - context.register(AzureActiveDirectoryAutoConfiguration.class); + context.register(AADOAuth2AutoConfiguration.class); context.refresh(); return context; diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java similarity index 84% rename from sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTest.java rename to sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java index ec581b33908e..bcd6de70161e 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java @@ -19,9 +19,11 @@ import static org.assertj.core.api.Assertions.assertThat; @RunWith(MockitoJUnitRunner.class) -public class GraphWebClientTest { +public class AzureADGraphClientTest { - private GraphWebClient adGraphClient; + private AzureADGraphClient adGraphClient; + + private AADAuthenticationProperties aadAuthenticationProperties; @Mock private ServiceEndpointsProperties endpointsProps; @@ -30,9 +32,9 @@ public class GraphWebClientTest { public void setup() { final List activeDirectoryGroups = new ArrayList<>(); activeDirectoryGroups.add("Test_Group"); - AADAuthenticationProperties aadAuthenticationProperties = new AADAuthenticationProperties(); + aadAuthenticationProperties = new AADAuthenticationProperties(); aadAuthenticationProperties.getUserGroup().setAllowedGroups(activeDirectoryGroups); - adGraphClient = new GraphWebClient(aadAuthenticationProperties, endpointsProps, null); + adGraphClient = new AzureADGraphClient(aadAuthenticationProperties, endpointsProps); } @Test diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientAzureADGraphTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientAzureADGraphTest.java deleted file mode 100644 index d235f4329f39..000000000000 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientAzureADGraphTest.java +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.junit.WireMockRule; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.springframework.security.core.GrantedAuthority; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; -import static com.github.tomakehurst.wiremock.client.WireMock.get; -import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; -import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; -import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; -import static com.github.tomakehurst.wiremock.client.WireMock.verify; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.HttpHeaders.ACCEPT; -import static org.springframework.http.HttpHeaders.AUTHORIZATION; -import static org.springframework.http.HttpHeaders.CONTENT_TYPE; -import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; - -public class GraphWebClientAzureADGraphTest { - @Rule - public WireMockRule wireMockRule = new WireMockRule(9519); - - private GraphWebClient graphWebClient; - private AADAuthenticationProperties aadAuthenticationProperties; - private static String userGroupsJson; - - static { - try { - final ObjectMapper objectMapper = new ObjectMapper(); - final Map json = objectMapper.readValue( - GraphWebClientAzureADGraphTest.class - .getClassLoader() - .getResourceAsStream("aad/azure-ad-graph-user-groups.json"), - new TypeReference>() { - } - ); - userGroupsJson = objectMapper.writeValueAsString(json); - } catch (IOException e) { - e.printStackTrace(); - userGroupsJson = null; - } - Assert.assertNotNull(userGroupsJson); - } - - @Before - public void setup() { - aadAuthenticationProperties = new AADAuthenticationProperties(); - ServiceEndpointsProperties serviceEndpointsProperties = new ServiceEndpointsProperties(); - final ServiceEndpoints serviceEndpoints = new ServiceEndpoints(); - serviceEndpoints.setAadMembershipRestUri("http://localhost:9519/memberOf"); - serviceEndpointsProperties.getEndpoints().put("global", serviceEndpoints); - this.graphWebClient = new GraphWebClient( - aadAuthenticationProperties, - serviceEndpointsProperties, - GraphWebClientTestUtil.createWebClientForTest() - ); - } - - @Test - public void getAuthoritiesByUserGroups() { - aadAuthenticationProperties.getUserGroup().setAllowedGroups(Collections.singletonList("group1")); - stubFor(get(urlEqualTo("/memberOf")) - .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - assertThat(graphWebClient.getGrantedAuthorities()) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactly("ROLE_group1"); - - verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", TestConstants.ACCESS_TOKEN))) - .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) - .withHeader("api-version", equalTo("1.6"))); - } - - @Test - public void getGroups() { - aadAuthenticationProperties.getUserGroup().setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - stubFor(get(urlEqualTo("/memberOf")) - .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - final Collection authorities = graphWebClient.getGrantedAuthorities(); - - assertThat(authorities) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); - - verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", TestConstants.ACCESS_TOKEN))) - .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) - .withHeader("api-version", equalTo("1.6"))); - } -} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientMicrosoftGraphTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientMicrosoftGraphTest.java deleted file mode 100644 index 53f9ca652f28..000000000000 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientMicrosoftGraphTest.java +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.github.tomakehurst.wiremock.junit.WireMockRule; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.springframework.security.core.GrantedAuthority; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; -import static com.github.tomakehurst.wiremock.client.WireMock.get; -import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; -import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; -import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; -import static com.github.tomakehurst.wiremock.client.WireMock.verify; -import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.HttpHeaders.ACCEPT; -import static org.springframework.http.HttpHeaders.AUTHORIZATION; -import static org.springframework.http.HttpHeaders.CONTENT_TYPE; -import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; - -public class GraphWebClientMicrosoftGraphTest { - @Rule - public WireMockRule wireMockRule = new WireMockRule(9519); - - private GraphWebClient graphWebClient; - private AADAuthenticationProperties aadAuthenticationProperties; - private ServiceEndpointsProperties serviceEndpointsProperties; - private static String userGroupsJson; - - static { - try { - final ObjectMapper objectMapper = new ObjectMapper(); - final Map json = objectMapper.readValue( - GraphWebClientMicrosoftGraphTest.class - .getClassLoader() - .getResourceAsStream("aad/microsoft-graph-user-groups.json"), - new TypeReference>() { - } - ); - userGroupsJson = objectMapper.writeValueAsString(json); - } catch (IOException e) { - e.printStackTrace(); - userGroupsJson = null; - } - Assert.assertNotNull(userGroupsJson); - } - - @Before - public void setup() { - aadAuthenticationProperties = new AADAuthenticationProperties(); - aadAuthenticationProperties.setEnvironment("global-v2-graph"); - aadAuthenticationProperties.getUserGroup().setKey("@odata.type"); - aadAuthenticationProperties.getUserGroup().setValue("#microsoft.graph.group"); - aadAuthenticationProperties.getUserGroup().setObjectIDKey("id"); - serviceEndpointsProperties = new ServiceEndpointsProperties(); - final ServiceEndpoints serviceEndpoints = new ServiceEndpoints(); - serviceEndpoints.setAadMembershipRestUri("http://localhost:9519/memberOf"); - serviceEndpoints.setAadTransitiveMemberRestUri("http://localhost:9519/transitiveMemberOf"); - serviceEndpointsProperties.getEndpoints().put("global-v2-graph", serviceEndpoints); - } - - @Test - public void getAuthoritiesByUserGroups() { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("direct"); - aadAuthenticationProperties.getUserGroup().setAllowedGroups(Collections.singletonList("group1")); - serviceEndpointsProperties.getServiceEndpoints("global-v2-graph") - .setAadMembershipRestUri("http://localhost:9519/memberOf"); - this.graphWebClient = new GraphWebClient( - aadAuthenticationProperties, - serviceEndpointsProperties, - GraphWebClientTestUtil.createWebClientForTest() - ); - - stubFor(get(urlEqualTo("/memberOf")) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - assertThat(graphWebClient.getGrantedAuthorities()) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactly("ROLE_group1"); - - verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", TestConstants.ACCESS_TOKEN))) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); - } - - @Test - public void getDirectGroups() { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("direct"); - AADAuthenticationProperties.UserGroupProperties userGroupProperties = aadAuthenticationProperties.getUserGroup(); - userGroupProperties.setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - aadAuthenticationProperties.setUserGroup(userGroupProperties); - this.graphWebClient = new GraphWebClient( - aadAuthenticationProperties, - serviceEndpointsProperties, - GraphWebClientTestUtil.createWebClientForTest() - ); - - stubFor(get(urlEqualTo("/memberOf")) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - final Collection authorities = graphWebClient.getGrantedAuthorities(); - - assertThat(authorities) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); - - verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", TestConstants.ACCESS_TOKEN))) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); - } - - @Test - public void getTransitiveGroups() { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("transitive"); - AADAuthenticationProperties.UserGroupProperties userGroupProperties = aadAuthenticationProperties.getUserGroup(); - userGroupProperties.setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - aadAuthenticationProperties.setUserGroup(userGroupProperties); - this.graphWebClient = new GraphWebClient( - aadAuthenticationProperties, - serviceEndpointsProperties, - GraphWebClientTestUtil.createWebClientForTest() - ); - - stubFor(get(urlEqualTo("/transitiveMemberOf")) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - final Collection authorities = graphWebClient.getGrantedAuthorities(); - - assertThat(authorities) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); - - verify(getRequestedFor(urlMatching("/transitiveMemberOf")) - .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", TestConstants.ACCESS_TOKEN))) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); - } -} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTestUtil.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTestUtil.java deleted file mode 100644 index 1ba4fd72690b..000000000000 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphWebClientTestUtil.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package com.azure.spring.autoconfigure.aad; - -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.web.reactive.function.client.WebClient; - -import java.time.Instant; -import java.time.temporal.ChronoUnit; - -public class GraphWebClientTestUtil { - - public static WebClient createWebClientForTest() { - ClientRegistration clientRegistration = - ClientRegistration.withRegistrationId("graph") - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}") - .clientId("test") - .clientSecret("test") - .authorizationUri("test") - .tokenUri("test") - .jwkSetUri("test") - .build(); - OAuth2AuthorizedClientManager oAuth2AuthorizedClientManager = authorizeRequest -> new OAuth2AuthorizedClient( - clientRegistration, - "principalName", - new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, - TestConstants.ACCESS_TOKEN, - Instant.now().minus(10, ChronoUnit.MINUTES), - Instant.now().plus(10, ChronoUnit.MINUTES) - ) - ); - ServletOAuth2AuthorizedClientExchangeFilterFunction servletOAuth2AuthorizedClientExchangeFilterFunction = - new ServletOAuth2AuthorizedClientExchangeFilterFunction(oAuth2AuthorizedClientManager); - return WebClient.builder() - .apply(servletOAuth2AuthorizedClientExchangeFilterFunction.oauth2Configuration()) - .build(); - } -} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/MicrosoftGraphConstants.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/MicrosoftGraphConstants.java new file mode 100644 index 000000000000..8529dfcd3660 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/MicrosoftGraphConstants.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.autoconfigure.aad; + +import java.util.Arrays; +import java.util.List; + +public class MicrosoftGraphConstants { + + public static final String SERVICE_ENVIRONMENT_PROPERTY = "azure.activedirectory.environment"; + public static final String CLIENT_ID_PROPERTY = "azure.activedirectory.client-id"; + public static final String CLIENT_SECRET_PROPERTY = "azure.activedirectory.client-secret"; + public static final String TARGETED_GROUPS_PROPERTY = "azure.activedirectory.user-group.allowed-groups"; + public static final String TENANT_ID_PROPERTY = "azure.activedirectory.tenant-id"; + + public static final String DEFAULT_ENVIRONMENT = "global"; + public static final String CLIENT_ID = "real_client_id"; + public static final String CLIENT_SECRET = "real_client_secret"; + public static final List TARGETED_GROUPS = Arrays.asList("group1", "group2", "group3"); + + public static final String TOKEN_HEADER = "Authorization"; + public static final String BEARER_TOKEN = "Bearer real_jtw_bearer_token"; + + /** Token from https://docs.microsoft.com/azure/active-directory/develop/v2-id-and-access-tokens */ + public static final String JWT_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik1uQ19WWmNBVGZNNXBPWWlKSE1" + + "iYTlnb0VLWSJ9.eyJhdWQiOiI2NzMxZGU3Ni0xNGE2LTQ5YWUtOTdiYy02ZWJhNjkxNDM5MWUiLCJpc3MiOiJodHRwczovL2xvZ2lu" + + "Lm1pY3Jvc29mdG9ubGluZS5jb20vYjk0MTk4MTgtMDlhZi00OWMyLWIwYzMtNjUzYWRjMWYzNzZlL3YyLjAiLCJpYXQiOjE0NTIyOD" + + "UzMzEsIm5iZiI6MTQ1MjI4NTMzMSwiZXhwIjoxNDUyMjg5MjMxLCJuYW1lIjoiQmFiZSBSdXRoIiwibm9uY2UiOiIxMjM0NSIsIm9p" + + "ZCI6ImExZGJkZGU4LWU0ZjktNDU3MS1hZDkzLTMwNTllMzc1MGQyMyIsInByZWZlcnJlZF91c2VybmFtZSI6InRoZWdyZWF0YmFtYm" + + "lub0BueXkub25taWNyb3NvZnQuY29tIiwic3ViIjoiTUY0Zi1nZ1dNRWppMTJLeW5KVU5RWnBoYVVUdkxjUXVnNWpkRjJubDAxUSIs" + + "InRpZCI6ImI5NDE5ODE4LTA5YWYtNDljMi1iMGMzLTY1M2FkYzFmMzc2ZSIsInZlciI6IjIuMCJ9.p_rYdrtJ1oCmgDBggNHB9O38K" + + "TnLCMGbMDODdirdmZbmJcTHiZDdtTc-hguu3krhbtOsoYM2HJeZM3Wsbp_YcfSKDY--X_NobMNsxbT7bqZHxDnA2jTMyrmt5v2EKUn" + + "EeVtSiJXyO3JWUq9R0dO-m4o9_8jGP6zHtR62zLaotTBYHmgeKpZgTFB9WtUq8DVdyMn_HSvQEfz-LWqckbcTwM_9RNKoGRVk38KCh" + + "VJo4z5LkksYRarDo8QgQ7xEKmYmPvRr_I7gvM2bmlZQds2OeqWLB1NSNbFZqyFOCgYn3bAQ-nEQSKwBaA36jYGPOVG2r2Qv1uKcpSO" + + "xzxaQybzYpQ"; +} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/ResourceRetrieverTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/ResourceRetrieverTest.java index 2cd00720ca9d..46b75b523b09 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/ResourceRetrieverTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/ResourceRetrieverTest.java @@ -14,13 +14,11 @@ public class ResourceRetrieverTest { private final WebApplicationContextRunner contextRunner = new WebApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(AADAuthenticationFilterAutoConfiguration.class)) - .withPropertyValues( - "azure.activedirectory.client-id=fake-client-id", - "azure.activedirectory.client-secret=fake-client-secret", - "azure.activedirectory.user-group.allowed-groups=fake-group", - "azure.service.endpoints.global.aadKeyDiscoveryUri=http://fake.aad.discovery.uri" - ); + .withConfiguration(AutoConfigurations.of(AADAuthenticationFilterAutoConfiguration.class)) + .withPropertyValues("azure.activedirectory.client-id=fake-client-id", + "azure.activedirectory.client-secret=fake-client-secret", + "azure.activedirectory.user-group.allowed-groups=fake-group", + "azure.service.endpoints.global.aadKeyDiscoveryUri=http://fake.aad.discovery.uri"); @Test public void resourceRetrieverDefaultConfig() { diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/TestConstants.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/TestConstants.java index ec08008cfd5a..00b37e65d235 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/TestConstants.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/TestConstants.java @@ -11,12 +11,28 @@ public class TestConstants { public static final String CLIENT_ID_PROPERTY = "azure.activedirectory.client-id"; public static final String CLIENT_SECRET_PROPERTY = "azure.activedirectory.client-secret"; public static final String TARGETED_GROUPS_PROPERTY = "azure.activedirectory.user-group.allowed-groups"; + public static final String TENANT_ID_PROPERTY = "azure.activedirectory.tenant-id"; public static final String ALLOW_TELEMETRY_PROPERTY = "azure.activedirectory.allow-telemetry"; + public static final String DEFAULT_ENVIRONMENT = "global"; public static final String CLIENT_ID = "real_client_id"; public static final String CLIENT_SECRET = "real_client_secret"; public static final List TARGETED_GROUPS = Arrays.asList("group1", "group2", "group3"); + public static final String TOKEN_HEADER = "Authorization"; public static final String ACCESS_TOKEN = "real_jwt_access_token"; - public static final String BEARER_TOKEN = "Bearer " + ACCESS_TOKEN; + public static final String BEARER_TOKEN = "Bearer real_jwt_bearer_token"; + + /** Token from https://docs.microsoft.com/azure/active-directory/develop/v2-id-and-access-tokens */ + public static final String JWT_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik1uQ19WWmNBVGZNNXBPWWlKSE1" + + "iYTlnb0VLWSJ9.eyJhdWQiOiI2NzMxZGU3Ni0xNGE2LTQ5YWUtOTdiYy02ZWJhNjkxNDM5MWUiLCJpc3MiOiJodHRwczovL2xvZ2lu" + + "Lm1pY3Jvc29mdG9ubGluZS5jb20vYjk0MTk4MTgtMDlhZi00OWMyLWIwYzMtNjUzYWRjMWYzNzZlL3YyLjAiLCJpYXQiOjE0NTIyOD" + + "UzMzEsIm5iZiI6MTQ1MjI4NTMzMSwiZXhwIjoxNDUyMjg5MjMxLCJuYW1lIjoiQmFiZSBSdXRoIiwibm9uY2UiOiIxMjM0NSIsIm9p" + + "ZCI6ImExZGJkZGU4LWU0ZjktNDU3MS1hZDkzLTMwNTllMzc1MGQyMyIsInByZWZlcnJlZF91c2VybmFtZSI6InRoZWdyZWF0YmFtYm" + + "lub0BueXkub25taWNyb3NvZnQuY29tIiwic3ViIjoiTUY0Zi1nZ1dNRWppMTJLeW5KVU5RWnBoYVVUdkxjUXVnNWpkRjJubDAxUSIs" + + "InRpZCI6ImI5NDE5ODE4LTA5YWYtNDljMi1iMGMzLTY1M2FkYzFmMzc2ZSIsInZlciI6IjIuMCJ9.p_rYdrtJ1oCmgDBggNHB9O38K" + + "TnLCMGbMDODdirdmZbmJcTHiZDdtTc-hguu3krhbtOsoYM2HJeZM3Wsbp_YcfSKDY--X_NobMNsxbT7bqZHxDnA2jTMyrmt5v2EKUn" + + "EeVtSiJXyO3JWUq9R0dO-m4o9_8jGP6zHtR62zLaotTBYHmgeKpZgTFB9WtUq8DVdyMn_HSvQEfz-LWqckbcTwM_9RNKoGRVk38KCh" + + "VJo4z5LkksYRarDo8QgQ7xEKmYmPvRr_I7gvM2bmlZQds2OeqWLB1NSNbFZqyFOCgYn3bAQ-nEQSKwBaA36jYGPOVG2r2Qv1uKcpSO" + + "xzxaQybzYpQ"; } diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientAzureADGraphTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalAzureADGraphTest.java similarity index 54% rename from sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientAzureADGraphTest.java rename to sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalAzureADGraphTest.java index 4a74a2bfc3e1..513f00cafa8c 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientAzureADGraphTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalAzureADGraphTest.java @@ -6,13 +6,24 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.github.tomakehurst.wiremock.junit.WireMockRule; +import com.nimbusds.jose.JWSObject; +import com.nimbusds.jwt.JWTClaimsSet; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.springframework.http.HttpHeaders; import org.springframework.security.core.GrantedAuthority; +import org.springframework.util.StringUtils; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.text.ParseException; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -29,29 +40,26 @@ import static com.github.tomakehurst.wiremock.client.WireMock.verify; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.http.HttpHeaders.ACCEPT; -import static org.springframework.http.HttpHeaders.AUTHORIZATION; import static org.springframework.http.HttpHeaders.CONTENT_TYPE; import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; -public class GraphOboClientAzureADGraphTest { + +public class UserPrincipalAzureADGraphTest { @Rule public WireMockRule wireMockRule = new WireMockRule(9519); - private GraphOboClient graphOboClient; + private AzureADGraphClient graphClientMock; private AADAuthenticationProperties aadAuthenticationProperties; private ServiceEndpointsProperties serviceEndpointsProperties; + private String accessToken; private static String userGroupsJson; static { try { final ObjectMapper objectMapper = new ObjectMapper(); - final Map json = objectMapper.readValue( - GraphOboClientAzureADGraphTest.class - .getClassLoader() - .getResourceAsStream("aad/azure-ad-graph-user-groups.json"), - new TypeReference>() { - } - ); + final Map json = objectMapper.readValue(UserPrincipalAzureADGraphTest.class.getClassLoader() + .getResourceAsStream("aad/azure-ad-graph-user-groups.json"), + new TypeReference>() { }); userGroupsJson = objectMapper.writeValueAsString(json); } catch (IOException e) { e.printStackTrace(); @@ -62,6 +70,7 @@ public class GraphOboClientAzureADGraphTest { @Before public void setup() { + accessToken = TestConstants.ACCESS_TOKEN; aadAuthenticationProperties = new AADAuthenticationProperties(); serviceEndpointsProperties = new ServiceEndpointsProperties(); final ServiceEndpoints serviceEndpoints = new ServiceEndpoints(); @@ -72,7 +81,7 @@ public void setup() { @Test public void getAuthoritiesByUserGroups() throws Exception { aadAuthenticationProperties.getUserGroup().setAllowedGroups(Collections.singletonList("group1")); - this.graphOboClient = new GraphOboClient(aadAuthenticationProperties, serviceEndpointsProperties); + this.graphClientMock = new AzureADGraphClient(aadAuthenticationProperties, serviceEndpointsProperties); stubFor(get(urlEqualTo("/memberOf")) .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) @@ -81,21 +90,19 @@ public void getAuthoritiesByUserGroups() throws Exception { .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) .withBody(userGroupsJson))); - assertThat(graphOboClient.getGrantedAuthorities(TestConstants.ACCESS_TOKEN)) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactly("ROLE_group1"); + assertThat(graphClientMock.getGrantedAuthorities(TestConstants.ACCESS_TOKEN)).isNotEmpty() + .extracting(GrantedAuthority::getAuthority).containsExactly("ROLE_group1"); verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(TestConstants.BEARER_TOKEN)) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo(String.format("Bearer %s", accessToken))) .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) .withHeader("api-version", equalTo("1.6"))); } @Test public void getGroups() throws Exception { - aadAuthenticationProperties.getUserGroup().setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - this.graphOboClient = new GraphOboClient(aadAuthenticationProperties, serviceEndpointsProperties); + aadAuthenticationProperties.setActiveDirectoryGroups(Arrays.asList("group1", "group2", "group3")); + this.graphClientMock = new AzureADGraphClient(aadAuthenticationProperties, serviceEndpointsProperties); stubFor(get(urlEqualTo("/memberOf")) .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) @@ -104,17 +111,43 @@ public void getGroups() throws Exception { .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) .withBody(userGroupsJson))); - final Collection authorities = graphOboClient + final Collection authorities = graphClientMock .getGrantedAuthorities(TestConstants.ACCESS_TOKEN); - assertThat(authorities) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); + assertThat(authorities).isNotEmpty().extracting(GrantedAuthority::getAuthority) + .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(TestConstants.BEARER_TOKEN)) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo(String.format("Bearer %s", accessToken))) .withHeader(ACCEPT, equalTo("application/json;odata=minimalmetadata")) .withHeader("api-version", equalTo("1.6"))); } + + @Test + public void userPrincipalIsSerializable() throws ParseException, IOException, ClassNotFoundException { + final File tmpOutputFile = File.createTempFile("test-user-principal", "txt"); + + try (FileOutputStream fileOutputStream = new FileOutputStream(tmpOutputFile); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream); + FileInputStream fileInputStream = new FileInputStream(tmpOutputFile); + ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream)) { + + final JWSObject jwsObject = JWSObject.parse(TestConstants.JWT_TOKEN); + final JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().subject("fake-subject").build(); + final UserPrincipal principal = new UserPrincipal("", jwsObject, jwtClaimsSet); + + objectOutputStream.writeObject(principal); + + final UserPrincipal serializedPrincipal = (UserPrincipal) objectInputStream.readObject(); + + Assert.assertNotNull("Serialized UserPrincipal not null", serializedPrincipal); + Assert.assertFalse("Serialized UserPrincipal kid not empty", + StringUtils.isEmpty(serializedPrincipal.getKid())); + Assert.assertNotNull("Serialized UserPrincipal claims not null.", serializedPrincipal.getClaims()); + Assert.assertTrue("Serialized UserPrincipal claims not empty.", + serializedPrincipal.getClaims().size() > 0); + } finally { + Files.deleteIfExists(tmpOutputFile.toPath()); + } + } } diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalManagerTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalManagerTest.java index 5d867887a3f5..d11d8cd3ab7a 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalManagerTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalManagerTest.java @@ -3,61 +3,37 @@ package com.azure.spring.autoconfigure.aad; -import com.nimbusds.jose.JWSObject; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.ImmutableJWKSet; import com.nimbusds.jose.proc.SecurityContext; -import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.proc.BadJWTException; import junitparams.FileParameters; import junitparams.JUnitParamsRunner; -import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; -import org.springframework.util.StringUtils; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; -import java.text.ParseException; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.junit.Assert.assertTrue; @RunWith(JUnitParamsRunner.class) public class UserPrincipalManagerTest { - /** Token from https://docs.microsoft.com/azure/active-directory/develop/v2-id-and-access-tokens */ - private static final String JWT_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik1uQ19WWmNBVGZNNXBPWWlKSE1" - + "iYTlnb0VLWSJ9.eyJhdWQiOiI2NzMxZGU3Ni0xNGE2LTQ5YWUtOTdiYy02ZWJhNjkxNDM5MWUiLCJpc3MiOiJodHRwczovL2xvZ2lu" - + "Lm1pY3Jvc29mdG9ubGluZS5jb20vYjk0MTk4MTgtMDlhZi00OWMyLWIwYzMtNjUzYWRjMWYzNzZlL3YyLjAiLCJpYXQiOjE0NTIyOD" - + "UzMzEsIm5iZiI6MTQ1MjI4NTMzMSwiZXhwIjoxNDUyMjg5MjMxLCJuYW1lIjoiQmFiZSBSdXRoIiwibm9uY2UiOiIxMjM0NSIsIm9p" - + "ZCI6ImExZGJkZGU4LWU0ZjktNDU3MS1hZDkzLTMwNTllMzc1MGQyMyIsInByZWZlcnJlZF91c2VybmFtZSI6InRoZWdyZWF0YmFtYm" - + "lub0BueXkub25taWNyb3NvZnQuY29tIiwic3ViIjoiTUY0Zi1nZ1dNRWppMTJLeW5KVU5RWnBoYVVUdkxjUXVnNWpkRjJubDAxUSIs" - + "InRpZCI6ImI5NDE5ODE4LTA5YWYtNDljMi1iMGMzLTY1M2FkYzFmMzc2ZSIsInZlciI6IjIuMCJ9.p_rYdrtJ1oCmgDBggNHB9O38K" - + "TnLCMGbMDODdirdmZbmJcTHiZDdtTc-hguu3krhbtOsoYM2HJeZM3Wsbp_YcfSKDY--X_NobMNsxbT7bqZHxDnA2jTMyrmt5v2EKUn" - + "EeVtSiJXyO3JWUq9R0dO-m4o9_8jGP6zHtR62zLaotTBYHmgeKpZgTFB9WtUq8DVdyMn_HSvQEfz-LWqckbcTwM_9RNKoGRVk38KCh" - + "VJo4z5LkksYRarDo8QgQ7xEKmYmPvRr_I7gvM2bmlZQds2OeqWLB1NSNbFZqyFOCgYn3bAQ-nEQSKwBaA36jYGPOVG2r2Qv1uKcpSO" - + "xzxaQybzYpQ"; - private static ImmutableJWKSet immutableJWKSet; @BeforeClass public static void setupClass() throws Exception { final X509Certificate cert = (X509Certificate) CertificateFactory.getInstance("X.509") .generateCertificate(Files.newInputStream(Paths.get("src/test/resources/test-public-key.txt"))); - immutableJWKSet = new ImmutableJWKSet<>(new JWKSet(JWK.parse(cert))); + immutableJWKSet = new ImmutableJWKSet<>(new JWKSet(JWK.parse( + cert))); } private UserPrincipalManager userPrincipalManager; @@ -67,28 +43,19 @@ public static void setupClass() throws Exception { public void testAlgIsTakenFromJWT() throws Exception { userPrincipalManager = new UserPrincipalManager(immutableJWKSet); final UserPrincipal userPrincipal = userPrincipalManager.buildUserPrincipal( - new String(Files.readAllBytes( - Paths.get("src/test/resources/jwt-signed.txt")), - StandardCharsets.UTF_8 - ) - ); - assertThat(userPrincipal) - .isNotNull() - .extracting(UserPrincipal::getIssuer, UserPrincipal::getSubject) - .containsExactly("https://sts.windows.net/test", "test@example.com"); + new String(Files.readAllBytes( + Paths.get("src/test/resources/jwt-signed.txt")), StandardCharsets.UTF_8)); + assertThat(userPrincipal).isNotNull().extracting(UserPrincipal::getIssuer, UserPrincipal::getSubject) + .containsExactly("https://sts.windows.net/test", "test@example.com"); } @Test public void invalidIssuer() { userPrincipalManager = new UserPrincipalManager(immutableJWKSet); - assertThatCode(() -> - userPrincipalManager.buildUserPrincipal( + assertThatCode(() -> userPrincipalManager.buildUserPrincipal( new String(Files.readAllBytes( - Paths.get("src/test/resources/jwt-bad-issuer.txt")), - StandardCharsets.UTF_8 - ) - ) - ).isInstanceOf(BadJWTException.class); + Paths.get("src/test/resources/jwt-bad-issuer.txt")), StandardCharsets.UTF_8))) + .isInstanceOf(BadJWTException.class); } @Test @@ -103,47 +70,9 @@ public void validIssuer(final String token) { @Test public void nullIssuer() { userPrincipalManager = new UserPrincipalManager(immutableJWKSet); - assertThatCode(() -> - userPrincipalManager.buildUserPrincipal( + assertThatCode(() -> userPrincipalManager.buildUserPrincipal( new String(Files.readAllBytes( - Paths.get("src/test/resources/jwt-null-issuer.txt")), - StandardCharsets.UTF_8 - ) - ) - ).isInstanceOf(BadJWTException.class); - } - - - - @Test - public void userPrincipalIsSerializable() throws ParseException, IOException, ClassNotFoundException { - final File tmpOutputFile = File.createTempFile("test-user-principal", "txt"); - - try (FileOutputStream fileOutputStream = new FileOutputStream(tmpOutputFile); - ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream); - FileInputStream fileInputStream = new FileInputStream(tmpOutputFile); - ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream)) { - - final JWSObject jwsObject = JWSObject.parse(JWT_TOKEN); - final JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().subject("fake-subject").build(); - final UserPrincipal principal = new UserPrincipal("", jwsObject, jwtClaimsSet); - - objectOutputStream.writeObject(principal); - - final UserPrincipal serializedPrincipal = (UserPrincipal) objectInputStream.readObject(); - - Assert.assertNotNull("Serialized UserPrincipal not null", serializedPrincipal); - Assert.assertFalse( - "Serialized UserPrincipal kid not empty", - StringUtils.isEmpty(serializedPrincipal.getKid()) - ); - Assert.assertNotNull("Serialized UserPrincipal claims not null.", serializedPrincipal.getClaims()); - assertTrue( - "Serialized UserPrincipal claims not empty.", - serializedPrincipal.getClaims().size() > 0 - ); - } finally { - Files.deleteIfExists(tmpOutputFile.toPath()); - } + Paths.get("src/test/resources/jwt-null-issuer.txt")), StandardCharsets.UTF_8))) + .isInstanceOf(BadJWTException.class); } } diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientMicrosoftGraphTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalMicrosoftGraphTest.java similarity index 58% rename from sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientMicrosoftGraphTest.java rename to sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalMicrosoftGraphTest.java index 402953a21b37..839d39d60905 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/GraphOboClientMicrosoftGraphTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/azure/spring/autoconfigure/aad/UserPrincipalMicrosoftGraphTest.java @@ -6,13 +6,23 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.github.tomakehurst.wiremock.junit.WireMockRule; +import com.nimbusds.jose.JWSObject; +import com.nimbusds.jwt.JWTClaimsSet; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.springframework.security.core.GrantedAuthority; +import org.springframework.util.StringUtils; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.text.ParseException; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -33,25 +43,22 @@ import static org.springframework.http.HttpHeaders.CONTENT_TYPE; import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; -public class GraphOboClientMicrosoftGraphTest { +public class UserPrincipalMicrosoftGraphTest { @Rule public WireMockRule wireMockRule = new WireMockRule(9519); - private GraphOboClient graphOboClient; + private AzureADGraphClient graphClientMock; private AADAuthenticationProperties aadAuthenticationProperties; private ServiceEndpointsProperties serviceEndpointsProperties; + private String accessToken; private static String userGroupsJson; static { try { final ObjectMapper objectMapper = new ObjectMapper(); - final Map json = objectMapper.readValue( - GraphOboClientMicrosoftGraphTest.class - .getClassLoader() - .getResourceAsStream("aad/microsoft-graph-user-groups.json"), - new TypeReference>() { - } - ); + final Map json = objectMapper.readValue(UserPrincipalMicrosoftGraphTest.class + .getClassLoader().getResourceAsStream("aad/microsoft-graph-user-groups.json"), + new TypeReference>() { }); userGroupsJson = objectMapper.writeValueAsString(json); } catch (IOException e) { e.printStackTrace(); @@ -62,6 +69,7 @@ public class GraphOboClientMicrosoftGraphTest { @Before public void setup() { + accessToken = MicrosoftGraphConstants.BEARER_TOKEN; aadAuthenticationProperties = new AADAuthenticationProperties(); aadAuthenticationProperties.setEnvironment("global-v2-graph"); aadAuthenticationProperties.getUserGroup().setKey("@odata.type"); @@ -70,17 +78,13 @@ public void setup() { serviceEndpointsProperties = new ServiceEndpointsProperties(); final ServiceEndpoints serviceEndpoints = new ServiceEndpoints(); serviceEndpoints.setAadMembershipRestUri("http://localhost:9519/memberOf"); - serviceEndpoints.setAadTransitiveMemberRestUri("http://localhost:9519/transitiveMemberOf"); serviceEndpointsProperties.getEndpoints().put("global-v2-graph", serviceEndpoints); } @Test public void getAuthoritiesByUserGroups() throws Exception { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("direct"); aadAuthenticationProperties.getUserGroup().setAllowedGroups(Collections.singletonList("group1")); - serviceEndpointsProperties.getServiceEndpoints("global-v2-graph") - .setAadMembershipRestUri("http://localhost:9519/memberOf"); - this.graphOboClient = new GraphOboClient(aadAuthenticationProperties, serviceEndpointsProperties); + this.graphClientMock = new AzureADGraphClient(aadAuthenticationProperties, serviceEndpointsProperties); stubFor(get(urlEqualTo("/memberOf")) .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) @@ -89,24 +93,20 @@ public void getAuthoritiesByUserGroups() throws Exception { .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) .withBody(userGroupsJson))); - assertThat(graphOboClient.getGrantedAuthorities(TestConstants.ACCESS_TOKEN)) + assertThat(graphClientMock.getGrantedAuthorities(MicrosoftGraphConstants.BEARER_TOKEN)) .isNotEmpty() .extracting(GrantedAuthority::getAuthority) .containsExactly("ROLE_group1"); verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(TestConstants.BEARER_TOKEN)) + .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", accessToken))) .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); } @Test - public void getDirectGroups() throws Exception { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("direct"); - AADAuthenticationProperties.UserGroupProperties userGroupProperties = - aadAuthenticationProperties.getUserGroup(); - userGroupProperties.setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - aadAuthenticationProperties.setUserGroup(userGroupProperties); - this.graphOboClient = new GraphOboClient(aadAuthenticationProperties, serviceEndpointsProperties); + public void getGroups() throws Exception { + aadAuthenticationProperties.setActiveDirectoryGroups(Arrays.asList("group1", "group2", "group3")); + this.graphClientMock = new AzureADGraphClient(aadAuthenticationProperties, serviceEndpointsProperties); stubFor(get(urlEqualTo("/memberOf")) .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) @@ -115,8 +115,8 @@ public void getDirectGroups() throws Exception { .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) .withBody(userGroupsJson))); - final Collection authorities = graphOboClient - .getGrantedAuthorities(TestConstants.ACCESS_TOKEN); + final Collection authorities = graphClientMock + .getGrantedAuthorities(MicrosoftGraphConstants.BEARER_TOKEN); assertThat(authorities) .isNotEmpty() @@ -124,36 +124,35 @@ public void getDirectGroups() throws Exception { .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); verify(getRequestedFor(urlMatching("/memberOf")) - .withHeader(AUTHORIZATION, equalTo(TestConstants.BEARER_TOKEN)) + .withHeader(AUTHORIZATION, equalTo(String.format("Bearer %s", accessToken))) .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); } @Test - public void getTransitiveGroups() throws Exception { - aadAuthenticationProperties.getUserGroup().setGroupRelationship("transitive"); - AADAuthenticationProperties.UserGroupProperties userGroupProperties = - aadAuthenticationProperties.getUserGroup(); - userGroupProperties.setAllowedGroups(Arrays.asList("group1", "group2", "group3")); - aadAuthenticationProperties.setUserGroup(userGroupProperties); - this.graphOboClient = new GraphOboClient(aadAuthenticationProperties, serviceEndpointsProperties); - - stubFor(get(urlEqualTo("/transitiveMemberOf")) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE)) - .willReturn(aResponse() - .withStatus(200) - .withHeader(CONTENT_TYPE, APPLICATION_JSON_VALUE) - .withBody(userGroupsJson))); - - final Collection authorities = graphOboClient - .getGrantedAuthorities(TestConstants.ACCESS_TOKEN); - - assertThat(authorities) - .isNotEmpty() - .extracting(GrantedAuthority::getAuthority) - .containsExactlyInAnyOrder("ROLE_group1", "ROLE_group2", "ROLE_group3"); - - verify(getRequestedFor(urlMatching("/transitiveMemberOf")) - .withHeader(AUTHORIZATION, equalTo(TestConstants.BEARER_TOKEN)) - .withHeader(ACCEPT, equalTo(APPLICATION_JSON_VALUE))); + public void userPrincipalIsSerializable() throws ParseException, IOException, ClassNotFoundException { + final File tmpOutputFile = File.createTempFile("test-user-principal", "txt"); + + try (FileOutputStream fileOutputStream = new FileOutputStream(tmpOutputFile); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream); + FileInputStream fileInputStream = new FileInputStream(tmpOutputFile); + ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream)) { + + final JWSObject jwsObject = JWSObject.parse(MicrosoftGraphConstants.JWT_TOKEN); + final JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().subject("fake-subject").build(); + final UserPrincipal principal = new UserPrincipal("", jwsObject, jwtClaimsSet); + + objectOutputStream.writeObject(principal); + + final UserPrincipal serializedPrincipal = (UserPrincipal) objectInputStream.readObject(); + + Assert.assertNotNull("Serialized UserPrincipal not null", serializedPrincipal); + Assert.assertFalse("Serialized UserPrincipal kid not empty", + StringUtils.isEmpty(serializedPrincipal.getKid())); + Assert.assertNotNull("Serialized UserPrincipal claims not null.", serializedPrincipal.getClaims()); + Assert.assertTrue("Serialized UserPrincipal claims not empty.", + serializedPrincipal.getClaims().size() > 0); + } finally { + Files.deleteIfExists(tmpOutputFile.toPath()); + } } }