From 5b657906c5b221faca0fcca3051e389f2798d24a Mon Sep 17 00:00:00 2001 From: machavan Date: Thu, 12 Dec 2024 12:45:24 +0530 Subject: [PATCH 01/12] Introduced timeouts for MSAL calls. --- .../sqlserver/jdbc/SQLServerConnection.java | 19 ++++++------ .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 31 ++++++++++++------- .../jdbc/SQLServerSecurityUtility.java | 6 ++-- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 6874deab4..578622cf8 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6110,10 +6110,11 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw } while (true) { + int millisecondsRemaining = timerRemaining(timerExpire); if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6141,12 +6142,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null && !aadPrincipalSecret.isEmpty()) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID, - aadPrincipalSecret, authenticationString); + aadPrincipalSecret, authenticationString, millisecondsRemaining); } else { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. @@ -6159,7 +6160,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), servicePrincipalCertificate, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString); + servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6194,7 +6195,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw throw new SQLServerException(form.format(msgArgs), null); } - int millisecondsRemaining = timerRemaining(timerExpire); + millisecondsRemaining = timerRemaining(timerExpire); if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) { @@ -6240,7 +6241,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString}; throw new SQLServerException(form.format(msgArgs), null, 0, null); } - fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. break; @@ -6248,7 +6249,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) { // interactive flow fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user, - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6258,12 +6259,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); break; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 689347db3..c7f26e50c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; @@ -80,7 +81,7 @@ private SQLServerMSAL4JUtils() { private static final Lock lock = new ReentrantLock(); static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -116,7 +117,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray()) .build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -132,6 +133,8 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); @@ -139,7 +142,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, - String aadPrincipalSecret, String authenticationString) throws SQLServerException { + String aadPrincipalSecret, String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -181,7 +184,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -197,6 +200,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); + } catch (TimeoutException e) { + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); @@ -205,7 +210,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, String certFile, String certPassword, String certKey, String certKeyPassword, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -297,7 +302,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -325,7 +330,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI } static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); /* @@ -352,7 +357,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut .acquireToken(IntegratedWindowsAuthenticationParameters .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -368,6 +373,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut throw new SQLServerException(e.getMessage(), e); } catch (IOException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); @@ -375,7 +382,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut } static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINER)) { @@ -432,7 +439,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } if (null != future) { - authenticationResult = future.get(); + authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); } else { // acquire token interactively with system browser if (logger.isLoggable(Level.FINEST)) { @@ -444,7 +451,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu .loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build(); future = pca.acquireToken(parameters); - authenticationResult = future.get(); + authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); } if (logger.isLoggable(Level.FINER)) { @@ -461,6 +468,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 70c50ca28..a72fb8422 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -8,6 +8,8 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.text.MessageFormat; +import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.HashMap; import java.util.Optional; @@ -408,7 +410,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, * @throws SQLServerException */ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, int millisecondsRemaining) throws SQLServerException { String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); @@ -463,7 +465,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(Duration.of(millisecondsRemaining, ChronoUnit.MILLIS)); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), From db6e39d09bc021cb57ba1dc4ff89c0d6bb7be4f3 Mon Sep 17 00:00:00 2001 From: machavan Date: Thu, 12 Dec 2024 12:53:29 +0530 Subject: [PATCH 02/12] Fixed indentation issues. --- .../com/microsoft/sqlserver/jdbc/SQLServerConnection.java | 2 +- .../com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 578622cf8..4739ccc9b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6110,7 +6110,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw } while (true) { - int millisecondsRemaining = timerRemaining(timerExpire); + int millisecondsRemaining = timerRemaining(timerExpire); if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index c7f26e50c..2a0d62ada 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -134,7 +134,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); @@ -201,7 +201,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); @@ -469,7 +469,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { lock.unlock(); executorService.shutdown(); From 51f1100d22edd9298c159497016c03820c4281f2 Mon Sep 17 00:00:00 2001 From: machavan Date: Mon, 16 Dec 2024 17:52:08 +0530 Subject: [PATCH 03/12] Added unit tests --- .../jdbc/SQLServerConnectionTest.java | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 916aa419f..11031eae2 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -42,6 +42,7 @@ import com.microsoft.aad.msal4j.TokenCache; import com.microsoft.aad.msal4j.TokenCacheAccessContext; +import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.AbstractTest; import com.microsoft.sqlserver.testframework.Constants; @@ -50,6 +51,7 @@ @RunWith(JUnitPlatform.class) public class SQLServerConnectionTest extends AbstractTest { + // If no retry is done, the function should at least exit in 5 seconds static int threshHoldForNoRetryInMilliseconds = 5000; static int loginTimeOutInSeconds = 10; @@ -1321,4 +1323,34 @@ public void testServerNameField() throws SQLException { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_errorServerName"))); } } + + + @Test + public void testGetSqlFedAuthTokenFailure() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); + fail("Expected the test to throw SQLServerException"); + } catch (SQLServerException e) { + //test pass + } + } + + + @Test + public void testGetSqlFedAuthTokenPrincipalFailure() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString(), 0); + fail("Expected the test to throw SQLServerException"); + } catch (SQLServerException e) { + } + } + } From 6f28f59ec8cafa0e0ccb680fa9432606bef8b086 Mon Sep 17 00:00:00 2001 From: machavan Date: Tue, 17 Dec 2024 10:15:41 +0530 Subject: [PATCH 04/12] Added a max wait duration of 20 seconds to MSAL calls - Added more tests - Improved test to check for specific error message --- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 14 ++++---- .../jdbc/SQLServerSecurityUtility.java | 4 ++- .../jdbc/SQLServerConnectionTest.java | 33 ++++++++++++++----- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 2a0d62ada..2b053cb53 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -65,7 +65,7 @@ class SQLServerMSAL4JUtils { static final String REDIRECTURI = "http://localhost"; static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; - + static final long TOKEN_WAIT_DURATION_MS = 20000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -117,7 +117,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray()) .build()); - final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -184,7 +184,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -302,7 +302,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -357,7 +357,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut .acquireToken(IntegratedWindowsAuthenticationParameters .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build()); - final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -439,7 +439,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } if (null != future) { - authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } else { // acquire token interactively with system browser if (logger.isLoggable(Level.FINEST)) { @@ -451,7 +451,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu .loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build(); future = pca.acquireToken(parameters); - authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } if (logger.isLoggable(Level.FINER)) { diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index a72fb8422..4fd459a34 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -58,6 +58,8 @@ class SQLServerSecurityUtility { private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); + private static final int TOKEN_WAIT_DURATION_MS = 0; + private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } @@ -465,7 +467,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(Duration.of(millisecondsRemaining, ChronoUnit.MILLIS)); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 11031eae2..d2f830e90 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -1332,25 +1332,42 @@ public void testGetSqlFedAuthTokenFailure() throws SQLException { fedAuthInfo.spn = "https://database.windows.net/"; fedAuthInfo.stsurl = "https://login.windows.net/xxx"; SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", - "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); - fail("Expected the test to throw SQLServerException"); + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 10); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { //test pass + assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); } } + @Test + public void testGetSqlFedAuthTokenFailureNoWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); + } + } @Test - public void testGetSqlFedAuthTokenPrincipalFailure() throws SQLException { + public void testGetSqlFedAuthTokenFailureNagativeWaiting() throws SQLException { try (Connection conn = getConnection()){ SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); fedAuthInfo.spn = "https://database.windows.net/"; fedAuthInfo.stsurl = "https://login.windows.net/xxx"; - SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, "xxx", - "xxx",SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString(), 0); - fail("Expected the test to throw SQLServerException"); + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), -1); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { - } + //test pass + assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); + } } - + } From 7837e0737a3200f7fab8d39d94b22ecf8a6474a4 Mon Sep 17 00:00:00 2001 From: machavan Date: Tue, 17 Dec 2024 10:18:47 +0530 Subject: [PATCH 05/12] Added Timeout Exception catch clause for one of the auth methods --- .../java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 2b053cb53..ef3e8afaf 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -320,6 +320,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI // this includes all certificate exceptions throw new SQLServerException(SQLServerException.getErrString("R_readCertError") + e.getMessage(), null, 0, null); + } catch (TimeoutException e) { + throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } catch (Exception e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); From e64f411f3dc1d414fce872d9ea1cd7b72a154e1a Mon Sep 17 00:00:00 2001 From: machavan Date: Tue, 17 Dec 2024 23:58:27 +0530 Subject: [PATCH 06/12] Replaced lock with tryLock. - Replaced lock with tryLock to avoid potential long waiting for other threads while one thread is taking long to complete. --- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index ef3e8afaf..7b168742b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -66,6 +66,7 @@ class SQLServerMSAL4JUtils { static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; static final long TOKEN_WAIT_DURATION_MS = 20000; + static final long TOKEN_LOCK_WAIT_DURATION_MS = 5000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -88,9 +89,11 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - lock.lock(); - + boolean lockAcquired = false; try { + //Just try to acquire the lock and if can't then proceed to attempt to get the token + lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, hashedSecret); @@ -136,7 +139,9 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - lock.unlock(); + if (lockAcquired) { + lock.unlock(); + } executorService.shutdown(); } } @@ -154,10 +159,12 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth : fedAuthInfo.spn + defaultScopeSuffix; Set scopes = new HashSet<>(); scopes.add(scope); - - lock.lock(); - + + boolean lockAcquired = false; try { + //Just try to acquire the lock and if can't then proceed to attempt to get the token + lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -203,7 +210,9 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - lock.unlock(); + if (lockAcquired) { + lock.unlock(); + } executorService.shutdown(); } } @@ -224,9 +233,11 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - lock.lock(); - + boolean lockAcquired = false; try { + //Just try to acquire the lock and if can't then proceed to attempt to get the token + lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -326,7 +337,9 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (lockAcquired) { + lock.unlock(); + } executorService.shutdown(); } } @@ -347,10 +360,12 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - lock.lock(); - + boolean lockAcquired = false; try { - final PublicClientApplication pca = PublicClientApplication + //Just try to acquire the lock and if can't then proceed to attempt to get the token + lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + + final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) .authority(fedAuthInfo.stsurl).build(); @@ -378,7 +393,9 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - lock.unlock(); + if (lockAcquired) { + lock.unlock(); + } executorService.shutdown(); } } @@ -391,9 +408,11 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - lock.lock(); - + boolean lockAcquired = false; try { + //Just try to acquire the lock and if can't then proceed to attempt to get the token + lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -473,7 +492,9 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - lock.unlock(); + if (lockAcquired) { + lock.unlock(); + } executorService.shutdown(); } } From 90518473d409ece22764ccc2c452e476b8731ad4 Mon Sep 17 00:00:00 2001 From: machavan Date: Thu, 19 Dec 2024 10:30:27 +0530 Subject: [PATCH 07/12] Replaced lock with semaphore for beter readablility. - Added detailed comment for the usage of semaphore. --- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 90 +++++++++++++------ 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 7b168742b..1ecf51bd1 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -25,6 +25,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; @@ -66,7 +67,7 @@ class SQLServerMSAL4JUtils { static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; static final long TOKEN_WAIT_DURATION_MS = 20000; - static final long TOKEN_LOCK_WAIT_DURATION_MS = 5000; + static final long TOKEN_SEM_WAIT_DURATION_MS = 5000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -79,7 +80,7 @@ private SQLServerMSAL4JUtils() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } - private static final Lock lock = new ReentrantLock(); + private static final Semaphore sem = new Semaphore(1); static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, String authenticationString, int millisecondsRemaining) throws SQLServerException { @@ -89,10 +90,17 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - boolean lockAcquired = false; + boolean semAcquired = false; try { - //Just try to acquire the lock and if can't then proceed to attempt to get the token - lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, @@ -139,8 +147,8 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (lockAcquired) { - lock.unlock(); + if (semAcquired) { + sem.release(); } executorService.shutdown(); } @@ -160,10 +168,17 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth Set scopes = new HashSet<>(); scopes.add(scope); - boolean lockAcquired = false; + boolean semAcquired = false; try { - //Just try to acquire the lock and if can't then proceed to attempt to get the token - lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); @@ -210,8 +225,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (lockAcquired) { - lock.unlock(); + if (semAcquired) { + sem.release(); } executorService.shutdown(); } @@ -233,10 +248,17 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - boolean lockAcquired = false; + boolean semAcquired = false; try { - //Just try to acquire the lock and if can't then proceed to attempt to get the token - lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); @@ -337,8 +359,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - if (lockAcquired) { - lock.unlock(); + if (semAcquired) { + sem.release(); } executorService.shutdown(); } @@ -360,10 +382,17 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - boolean lockAcquired = false; + boolean semAcquired = false; try { - //Just try to acquire the lock and if can't then proceed to attempt to get the token - lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) @@ -393,8 +422,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (lockAcquired) { - lock.unlock(); + if (semAcquired) { + sem.release(); } executorService.shutdown(); } @@ -408,10 +437,17 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - boolean lockAcquired = false; + boolean semAcquired = false; try { - //Just try to acquire the lock and if can't then proceed to attempt to get the token - lockAcquired = lock.tryLock(Math.min(millisecondsRemaining, TOKEN_LOCK_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) @@ -492,8 +528,8 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (lockAcquired) { - lock.unlock(); + if (semAcquired) { + sem.release(); } executorService.shutdown(); } From da74b9fc5e36dc510089ce6b7c7e38af530da7e4 Mon Sep 17 00:00:00 2001 From: machavan Date: Tue, 7 Jan 2025 13:57:29 +0530 Subject: [PATCH 08/12] Renamed semAcquired to isSemAcquired --- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 1ecf51bd1..7660b4ccf 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -90,7 +90,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - boolean semAcquired = false; + boolean isSemAcquired = false; try { // //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. @@ -100,7 +100,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints //to get their tokens at the same time, stressing the auth endpoint. // - semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, @@ -147,7 +147,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (semAcquired) { + if (isSemAcquired) { sem.release(); } executorService.shutdown(); @@ -168,7 +168,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth Set scopes = new HashSet<>(); scopes.add(scope); - boolean semAcquired = false; + boolean isSemAcquired = false; try { // //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. @@ -178,7 +178,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints //to get their tokens at the same time, stressing the auth endpoint. // - semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); @@ -225,7 +225,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (semAcquired) { + if (isSemAcquired) { sem.release(); } executorService.shutdown(); @@ -248,7 +248,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - boolean semAcquired = false; + boolean isSemAcquired = false; try { // //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. @@ -258,7 +258,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints //to get their tokens at the same time, stressing the auth endpoint. // - semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); @@ -359,7 +359,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - if (semAcquired) { + if (isSemAcquired) { sem.release(); } executorService.shutdown(); @@ -382,7 +382,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - boolean semAcquired = false; + boolean isSemAcquired = false; try { // //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. @@ -392,7 +392,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints //to get their tokens at the same time, stressing the auth endpoint. // - semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) @@ -422,7 +422,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (semAcquired) { + if (isSemAcquired) { sem.release(); } executorService.shutdown(); @@ -437,7 +437,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - boolean semAcquired = false; + boolean isSemAcquired = false; try { // //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. @@ -447,7 +447,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints //to get their tokens at the same time, stressing the auth endpoint. // - semAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) @@ -528,7 +528,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } catch (TimeoutException e) { throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); } finally { - if (semAcquired) { + if (isSemAcquired) { sem.release(); } executorService.shutdown(); From 6c4dca3bbf36ebc647e3afb6ecd25e59133aabd2 Mon Sep 17 00:00:00 2001 From: machavan Date: Wed, 8 Jan 2025 09:38:11 +0530 Subject: [PATCH 09/12] Fixed indentation for an existing code line --- .../java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 7660b4ccf..f4c040d0a 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -394,7 +394,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut // isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); - final PublicClientApplication pca = PublicClientApplication + final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) .authority(fedAuthInfo.stsurl).build(); From c77a7e022ced855cdc7a5fc472c4f958bf4860fe Mon Sep 17 00:00:00 2001 From: machavan Date: Wed, 8 Jan 2025 18:16:30 +0530 Subject: [PATCH 10/12] Change to use Mono::timeout method --- .../com/microsoft/sqlserver/jdbc/SQLServerConnection.java | 4 ++-- .../microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index a5f299320..5f076f9af 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6126,12 +6126,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); // Break out of the retry loop in successful case. break; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 4fd459a34..666f541ed 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -344,7 +344,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer * @throws SQLServerException */ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, long millisecondsRemaining) throws SQLServerException { if (logger.isLoggable(java.util.logging.Level.FINEST)) { logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); @@ -383,7 +383,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = mic.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), @@ -467,7 +467,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), From 87a67ab7d47ca46a947afd4359a56608d89ee9e8 Mon Sep 17 00:00:00 2001 From: machavan Date: Wed, 8 Jan 2025 20:20:00 +0530 Subject: [PATCH 11/12] Updated TOKEN_WAIT_DURATION_MS to correct value. --- .../com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 666f541ed..d4e49ccde 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -58,7 +58,7 @@ class SQLServerSecurityUtility { private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); - private static final int TOKEN_WAIT_DURATION_MS = 0; + private static final int TOKEN_WAIT_DURATION_MS = 20000; private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); From 65ee65dbd125e01fcee2305a02614dc33574d972 Mon Sep 17 00:00:00 2001 From: machavan Date: Mon, 13 Jan 2025 22:58:46 +0530 Subject: [PATCH 12/12] Improved error messages as requested by Walmart --- .../microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java | 10 +++++----- .../sqlserver/jdbc/SQLServerConnectionTest.java | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index f4c040d0a..8850e74fc 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -145,7 +145,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { if (isSemAcquired) { sem.release(); @@ -223,7 +223,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } finally { if (isSemAcquired) { sem.release(); @@ -354,7 +354,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI throw new SQLServerException(SQLServerException.getErrString("R_readCertError") + e.getMessage(), null, 0, null); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } catch (Exception e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); @@ -420,7 +420,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut } catch (IOException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { if (isSemAcquired) { sem.release(); @@ -526,7 +526,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } catch (TimeoutException e) { - throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e); + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { if (isSemAcquired) { sem.release(); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index d2f830e90..787c8151e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -1336,7 +1336,7 @@ public void testGetSqlFedAuthTokenFailure() throws SQLException { fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { //test pass - assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); } } @@ -1351,7 +1351,7 @@ public void testGetSqlFedAuthTokenFailureNoWaiting() throws SQLException { fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { //test pass - assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); } } @@ -1366,7 +1366,7 @@ public void testGetSqlFedAuthTokenFailureNagativeWaiting() throws SQLException { fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { //test pass - assertEquals(e.getMessage(), SQLServerException.getErrString("R_connectionTimedOut"), "Expected Timeout Exception was not thrown"); + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); } }