Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature | Introduced support for Azure SQL DNS Caching #1357

Merged
merged 8 commits into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ final class TDS {
static final int AE_METADATA = 0x08;

static final byte TDS_FEATURE_EXT_UTF8SUPPORT = 0x0A;
static final byte TDS_FEATURE_EXT_AZURESQLDNSCACHING = 0x0B;

static final int TDS_TVP = 0xF3;
static final int TVP_ROW = 0x01;
Expand Down Expand Up @@ -195,6 +196,8 @@ static final String getTokenName(int tdsTokenType) {
return "TDS_FEATURE_EXT_DATACLASSIFICATION (0x09)";
case TDS_FEATURE_EXT_UTF8SUPPORT:
return "TDS_FEATURE_EXT_UTF8SUPPORT (0x0A)";
case TDS_FEATURE_EXT_AZURESQLDNSCACHING:
return "TDS_FEATURE_EXT_AZURESQLDNSCACHING (0x0B)";
default:
return "unknown token (0x" + Integer.toHexString(tdsTokenType).toUpperCase() + ")";
}
Expand Down Expand Up @@ -652,16 +655,17 @@ void resetPooledConnection() {

/**
* Opens the physical communications channel (TCP/IP socket and I/O streams) to the SQL Server.
*
* @return InetSocketAddress of the connection socket.
*/
final void open(String host, int port, int timeoutMillis, boolean useParallel, boolean useTnir,
final InetSocketAddress open(String host, int port, int timeoutMillis, boolean useParallel, boolean useTnir,
boolean isTnirFirstAttempt, int timeoutMillisForFullTimeout) throws SQLServerException {
if (logger.isLoggable(Level.FINER))
logger.finer(this.toString() + ": Opening TCP socket...");

SocketFinder socketFinder = new SocketFinder(traceID, con);
channelSocket = tcpSocket = socketFinder.findSocket(host, port, timeoutMillis, useParallel, useTnir,
isTnirFirstAttempt, timeoutMillisForFullTimeout);

try {

// Set socket options
Expand All @@ -677,6 +681,7 @@ final void open(String host, int port, int timeoutMillis, boolean useParallel, b
} catch (IOException ex) {
SQLServerException.ConvertConnectExceptionToSQLServerException(host, port, con, ex);
}
return (InetSocketAddress) channelSocket.getRemoteSocketAddress();
}

/**
Expand Down Expand Up @@ -2333,6 +2338,16 @@ Socket findSocket(String hostName, int portNumber, int timeoutInMilliSeconds, bo
try {
InetAddress[] inetAddrs = null;

if (!useParallel) {
// MSF is false. TNIR could be true or false. DBMirroring could be true or false.
// For TNIR first attempt, we should do existing behavior including how host name is resolved.
if (useTnir && isTnirFirstAttempt) {
return getDefaultSocket(hostName, portNumber, SQLServerConnection.TnirFirstAttemptTimeoutMs);
} else if (!useTnir) {
return getDefaultSocket(hostName, portNumber, timeoutInMilliSeconds);
}
}

// inetAddrs is only used if useParallel is true or TNIR is true. Skip resolving address if that's not the
// case.
if (useParallel || useTnir) {
Expand All @@ -2345,16 +2360,6 @@ Socket findSocket(String hostName, int portNumber, int timeoutInMilliSeconds, bo
}
}

if (!useParallel) {
// MSF is false. TNIR could be true or false. DBMirroring could be true or false.
// For TNIR first attempt, we should do existing behavior including how host name is resolved.
if (useTnir && isTnirFirstAttempt) {
return getDefaultSocket(hostName, portNumber, SQLServerConnection.TnirFirstAttemptTimeoutMs);
} else if (!useTnir) {
return getDefaultSocket(hostName, portNumber, timeoutInMilliSeconds);
}
}

// Code reaches here only if MSF = true or (TNIR = true and not TNIR first attempt)

if (logger.isLoggable(Level.FINER)) {
Expand Down Expand Up @@ -2645,6 +2650,14 @@ private Socket getDefaultSocket(String hostName, int portNumber, int timeoutInMi
// cannot be resolved, but that InetSocketAddress(host, port) does not - it sets
// the returned InetSocketAddress as unresolved.
InetSocketAddress addr = new InetSocketAddress(hostName, portNumber);
if (addr.isUnresolved()) {
if (logger.isLoggable(Level.FINER)) {
logger.finer(this.toString() + "Failed to resolve host name: " + hostName
+ ". Using IP address from DNS cache.");
}
InetSocketAddress cacheEntry = SQLServerConnection.getDNSEntry(hostName);
addr = (null != cacheEntry) ? cacheEntry : addr;
}
return getConnectedSocket(addr, timeoutInMilliSeconds);
}

Expand Down
76 changes: 64 additions & 12 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.sql.CallableStatement;
Expand All @@ -37,6 +38,7 @@
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -686,10 +688,17 @@ boolean getServerSupportsDataClassification() {
return serverSupportsDataClassification;
}

private boolean serverSupportsDNSCaching = false;
private static ConcurrentHashMap<String, InetSocketAddress> dnsCache = null;

static InetSocketAddress getDNSEntry(String key) {
return (null != dnsCache) ? dnsCache.get(key) : null;
}

byte getServerSupportedDataClassificationVersion() {
return serverSupportedDataClassificationVersion;
}

// Boolean that indicates whether LOB objects created by this connection should be loaded into memory
private boolean delayLoadingLobs = SQLServerDriverBooleanProperty.DELAY_LOADING_LOBS.getDefaultValue();

Expand Down Expand Up @@ -2344,9 +2353,15 @@ private void login(String primary, String primaryInstanceName, int primaryPortNu
}

// Attempt login. Use Place holder to make sure that the failoverdemand is done.
connectHelper(currentConnectPlaceHolder, timerRemaining(intervalExpire), timeout, useParallel, useTnir,
(0 == attemptNumber), // is this the TNIR first attempt
InetSocketAddress inetSocketAddress = connectHelper(currentConnectPlaceHolder,
timerRemaining(intervalExpire), timeout, useParallel, useTnir, (0 == attemptNumber), // TNIR
// first
// attempt
timerRemaining(intervalExpireFullTimeout)); // Only used when host resolves to >64 IPs
// Successful connection, cache the IP address and port if server supports DNS Cache.
if (serverSupportsDNSCaching) {
dnsCache.put(currentConnectPlaceHolder.getServerName(), inetSocketAddress);
}

if (isRoutedInCurrentAttempt) {
// we ignore the failoverpartner ENVCHANGE if we got routed so no error needs to be thrown
Expand Down Expand Up @@ -2647,10 +2662,11 @@ static int timerRemaining(long timerExpire) {
* @param useTnir
* @param isTnirFirstAttempt
* @param timeOutsliceInMillisForFullTimeout
* @return InetSocketAddress of the connected socket.
* @throws SQLServerException
*/
private void connectHelper(ServerPortPlaceHolder serverInfo, int timeOutSliceInMillis, int timeOutFullInSeconds,
boolean useParallel, boolean useTnir, boolean isTnirFirstAttempt,
private InetSocketAddress connectHelper(ServerPortPlaceHolder serverInfo, int timeOutSliceInMillis,
int timeOutFullInSeconds, boolean useParallel, boolean useTnir, boolean isTnirFirstAttempt,
int timeOutsliceInMillisForFullTimeout) throws SQLServerException {
// Make the initial tcp-ip connection.

Expand All @@ -2670,12 +2686,9 @@ private void connectHelper(ServerPortPlaceHolder serverInfo, int timeOutSliceInM

// if the timeout is infinite slices are infinite too.
tdsChannel = new TDSChannel(this);
if (0 == timeOutFullInSeconds)
tdsChannel.open(serverInfo.getServerName(), serverInfo.getPortNumber(), 0, useParallel, useTnir,
isTnirFirstAttempt, timeOutsliceInMillisForFullTimeout);
else
tdsChannel.open(serverInfo.getServerName(), serverInfo.getPortNumber(), timeOutSliceInMillis, useParallel,
useTnir, isTnirFirstAttempt, timeOutsliceInMillisForFullTimeout);
InetSocketAddress inetSocketAddress = tdsChannel.open(serverInfo.getServerName(), serverInfo.getPortNumber(),
(0 == timeOutFullInSeconds) ? 0 : timeOutSliceInMillis, useParallel, useTnir, isTnirFirstAttempt,
timeOutsliceInMillisForFullTimeout);

setState(State.Connected);

Expand All @@ -2692,6 +2705,7 @@ private void connectHelper(ServerPortPlaceHolder serverInfo, int timeOutSliceInM

// We have successfully connected, now do the login. logon takes seconds timeout
executeCommand(new LogonCommand());
return inetSocketAddress;
}

/**
Expand Down Expand Up @@ -3894,6 +3908,16 @@ int writeUTF8SupportFeatureRequest(boolean write, /* if false just calculates th
return len;
}

int writeDNSCacheFeatureRequest(boolean write, /* if false just calculates the length */
TDSWriter tdsWriter) throws SQLServerException {
int len = 5; // 1byte = featureID, 4bytes = featureData length
if (write) {
tdsWriter.writeByte(TDS.TDS_FEATURE_EXT_AZURESQLDNSCACHING);
tdsWriter.writeInt(0);
}
return len;
}

private final class LogonCommand extends UninterruptableTDSCommand {
// Always update serialVersionUID when prompted.
private static final long serialVersionUID = 1L;
Expand Down Expand Up @@ -4580,7 +4604,8 @@ final void processFeatureExtAck(TDSReader tdsReader) throws SQLServerException {
}

private void onFeatureExtAck(byte featureId, byte[] data) throws SQLServerException {
if (null != routingInfo) {
// To be able to cache both control and tenant ring IPs, need to parse AZURESQLDNSCACHING.
if (null != routingInfo && TDS.TDS_FEATURE_EXT_AZURESQLDNSCACHING != featureId) {
return;
}

Expand Down Expand Up @@ -4695,6 +4720,30 @@ private void onFeatureExtAck(byte featureId, byte[] data) throws SQLServerExcept
}
break;
}
case TDS.TDS_FEATURE_EXT_AZURESQLDNSCACHING: {
if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.fine(
toString() + " Received feature extension acknowledgement for Azure SQL DNS Caching.");
}

if (1 > data.length) {
throw new SQLServerException(SQLServerException.getErrString("R_unknownAzureSQLDNSCachingValue"),
null);
}

if (1 == data[0]) {
serverSupportsDNSCaching = true;
if (null == dnsCache) {
dnsCache = new ConcurrentHashMap<String, InetSocketAddress>();
}
} else {
serverSupportsDNSCaching = false;
if (null != dnsCache) {
dnsCache.remove(currentConnectPlaceHolder.getServerName());
}
}
break;
}
default: {
// Unknown feature ack
throw new SQLServerException(SQLServerException.getErrString("R_UnknownFeatureAck"), null);
Expand Down Expand Up @@ -4979,6 +5028,8 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ

len = len + writeUTF8SupportFeatureRequest(false, tdsWriter);

len = len + writeDNSCacheFeatureRequest(false, tdsWriter);

len = len + 1; // add 1 to length because of FeatureEx terminator

// Length of entire Login 7 packet
Expand Down Expand Up @@ -5168,6 +5219,7 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ

writeDataClassificationFeatureRequest(true, tdsWriter);
writeUTF8SupportFeatureRequest(true, tdsWriter);
writeDNSCacheFeatureRequest(true, tdsWriter);

tdsWriter.writeByte((byte) TDS.FEATURE_EXT_TERMINATOR);
tdsWriter.setDataLoggable(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ protected Object[][] getContents() {
{"R_UnknownDataClsTokenNumber", "Unknown token for Data Classification."}, // From Server
{"R_InvalidDataClsVersionNumber", "Invalid version number {0} for Data Classification."}, // From Server
{"R_unknownUTF8SupportValue", "Unknown value for UTF8 support."},
{"R_unknownAzureSQLDNSCachingValue", "Unknown value for Azure SQL DNS Caching."},
{"R_illegalWKT", "Illegal Well-Known text. Please make sure Well-Known text is valid."},
{"R_illegalTypeForGeometry", "{0} is not supported for Geometry."},
{"R_illegalWKTposition", "Illegal character in Well-Known text at position {0}."},
Expand Down