Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove token refresh options and default to 5 minutes #13148

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
* A token cache that supports caching a token and refreshing it.
*/
public class SimpleTokenCache {
private static final Duration REFRESH_TIMEOUT = Duration.ofSeconds(30);
// The delay after a refresh to attempt another token refresh
private static final Duration REFRESH_DELAY = Duration.ofSeconds(30);
// the offset before token expiry to attempt proactive token refresh
private static final Duration REFRESH_OFFSET = Duration.ofMinutes(5);
private final AtomicReference<MonoProcessor<AccessToken>> wip;
private volatile AccessToken cache;
private volatile OffsetDateTime nextTokenRefresh = OffsetDateTime.now();
Expand All @@ -31,20 +34,10 @@ public class SimpleTokenCache {
* @param tokenSupplier a method to get a new token
*/
public SimpleTokenCache(Supplier<Mono<AccessToken>> tokenSupplier) {
this(tokenSupplier, new TokenRefreshOptions());
}

/**
* Creates an instance of RefreshableTokenCredential with default scheme "Bearer".
*
* @param tokenSupplier a method to get a new token
* @param tokenRefreshOptions the options to configure the token refresh behavior
*/
public SimpleTokenCache(Supplier<Mono<AccessToken>> tokenSupplier, TokenRefreshOptions tokenRefreshOptions) {
this.wip = new AtomicReference<>();
this.tokenSupplier = tokenSupplier;
this.shouldRefresh = accessToken -> OffsetDateTime.now().isAfter(accessToken.getExpiresAt()
.minus(tokenRefreshOptions.getOffset()));
this.shouldRefresh = accessToken -> OffsetDateTime.now()
.isAfter(accessToken.getExpiresAt().minus(REFRESH_OFFSET));
}

/**
Expand Down Expand Up @@ -97,11 +90,11 @@ public Mono<AccessToken> getToken() {
cache = accessToken;
monoProcessor.onNext(accessToken);
monoProcessor.onComplete();
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_TIMEOUT);
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return Mono.just(accessToken);
} else if (signal.isOnError() && error != null) { // ERROR
logger.error(refreshLog(cache, now, "Failed to acquire a new access token"));
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_TIMEOUT);
nextTokenRefresh = OffsetDateTime.now().plus(REFRESH_DELAY);
return fallback.switchIfEmpty(Mono.error(error));
} else { // NO REFRESH
monoProcessor.onComplete();
Expand Down Expand Up @@ -138,7 +131,7 @@ private String refreshLog(AccessToken cache, OffsetDateTime now, String log) {
Duration tte = Duration.between(now, cache.getExpiresAt());
info.append(" at ").append(tte.abs().getSeconds()).append(" seconds ")
.append(tte.isNegative() ? "after" : "before").append(" expiry. ")
.append("Retry may be attempted after ").append(REFRESH_TIMEOUT.getSeconds()).append(" seconds.");
.append("Retry may be attempted after ").append(REFRESH_DELAY.getSeconds()).append(" seconds.");
if (!tte.isNegative()) {
info.append(" The token currently cached will be used.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,4 @@ public interface TokenCredential {
* @return a Publisher that emits a single access token
*/
Mono<AccessToken> getToken(TokenRequestContext request);

/**
* The options to configure the token refresh behavior.
*
* @return the current offset for token refresh
*/
default TokenRefreshOptions getTokenRefreshOptions() {
return new TokenRefreshOptions();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ public BearerTokenAuthenticationPolicy(TokenCredential credential, String... sco
assert scopes.length > 0;
this.credential = credential;
this.scopes = scopes;
this.cache = new SimpleTokenCache(() -> credential.getToken(new TokenRequestContext().addScopes(scopes)),
credential.getTokenRefreshOptions());
this.cache = new SimpleTokenCache(() -> credential.getToken(new TokenRequestContext().addScopes(scopes)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void testOnlyOneThreadRefreshesToken() throws Exception {
}

@Test
public void testMultipleThreadsWaitForTimeout() throws Exception {
public void testLongRunningWontOverflow() throws Exception {
AtomicLong refreshes = new AtomicLong(0);

// token expires on creation. Run this 100 times to simulate running the application a long time
Expand All @@ -68,134 +68,15 @@ public void testMultipleThreadsWaitForTimeout() throws Exception {
.flatMap(start -> cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
// System.out.format("Thread: %s\tDuration: %smillis%n",
// Thread.currentThread().getName(), Duration.between(start, OffsetDateTime.now()).toMillis());
})))
.doOnComplete(latch::countDown)
.subscribe();

latch.await();
Assertions.assertEquals(2, refreshes.get());
}

@Test
public void testProactiveRefreshBeforeExpiry() throws Exception {
AtomicInteger latency = new AtomicInteger(1);
SimpleTokenCache cache = new SimpleTokenCache(
() -> remoteGetTokenThatExpiresSoonAsync(1000 * latency.getAndIncrement(), 60 * 1000),
new TestTokenRefreshOptions(Duration.ofSeconds(28))); // refresh at second 32, just past REFRESH_TIMEOUT

CountDownLatch latch = new CountDownLatch(1);
AtomicLong maxMillis = new AtomicLong(0);

Flux.interval(Duration.ofSeconds(2))
.take(25) // 48 seconds after first token, making sure of a refresh
.flatMap(i -> {
OffsetDateTime start = OffsetDateTime.now();
return cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
if (millis > maxMillis.get()) {
maxMillis.set(millis);
}
});
}).doOnComplete(latch::countDown)
.subscribe();

latch.await();
Assertions.assertTrue(maxMillis.get() >= 2000);
Assertions.assertTrue(maxMillis.get() < 3000); // Big enough for any latency, small enough to make sure no get token is called twice
}

@Test
public void testRefreshAfterExpiry() throws Exception {
AtomicInteger latency = new AtomicInteger(1);
SimpleTokenCache cache = new SimpleTokenCache(
() -> remoteGetTokenThatExpiresSoonAsync(1000 * latency.getAndIncrement(), 15 * 1000),
new TestTokenRefreshOptions(Duration.ZERO)); // refresh at second 30 because of REFRESH_TIMEOUT

CountDownLatch latch = new CountDownLatch(1);
AtomicLong maxMillis = new AtomicLong(0);

Flux.interval(Duration.ofSeconds(2))
.take(10) // 38 seconds after first token, making sure of a refresh
.flatMap(i -> {
OffsetDateTime start = OffsetDateTime.now();
return cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
if (millis > maxMillis.get()) {
maxMillis.set(millis);
}
});
}).doOnComplete(latch::countDown)
.subscribe();

latch.await();
Assertions.assertTrue(maxMillis.get() >= 15000);
}

@Test
public void testProactiveRefreshError() throws Exception {
AtomicInteger latency = new AtomicInteger(1);
AtomicInteger tryCount = new AtomicInteger(0);
SimpleTokenCache cache = new SimpleTokenCache(
() -> remoteGetTokenWithPersistentError(1000 * latency.getAndIncrement(), 60 * 1000, 2, tryCount),
new TestTokenRefreshOptions(Duration.ofSeconds(28))); // refresh at second 32, just past REFRESH_TIMEOUT

CountDownLatch latch = new CountDownLatch(1);
AtomicLong maxMillis = new AtomicLong(0);
AtomicInteger errorCount = new AtomicInteger(0);

Flux.interval(Duration.ofSeconds(2))
.take(32) // 64 seconds after first token, making sure of a refresh
.flatMap(i -> {
OffsetDateTime start = OffsetDateTime.now();
return cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
if (millis > maxMillis.get()) {
maxMillis.set(millis);
}
})
.doOnError(t -> errorCount.incrementAndGet());
}).doOnTerminate(latch::countDown)
.subscribe();

latch.await();
Assertions.assertTrue(maxMillis.get() >= 1000);
Assertions.assertTrue(maxMillis.get() < 2000); // Big enough for any latency, small enough to make sure no get token is called twice
Assertions.assertEquals(1, errorCount.get()); // Only the error after expiresAt will be propagated
}

@Test
public void testProactiveRefreshErrorTimeout() throws Exception {
AtomicInteger latency = new AtomicInteger(1);
AtomicInteger tryCount = new AtomicInteger(0);
SimpleTokenCache cache = new SimpleTokenCache(
() -> remoteGetTokenWithTemporaryError(1000 * latency.getAndIncrement(), 60 * 1000, 2, tryCount),
new TestTokenRefreshOptions(Duration.ofSeconds(28))); // refresh at second 32, just past REFRESH_TIMEOUT

CountDownLatch latch = new CountDownLatch(1);
AtomicLong maxMillis = new AtomicLong(0);
AtomicInteger errorCount = new AtomicInteger(0);

Flux.interval(Duration.ofSeconds(2))
.take(32) // 64 seconds after first token, making sure of a refresh
.flatMap(i -> {
OffsetDateTime start = OffsetDateTime.now();
return cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
if (millis > maxMillis.get()) {
maxMillis.set(millis);
}
})
.doOnError(t -> errorCount.incrementAndGet());
}).doOnTerminate(latch::countDown)
.subscribe();

latch.await();
Assertions.assertTrue(maxMillis.get() >= 3000);
Assertions.assertEquals(0, errorCount.get()); // Only the error after expiresAt will be propagated
// At most 10 requests should do actual token acquisition, use 11 for safe
Assertions.assertTrue(refreshes.get() <= 11);
}

private Mono<AccessToken> remoteGetTokenAsync(long delayInMillis) {
Expand All @@ -214,24 +95,6 @@ private Mono<AccessToken> incrementalRemoteGetTokenAsync(AtomicInteger latency)
.map(l -> new Token(Integer.toString(RANDOM.nextInt(100))));
}

private Mono<AccessToken> remoteGetTokenWithTemporaryError(long delayInMillis, long validityInMillis, int errorAt, AtomicInteger tryCount) {
if (tryCount.incrementAndGet() == errorAt) {
return Mono.error(new RuntimeException("Expected error"));
} else {
return Mono.delay(Duration.ofMillis(delayInMillis))
.map(l -> new Token(Integer.toString(RANDOM.nextInt(100)), validityInMillis));
}
}

private Mono<AccessToken> remoteGetTokenWithPersistentError(long delayInMillis, long validityInMillis, int errorAfter, AtomicInteger tryCount) {
if (tryCount.incrementAndGet() >= errorAfter) {
return Mono.error(new RuntimeException("Expected error"));
} else {
return Mono.delay(Duration.ofMillis(delayInMillis))
.map(l -> new Token(Integer.toString(RANDOM.nextInt(100)), validityInMillis));
}
}

private static class Token extends AccessToken {
private String token;
private OffsetDateTime expiry;
Expand Down Expand Up @@ -261,17 +124,4 @@ public boolean isExpired() {
return OffsetDateTime.now().isAfter(expiry);
}
}

private static final class TestTokenRefreshOptions extends TokenRefreshOptions {
private final Duration offset;

private TestTokenRefreshOptions(Duration offset) {
this.offset = offset;
}

@Override
public Duration getOffset() {
return offset;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
import com.azure.perf.test.core.PerfStressOptions;
import reactor.core.publisher.Mono;

import java.time.Duration;

public class WriteCache extends ServiceTest<PerfStressOptions> {
private final SharedTokenCacheCredential credential;

public WriteCache(PerfStressOptions options) {
super(options);
credential = new SharedTokenCacheCredentialBuilder()
.clientId(CLI_CLIENT_ID)
.tokenRefreshOffset(Duration.ofMinutes(60))
.build();
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-core</artifactId>
<version>1.7.0-beta.1</version> <!-- {x-version-update;beta_com.azure:azure-core;dependency} -->
<version>1.7.0-beta.2</version> <!-- {x-version-update;com.azure:azure-core;current} -->
</dependency>
<dependency>
<groupId>com.microsoft.azure</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.azure.core.annotation.Immutable;
import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRefreshOptions;
import com.azure.core.credential.TokenRequestContext;
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.implementation.IdentityClient;
Expand All @@ -28,7 +27,6 @@ public class AuthorizationCodeCredential implements TokenCredential {
private final String authCode;
private final URI redirectUri;
private final IdentityClient identityClient;
private final IdentityClientOptions identityClientOptions;
private final AtomicReference<MsalAuthenticationAccount> cachedToken;
private final ClientLogger logger = new ClientLogger(AuthorizationCodeCredential.class);

Expand All @@ -48,7 +46,6 @@ public class AuthorizationCodeCredential implements TokenCredential {
.clientId(clientId)
.identityClientOptions(identityClientOptions)
.build();
this.identityClientOptions = identityClientOptions;
this.cachedToken = new AtomicReference<>();
this.authCode = authCode;
this.redirectUri = redirectUri;
Expand All @@ -74,9 +71,4 @@ public Mono<AccessToken> getToken(TokenRequestContext request) {
.doOnNext(token -> LoggingUtil.logTokenSuccess(logger, request))
.doOnError(error -> LoggingUtil.logTokenError(logger, request, error));
}

@Override
public TokenRefreshOptions getTokenRefreshOptions() {
return identityClientOptions.getTokenRefreshOptions();
}
}
Loading