Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try to keep an SSLSocketFactory instance for an SslConfig instance ma… #251

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/java/com/bettercloud/vault/SslConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.SSLSocketFactory;

/**
* <p>A container for SSL-related configuration options, meant to be stored within a {@link VaultConfig} instance.</p>
Expand All @@ -46,6 +47,7 @@ public class SslConfig implements Serializable {

private boolean verify;
private transient SSLContext sslContext;
private transient SSLSocketFactory sslSocketFactory;
private transient KeyStore trustStore;
private transient KeyStore keyStore;
private String keyStorePassword;
Expand Down Expand Up @@ -469,6 +471,10 @@ public SSLContext getSslContext() {
return sslContext;
}

public SSLSocketFactory getSslSocketFactory() {
return sslSocketFactory;
}

protected String getPemUTF8() {
return pemUTF8;
}
Expand All @@ -489,8 +495,10 @@ private void buildSsl() throws VaultException {
if (verify) {
if (keyStore != null || trustStore != null) {
this.sslContext = buildSslContextFromJks();
this.sslSocketFactory = sslContext.getSocketFactory();
} else if (pemUTF8 != null || clientPemUTF8 != null || clientKeyPemUTF8 != null) {
this.sslContext = buildSslContextFromPem();
this.sslSocketFactory = sslContext.getSocketFactory();
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/com/bettercloud/vault/api/Logical.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ private LogicalResponse read(final String path, Boolean shouldRetry, final logic
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.get();

// Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response
Expand Down Expand Up @@ -160,6 +161,7 @@ public LogicalResponse read(final String path, Boolean shouldRetry, final Intege
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.get();

// Validate response - don't treat 4xx class errors as exceptions, we want to return an error as the response
Expand Down Expand Up @@ -261,6 +263,7 @@ private LogicalResponse write(final String path, final Map<String, Object> nameV
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.post();

// HTTP Status should be either 200 (with content - e.g. PKI write) or 204 (no content)
Expand Down Expand Up @@ -352,6 +355,7 @@ private LogicalResponse delete(final String path, final Logical.logicalOperation
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.delete();

// Validate response
Expand Down Expand Up @@ -412,6 +416,7 @@ public LogicalResponse delete(final String path, final int[] versions) throws Va
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToDelete.toString().getBytes(StandardCharsets.UTF_8))
.post();

Expand Down Expand Up @@ -483,6 +488,7 @@ public LogicalResponse unDelete(final String path, final int[] versions) throws
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToUnDelete.toString().getBytes(StandardCharsets.UTF_8))
.post();

Expand Down Expand Up @@ -542,6 +548,7 @@ public LogicalResponse destroy(final String path, final int[] versions) throws V
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(versionsToDestroy.toString().getBytes(StandardCharsets.UTF_8))
.post();

Expand Down Expand Up @@ -593,6 +600,7 @@ public LogicalResponse upgrade(final String kvPath) throws VaultException {
.readTimeoutSeconds(config.getReadTimeout())
.sslVerification(config.getSslConfig().isVerify())
.sslContext(config.getSslConfig().getSslContext())
.sslSocketFactory(config.getSslConfig().getSslSocketFactory())
.body(kvToUpgrade.toString().getBytes(StandardCharsets.UTF_8))
.post();

Expand Down
32 changes: 27 additions & 5 deletions src/main/java/com/bettercloud/vault/rest/Rest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.net.ssl.SSLSocketFactory;

/**
* <p>A simple client for issuing HTTP requests. Supports the HTTP verbs:</p>
Expand Down Expand Up @@ -66,6 +67,7 @@ public class Rest {
* verification process, to always trust any certificates.
*/
private static SSLContext DISABLED_SSL_CONTEXT;
private static SSLSocketFactory DISABLED_SSL_SOCKET_FACTORY;

static {
try {
Expand All @@ -84,6 +86,7 @@ public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
}}, new java.security.SecureRandom());
DISABLED_SSL_SOCKET_FACTORY = DISABLED_SSL_CONTEXT.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
e.printStackTrace();
}
Expand All @@ -98,6 +101,7 @@ public X509Certificate[] getAcceptedIssuers() {
private Integer readTimeoutSeconds;
private Boolean sslVerification;
private SSLContext sslContext;
private SSLSocketFactory sslSocketFactory;

/**
* <p>Sets the base URL to which the HTTP request will be sent. The URL may or may not include query parameters
Expand Down Expand Up @@ -248,6 +252,11 @@ public Rest sslContext(final SSLContext sslContext) {
return this;
}

public Rest sslSocketFactory(final SSLSocketFactory sslSocketFactory) {
this.sslSocketFactory = sslSocketFactory;
return this;
}

/**
* <p>Executes an HTTP GET request with the settings already configured. Parameters and headers are optional, but
* a <code>RestException</code> will be thrown if the caller has not first set a base URL with the
Expand Down Expand Up @@ -446,8 +455,11 @@ private URLConnection initURLConnection(final String urlString, final String met
final HttpsURLConnection httpsURLConnection = (HttpsURLConnection) connection;
if (sslVerification != null && !sslVerification) {
// SSL verification disabled
httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_CONTEXT.getSocketFactory());
httpsURLConnection.setSSLSocketFactory(DISABLED_SSL_SOCKET_FACTORY);
httpsURLConnection.setHostnameVerifier((s, sslSession) -> true);
} else if (sslSocketFactory != null) {
// Socket factory supplied for keep-alive connections
httpsURLConnection.setSSLSocketFactory(sslSocketFactory);
} else if (sslContext != null) {
// Cert file supplied
httpsURLConnection.setSSLSocketFactory(sslContext.getSocketFactory());
Expand All @@ -463,11 +475,10 @@ private URLConnection initURLConnection(final String urlString, final String met

return connection;
} catch (Exception e) {
throw new RestException(e);
} finally {
if (connection instanceof HttpURLConnection) {
if (connection != null && connection instanceof HttpURLConnection) {
((HttpURLConnection) connection).disconnect();
}
throw new RestException(e);
}
}

Expand Down Expand Up @@ -499,8 +510,8 @@ private String parametersToQueryString() {
* @throws RestException
*/
private byte[] responseBodyBytes(final URLConnection connection) throws RestException {
InputStream inputStream = null;
try {
final InputStream inputStream;
final int responseCode = this.connectionStatus(connection);
if (200 <= responseCode && responseCode <= 299) {
inputStream = connection.getInputStream();
Expand All @@ -519,9 +530,20 @@ private byte[] responseBodyBytes(final URLConnection connection) throws RestExce
while ((bytesRead = inputStream.read(bytes, 0, bytes.length)) != -1) {
byteArrayOutputStream.write(bytes, 0, bytesRead);
}
inputStream.close();
byteArrayOutputStream.flush();
return byteArrayOutputStream.toByteArray();
} catch (IOException e) {
try {
if (inputStream == null) {
inputStream = ((HttpURLConnection) connection).getErrorStream();
}
if (inputStream != null) {
inputStream.close();
}
} catch (IOException ee) {
//do nothing
}
return new byte[0];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.bettercloud.vault.response.LogicalResponse;
import com.bettercloud.vault.util.VaultContainer;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -443,4 +444,56 @@ public void testVaultUpgrade() throws VaultException {
Assert.assertEquals(kVOriginalVersion, "1");
Assert.assertEquals(kVUpgradedVersion, "2");
}

/**
* Verify that value can be read several times with connection re-usage
*
* @throws VaultException
*/
@Test
public void testReadSeveralTimesWithConnectionReUsage() throws VaultException {
final int readTimes = 10;
final String pathToWrite = "secret/hello";
final String pathToRead = "secret/hello";

final String value = "world";
final Map<String, Object> testMap = new HashMap<>();
testMap.put("value", value);

final Vault vault = container.getRootVault();
String hostport = String.format("%s.%s", container.getContainerIpAddress(), container.getMappedPort(8200));
Logical logical = vault.logical();

int connBefore = connStat(hostport);
logical.write(pathToWrite, testMap);
for(int i = 0; i < readTimes; i++) {
final String valueRead = logical.read(pathToRead).getData().get("value");
assertEquals(value, valueRead);
}
int connCreated = connStat(hostport) - connBefore;
assertTrue("Too many new connections to '" + hostport + "' created: " + connCreated, connCreated <= 1);
}

private int connStat(String host) {
ProcessBuilder pb = new ProcessBuilder();
pb.command("netstat");
pb.redirectErrorStream(true);

try {
Process p = pb.start();
InputStream inputStream = p.getInputStream();
String result = new String(inputStream.readAllBytes());
int conn = 0;
for (String line : result.split("\n")) {
if (line.matches(".*" + host + "\\s+ESTABLISHED")) {
conn++;
System.out.println(line);
}
}
return conn;
} catch (IOException e) {
System.err.println("Error executing netstat: " + e.getMessage());
return 0;
}
}
}