Skip to content
This repository has been archived by the owner on Aug 28, 2024. It is now read-only.

Added support for providing a JWKSCache implementation #804

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@
*/
package com.microsoft.azure.spring.autoconfigure.aad;

import com.microsoft.aad.adal4j.ClientCredential;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.util.ResourceRetriever;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
import java.util.concurrent.ExecutionException;

import javax.naming.ServiceUnavailableException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.naming.ServiceUnavailableException;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
import java.util.concurrent.ExecutionException;
import com.microsoft.aad.adal4j.ClientCredential;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.util.ResourceRetriever;

public class AADAuthenticationFilter extends OncePerRequestFilter {
private static final Logger log = LoggerFactory.getLogger(AADAuthenticationFilter.class);
Expand All @@ -48,6 +51,15 @@ public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false);
}

public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever,
JWKSetCache jwkSetCache) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false, jwkSetCache);
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
import static com.microsoft.azure.telemetry.TelemetryData.SERVICE_NAME;
import static com.microsoft.azure.telemetry.TelemetryData.getClassPackageSimpleName;

import com.microsoft.azure.telemetry.TelemetrySender;
import com.nimbusds.jose.util.DefaultResourceRetriever;
import com.nimbusds.jose.util.ResourceRetriever;
import java.util.HashMap;
import java.util.Map;

import javax.annotation.PostConstruct;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
Expand All @@ -27,6 +26,11 @@
import org.springframework.context.annotation.PropertySource;
import org.springframework.util.ClassUtils;

import com.microsoft.azure.telemetry.TelemetrySender;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.util.DefaultResourceRetriever;
import com.nimbusds.jose.util.ResourceRetriever;

@Configuration
@ConditionalOnWebApplication
@ConditionalOnResource(resources = "classpath:aad.enable.config")
Expand Down Expand Up @@ -60,7 +64,12 @@ public AADAuthenticationFilterAutoConfiguration(AADAuthenticationProperties aadA
@ConditionalOnExpression("${azure.activedirectory.session-stateless:false} == false")
public AADAuthenticationFilter azureADJwtTokenFilter() {
LOG.info("AzureADJwtTokenFilter Constructor.");
return new AADAuthenticationFilter(aadAuthProps, serviceEndpointsProps, getJWTResourceRetriever());
if (getJWKSetCache() != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method will always return null. How can you get the specific JWKSetCache customized by yourself?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also wondering about this. Trying to figure this part out

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with the below changes I suggested we might not need the null check. We would always add it as the 4th parameter. And we could default the lifespan to 5 minutes. That's what it gets defaulted to if you don't pass it in as per the documentation so that would be the same behavior unless you specify in the properties file

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

LOG.info("Initializing with JWKS cache");
return new AADAuthenticationFilter(aadAuthProps, serviceEndpointsProps, getJWTResourceRetriever(), getJWKSetCache());
} else {
return new AADAuthenticationFilter(aadAuthProps, serviceEndpointsProps, getJWTResourceRetriever());
}
}

@Bean
Expand All @@ -80,6 +89,12 @@ public ResourceRetriever getJWTResourceRetriever() {
aadAuthProps.getJwtSizeLimit());
}

@Bean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@Bean
@Bean
@ConditionalOnMissingBean(JWKSetCache.class)
public JWKSetCache getJWKSetCache() {
return new DefaultJWKSetCache(properties.lifespan, TimeUnit.MILLISECONDS)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be something like what we would want to do? We would then define the lifespan within a properties file

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sound good! It is a cool way to customize the cache by properties. If users don't set related configuration, it will use 5 min by default.

@ConditionalOnMissingBean(JWKSetCache.class)
public JWKSetCache getJWKSetCache() {
return null;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method will always return null. It makes no sense.

Copy link
Contributor

@bcannariato bcannariato Feb 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well with my understanding that @ConditionalOnMissingBean annotation makes it so this method will only get called when JWKSetCache is missing? So in that case just return null. So the check above will create the AADAuthenticationFilter without the cache object.

That's what the library currently does, it doesn't include that last argument, it has 2 constructers the old one without jwkSetCache and a new one with it.

But I do not understand where the cache is configured if that bean is present...

}

@PostConstruct
private void sendTelemetry() {
if (aadAuthProps.isAllowTelemetry()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
*/
package com.microsoft.azure.spring.autoconfigure.aad;

import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
Expand All @@ -16,15 +24,13 @@
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.*;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;

import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class UserPrincipalManager {
Expand Down Expand Up @@ -79,6 +85,37 @@ public UserPrincipalManager(ServiceEndpointsProperties serviceEndpointsProps,
}
}

/**
* Create a new {@link UserPrincipalManager} based of the {@link ServiceEndpoints#getAadKeyDiscoveryUri()} and
* {@link AADAuthenticationProperties#getEnvironment()}.
*
* @param serviceEndpointsProps - used to retrieve the JWKS URL
* @param aadAuthProps - used to retrieve the environment.
* @param resourceRetriever - configures the {@link RemoteJWKSet} call.
* @param jwkSetCache - used to cache the JWK set for a finite time
*/
public UserPrincipalManager(ServiceEndpointsProperties serviceEndpointsProps,
AADAuthenticationProperties aadAuthProps,
ResourceRetriever resourceRetriever,
boolean explicitAudienceCheck,
JWKSetCache jwkSetCache) {
this.aadAuthProps = aadAuthProps;
this.explicitAudienceCheck = explicitAudienceCheck;
if (explicitAudienceCheck) {
// client-id for "normal" check
this.validAudiences.add(this.aadAuthProps.getClientId());
// app id uri for client credentials flow (server to server communication)
this.validAudiences.add(this.aadAuthProps.getAppIdUri());
}
try {
keySource = new RemoteJWKSet<>(new URL(serviceEndpointsProps
.getServiceEndpoints(aadAuthProps.getEnvironment()).getAadKeyDiscoveryUri()), resourceRetriever, jwkSetCache);
} catch (MalformedURLException e) {
log.error("Failed to parse active directory key discovery uri.", e);
throw new IllegalStateException("Failed to parse active directory key discovery uri.", e);
}
}

public UserPrincipal buildUserPrincipal(String idToken) throws ParseException, JOSEException, BadJOSEException {
final JWSObject jwsObject = JWSObject.parse(idToken);
final ConfigurableJWTProcessor<SecurityContext> validator =
Expand Down