diff --git a/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/JwtPublicKeyStore.java b/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/JwtPublicKeyStore.java index e9d930311b1..7b833f7af33 100644 --- a/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/JwtPublicKeyStore.java +++ b/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/JwtPublicKeyStore.java @@ -59,12 +59,16 @@ import java.security.PublicKey; import java.security.spec.RSAPublicKeySpec; import java.security.spec.X509EncodedKeySpec; +import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Base64; +import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import java.util.logging.Logger; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.enterprise.inject.spi.DeploymentException; import javax.json.Json; @@ -81,21 +85,19 @@ class JwtPublicKeyStore { private static final Logger LOGGER = Logger.getLogger(JwtPublicKeyStore.class.getName()); private static final String RSA_ALGORITHM = "RSA"; + private final Config config; private final Supplier> cacheSupplier; + private final Duration defaultCacheTTL; /** - * @param cacheTTL Public key cache TTL in milliseconds + * @param defaultCacheTTL Public key cache TTL */ - public JwtPublicKeyStore(Long cacheTTL) { + public JwtPublicKeyStore(Duration defaultCacheTTL) { this.config = ConfigProvider.getConfig(); - if(cacheTTL > 0) { - cacheSupplier = new PublicKeyLoadingCache(cacheTTL, this::readRawPublicKey)::get; - } - else { - cacheSupplier = this::readRawPublicKey; - } + this.defaultCacheTTL = defaultCacheTTL; + this.cacheSupplier = new PublicKeyLoadingCache(this::readRawPublicKey)::get; } /** @@ -110,8 +112,8 @@ public PublicKey getPublicKey(String keyID) { .orElseThrow(() -> new IllegalStateException("No PublicKey found")); } - private Optional readRawPublicKey() { - Optional publicKey = readDefaultPublicKey(); + private CacheableString readRawPublicKey() { + CacheableString publicKey = readDefaultPublicKey(); if (!publicKey.isPresent()) { publicKey = readMPEmbeddedPublicKey(); @@ -122,19 +124,20 @@ private Optional readRawPublicKey() { return publicKey; } - private Optional readDefaultPublicKey() { + private CacheableString readDefaultPublicKey() { return readPublicKeyFromLocation("/publicKey.pem"); } - private Optional readMPEmbeddedPublicKey() { - return config.getOptionalValue(VERIFIER_PUBLIC_KEY, String.class); + private CacheableString readMPEmbeddedPublicKey() { + String publicKey = config.getOptionalValue(VERIFIER_PUBLIC_KEY, String.class).orElse(null); + return CacheableString.from(publicKey, defaultCacheTTL); } - private Optional readMPPublicKeyFromLocation() { + private CacheableString readMPPublicKeyFromLocation() { Optional locationOpt = config.getOptionalValue(VERIFIER_PUBLIC_KEY_LOCATION, String.class); if (!locationOpt.isPresent()) { - return Optional.empty(); + return CacheableString.empty(defaultCacheTTL); } String publicKeyLocation = locationOpt.get(); @@ -142,7 +145,7 @@ private Optional readMPPublicKeyFromLocation() { return readPublicKeyFromLocation(publicKeyLocation); } - private Optional readPublicKeyFromLocation(String publicKeyLocation) { + private CacheableString readPublicKeyFromLocation(String publicKeyLocation) { URL publicKeyURL = currentThread().getContextClassLoader().getResource(publicKeyLocation); @@ -154,7 +157,7 @@ private Optional readPublicKeyFromLocation(String publicKeyLocation) { } } if (publicKeyURL == null) { - return Optional.empty(); + return CacheableString.empty(defaultCacheTTL); } try { @@ -164,7 +167,7 @@ private Optional readPublicKeyFromLocation(String publicKeyLocation) { } } - private Optional readPublicKeyFromURL(URL publicKeyURL) throws IOException { + private CacheableString readPublicKeyFromURL(URL publicKeyURL) throws IOException { URLConnection urlConnection = publicKeyURL.openConnection(); Charset charset = Charset.defaultCharset(); @@ -182,11 +185,34 @@ private Optional readPublicKeyFromURL(URL publicKeyURL) throws IOExcepti LOGGER.severe("Charset " + ex.getCharsetName() + " for remote public key not support, Cause: " + ex.getMessage()); } } + } + + + // There's no guarantee that the response will contain at most one Cache-Control header and at most one max-age directive. + // Here, we apply the smallest of all max-age directives. + Duration cacheTTL = urlConnection.getHeaderFields().entrySet().stream() + .filter(e -> e.getKey() != null && e.getKey().trim().equalsIgnoreCase("Cache-Control")) + .flatMap(headers -> headers.getValue().stream()) + .flatMap(headerValue -> Stream.of(headerValue.split(","))) + .filter(directive -> directive.trim().startsWith("max-age")) + .map(maxAgeDirective -> { + String[] keyValue = maxAgeDirective.split("=",2); + String maxAge = keyValue[keyValue.length-1]; + try { + return Duration.ofSeconds(Long.parseLong(maxAge)); + } catch(NumberFormatException e) { + return null; + } + }) + .filter(Objects::nonNull) + .min(Duration::compareTo) + .orElse(defaultCacheTTL); + try (InputStream inputStream = urlConnection.getInputStream(); BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, charset))){ String keyContents = reader.lines().collect(Collectors.joining(System.lineSeparator())); - return Optional.of(keyContents); + return CacheableString.from(keyContents, cacheTTL); } } @@ -267,25 +293,20 @@ private JsonObject findJwk(JsonArray keys, String keyID) { private static class PublicKeyLoadingCache { - private final long ttl; - private final Supplier> keySupplier; + private final Supplier keySupplier; + private Duration ttl; private long lastUpdated; private Optional publicKey; - /** - * - * @param ttl Public key cache TTL in milliseconds - * @param keySupplier A supplier to load the public key. - */ - public PublicKeyLoadingCache(long ttl, Supplier> keySupplier) { - this.ttl = ttl; + public PublicKeyLoadingCache(Supplier keySupplier) { + this.ttl = Duration.ZERO; this.keySupplier = keySupplier; } public Optional get() { long now = System.currentTimeMillis(); - if(now - lastUpdated > ttl) { + if(now - lastUpdated > ttl.toMillis()) { refresh(); } @@ -294,11 +315,42 @@ public Optional get() { private synchronized void refresh() { long now = System.currentTimeMillis(); - if(now - lastUpdated > ttl) { - publicKey = keySupplier.get(); + if(now - lastUpdated > ttl.toMillis()) { + CacheableString result = keySupplier.get(); + publicKey = result.getValue(); + ttl = result.getCacheTTL(); lastUpdated = now; } } } + + private static class CacheableString { + + public static CacheableString empty(Duration cacheTTL) { + return from(null, cacheTTL); + } + + public static CacheableString from(String value, Duration cacheTTL) { + CacheableString instance = new CacheableString(); + instance.cacheTTL = cacheTTL; + instance.value = value; + return instance; + } + + private String value; + private Duration cacheTTL; + + public Optional getValue() { + return Optional.ofNullable(value); + } + + public Duration getCacheTTL() { + return cacheTTL; + } + + public boolean isPresent() { + return value != null; + } + } } diff --git a/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/SignedJWTIdentityStore.java b/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/SignedJWTIdentityStore.java index 192222e7aeb..1a9d26cc1df 100644 --- a/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/SignedJWTIdentityStore.java +++ b/appserver/payara-appserver-modules/microprofile/jwt-auth/src/main/java/fish/payara/microprofile/jwtauth/eesecurity/SignedJWTIdentityStore.java @@ -47,12 +47,12 @@ import java.io.IOException; import java.net.URL; import java.security.PublicKey; +import java.time.Duration; import java.util.Collection; import java.util.HashSet; import java.util.Optional; import java.util.Properties; import java.util.Set; -import java.util.concurrent.TimeUnit; import java.util.logging.Logger; import javax.security.enterprise.identitystore.CredentialValidationResult; @@ -152,11 +152,12 @@ private Optional readDisableTypeVerification(Optional prope return properties.isPresent() ? Optional.ofNullable(Boolean.valueOf(properties.get().getProperty("disable.type.verification", "false"))) : Optional.empty(); } - private Long readPublicKeyCacheTTL(Optional properties) { + private Duration readPublicKeyCacheTTL(Optional properties) { return properties .map(props -> props.getProperty("publicKey.cache.ttl")) .map(Long::valueOf) - .orElseGet( () -> TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES)); + .map(Duration::ofMillis) + .orElseGet( () -> Duration.ofMinutes(5)); }