Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced timeouts for MSAL calls. #2562

Merged
merged 13 commits into from
Jan 15, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -6110,10 +6110,11 @@
}

while (true) {
int millisecondsRemaining = timerRemaining(timerExpire);
machavan marked this conversation as resolved.
Show resolved Hide resolved
if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) {
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
authenticationString);
authenticationString, millisecondsRemaining);

// Break out of the retry loop in successful case.
break;
Expand Down Expand Up @@ -6141,12 +6142,12 @@
if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null
&& !aadPrincipalSecret.isEmpty()) {
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID,
aadPrincipalSecret, authenticationString);
aadPrincipalSecret, authenticationString, millisecondsRemaining);
} else {
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
authenticationString);
authenticationString, millisecondsRemaining);
}

// Break out of the retry loop in successful case.
Expand All @@ -6159,7 +6160,7 @@
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()),
servicePrincipalCertificate,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()),
servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString);
servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString, millisecondsRemaining);

// Break out of the retry loop in successful case.
break;
Expand Down Expand Up @@ -6194,7 +6195,7 @@
throw new SQLServerException(form.format(msgArgs), null);
}

int millisecondsRemaining = timerRemaining(timerExpire);
millisecondsRemaining = timerRemaining(timerExpire);

Check warning on line 6198 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java#L6198

Added line #L6198 was not covered by tests
if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory
|| timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) {

Expand Down Expand Up @@ -6240,15 +6241,15 @@
Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString};
throw new SQLServerException(form.format(msgArgs), null, 0, null);
}
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString);
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString, millisecondsRemaining);

Check warning on line 6244 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java#L6244

Added line #L6244 was not covered by tests
}
// Break out of the retry loop in successful case.
break;
} else if (authenticationString
.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) {
// interactive flow
fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user,
authenticationString);
authenticationString, millisecondsRemaining);

// Break out of the retry loop in successful case.
break;
Expand All @@ -6258,12 +6259,12 @@

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
managedIdentityClientId);
managedIdentityClientId, millisecondsRemaining);
break;
}

fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn,
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()));
activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining);

Check warning on line 6267 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java#L6267

Added line #L6267 was not covered by tests

break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
Expand Down Expand Up @@ -80,7 +81,7 @@
private static final Lock lock = new ReentrantLock();

static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password,
String authenticationString) throws SQLServerException {
String authenticationString, int millisecondsRemaining) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

if (logger.isLoggable(Level.FINEST)) {
Expand Down Expand Up @@ -116,7 +117,7 @@
.builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray())
.build());

final IAuthenticationResult authenticationResult = future.get();
final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 120 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L120

Added line #L120 was not covered by tests

if (logger.isLoggable(Level.FINER)) {
logger.finer(
Expand All @@ -132,14 +133,16 @@
throw new SQLServerException(e.getMessage(), e);
} catch (MalformedURLException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} catch (TimeoutException e) {
throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e);
} finally {
lock.unlock();
executorService.shutdown();
}
}

static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID,
String aadPrincipalSecret, String authenticationString) throws SQLServerException {
String aadPrincipalSecret, String authenticationString, int millisecondsRemaining) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

if (logger.isLoggable(Level.FINEST)) {
Expand Down Expand Up @@ -181,7 +184,7 @@

final CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());
final IAuthenticationResult authenticationResult = future.get();
final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 187 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L187

Added line #L187 was not covered by tests

if (logger.isLoggable(Level.FINER)) {
logger.finer(
Expand All @@ -197,6 +200,8 @@
throw new SQLServerException(e.getMessage(), e);
} catch (MalformedURLException | ExecutionException e) {
throw getCorrectedException(e, aadPrincipalID, authenticationString);
} catch (TimeoutException e) {
throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e);
} finally {
lock.unlock();
executorService.shutdown();
Expand All @@ -205,7 +210,7 @@

static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthInfo fedAuthInfo,
String aadPrincipalID, String certFile, String certPassword, String certKey, String certKeyPassword,
String authenticationString) throws SQLServerException {
String authenticationString, int millisecondsRemaining) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

if (logger.isLoggable(Level.FINEST)) {
Expand Down Expand Up @@ -297,7 +302,7 @@

final CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());
final IAuthenticationResult authenticationResult = future.get();
final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 305 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L305

Added line #L305 was not covered by tests

machavan marked this conversation as resolved.
Show resolved Hide resolved
if (logger.isLoggable(Level.FINER)) {
logger.finer(
Expand Down Expand Up @@ -325,7 +330,7 @@
}

static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
String authenticationString) throws SQLServerException {
String authenticationString, int millisecondsRemaining) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

/*
Expand All @@ -352,7 +357,7 @@
.acquireToken(IntegratedWindowsAuthenticationParameters
.builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build());

final IAuthenticationResult authenticationResult = future.get();
final IAuthenticationResult authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 360 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L360

Added line #L360 was not covered by tests

if (logger.isLoggable(Level.FINER)) {
logger.finer(
Expand All @@ -368,14 +373,16 @@
throw new SQLServerException(e.getMessage(), e);
} catch (IOException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} catch (TimeoutException e) {
throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e);

Check warning on line 377 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L376-L377

Added lines #L376 - L377 were not covered by tests
} finally {
lock.unlock();
executorService.shutdown();
}
}

static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user,
String authenticationString) throws SQLServerException {
String authenticationString, int millisecondsRemaining) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

if (logger.isLoggable(Level.FINER)) {
Expand Down Expand Up @@ -432,7 +439,7 @@
}

if (null != future) {
authenticationResult = future.get();
authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 442 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L442

Added line #L442 was not covered by tests
} else {
// acquire token interactively with system browser
if (logger.isLoggable(Level.FINEST)) {
Expand All @@ -444,7 +451,7 @@
.loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build();

future = pca.acquireToken(parameters);
authenticationResult = future.get();
authenticationResult = future.get(millisecondsRemaining, TimeUnit.MILLISECONDS);

Check warning on line 454 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L454

Added line #L454 was not covered by tests
}

if (logger.isLoggable(Level.FINER)) {
Expand All @@ -461,6 +468,8 @@
throw new SQLServerException(e.getMessage(), e);
} catch (MalformedURLException | URISyntaxException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} catch (TimeoutException e) {
throw new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e);

Check warning on line 472 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java#L471-L472

Added lines #L471 - L472 were not covered by tests
} finally {
lock.unlock();
executorService.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Optional;
Expand Down Expand Up @@ -408,7 +410,7 @@
* @throws SQLServerException
*/
static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
String managedIdentityClientId, int millisecondsRemaining) throws SQLServerException {
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();

Expand Down Expand Up @@ -463,7 +465,7 @@

SqlAuthenticationToken sqlFedAuthToken = null;

Optional<AccessToken> accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();
Optional<AccessToken> accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(Duration.of(millisecondsRemaining, ChronoUnit.MILLIS));

Check warning on line 468 in src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java#L468

Added line #L468 was not covered by tests

if (!accessTokenOptional.isPresent()) {
throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import com.microsoft.aad.msal4j.TokenCache;
import com.microsoft.aad.msal4j.TokenCacheAccessContext;
import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo;
import com.microsoft.sqlserver.testframework.AbstractSQLGenerator;
import com.microsoft.sqlserver.testframework.AbstractTest;
import com.microsoft.sqlserver.testframework.Constants;
Expand All @@ -50,6 +51,7 @@

@RunWith(JUnitPlatform.class)
public class SQLServerConnectionTest extends AbstractTest {

// If no retry is done, the function should at least exit in 5 seconds
static int threshHoldForNoRetryInMilliseconds = 5000;
static int loginTimeOutInSeconds = 10;
Expand Down Expand Up @@ -1321,4 +1323,34 @@ public void testServerNameField() throws SQLException {
assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_errorServerName")));
}
}


@Test
public void testGetSqlFedAuthTokenFailure() throws SQLException {
try (Connection conn = getConnection()){
SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo();
fedAuthInfo.spn = "https://database.windows.net/";
fedAuthInfo.stsurl = "https://login.windows.net/xxx";
SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx",
"xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0);
machavan marked this conversation as resolved.
Show resolved Hide resolved
fail("Expected the test to throw SQLServerException");
machavan marked this conversation as resolved.
Show resolved Hide resolved
} catch (SQLServerException e) {
//test pass
machavan marked this conversation as resolved.
Show resolved Hide resolved
}
}


@Test
public void testGetSqlFedAuthTokenPrincipalFailure() throws SQLException {
try (Connection conn = getConnection()){
SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo();
fedAuthInfo.spn = "https://database.windows.net/";
fedAuthInfo.stsurl = "https://login.windows.net/xxx";
SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, "xxx",
"xxx",SqlAuthentication.ACTIVE_DIRECTORY_SERVICE_PRINCIPAL.toString(), 0);
fail("Expected the test to throw SQLServerException");
machavan marked this conversation as resolved.
Show resolved Hide resolved
} catch (SQLServerException e) {
}
machavan marked this conversation as resolved.
Show resolved Hide resolved
}

}
Loading