diff --git a/authorization/msi-auth-token-provider-jar/src/main/java/com/microsoft/azure/msiAuthTokenProvider/MSICredentials.java b/authorization/msi-auth-token-provider-jar/src/main/java/com/microsoft/azure/msiAuthTokenProvider/MSICredentials.java index eb32fe913e0ea..0ce890e3a531b 100644 --- a/authorization/msi-auth-token-provider-jar/src/main/java/com/microsoft/azure/msiAuthTokenProvider/MSICredentials.java +++ b/authorization/msi-auth-token-provider-jar/src/main/java/com/microsoft/azure/msiAuthTokenProvider/MSICredentials.java @@ -15,6 +15,9 @@ import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; /** * Managed Service Identity token based credentials for use with a REST Service Client. @@ -24,6 +27,9 @@ public final class MSICredentials{ // private final List retrySlots = new ArrayList<>(); // + private final Lock lock = new ReentrantLock(); + + private final ConcurrentHashMap cache = new ConcurrentHashMap<>(); private final MSIConfigurationForVirtualMachine configForVM; private final MSIConfigurationForAppService configForAppService; private final HostType hostType; @@ -157,7 +163,7 @@ public void updateObjectId(String objectId) { public MSIToken getToken(String tokenAudience) throws IOException, AzureMSICredentialException{ switch (hostType) { case VIRTUAL_MACHINE: - return this.retrieveTokenFromIDMSWithRetry(tokenAudience == null ? this.configForVM.resource() : tokenAudience); + return this.getTokenForVirtualMachineFromIMDSEndpoint(tokenAudience == null ? this.configForVM.resource() : tokenAudience); case APP_SERVICE: return this.getTokenForAppService(tokenAudience); default: @@ -217,6 +223,51 @@ private MSIToken getTokenForAppService(String tokenAudience) throws IOException, } } + private MSIToken getTokenForVirtualMachineFromIMDSEndpoint(String tokenAudience) throws AzureMSICredentialException { + String tokenIdentifier = tokenAudience; + + String extraIdentifier = null; + if (this.configForVM.objectId() != null) + { + extraIdentifier = configForVM.objectId(); + } else if (this.configForVM.clientId() != null) { + extraIdentifier = configForVM.clientId(); + } else if (this.configForVM.identityId() != null) { + extraIdentifier = configForVM.identityId(); + } + + if (extraIdentifier != null) { + tokenIdentifier = tokenIdentifier + "#" + extraIdentifier; + } + + MSIToken token = cache.get(tokenIdentifier); + if (token != null && !token.isExpired()) { + return token; + } + + lock.lock(); + + try { + token = cache.get(tokenIdentifier); + if (token != null && !token.isExpired()) { + return token; + } + + try { + token = retrieveTokenFromIDMSWithRetry(tokenAudience); + if (token != null) { + cache.put(tokenIdentifier, token); + } + } catch (IOException exception) { + throw new AzureMSICredentialException(exception); + } + + return token; + } finally { + lock.unlock(); + } + } + private MSIToken retrieveTokenFromIDMSWithRetry(String tokenAudience) throws AzureMSICredentialException, IOException { StringBuilder payload = new StringBuilder(); final int imdsUpgradeTimeInMs = 70 * 1000;