Skip to content

Commit

Permalink
FISH-868 JWT public key cache TTL can be set by HTTP header max-age.
Browse files Browse the repository at this point in the history
  • Loading branch information
ghunteranderson committed Jan 16, 2021
1 parent 2276d6a commit 6bd8fb3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,16 @@
import java.security.PublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.enterprise.inject.spi.DeploymentException;
import javax.json.Json;
Expand All @@ -81,21 +85,19 @@ class JwtPublicKeyStore {

private static final Logger LOGGER = Logger.getLogger(JwtPublicKeyStore.class.getName());
private static final String RSA_ALGORITHM = "RSA";


private final Config config;
private final Supplier<Optional<String>> cacheSupplier;
private final Duration defaultCacheTTL;

/**
* @param cacheTTL Public key cache TTL in milliseconds
* @param defaultCacheTTL Public key cache TTL
*/
public JwtPublicKeyStore(Long cacheTTL) {
public JwtPublicKeyStore(Duration defaultCacheTTL) {
this.config = ConfigProvider.getConfig();
if(cacheTTL > 0) {
cacheSupplier = new PublicKeyLoadingCache(cacheTTL, this::readRawPublicKey)::get;
}
else {
cacheSupplier = this::readRawPublicKey;
}
this.defaultCacheTTL = defaultCacheTTL;
this.cacheSupplier = new PublicKeyLoadingCache(this::readRawPublicKey)::get;
}

/**
Expand All @@ -110,8 +112,8 @@ public PublicKey getPublicKey(String keyID) {
.orElseThrow(() -> new IllegalStateException("No PublicKey found"));
}

private Optional<String> readRawPublicKey() {
Optional<String> publicKey = readDefaultPublicKey();
private CacheableString readRawPublicKey() {
CacheableString publicKey = readDefaultPublicKey();

if (!publicKey.isPresent()) {
publicKey = readMPEmbeddedPublicKey();
Expand All @@ -122,27 +124,28 @@ private Optional<String> readRawPublicKey() {
return publicKey;
}

private Optional<String> readDefaultPublicKey() {
private CacheableString readDefaultPublicKey() {
return readPublicKeyFromLocation("/publicKey.pem");
}

private Optional<String> readMPEmbeddedPublicKey() {
return config.getOptionalValue(VERIFIER_PUBLIC_KEY, String.class);
private CacheableString readMPEmbeddedPublicKey() {
String publicKey = config.getOptionalValue(VERIFIER_PUBLIC_KEY, String.class).orElse(null);
return CacheableString.from(publicKey, defaultCacheTTL);
}

private Optional<String> readMPPublicKeyFromLocation() {
private CacheableString readMPPublicKeyFromLocation() {
Optional<String> locationOpt = config.getOptionalValue(VERIFIER_PUBLIC_KEY_LOCATION, String.class);

if (!locationOpt.isPresent()) {
return Optional.empty();
return CacheableString.empty(defaultCacheTTL);
}

String publicKeyLocation = locationOpt.get();

return readPublicKeyFromLocation(publicKeyLocation);
}

private Optional<String> readPublicKeyFromLocation(String publicKeyLocation) {
private CacheableString readPublicKeyFromLocation(String publicKeyLocation) {

URL publicKeyURL = currentThread().getContextClassLoader().getResource(publicKeyLocation);

Expand All @@ -154,7 +157,7 @@ private Optional<String> readPublicKeyFromLocation(String publicKeyLocation) {
}
}
if (publicKeyURL == null) {
return Optional.empty();
return CacheableString.empty(defaultCacheTTL);
}

try {
Expand All @@ -164,7 +167,7 @@ private Optional<String> readPublicKeyFromLocation(String publicKeyLocation) {
}
}

private Optional<String> readPublicKeyFromURL(URL publicKeyURL) throws IOException {
private CacheableString readPublicKeyFromURL(URL publicKeyURL) throws IOException {

URLConnection urlConnection = publicKeyURL.openConnection();
Charset charset = Charset.defaultCharset();
Expand All @@ -182,11 +185,34 @@ private Optional<String> readPublicKeyFromURL(URL publicKeyURL) throws IOExcepti
LOGGER.severe("Charset " + ex.getCharsetName() + " for remote public key not support, Cause: " + ex.getMessage());
}
}

}


// There's no guarantee that the response will contain at most one Cache-Control header and at most one max-age directive.
// Here, we apply the smallest of all max-age directives.
Duration cacheTTL = urlConnection.getHeaderFields().entrySet().stream()
.filter(e -> e.getKey() != null && e.getKey().trim().equalsIgnoreCase("Cache-Control"))
.flatMap(headers -> headers.getValue().stream())
.flatMap(headerValue -> Stream.of(headerValue.split(",")))
.filter(directive -> directive.trim().startsWith("max-age"))
.map(maxAgeDirective -> {
String[] keyValue = maxAgeDirective.split("=",2);
String maxAge = keyValue[keyValue.length-1];
try {
return Duration.ofSeconds(Long.parseLong(maxAge));
} catch(NumberFormatException e) {
return null;
}
})
.filter(Objects::nonNull)
.min(Duration::compareTo)
.orElse(defaultCacheTTL);

try (InputStream inputStream = urlConnection.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, charset))){
String keyContents = reader.lines().collect(Collectors.joining(System.lineSeparator()));
return Optional.of(keyContents);
return CacheableString.from(keyContents, cacheTTL);
}

}
Expand Down Expand Up @@ -267,25 +293,20 @@ private JsonObject findJwk(JsonArray keys, String keyID) {

private static class PublicKeyLoadingCache {

private final long ttl;
private final Supplier<Optional<String>> keySupplier;
private final Supplier<CacheableString> keySupplier;
private Duration ttl;
private long lastUpdated;
private Optional<String> publicKey;


/**
*
* @param ttl Public key cache TTL in milliseconds
* @param keySupplier A supplier to load the public key.
*/
public PublicKeyLoadingCache(long ttl, Supplier<Optional<String>> keySupplier) {
this.ttl = ttl;
public PublicKeyLoadingCache(Supplier<CacheableString> keySupplier) {
this.ttl = Duration.ZERO;
this.keySupplier = keySupplier;
}

public Optional<String> get() {
long now = System.currentTimeMillis();
if(now - lastUpdated > ttl) {
if(now - lastUpdated > ttl.toMillis()) {
refresh();
}

Expand All @@ -294,11 +315,42 @@ public Optional<String> get() {

private synchronized void refresh() {
long now = System.currentTimeMillis();
if(now - lastUpdated > ttl) {
publicKey = keySupplier.get();
if(now - lastUpdated > ttl.toMillis()) {
CacheableString result = keySupplier.get();
publicKey = result.getValue();
ttl = result.getCacheTTL();
lastUpdated = now;
}
}

}

private static class CacheableString {

public static CacheableString empty(Duration cacheTTL) {
return from(null, cacheTTL);
}

public static CacheableString from(String value, Duration cacheTTL) {
CacheableString instance = new CacheableString();
instance.cacheTTL = cacheTTL;
instance.value = value;
return instance;
}

private String value;
private Duration cacheTTL;

public Optional<String> getValue() {
return Optional.ofNullable(value);
}

public Duration getCacheTTL() {
return cacheTTL;
}

public boolean isPresent() {
return value != null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
import java.io.IOException;
import java.net.URL;
import java.security.PublicKey;
import java.time.Duration;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;

import javax.security.enterprise.identitystore.CredentialValidationResult;
Expand Down Expand Up @@ -152,11 +152,12 @@ private Optional<Boolean> readDisableTypeVerification(Optional<Properties> prope
return properties.isPresent() ? Optional.ofNullable(Boolean.valueOf(properties.get().getProperty("disable.type.verification", "false"))) : Optional.empty();
}

private Long readPublicKeyCacheTTL(Optional<Properties> properties) {
private Duration readPublicKeyCacheTTL(Optional<Properties> properties) {
return properties
.map(props -> props.getProperty("publicKey.cache.ttl"))
.map(Long::valueOf)
.orElseGet( () -> TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES));
.map(Duration::ofMillis)
.orElseGet( () -> Duration.ofMinutes(5));
}


Expand Down

0 comments on commit 6bd8fb3

Please sign in to comment.