Skip to content

Commit

Permalink
Fix creating lots of Sessions (#1208)
Browse files Browse the repository at this point in the history
* trigger CI

* Fix bigInt issue

* Add new line

* fix calling sdpe

* move to util class

* Change atomic int to atomic long

* Remove methods that are incompatible with JDK 8

* Fix issues when using JDK 8, use BouncyCastle as a workaround

* Remove swap file

* Editing comment for jdk8

Co-authored-by: ulvii <[email protected]>
  • Loading branch information
rene-ye and ulvii authored Jan 9, 2020
1 parent ee2708e commit b3ae62f
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 50 deletions.
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@
<version>${google.gson.version}</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk15on</artifactId>
<version>1.64</version>
<optional>true</optional>
</dependency>


<!-- dependencies provided by an OSGi-Framework -->
<dependency>
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import java.util.Hashtable;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import javax.crypto.KeyAgreement;

Expand All @@ -49,40 +49,41 @@
*
*/
public interface ISQLServerEnclaveProvider {
static final String proc = "EXEC sp_describe_parameter_encryption ?,?,?";
static final String SDPE1 = "EXEC sp_describe_parameter_encryption ?,?";
static final String SDPE2 = "EXEC sp_describe_parameter_encryption ?,?,?";

default byte[] getEnclavePackage(String userSQL, ArrayList<byte[]> enclaveCEKs) throws SQLServerException {
EnclaveSession enclaveSession = getEnclaveSession();
if (null != enclaveSession) {
try {
ByteArrayOutputStream enclavePackage = new ByteArrayOutputStream();
enclavePackage.writeBytes(enclaveSession.getSessionID());
enclavePackage.write(enclaveSession.getSessionID());
ByteArrayOutputStream keys = new ByteArrayOutputStream();
byte[] randomGUID = new byte[16];
SecureRandom.getInstanceStrong().nextBytes(randomGUID);
keys.writeBytes(randomGUID);
keys.writeBytes(ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
.putLong(enclaveSession.getCounter()).array());
keys.writeBytes(MessageDigest.getInstance("SHA-256").digest((userSQL).getBytes(UTF_16LE)));
keys.write(randomGUID);
keys.write(ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(enclaveSession.getCounter())
.array());
keys.write(MessageDigest.getInstance("SHA-256").digest((userSQL).getBytes(UTF_16LE)));
for (byte[] b : enclaveCEKs) {
keys.writeBytes(b);
keys.write(b);
}
enclaveCEKs.clear();
SQLServerAeadAes256CbcHmac256EncryptionKey encryptedKey = new SQLServerAeadAes256CbcHmac256EncryptionKey(
enclaveSession.getSessionSecret(), SQLServerAeadAes256CbcHmac256Algorithm.algorithmName);
SQLServerAeadAes256CbcHmac256Algorithm algo = new SQLServerAeadAes256CbcHmac256Algorithm(encryptedKey,
SQLServerEncryptionType.Randomized, (byte) 0x1);
enclavePackage.writeBytes(algo.encryptData(keys.toByteArray()));
enclavePackage.write(algo.encryptData(keys.toByteArray()));
return enclavePackage.toByteArray();
} catch (GeneralSecurityException | SQLServerException e) {
} catch (GeneralSecurityException | SQLServerException | IOException e) {
SQLServerException.makeFromDriverError(null, this, e.getLocalizedMessage(), "0", false);
}
}
return null;
}

default ResultSet executeProc(PreparedStatement stmt, String userSql, String preparedTypeDefinitions,
BaseAttestationRequest req) throws SQLException {
default ResultSet executeSDPEv2(PreparedStatement stmt, String userSql, String preparedTypeDefinitions,
BaseAttestationRequest req) throws SQLException, IOException {
((SQLServerPreparedStatement) stmt).isInternalEncryptionQuery = true;
stmt.setNString(1, userSql);
if (preparedTypeDefinitions != null && preparedTypeDefinitions.length() != 0) {
Expand All @@ -94,7 +95,19 @@ default ResultSet executeProc(PreparedStatement stmt, String userSql, String pre
return ((SQLServerPreparedStatement) stmt).executeQueryInternal();
}

default void processAev1SPDE(String userSql, String preparedTypeDefinitions, Parameter[] params,
default ResultSet executeSDPEv1(PreparedStatement stmt, String userSql,
String preparedTypeDefinitions) throws SQLException {
((SQLServerPreparedStatement) stmt).isInternalEncryptionQuery = true;
stmt.setNString(1, userSql);
if (preparedTypeDefinitions != null && preparedTypeDefinitions.length() != 0) {
stmt.setNString(2, preparedTypeDefinitions);
} else {
stmt.setNString(2, "");
}
return ((SQLServerPreparedStatement) stmt).executeQueryInternal();
}

default void processSDPEv1(String userSql, String preparedTypeDefinitions, Parameter[] params,
ArrayList<String> parameterNames, SQLServerConnection connection, PreparedStatement stmt, ResultSet rs,
ArrayList<byte[]> enclaveRequestedCEKs) throws SQLException {
Map<Integer, CekTableEntry> cekList = new HashMap<>();
Expand Down Expand Up @@ -148,7 +161,7 @@ default void processAev1SPDE(String userSql, String preparedTypeDefinitions, Par

// Process the second resultset.
if (!stmt.getMoreResults()) {
throw new SQLServerException(this, SQLServerException.getErrString("R_UnexpectedDescribeParamFormat"), null,
throw new SQLServerException(null, SQLServerException.getErrString("R_UnexpectedDescribeParamFormat"), null,
0, false);
}

Expand All @@ -164,7 +177,7 @@ default void processAev1SPDE(String userSql, String preparedTypeDefinitions, Par
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_InvalidEncryptionKeyOrdinal"));
Object[] msgArgs = {cekOrdinal, cekEntry.getSize()};
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
throw new SQLServerException(null, form.format(msgArgs), null, 0, false);
}
SQLServerEncryptionType encType = SQLServerEncryptionType
.of((byte) rs.getInt(DescribeParameterEncryptionResultSet2.ColumnEncrytionType.value()));
Expand All @@ -180,7 +193,7 @@ default void processAev1SPDE(String userSql, String preparedTypeDefinitions, Par
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_ForceEncryptionTrue_HonorAETrue_UnencryptedColumn"));
Object[] msgArgs = {userSql, paramIndex + 1};
SQLServerException.makeFromDriverError(connection, this, form.format(msgArgs), "0", true);
SQLServerException.makeFromDriverError(null, connection, form.format(msgArgs), "0", true);
}
}
}
Expand Down Expand Up @@ -241,7 +254,7 @@ abstract class BaseAttestationRequest {
protected byte[] x;
protected byte[] y;

byte[] getBytes() {
byte[] getBytes() throws IOException {
return null;
};

Expand Down Expand Up @@ -388,13 +401,13 @@ byte[] getSessionID() {

class EnclaveSession {
private byte[] sessionID;
private AtomicInteger counter;
private AtomicLong counter;
private byte[] sessionSecret;

EnclaveSession(byte[] cs, byte[] b) {
sessionID = cs;
sessionSecret = b;
counter = new AtomicInteger(0);
counter = new AtomicLong(0);
}

byte[] getSessionID() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,18 @@ private ArrayList<byte[]> describeParameterEncryption(SQLServerConnection connec
ArrayList<String> parameterNames) throws SQLServerException {
ArrayList<byte[]> enclaveRequestedCEKs = new ArrayList<>();
ResultSet rs = null;
try (PreparedStatement stmt = connection.prepareStatement(proc)) {
rs = executeProc(stmt, userSql, preparedTypeDefinitions, aasParams);
try (PreparedStatement stmt = connection.prepareStatement(connection.enclaveEstablished() ? SDPE1 : SDPE2)) {
if (connection.enclaveEstablished()) {
rs = executeSDPEv1(stmt, userSql, preparedTypeDefinitions);
} else {
rs = executeSDPEv2(stmt, userSql, preparedTypeDefinitions, aasParams);
}
if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return enclaveRequestedCEKs;
}
processAev1SPDE(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
processSDPEv1(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
enclaveRequestedCEKs);
// Process the third resultset.
if (connection.isAEv2() && stmt.getMoreResults()) {
Expand All @@ -139,7 +143,7 @@ private ArrayList<byte[]> describeParameterEncryption(SQLServerConnection connec
}
// Null check for rs is done already.
rs.close();
} catch (SQLException e) {
} catch (SQLException | IOException e) {
if (e instanceof SQLServerException) {
throw (SQLServerException) e;
} else {
Expand All @@ -164,26 +168,26 @@ class AASAttestationParameters extends BaseAttestationRequest {
byte[] attestationUrlBytes = (attestationUrl + '\0').getBytes(UTF_16LE);

ByteArrayOutputStream os = new ByteArrayOutputStream();
os.writeBytes(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(attestationUrlBytes.length).array());
os.writeBytes(attestationUrlBytes);
os.writeBytes(NONCE_LENGTH);
os.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(attestationUrlBytes.length).array());
os.write(attestationUrlBytes);
os.write(NONCE_LENGTH);
new SecureRandom().nextBytes(nonce);
os.writeBytes(nonce);
os.write(nonce);
enclaveChallenge = os.toByteArray();

initBcryptECDH();
}

@Override
byte[] getBytes() {
byte[] getBytes() throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
os.writeBytes(ENCLAVE_TYPE);
os.writeBytes(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(enclaveChallenge.length).array());
os.writeBytes(enclaveChallenge);
os.writeBytes(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(ENCLAVE_LENGTH).array());
os.writeBytes(ECDH_MAGIC);
os.writeBytes(x);
os.writeBytes(y);
os.write(ENCLAVE_TYPE);
os.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(enclaveChallenge.length).array());
os.write(enclaveChallenge);
os.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(ENCLAVE_LENGTH).array());
os.write(ECDH_MAGIC);
os.write(x);
os.write(y);
return os.toByteArray();
}

Expand Down Expand Up @@ -293,12 +297,12 @@ void validateToken(String attestationUrl, byte[] nonce) throws SQLServerExceptio
String authorityUrl = new URL(attestationUrl).getAuthority();
URL wellKnownUrl = new URL("https://" + authorityUrl + "/.well-known/openid-configuration");
URLConnection con = wellKnownUrl.openConnection();
String wellKnownUrlJson = new String(con.getInputStream().readAllBytes());
String wellKnownUrlJson = new String(Util.convertInputStreamToString(con.getInputStream()));
JsonObject attestationJson = JsonParser.parseString(wellKnownUrlJson).getAsJsonObject();
// Get our Keys
URL jwksUrl = new URL(attestationJson.get("jwks_uri").getAsString());
URLConnection jwksCon = jwksUrl.openConnection();
String jwksUrlJson = new String(jwksCon.getInputStream().readAllBytes());
String jwksUrlJson = new String(Util.convertInputStreamToString(jwksCon.getInputStream()));
JsonObject jwksJson = JsonParser.parseString(jwksUrlJson).getAsJsonObject();
keys = jwksJson.get("keys").getAsJsonArray();
certificateCache.put(attestationUrl, new JWTCertificateEntry(keys));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import java.security.Security;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.CertificateException;
Expand Down Expand Up @@ -119,7 +121,9 @@ private byte[] getAttestationCertificates() throws IOException {
if (null == certData) {
java.net.URL url = new java.net.URL(attestationUrl + "/attestationservice.svc/v2.0/signingCertificates/");
java.net.URLConnection con = url.openConnection();
String s = new String(con.getInputStream().readAllBytes());
byte[] buff = new byte[con.getInputStream().available()];
con.getInputStream().read(buff, 0, buff.length);
String s = new String(buff);
// omit the square brackets that come with the JSON
String[] bytesString = s.substring(1, s.length() - 1).split(",");
certData = new byte[bytesString.length];
Expand All @@ -136,14 +140,18 @@ private ArrayList<byte[]> describeParameterEncryption(SQLServerConnection connec
ArrayList<String> parameterNames) throws SQLServerException {
ArrayList<byte[]> enclaveRequestedCEKs = new ArrayList<>();
ResultSet rs = null;
try (PreparedStatement stmt = connection.prepareStatement(proc)) {
rs = executeProc(stmt, userSql, preparedTypeDefinitions, vsmParams);
try (PreparedStatement stmt = connection.prepareStatement(connection.enclaveEstablished() ? SDPE1 : SDPE2)) {
if (connection.enclaveEstablished()) {
rs = executeSDPEv1(stmt, userSql, preparedTypeDefinitions);
} else {
rs = executeSDPEv2(stmt, userSql, preparedTypeDefinitions, vsmParams);
}
if (null == rs) {
// No results. Meaning no parameter.
// Should never happen.
return enclaveRequestedCEKs;
}
processAev1SPDE(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
processSDPEv1(userSql, preparedTypeDefinitions, params, parameterNames, connection, stmt, rs,
enclaveRequestedCEKs);
// Process the third resultset.
if (connection.isAEv2() && stmt.getMoreResults()) {
Expand All @@ -158,7 +166,7 @@ private ArrayList<byte[]> describeParameterEncryption(SQLServerConnection connec
}
// Null check for rs is done already.
rs.close();
} catch (SQLException e) {
} catch (SQLException | IOException e) {
if (e instanceof SQLServerException) {
throw (SQLServerException) e;
} else {
Expand All @@ -181,14 +189,14 @@ class VSMAttestationParameters extends BaseAttestationRequest {
}

@Override
byte[] getBytes() {
byte[] getBytes() throws IOException {
ByteArrayOutputStream os = new ByteArrayOutputStream();
os.writeBytes(ENCLAVE_TYPE);
os.writeBytes(enclaveChallenge);
os.writeBytes(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(ENCLAVE_LENGTH).array());
os.writeBytes(ECDH_MAGIC);
os.writeBytes(x);
os.writeBytes(y);
os.write(ENCLAVE_TYPE);
os.write(enclaveChallenge);
os.write(ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(ENCLAVE_LENGTH).array());
os.write(ECDH_MAGIC);
os.write(x);
os.write(y);
return os.toByteArray();
}
}
Expand Down Expand Up @@ -319,7 +327,17 @@ void validateStatementSignature() throws SQLServerException, GeneralSecurityExce
SQLServerResource.getResource("R_EnclavePackageLengthError"), "0", false);
}

Signature sig = Signature.getInstance("RSASSA-PSS");
Signature sig = null;
try {
sig = Signature.getInstance("RSASSA-PSS");
} catch (NoSuchAlgorithmException e) {
/*
* RSASSA-PSS was added in JDK 11, the user might be using an older version of Java. Use BC as backup.
* Remove this logic if JDK 8 stops being supported or backports RSASSA-PSS
*/
Security.addProvider(new org.bouncycastle.jce.provider.BouncyCastleProvider());
sig = Signature.getInstance("RSASSA-PSS");
}
PSSParameterSpec pss = new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1);
sig.setParameter(pss);
sig.initVerify(healthCert);
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package com.microsoft.sqlserver.jdbc;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.InetAddress;
Expand Down Expand Up @@ -991,6 +992,16 @@ static boolean use43Wrapper() {
static String escapeSingleQuotes(String name) {
return name.replace("'", "''");
}

static String convertInputStreamToString(java.io.InputStream is) throws IOException {
java.io.ByteArrayOutputStream result = new java.io.ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int length;
while ((length = is.read(buffer)) != -1) {
result.write(buffer, 0, length);
}
return result.toString();
}
}


Expand Down

0 comments on commit b3ae62f

Please sign in to comment.