From da6db689853e2ba5781d74efd824a91ff3299c33 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Mon, 29 Apr 2019 13:08:35 -0600 Subject: [PATCH] Handle WRAP ops during SSL read (#41611) It is possible that a WRAP operation can occur while decrypting handshake data in TLS 1.3. The SSLDriver does not currently handle this well as it does not have access to the outbound buffer during read call. This commit moves the buffer into the Driver to fix this issue. Data wrapped during a read call will be queued for writing after the read call is complete. --- .../transport/nio/SSLChannelContext.java | 42 ++++++++------- .../security/transport/nio/SSLDriver.java | 51 ++++++++++--------- .../transport/nio/SSLChannelContextTests.java | 34 ++++++------- .../transport/nio/SSLDriverTests.java | 21 ++++---- 4 files changed, 77 insertions(+), 71 deletions(-) diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 2c00dd7092950..9372cb1ec54fc 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -9,17 +9,16 @@ import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.WriteOperation; import javax.net.ssl.SSLEngine; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.util.LinkedList; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -37,8 +36,7 @@ public final class SSLChannelContext extends SocketChannelContext { private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {}; private final SSLDriver sslDriver; - private final SSLOutboundBuffer outboundBuffer; - private FlushOperation encryptedFlush; + private final LinkedList encryptedFlushes = new LinkedList<>(); private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, @@ -51,14 +49,16 @@ public final class SSLChannelContext extends SocketChannelContext { Predicate allowChannelPredicate) { super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); this.sslDriver = sslDriver; - // TODO: When the bytes are actually recycled, we need to test that they are released on context close - this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); } @Override public void register() throws IOException { super.register(); sslDriver.init(); + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); + if (outboundBuffer.hasEncryptedBytesToFlush()) { + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); + } } @Override @@ -98,11 +98,12 @@ public void flushChannel() throws IOException { try { // Attempt to encrypt application write data. The encrypted data ends up in the // outbound write buffer. - sslDriver.write(unencryptedFlush, outboundBuffer); + sslDriver.write(unencryptedFlush); + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush() == false) { break; } - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); // Flush the write buffer to the channel flushEncryptedOperation(); } catch (IOException e) { @@ -115,10 +116,11 @@ public void flushChannel() throws IOException { // We are not ready for application writes, check if the driver has non-application writes. We // only want to continue producing new writes if the outbound write buffer is fully flushed. while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) { - sslDriver.nonApplicationWrite(outboundBuffer); + sslDriver.nonApplicationWrite(); // If non-application writes were produced, flush the outbound write buffer. + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush()) { - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + encryptedFlushes.addFirst(outboundBuffer.buildNetworkFlushOperation()); flushEncryptedOperation(); } } @@ -127,14 +129,14 @@ public void flushChannel() throws IOException { private void flushEncryptedOperation() throws IOException { try { + FlushOperation encryptedFlush = encryptedFlushes.getFirst(); flushToChannel(encryptedFlush); if (encryptedFlush.isFullyFlushed()) { getSelector().executeListener(encryptedFlush.getListener(), null); - encryptedFlush = null; + encryptedFlushes.removeFirst(); } } catch (IOException e) { - getSelector().executeFailedListener(encryptedFlush.getListener(), e); - encryptedFlush = null; + getSelector().executeFailedListener(encryptedFlushes.removeFirst().getListener(), e); throw e; } } @@ -163,6 +165,11 @@ public int read() throws IOException { sslDriver.read(channelBuffer); handleReadBytes(); + // It is possible that a read call produced non-application bytes to flush + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); + if (outboundBuffer.hasEncryptedBytesToFlush()) { + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); + } return bytesRead; } @@ -190,10 +197,11 @@ public void closeFromSelector() throws IOException { getSelector().assertOnSelectorThread(); if (channel.isOpen()) { closeTimeoutCanceller.run(); - if (encryptedFlush != null) { + for (FlushOperation encryptedFlush : encryptedFlushes) { getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException()); } - IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close); + encryptedFlushes.clear(); + IOUtils.close(super::closeFromSelector, sslDriver::close); } } @@ -208,7 +216,7 @@ private void channelCloseTimeout() { } private boolean pendingChannelFlush() { - return encryptedFlush != null; + return encryptedFlushes.isEmpty() == false; } private static class CloseNotifyOperation implements WriteOperation { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java index 4dbf1d1f03fdf..bc112dd3a60ad 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java @@ -7,6 +7,7 @@ import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.utils.ExceptionsHelper; import javax.net.ssl.SSLEngine; @@ -32,14 +33,14 @@ * * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be * called to determine if this driver needs to produce more data to advance the handshake or close process. - * If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the + * If that method returns true, {@link #nonApplicationWrite()} should be called (and the * data produced then flushed to the channel) until no further non-application writes are needed. * * If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine * if the driver is ready to consume application data. (Note: It is possible that * {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the * driver is waiting on non-application data from the peer.) If the driver indicates it is ready for - * application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will + * application writes, {@link #write(FlushOperation)} can be called. This method will * encrypt flush operation application data and place it in the outbound buffer for flushing to a channel. * * If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the @@ -53,6 +54,8 @@ public class SSLDriver implements AutoCloseable { private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {}); private final SSLEngine engine; + // TODO: When the bytes are actually recycled, we need to test that they are released on driver close + private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); private final boolean isClientMode; // This should only be accessed by the network thread associated with this channel, so nothing needs to // be volatile. @@ -107,6 +110,10 @@ public ByteBuffer getNetworkReadBuffer() { return networkReadBuffer; } + public SSLOutboundBuffer getOutboundBuffer() { + return outboundBuffer; + } + public void read(InboundChannelBuffer buffer) throws SSLException { Mode modePriorToRead; do { @@ -125,14 +132,14 @@ public boolean needsNonApplicationWrite() { return currentMode.needsNonApplicationWrite(); } - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { - return currentMode.write(applicationBytes, outboundBuffer); + public int write(FlushOperation applicationBytes) throws SSLException { + return currentMode.write(applicationBytes); } - public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException { + public void nonApplicationWrite() throws SSLException { assert currentMode.isApplication() == false : "Should not be called if driver is in application mode"; if (currentMode.isApplication() == false) { - currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer); + currentMode.write(EMPTY_FLUSH_OPERATION); } else { throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName()); } @@ -148,6 +155,7 @@ public boolean isClosed() { @Override public void close() throws SSLException { + outboundBuffer.close(); ArrayList closingExceptions = new ArrayList<>(2); closingInternal(); CloseMode closeMode = (CloseMode) this.currentMode; @@ -276,7 +284,7 @@ private interface Mode { void read(InboundChannelBuffer buffer) throws SSLException; - int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException; + int write(FlushOperation applicationBytes) throws SSLException; boolean needsNonApplicationWrite(); @@ -296,10 +304,9 @@ private class HandshakeMode implements Mode { private void startHandshake() throws SSLException { handshakeStatus = engine.getHandshakeStatus(); - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP && - handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) { + if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { try { - handshake(null); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -307,7 +314,7 @@ private void startHandshake() throws SSLException { } } - private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { + private void handshake() throws SSLException { boolean continueHandshaking = true; while (continueHandshaking) { switch (handshakeStatus) { @@ -316,15 +323,7 @@ private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { continueHandshaking = false; break; case NEED_WRAP: - if (outboundBuffer != null) { - handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); - // If we need NEED_TASK we should run the tasks immediately - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { - continueHandshaking = false; - } - } else { - continueHandshaking = false; - } + handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); break; case NEED_TASK: runTasks(); @@ -351,7 +350,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { try { SSLEngineResult result = unwrap(buffer); handshakeStatus = result.getHandshakeStatus(); - handshake(null); + handshake(); // If we are done handshaking we should exit the handshake read continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake(); } catch (SSLException e) { @@ -362,9 +361,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { try { - handshake(outboundBuffer); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -444,7 +443,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { boolean continueWrap = true; int totalBytesProduced = 0; while (continueWrap && applicationBytes.isFullyFlushed() == false) { @@ -538,7 +537,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { int bytesProduced = 0; if (engine.isOutboundDone() == false) { bytesProduced += wrap(outboundBuffer).bytesProduced(); @@ -549,6 +548,8 @@ public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuff closeInboundAndSwallowPeerDidNotCloseException(); } } + } else { + needToSendClose = false; } return bytesProduced; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 893af2140b9b0..dcccb23f1f665 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.TaskScheduler; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.test.ESTestCase; @@ -45,6 +46,7 @@ public class SSLChannelContextTests extends ESTestCase { private SocketChannel rawChannel; private SSLChannelContext context; private InboundChannelBuffer channelBuffer; + private SSLOutboundBuffer outboundBuffer; private NioSelector selector; private TaskScheduler nioTimer; private BiConsumer listener; @@ -67,6 +69,7 @@ public void init() { rawChannel = mock(SocketChannel.class); sslDriver = mock(SSLDriver.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); + outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {})); when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); @@ -74,6 +77,7 @@ public void init() { when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); + when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { buffer.clear(); @@ -183,7 +187,7 @@ public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { public void testPendingEncryptedFlushMeansWriteInterested() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); // Call will put bytes in buffer to flush context.flushChannel(); @@ -208,7 +212,7 @@ public void testNoNonAppWriteInterestInAppMode() { public void testFirstFlushMustFinishForWriteToContinue() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); // First call will put bytes in buffer to flush context.flushChannel(); @@ -217,30 +221,30 @@ public void testFirstFlushMustFinishForWriteToContinue() throws Exception { context.flushChannel(); assertTrue(context.readyForFlush()); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(1)).nonApplicationWrite(); } public void testNonAppWrites() throws Exception { when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1); context.flushChannel(); - verify(sslDriver, times(2)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(2)).nonApplicationWrite(); verify(rawChannel, times(2)).write(same(selector.getIoBuffer())); } public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { when(sslDriver.needsNonApplicationWrite()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0); context.flushChannel(); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(1)).nonApplicationWrite(); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); } @@ -250,7 +254,7 @@ public void testQueuedWriteIsFlushedInFlushCall() throws Exception { context.queueWriteOperation(flushOperation); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10); context.flushChannel(); @@ -266,7 +270,7 @@ public void testPartialFlush() throws IOException { context.queueWriteOperation(flushOperation); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4); context.flushChannel(); @@ -286,7 +290,7 @@ public void testMultipleWritesPartialFlushes() throws IOException { context.queueWriteOperation(flushOperation2); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2); context.flushChannel(); @@ -303,7 +307,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { IOException exception = new IOException(); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); expectThrows(IOException.class, () -> context.flushChannel()); @@ -314,7 +318,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); context.flushChannel(); @@ -406,12 +410,6 @@ public void testRegisterInitiatesDriver() throws IOException { private Answer getWriteAnswer(int bytesToEncrypt, boolean isApp) { return invocationOnMock -> { - SSLOutboundBuffer outboundBuffer; - if (isApp) { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[1]; - } else { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[0]; - } ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1); for (int i = 0; i < bytesToEncrypt; ++i) { byteBuffer.put((byte) i); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java index 4b86d3223b061..5003d029043e9 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java @@ -181,7 +181,7 @@ public void testCloseDuringHandshakeJDK11() throws Exception { clientDriver.init(); serverDriver.init(); - assertTrue(clientDriver.needsNonApplicationWrite()); + assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); @@ -296,12 +296,12 @@ private void normalClose(SSLDriver sendDriver, SSLDriver receiveDriver) throws I } private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException { - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); } else { - sendDriver.nonApplicationWrite(outboundBuffer); + sendDriver.nonApplicationWrite(); } } } @@ -316,7 +316,7 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i serverDriver.init(); } - assertTrue(clientDriver.needsNonApplicationWrite()); + assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); @@ -331,7 +331,6 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i sendHandshakeMessages(clientDriver, serverDriver); assertTrue(clientDriver.isHandshaking()); - assertTrue(serverDriver.isHandshaking()); sendHandshakeMessages(serverDriver, clientDriver); @@ -340,20 +339,20 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i } private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException { - assertTrue(sendDriver.needsNonApplicationWrite()); + assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); receiveDriver.read(genericBuffer); } else { - sendDriver.nonApplicationWrite(outboundBuffer); + sendDriver.nonApplicationWrite(); } } if (receiveDriver.isHandshaking()) { - assertTrue(receiveDriver.needsNonApplicationWrite()); + assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); } } @@ -361,12 +360,12 @@ private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuff assertFalse(sendDriver.needsNonApplicationWrite()); int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum(); - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {}); int bytesEncrypted = 0; while (bytesToEncrypt > bytesEncrypted) { - bytesEncrypted += sendDriver.write(flushOperation, outboundBuffer); + bytesEncrypted += sendDriver.write(flushOperation); sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); } }