Skip to content

Commit

Permalink
Support for OIDC verification time JWK set resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed Nov 10, 2023
1 parent e863473 commit d161990
Show file tree
Hide file tree
Showing 13 changed files with 512 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,81 @@ public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) {
}
}

/**
* Configuration for controlling how JsonWebKeySet containing verification keys should be acquired and managed.
*/
@ConfigItem
public Jwks jwks = new Jwks();

@ConfigGroup
public static class Jwks {

/**
* If JWK verification keys should be fetched at the moment a connection to the OIDC provider
* is initialized.
* <p/>
* Disabling this property will delay the key acquisition until the moment the current token
* has to be verified. Typically it can only be necessary if the token or other telated request properties
* provide an additional context which is required to resolve the keys correctly.
*/
@ConfigItem(defaultValue = "true")
public boolean resolveEarly = true;

/**
* Maximum number of JWK keys that can be cached.
* This property will be ignored if the {@link #resolveEarly} property is set to true.
*/
@ConfigItem(defaultValue = "10")
public int cacheSize = 10;

/**
* Number of minutes a JWK key can be cached for.
* This property will be ignored if the {@link #resolveEarly} property is set to true.
*/
@ConfigItem(defaultValue = "10M")
public Duration cacheTimeToLive = Duration.ofMinutes(10);

/**
* Cache timer interval.
* If this property is set then a timer will check and remove the stale entries periodically.
* This property will be ignored if the {@link #resolveEarly} property is set to true.
*/
@ConfigItem
public Optional<Duration> cleanUpTimerInterval = Optional.empty();

public int getCacheSize() {
return cacheSize;
}

public void setCacheSize(int cacheSize) {
this.cacheSize = cacheSize;
}

public Duration getCacheTimeToLive() {
return cacheTimeToLive;
}

public void setCacheTimeToLive(Duration cacheTimeToLive) {
this.cacheTimeToLive = cacheTimeToLive;
}

public Optional<Duration> getCleanUpTimerInterval() {
return cleanUpTimerInterval;
}

public void setCleanUpTimerInterval(Duration cleanUpTimerInterval) {
this.cleanUpTimerInterval = Optional.of(cleanUpTimerInterval);
}

public boolean isResolveEarly() {
return resolveEarly;
}

public void setResolveEarly(boolean resolveEarly) {
this.resolveEarly = resolveEarly;
}
}

@ConfigGroup
public static class Frontchannel {
/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,103 +1,33 @@
package io.quarkus.oidc.runtime;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import jakarta.enterprise.event.Observes;

import io.quarkus.oidc.OidcTenantConfig;
import io.vertx.core.Handler;
import io.quarkus.runtime.ShutdownEvent;
import io.vertx.core.Vertx;

public class BackChannelLogoutTokenCache {
private OidcTenantConfig oidcConfig;

private Map<String, CacheEntry> cacheMap = new ConcurrentHashMap<>();;
private AtomicInteger size = new AtomicInteger();
final MemoryCache<TokenVerificationResult> cache;

public BackChannelLogoutTokenCache(OidcTenantConfig oidcTenantConfig, Vertx vertx) {
this.oidcConfig = oidcTenantConfig;
init(vertx);
}

private void init(Vertx vertx) {
cacheMap = new ConcurrentHashMap<>();
if (oidcConfig.logout.backchannel.cleanUpTimerInterval.isPresent()) {
vertx.setPeriodic(oidcConfig.logout.backchannel.cleanUpTimerInterval.get().toMillis(), new Handler<Long>() {
@Override
public void handle(Long event) {
// Remove all the entries which have expired
removeInvalidEntries();
}
});
}
cache = new MemoryCache<TokenVerificationResult>(vertx, oidcTenantConfig.logout.backchannel.cleanUpTimerInterval,
oidcTenantConfig.logout.backchannel.tokenCacheTimeToLive, oidcTenantConfig.logout.backchannel.tokenCacheSize);
}

public void addTokenVerification(String token, TokenVerificationResult result) {
if (!prepareSpaceForNewCacheEntry()) {
clearCache();
}
cacheMap.put(token, new CacheEntry(result));
cache.add(token, result);
}

public TokenVerificationResult removeTokenVerification(String token) {
CacheEntry entry = removeCacheEntry(token);
return entry == null ? null : entry.result;
return cache.remove(token);
}

public boolean containsTokenVerification(String token) {
return cacheMap.containsKey(token);
}

public void clearCache() {
cacheMap.clear();
size.set(0);
}

private void removeInvalidEntries() {
long now = now();
for (Iterator<Map.Entry<String, CacheEntry>> it = cacheMap.entrySet().iterator(); it.hasNext();) {
Map.Entry<String, CacheEntry> next = it.next();
if (isEntryExpired(next.getValue(), now)) {
it.remove();
size.decrementAndGet();
}
}
}

private boolean prepareSpaceForNewCacheEntry() {
int currentSize;
do {
currentSize = size.get();
if (currentSize == oidcConfig.logout.backchannel.tokenCacheSize) {
return false;
}
} while (!size.compareAndSet(currentSize, currentSize + 1));
return true;
return cache.containsKey(token);
}

private CacheEntry removeCacheEntry(String token) {
CacheEntry entry = cacheMap.remove(token);
if (entry != null) {
size.decrementAndGet();
}
return entry;
}

private boolean isEntryExpired(CacheEntry entry, long now) {
return entry.createdTime + oidcConfig.logout.backchannel.tokenCacheTimeToLive.toMillis() < now;
}

private static long now() {
return System.currentTimeMillis();
}

private static class CacheEntry {
volatile TokenVerificationResult result;
long createdTime = System.currentTimeMillis();

public CacheEntry(TokenVerificationResult result) {
this.result = result;
}
void shutdown(@Observes ShutdownEvent event, Vertx vertx) {
cache.stopTimer(vertx);
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
package io.quarkus.oidc.runtime;

import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import jakarta.enterprise.event.Observes;

import io.quarkus.oidc.OidcRequestContext;
import io.quarkus.oidc.OidcTenantConfig;
import io.quarkus.oidc.TokenIntrospection;
import io.quarkus.oidc.TokenIntrospectionCache;
import io.quarkus.oidc.UserInfo;
import io.quarkus.oidc.UserInfoCache;
import io.quarkus.oidc.runtime.OidcConfig.TokenCache;
import io.quarkus.runtime.ShutdownEvent;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;

/**
Expand All @@ -31,43 +26,21 @@ public class DefaultTokenIntrospectionUserInfoCache implements TokenIntrospectio
private static final Uni<TokenIntrospection> NULL_INTROSPECTION_UNI = Uni.createFrom().nullItem();
private static final Uni<UserInfo> NULL_USERINFO_UNI = Uni.createFrom().nullItem();

private TokenCache cacheConfig;

private Map<String, CacheEntry> cacheMap;
private AtomicInteger size = new AtomicInteger();
final MemoryCache<CacheEntry> cache;

public DefaultTokenIntrospectionUserInfoCache(OidcConfig oidcConfig, Vertx vertx) {
this.cacheConfig = oidcConfig.tokenCache;
init(vertx);
}

private void init(Vertx vertx) {
if (cacheConfig.maxSize > 0) {
cacheMap = new ConcurrentHashMap<>();
if (cacheConfig.cleanUpTimerInterval.isPresent()) {
vertx.setPeriodic(cacheConfig.cleanUpTimerInterval.get().toMillis(), new Handler<Long>() {
@Override
public void handle(Long event) {
// Remove all the entries which have expired
removeInvalidEntries();
}
});
}
} else {
cacheMap = Collections.emptyMap();
}
cache = new MemoryCache<CacheEntry>(vertx, oidcConfig.tokenCache.cleanUpTimerInterval,
oidcConfig.tokenCache.timeToLive, oidcConfig.tokenCache.maxSize);
}

@Override
public Uni<Void> addIntrospection(String token, TokenIntrospection introspection, OidcTenantConfig oidcTenantConfig,
OidcRequestContext<Void> requestContext) {
if (cacheConfig.maxSize > 0) {
CacheEntry entry = findValidCacheEntry(token);
if (entry != null) {
entry.introspection = introspection;
} else if (prepareSpaceForNewCacheEntry()) {
cacheMap.put(token, new CacheEntry(introspection));
}
CacheEntry entry = cache.get(token);
if (entry != null) {
entry.introspection = introspection;
} else {
cache.add(token, new CacheEntry(introspection));
}

return CodeAuthenticationMechanism.VOID_UNI;
Expand All @@ -76,20 +49,18 @@ public Uni<Void> addIntrospection(String token, TokenIntrospection introspection
@Override
public Uni<TokenIntrospection> getIntrospection(String token, OidcTenantConfig oidcConfig,
OidcRequestContext<TokenIntrospection> requestContext) {
CacheEntry entry = findValidCacheEntry(token);
CacheEntry entry = cache.get(token);
return entry == null ? NULL_INTROSPECTION_UNI : Uni.createFrom().item(entry.introspection);
}

@Override
public Uni<Void> addUserInfo(String token, UserInfo userInfo, OidcTenantConfig oidcTenantConfig,
OidcRequestContext<Void> requestContext) {
if (cacheConfig.maxSize > 0) {
CacheEntry entry = findValidCacheEntry(token);
if (entry != null) {
entry.userInfo = userInfo;
} else if (prepareSpaceForNewCacheEntry()) {
cacheMap.put(token, new CacheEntry(userInfo));
}
CacheEntry entry = cache.get(token);
if (entry != null) {
entry.userInfo = userInfo;
} else {
cache.add(token, new CacheEntry(userInfo));
}

return CodeAuthenticationMechanism.VOID_UNI;
Expand All @@ -98,67 +69,13 @@ public Uni<Void> addUserInfo(String token, UserInfo userInfo, OidcTenantConfig o
@Override
public Uni<UserInfo> getUserInfo(String token, OidcTenantConfig oidcConfig,
OidcRequestContext<UserInfo> requestContext) {
CacheEntry entry = findValidCacheEntry(token);
CacheEntry entry = cache.get(token);
return entry == null ? NULL_USERINFO_UNI : Uni.createFrom().item(entry.userInfo);
}

public int getCacheSize() {
return cacheMap.size();
}

public void clearCache() {
cacheMap.clear();
size.set(0);
}

private void removeInvalidEntries() {
long now = now();
for (Iterator<Map.Entry<String, CacheEntry>> it = cacheMap.entrySet().iterator(); it.hasNext();) {
Map.Entry<String, CacheEntry> next = it.next();
if (isEntryExpired(next.getValue(), now)) {
it.remove();
size.decrementAndGet();
}
}
}

private boolean prepareSpaceForNewCacheEntry() {
int currentSize;
do {
currentSize = size.get();
if (currentSize == cacheConfig.maxSize) {
return false;
}
} while (!size.compareAndSet(currentSize, currentSize + 1));
return true;
}

private CacheEntry findValidCacheEntry(String token) {
CacheEntry entry = cacheMap.get(token);
if (entry != null) {
long now = now();
if (isEntryExpired(entry, now)) {
// Entry has expired, remote introspection will be required
entry = null;
cacheMap.remove(token);
size.decrementAndGet();
}
}
return entry;
}

private boolean isEntryExpired(CacheEntry entry, long now) {
return entry.createdTime + cacheConfig.timeToLive.toMillis() < now;
}

private static long now() {
return System.currentTimeMillis();
}

private static class CacheEntry {
volatile TokenIntrospection introspection;
volatile UserInfo userInfo;
long createdTime = System.currentTimeMillis();

public CacheEntry(TokenIntrospection introspection) {
this.introspection = introspection;
Expand All @@ -168,4 +85,17 @@ public CacheEntry(UserInfo userInfo) {
this.userInfo = userInfo;
}
}

public void clearCache() {
cache.clearCache();
}

public int getCacheSize() {
return cache.getCacheSize();
}

void shutdown(@Observes ShutdownEvent event, Vertx vertx) {
cache.stopTimer(vertx);
}

}
Loading

0 comments on commit d161990

Please sign in to comment.