Skip to content

Commit

Permalink
Fix Azure#31233: NimbusJwtDecoder still uses RestTemplate() instead R…
Browse files Browse the repository at this point in the history
…estTemplateBuilder
  • Loading branch information
rujche committed Oct 17, 2022
1 parent bcee590 commit 7a47e1b
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import com.azure.spring.cloud.autoconfigure.aad.implementation.oauth2.OAuth2ClientAuthenticationJwkResolver;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapi.AadJwtBearerGrantRequestEntityConverter;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapp.AadAzureDelegatedOAuth2AuthorizedClientProvider;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapp.AadOidcIdTokenDecoderFactory;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthorizationServerEndpoints;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadProfileProperties;
import com.nimbusds.jose.jwk.JWK;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.configurers.oauth2.client.OAuth2LoginConfigurer;
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
Expand All @@ -36,8 +40,12 @@
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.web.client.RestTemplate;

import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createOAuth2AccessTokenResponseClientRestTemplate;
import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createRestTemplate;

/**
* <p>
Expand Down Expand Up @@ -170,6 +178,22 @@ RefreshTokenOAuth2AuthorizedClientProvider azureRefreshTokenProvider(
return provider;
}

/**
* Provide {@link JwtDecoderFactory} used in {@link OAuth2LoginConfigurer#init}. The {@link JwtDecoder} created by
* current {@link JwtDecoderFactory} will use {@link RestTemplate} created by {@link RestTemplateBuilder} bean.
*
* @param properties the AadAuthenticationProperties
* @return JwtDecoderFactory
*/
@Bean
@ConditionalOnMissingBean
public JwtDecoderFactory<ClientRegistration> azureAdJwtDecoderFactory(AadAuthenticationProperties properties) {
AadProfileProperties profile = properties.getProfile();
AadAuthorizationServerEndpoints endpoints = new AadAuthorizationServerEndpoints(
profile.getEnvironment().getActiveDirectoryEndpoint(), profile.getTenantId());
return new AadOidcIdTokenDecoderFactory(endpoints.getJwkSetEndpoint(), createRestTemplate(restTemplateBuilder));
}

private void passwordGrantBuilderAccessTokenResponseClientCustomizer(
OAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilder builder,
OAuth2ClientAuthenticationJwkResolver resolver) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.cloud.autoconfigure.aad.implementation.webapp;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.web.client.RestOperations;

/**
* A factory that provides a {@link JwtDecoder} used for {@link OidcIdToken} signature verification.
*
* @see <a href="https://learn.microsoft.com/azure/active-directory/develop/id-tokens">azure-active-directory id-tokens</a>
*/
public class AadOidcIdTokenDecoderFactory implements JwtDecoderFactory<ClientRegistration> {

private final JwtDecoder jwtDecoder;

/**
*
* @param jwkSetUri The uri of the jwk set. For example:
* <a href="https://login.microsoftonline.com/common/discovery/v2.0/keys">
* https://login.microsoftonline.com/common/discovery/v2.0/keys</a>
* @param restOperations The RestOperations used to retrieve jwk from jwkSetUri.
*/
public AadOidcIdTokenDecoderFactory(String jwkSetUri, RestOperations restOperations) {
this.jwtDecoder = NimbusJwtDecoder
.withJwkSetUri(jwkSetUri)
.jwsAlgorithm(SignatureAlgorithm.RS256)
.restOperations(restOperations)
.build();
}

@Override
public JwtDecoder createDecoder(ClientRegistration context) {
return jwtDecoder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthenticationProperties;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.http.HttpMessageConvertersAutoConfiguration;
Expand All @@ -34,6 +38,9 @@
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
Expand All @@ -52,6 +59,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -265,6 +273,24 @@ void restTemplateWellConfiguredForAllOAuth2AuthorizedClientProvidersWhenUsingPri
});
}

@Test
void restTemplateWellConfiguredForJwtDecoderCreatedByJwtDecoderFactory() {
webApplicationContextRunner()
.withUserConfiguration(AadOAuth2ClientConfiguration.class, RestTemplateProxyCustomizerConfiguration.class)
.run(context -> {
assertThat(context).hasSingleBean(JwtDecoderFactory.class);
JwtDecoderFactory<?> factory = context.getBean(JwtDecoderFactory.class);
JwtDecoder jwtDecoder = factory.createDecoder(null);
assertTrue(jwtDecoder instanceof NimbusJwtDecoder);
DefaultJWTProcessor<?> processor = (DefaultJWTProcessor<?>) getField(NimbusJwtDecoder.class, "jwtProcessor", jwtDecoder);
JWSVerificationKeySelector<?> selector = (JWSVerificationKeySelector<?>) processor.getJWSKeySelector();
RemoteJWKSet<?> source = (RemoteJWKSet<?>)selector.getJWKSource();
ResourceRetriever retriever = source.getResourceRetriever();
RestTemplate restTemplate = (RestTemplate) getField(retriever.getClass(), "restOperations", retriever);
assertEquals(FACTORY, restTemplate.getRequestFactory());
});
}

private static void assertRestTemplateWellConfiguredForAllOAuth2AuthorizedClientProviders(ApplicationContext context) {
List<OAuth2AuthorizedClientProvider> providers = getAllOAuth2AuthorizedClientProviderThatShouldConfiguredRestTemplate(context);
assertEquals(4, providers.size());
Expand Down

0 comments on commit 7a47e1b

Please sign in to comment.