Skip to content

Commit

Permalink
Fix #31233: NimbusJwtDecoder still uses RestTemplate() instead RestTe…
Browse files Browse the repository at this point in the history
…mplateBuilder (#31521)
  • Loading branch information
rujche authored Oct 19, 2022
1 parent 6869ddd commit 359bfba
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 1 deletion.
1 change: 1 addition & 0 deletions sdk/spring/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Upgrade Spring Boot dependencies version to 2.7.4 and Spring Cloud dependencies
- Fix bug: RestTemplate used to get access token should only contain 2 converters. [#31482](https://github.com/Azure/azure-sdk-for-java/issues/31482).
- Fix bug: RestOperations is not well configured when jwkResolver is null. [#31218](https://github.com/Azure/azure-sdk-for-java/issues/31218).
- Fix bug: Duplicated "scope" parameter. [#31191](https://github.com/Azure/azure-sdk-for-java/issues/31191).
- Fix bug: NimbusJwtDecoder still uses `RestTemplate()` instead `RestTemplateBuilder` [#31233](https://github.com/Azure/azure-sdk-for-java/issues/31233)

## 4.4.0 (2022-09-26)
Upgrade Spring Boot dependencies version to 2.7.3 and Spring Cloud dependencies version to 2021.0.3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import com.azure.spring.cloud.autoconfigure.aad.implementation.oauth2.OAuth2ClientAuthenticationJwkResolver;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapi.AadJwtBearerGrantRequestEntityConverter;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapp.AadAzureDelegatedOAuth2AuthorizedClientProvider;
import com.azure.spring.cloud.autoconfigure.aad.implementation.webapp.AadOidcIdTokenDecoderFactory;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthenticationProperties;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthorizationServerEndpoints;
import com.azure.spring.cloud.autoconfigure.aad.properties.AadProfileProperties;
import com.nimbusds.jose.jwk.JWK;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.configurers.oauth2.client.OAuth2LoginConfigurer;
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
Expand All @@ -36,8 +40,12 @@
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.web.client.RestTemplate;

import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createOAuth2AccessTokenResponseClientRestTemplate;
import static com.azure.spring.cloud.autoconfigure.aad.implementation.AadRestTemplateCreator.createRestTemplate;

/**
* <p>
Expand Down Expand Up @@ -170,6 +178,22 @@ RefreshTokenOAuth2AuthorizedClientProvider azureRefreshTokenProvider(
return provider;
}

/**
* Provide {@link JwtDecoderFactory} used in {@link OAuth2LoginConfigurer#init}. The {@link JwtDecoder} created by
* current {@link JwtDecoderFactory} will use {@link RestTemplate} created by {@link RestTemplateBuilder} bean.
*
* @param properties the AadAuthenticationProperties
* @return JwtDecoderFactory
*/
@Bean
@ConditionalOnMissingBean
JwtDecoderFactory<ClientRegistration> azureAdJwtDecoderFactory(AadAuthenticationProperties properties) {
AadProfileProperties profile = properties.getProfile();
AadAuthorizationServerEndpoints endpoints = new AadAuthorizationServerEndpoints(
profile.getEnvironment().getActiveDirectoryEndpoint(), profile.getTenantId());
return new AadOidcIdTokenDecoderFactory(endpoints.getJwkSetEndpoint(), createRestTemplate(restTemplateBuilder));
}

private void passwordGrantBuilderAccessTokenResponseClientCustomizer(
OAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilder builder,
OAuth2ClientAuthenticationJwkResolver resolver) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public JwtDecoder jwtDecoder(AadAuthenticationProperties aadAuthenticationProper
AadAuthorizationServerEndpoints identityEndpoints = new AadAuthorizationServerEndpoints(
aadAuthenticationProperties.getProfile().getEnvironment().getActiveDirectoryEndpoint(), aadAuthenticationProperties.getProfile().getTenantId());
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder
.withJwkSetUri(identityEndpoints.getJwkSetEndpoint())
.withJwkSetUri(identityEndpoints.getJwkSetEndpoint())
.restOperations(createRestTemplate(restTemplateBuilder))
.build();
List<OAuth2TokenValidator<Jwt>> validators = createDefaultValidator(aadAuthenticationProperties);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.cloud.autoconfigure.aad.implementation.webapp;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.web.client.RestOperations;

/**
* A factory that provides a {@link JwtDecoder} used for {@link OidcIdToken} signature verification.
*
* @see <a href="https://learn.microsoft.com/azure/active-directory/develop/id-tokens">azure-active-directory id-tokens</a>
*/
public class AadOidcIdTokenDecoderFactory implements JwtDecoderFactory<ClientRegistration> {

private final JwtDecoder jwtDecoder;

/**
*
* @param jwkSetUri The uri of the jwk set. For example:
* <a href="https://login.microsoftonline.com/common/discovery/v2.0/keys">
* https://login.microsoftonline.com/common/discovery/v2.0/keys</a>
* @param restOperations The RestOperations used to retrieve jwk from jwkSetUri.
*/
public AadOidcIdTokenDecoderFactory(String jwkSetUri, RestOperations restOperations) {
this.jwtDecoder = NimbusJwtDecoder
.withJwkSetUri(jwkSetUri)
.jwsAlgorithm(SignatureAlgorithm.RS256)
.restOperations(restOperations)
.build();
}

@Override
public JwtDecoder createDecoder(ClientRegistration context) {
return jwtDecoder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import com.azure.spring.cloud.autoconfigure.aad.properties.AadAuthenticationProperties;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.http.HttpMessageConvertersAutoConfiguration;
Expand All @@ -34,6 +38,9 @@
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
Expand All @@ -52,6 +59,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -265,6 +273,24 @@ void restTemplateWellConfiguredForAllOAuth2AuthorizedClientProvidersWhenUsingPri
});
}

@Test
void restTemplateWellConfiguredForJwtDecoderCreatedByJwtDecoderFactory() {
webApplicationContextRunner()
.withUserConfiguration(AadOAuth2ClientConfiguration.class, RestTemplateProxyCustomizerConfiguration.class)
.run(context -> {
assertThat(context).hasSingleBean(JwtDecoderFactory.class);
JwtDecoderFactory<?> factory = context.getBean(JwtDecoderFactory.class);
JwtDecoder jwtDecoder = factory.createDecoder(null);
assertTrue(jwtDecoder instanceof NimbusJwtDecoder);
DefaultJWTProcessor<?> processor = (DefaultJWTProcessor<?>) getField(NimbusJwtDecoder.class, "jwtProcessor", jwtDecoder);
JWSVerificationKeySelector<?> selector = (JWSVerificationKeySelector<?>) processor.getJWSKeySelector();
RemoteJWKSet<?> source = (RemoteJWKSet<?>) selector.getJWKSource();
ResourceRetriever retriever = source.getResourceRetriever();
RestTemplate restTemplate = (RestTemplate) getField(retriever.getClass(), "restOperations", retriever);
assertEquals(FACTORY, restTemplate.getRequestFactory());
});
}

private static void assertRestTemplateWellConfiguredForAllOAuth2AuthorizedClientProviders(ApplicationContext context) {
List<OAuth2AuthorizedClientProvider> providers = getAllOAuth2AuthorizedClientProviderThatShouldConfiguredRestTemplate(context);
assertEquals(4, providers.size());
Expand Down

0 comments on commit 359bfba

Please sign in to comment.