diff --git a/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java index 78fb3543a7c..a1c67d0c176 100644 --- a/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java @@ -42,6 +42,7 @@ /** * @author Ruby Hartono + * @author Evgeniy Cheban * @since 5.3 */ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefinitionParser { @@ -87,7 +88,7 @@ public BeanDefinition parse(Element element, ParserContext parserContext) { CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(), parserContext.extractSource(element)); parserContext.pushContainingComponent(compositeDef); - Map> providers = getProviders(element); + Map> providers = getProviders(parserContext, element); List clientRegistrations = getClientRegistrations(element, parserContext, providers); BeanDefinition clientRegistrationRepositoryBean = BeanDefinitionBuilder .rootBeanDefinition(InMemoryClientRegistrationRepository.class) @@ -107,9 +108,10 @@ private List getClientRegistrations(Element element, ParserC for (Element clientRegistrationElt : clientRegistrationElts) { String registrationId = clientRegistrationElt.getAttribute(ATT_REGISTRATION_ID); String providerId = clientRegistrationElt.getAttribute(ATT_PROVIDER_ID); - ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(registrationId, providerId, providers); + ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(parserContext, registrationId, + providerId, providers); if (builder == null) { - builder = getBuilder(registrationId, providerId, providers); + builder = getBuilder(parserContext, registrationId, providerId, providers); if (builder == null) { Object source = parserContext.extractSource(element); parserContext.getReaderContext().error(getErrorMessage(providerId, registrationId), source); @@ -117,50 +119,53 @@ private List getClientRegistrations(Element element, ParserC continue; } } - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_ID)).ifPresent(builder::clientId); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET)) + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_ID)) + .ifPresent(builder::clientId); + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET)) .ifPresent(builder::clientSecret); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD)) + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD)) .map(ClientAuthenticationMethod::new).ifPresent(builder::clientAuthenticationMethod); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE)) + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE)) .map(AuthorizationGrantType::new).ifPresent(builder::authorizationGrantType); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)).ifPresent(builder::redirectUri); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_SCOPE)) + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)) + .ifPresent(builder::redirectUri); + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_SCOPE)) .map(StringUtils::commaDelimitedListToSet).ifPresent(builder::scope); - getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)).ifPresent(builder::clientName); + getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)) + .ifPresent(builder::clientName); clientRegistrations.add(builder.build()); } return clientRegistrations; } - private Map> getProviders(Element element) { + private Map> getProviders(ParserContext parserContext, Element element) { List providerElts = DomUtils.getChildElementsByTagName(element, ELT_PROVIDER); Map> providers = new HashMap<>(); for (Element providerElt : providerElts) { Map provider = new HashMap<>(); String providerId = providerElt.getAttribute(ATT_PROVIDER_ID); provider.put(ATT_PROVIDER_ID, providerId); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_AUTHORIZATION_URI)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_AUTHORIZATION_URI)) .ifPresent((value) -> provider.put(ATT_AUTHORIZATION_URI, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_TOKEN_URI)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_TOKEN_URI)) .ifPresent((value) -> provider.put(ATT_TOKEN_URI, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_URI)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_URI)) .ifPresent((value) -> provider.put(ATT_USER_INFO_URI, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD)) .ifPresent((value) -> provider.put(ATT_USER_INFO_AUTHENTICATION_METHOD, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) .ifPresent((value) -> provider.put(ATT_USER_INFO_USER_NAME_ATTRIBUTE, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_JWK_SET_URI)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_JWK_SET_URI)) .ifPresent((value) -> provider.put(ATT_JWK_SET_URI, value)); - getOptionalIfNotEmpty(providerElt.getAttribute(ATT_ISSUER_URI)) + getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_ISSUER_URI)) .ifPresent((value) -> provider.put(ATT_ISSUER_URI, value)); providers.put(providerId, provider); } return providers; } - private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(String registrationId, - String configuredProviderId, Map> providers) { + private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(ParserContext parserContext, + String registrationId, String configuredProviderId, Map> providers) { String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId; if (providers.containsKey(providerId)) { Map provider = providers.get(providerId); @@ -168,14 +173,14 @@ private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(String if (!StringUtils.isEmpty(issuer)) { ClientRegistration.Builder builder = ClientRegistrations.fromIssuerLocation(issuer) .registrationId(registrationId); - return getBuilder(builder, provider); + return getBuilder(parserContext, builder, provider); } } return null; } - private static ClientRegistration.Builder getBuilder(String registrationId, String configuredProviderId, - Map> providers) { + private static ClientRegistration.Builder getBuilder(ParserContext parserContext, String registrationId, + String configuredProviderId, Map> providers) { String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId; CommonOAuth2Provider provider = getCommonProvider(providerId); if (provider == null && !providers.containsKey(providerId)) { @@ -184,26 +189,27 @@ private static ClientRegistration.Builder getBuilder(String registrationId, Stri ClientRegistration.Builder builder = (provider != null) ? provider.getBuilder(registrationId) : ClientRegistration.withRegistrationId(registrationId); if (providers.containsKey(providerId)) { - return getBuilder(builder, providers.get(providerId)); + return getBuilder(parserContext, builder, providers.get(providerId)); } return builder; } - private static ClientRegistration.Builder getBuilder(ClientRegistration.Builder builder, - Map provider) { - getOptionalIfNotEmpty(provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri); - getOptionalIfNotEmpty(provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri); - getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri); - getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)).map(AuthenticationMethod::new) - .ifPresent(builder::userInfoAuthenticationMethod); - getOptionalIfNotEmpty(provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri); - getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) + private static ClientRegistration.Builder getBuilder(ParserContext parserContext, + ClientRegistration.Builder builder, Map provider) { + getOptionalIfNotEmpty(parserContext, provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri); + getOptionalIfNotEmpty(parserContext, provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri); + getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri); + getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)) + .map(AuthenticationMethod::new).ifPresent(builder::userInfoAuthenticationMethod); + getOptionalIfNotEmpty(parserContext, provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri); + getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE)) .ifPresent(builder::userNameAttributeName); return builder; } - private static Optional getOptionalIfNotEmpty(String str) { - return Optional.ofNullable(str).filter((s) -> !s.isEmpty()); + private static Optional getOptionalIfNotEmpty(ParserContext parserContext, String str) { + return Optional.ofNullable(str).filter((s) -> !s.isEmpty()) + .map(parserContext.getReaderContext().getEnvironment()::resolvePlaceholders); } private static CommonOAuth2Provider getCommonProvider(String providerId) { diff --git a/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java index 74cfdb56728..94b7e093d59 100644 --- a/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java @@ -41,6 +41,7 @@ * Tests for {@link ClientRegistrationsBeanDefinitionParser}. * * @author Ruby Hartono + * @author Evgeniy Cheban */ public class ClientRegistrationsBeanDefinitionParserTests { @@ -218,6 +219,20 @@ public void parseWhenMultipleClientsConfiguredThenAvailableInRepository() { assertThat(githubProviderDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("id"); } + @Test + public void parseWhenClientPlaceholdersThenResolvePlaceholders() { + System.setProperty("oauth2.client.id", "github-client-id"); + System.setProperty("oauth2.client.secret", "github-client-secret"); + + this.spring.configLocations(xml("ClientPlaceholders")).autowire(); + + assertThat(this.clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class); + + ClientRegistration githubRegistration = this.clientRegistrationRepository.findByRegistrationId("github"); + assertThat(githubRegistration.getClientId()).isEqualTo("github-client-id"); + assertThat(githubRegistration.getClientSecret()).isEqualTo("github-client-secret"); + } + private static MockResponse jsonResponse(String json) { return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); } diff --git a/config/src/test/resources/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests-ClientPlaceholders.xml b/config/src/test/resources/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests-ClientPlaceholders.xml new file mode 100644 index 00000000000..b929c3eee8f --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests-ClientPlaceholders.xml @@ -0,0 +1,32 @@ + + + + + + + +