Skip to content

Commit

Permalink
Resolve oauth2 client placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
evgeniycheban committed Aug 28, 2020
1 parent 902fca6 commit 875c323
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

/**
* @author Ruby Hartono
* @author Evgeniy Cheban
* @since 5.3
*/
public final class ClientRegistrationsBeanDefinitionParser implements BeanDefinitionParser {
Expand Down Expand Up @@ -87,7 +88,7 @@ public BeanDefinition parse(Element element, ParserContext parserContext) {
CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(),
parserContext.extractSource(element));
parserContext.pushContainingComponent(compositeDef);
Map<String, Map<String, String>> providers = getProviders(element);
Map<String, Map<String, String>> providers = getProviders(parserContext, element);
List<ClientRegistration> clientRegistrations = getClientRegistrations(element, parserContext, providers);
BeanDefinition clientRegistrationRepositoryBean = BeanDefinitionBuilder
.rootBeanDefinition(InMemoryClientRegistrationRepository.class)
Expand All @@ -107,75 +108,79 @@ private List<ClientRegistration> getClientRegistrations(Element element, ParserC
for (Element clientRegistrationElt : clientRegistrationElts) {
String registrationId = clientRegistrationElt.getAttribute(ATT_REGISTRATION_ID);
String providerId = clientRegistrationElt.getAttribute(ATT_PROVIDER_ID);
ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(registrationId, providerId, providers);
ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(parserContext, registrationId,
providerId, providers);
if (builder == null) {
builder = getBuilder(registrationId, providerId, providers);
builder = getBuilder(parserContext, registrationId, providerId, providers);
if (builder == null) {
Object source = parserContext.extractSource(element);
parserContext.getReaderContext().error(getErrorMessage(providerId, registrationId), source);
// error on the config skip to next element
continue;
}
}
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_ID)).ifPresent(builder::clientId);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET))
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_ID))
.ifPresent(builder::clientId);
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET))
.ifPresent(builder::clientSecret);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD))
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD))
.map(ClientAuthenticationMethod::new).ifPresent(builder::clientAuthenticationMethod);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE))
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE))
.map(AuthorizationGrantType::new).ifPresent(builder::authorizationGrantType);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)).ifPresent(builder::redirectUri);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_SCOPE))
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_REDIRECT_URI))
.ifPresent(builder::redirectUri);
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_SCOPE))
.map(StringUtils::commaDelimitedListToSet).ifPresent(builder::scope);
getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)).ifPresent(builder::clientName);
getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_NAME))
.ifPresent(builder::clientName);
clientRegistrations.add(builder.build());
}
return clientRegistrations;
}

private Map<String, Map<String, String>> getProviders(Element element) {
private Map<String, Map<String, String>> getProviders(ParserContext parserContext, Element element) {
List<Element> providerElts = DomUtils.getChildElementsByTagName(element, ELT_PROVIDER);
Map<String, Map<String, String>> providers = new HashMap<>();
for (Element providerElt : providerElts) {
Map<String, String> provider = new HashMap<>();
String providerId = providerElt.getAttribute(ATT_PROVIDER_ID);
provider.put(ATT_PROVIDER_ID, providerId);
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_AUTHORIZATION_URI))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_AUTHORIZATION_URI))
.ifPresent((value) -> provider.put(ATT_AUTHORIZATION_URI, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_TOKEN_URI))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_TOKEN_URI))
.ifPresent((value) -> provider.put(ATT_TOKEN_URI, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_URI))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_URI))
.ifPresent((value) -> provider.put(ATT_USER_INFO_URI, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD))
.ifPresent((value) -> provider.put(ATT_USER_INFO_AUTHENTICATION_METHOD, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
.ifPresent((value) -> provider.put(ATT_USER_INFO_USER_NAME_ATTRIBUTE, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_JWK_SET_URI))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_JWK_SET_URI))
.ifPresent((value) -> provider.put(ATT_JWK_SET_URI, value));
getOptionalIfNotEmpty(providerElt.getAttribute(ATT_ISSUER_URI))
getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_ISSUER_URI))
.ifPresent((value) -> provider.put(ATT_ISSUER_URI, value));
providers.put(providerId, provider);
}
return providers;
}

private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(String registrationId,
String configuredProviderId, Map<String, Map<String, String>> providers) {
private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(ParserContext parserContext,
String registrationId, String configuredProviderId, Map<String, Map<String, String>> providers) {
String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId;
if (providers.containsKey(providerId)) {
Map<String, String> provider = providers.get(providerId);
String issuer = provider.get(ATT_ISSUER_URI);
if (!StringUtils.isEmpty(issuer)) {
ClientRegistration.Builder builder = ClientRegistrations.fromIssuerLocation(issuer)
.registrationId(registrationId);
return getBuilder(builder, provider);
return getBuilder(parserContext, builder, provider);
}
}
return null;
}

private static ClientRegistration.Builder getBuilder(String registrationId, String configuredProviderId,
Map<String, Map<String, String>> providers) {
private static ClientRegistration.Builder getBuilder(ParserContext parserContext, String registrationId,
String configuredProviderId, Map<String, Map<String, String>> providers) {
String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId;
CommonOAuth2Provider provider = getCommonProvider(providerId);
if (provider == null && !providers.containsKey(providerId)) {
Expand All @@ -184,26 +189,27 @@ private static ClientRegistration.Builder getBuilder(String registrationId, Stri
ClientRegistration.Builder builder = (provider != null) ? provider.getBuilder(registrationId)
: ClientRegistration.withRegistrationId(registrationId);
if (providers.containsKey(providerId)) {
return getBuilder(builder, providers.get(providerId));
return getBuilder(parserContext, builder, providers.get(providerId));
}
return builder;
}

private static ClientRegistration.Builder getBuilder(ClientRegistration.Builder builder,
Map<String, String> provider) {
getOptionalIfNotEmpty(provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri);
getOptionalIfNotEmpty(provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri);
getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri);
getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)).map(AuthenticationMethod::new)
.ifPresent(builder::userInfoAuthenticationMethod);
getOptionalIfNotEmpty(provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri);
getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
private static ClientRegistration.Builder getBuilder(ParserContext parserContext,
ClientRegistration.Builder builder, Map<String, String> provider) {
getOptionalIfNotEmpty(parserContext, provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri);
getOptionalIfNotEmpty(parserContext, provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri);
getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri);
getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD))
.map(AuthenticationMethod::new).ifPresent(builder::userInfoAuthenticationMethod);
getOptionalIfNotEmpty(parserContext, provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri);
getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
.ifPresent(builder::userNameAttributeName);
return builder;
}

private static Optional<String> getOptionalIfNotEmpty(String str) {
return Optional.ofNullable(str).filter((s) -> !s.isEmpty());
private static Optional<String> getOptionalIfNotEmpty(ParserContext parserContext, String str) {
return Optional.ofNullable(str).filter((s) -> !s.isEmpty())
.map(parserContext.getReaderContext().getEnvironment()::resolvePlaceholders);
}

private static CommonOAuth2Provider getCommonProvider(String providerId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
* Tests for {@link ClientRegistrationsBeanDefinitionParser}.
*
* @author Ruby Hartono
* @author Evgeniy Cheban
*/
public class ClientRegistrationsBeanDefinitionParserTests {

Expand Down Expand Up @@ -218,6 +219,20 @@ public void parseWhenMultipleClientsConfiguredThenAvailableInRepository() {
assertThat(githubProviderDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("id");
}

@Test
public void parseWhenClientPlaceholdersThenResolvePlaceholders() {
System.setProperty("oauth2.client.id", "github-client-id");
System.setProperty("oauth2.client.secret", "github-client-secret");

this.spring.configLocations(xml("ClientPlaceholders")).autowire();

assertThat(this.clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class);

ClientRegistration githubRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
assertThat(githubRegistration.getClientId()).isEqualTo("github-client-id");
assertThat(githubRegistration.getClientSecret()).isEqualTo("github-client-secret");
}

private static MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2002-2020 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<b:beans xmlns:b="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://www.springframework.org/schema/security"
xsi:schemaLocation="
http://www.springframework.org/schema/security
https://www.springframework.org/schema/security/spring-security.xsd
http://www.springframework.org/schema/beans
https://www.springframework.org/schema/beans/spring-beans.xsd">
<client-registrations>
<client-registration registration-id="github"
client-id="${oauth2.client.id}"
client-secret="${oauth2.client.secret}"
provider-id="github"/>
</client-registrations>
</b:beans>

0 comments on commit 875c323

Please sign in to comment.