diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 2c00dd7092950..9372cb1ec54fc 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -9,17 +9,16 @@ import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.Page; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.WriteOperation; import javax.net.ssl.SSLEngine; import java.io.IOException; -import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; +import java.util.LinkedList; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -37,8 +36,7 @@ public final class SSLChannelContext extends SocketChannelContext { private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {}; private final SSLDriver sslDriver; - private final SSLOutboundBuffer outboundBuffer; - private FlushOperation encryptedFlush; + private final LinkedList encryptedFlushes = new LinkedList<>(); private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, @@ -51,14 +49,16 @@ public final class SSLChannelContext extends SocketChannelContext { Predicate allowChannelPredicate) { super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); this.sslDriver = sslDriver; - // TODO: When the bytes are actually recycled, we need to test that they are released on context close - this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); } @Override public void register() throws IOException { super.register(); sslDriver.init(); + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); + if (outboundBuffer.hasEncryptedBytesToFlush()) { + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); + } } @Override @@ -98,11 +98,12 @@ public void flushChannel() throws IOException { try { // Attempt to encrypt application write data. The encrypted data ends up in the // outbound write buffer. - sslDriver.write(unencryptedFlush, outboundBuffer); + sslDriver.write(unencryptedFlush); + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush() == false) { break; } - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); // Flush the write buffer to the channel flushEncryptedOperation(); } catch (IOException e) { @@ -115,10 +116,11 @@ public void flushChannel() throws IOException { // We are not ready for application writes, check if the driver has non-application writes. We // only want to continue producing new writes if the outbound write buffer is fully flushed. while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) { - sslDriver.nonApplicationWrite(outboundBuffer); + sslDriver.nonApplicationWrite(); // If non-application writes were produced, flush the outbound write buffer. + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); if (outboundBuffer.hasEncryptedBytesToFlush()) { - encryptedFlush = outboundBuffer.buildNetworkFlushOperation(); + encryptedFlushes.addFirst(outboundBuffer.buildNetworkFlushOperation()); flushEncryptedOperation(); } } @@ -127,14 +129,14 @@ public void flushChannel() throws IOException { private void flushEncryptedOperation() throws IOException { try { + FlushOperation encryptedFlush = encryptedFlushes.getFirst(); flushToChannel(encryptedFlush); if (encryptedFlush.isFullyFlushed()) { getSelector().executeListener(encryptedFlush.getListener(), null); - encryptedFlush = null; + encryptedFlushes.removeFirst(); } } catch (IOException e) { - getSelector().executeFailedListener(encryptedFlush.getListener(), e); - encryptedFlush = null; + getSelector().executeFailedListener(encryptedFlushes.removeFirst().getListener(), e); throw e; } } @@ -163,6 +165,11 @@ public int read() throws IOException { sslDriver.read(channelBuffer); handleReadBytes(); + // It is possible that a read call produced non-application bytes to flush + SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer(); + if (outboundBuffer.hasEncryptedBytesToFlush()) { + encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation()); + } return bytesRead; } @@ -190,10 +197,11 @@ public void closeFromSelector() throws IOException { getSelector().assertOnSelectorThread(); if (channel.isOpen()) { closeTimeoutCanceller.run(); - if (encryptedFlush != null) { + for (FlushOperation encryptedFlush : encryptedFlushes) { getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException()); } - IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close); + encryptedFlushes.clear(); + IOUtils.close(super::closeFromSelector, sslDriver::close); } } @@ -208,7 +216,7 @@ private void channelCloseTimeout() { } private boolean pendingChannelFlush() { - return encryptedFlush != null; + return encryptedFlushes.isEmpty() == false; } private static class CloseNotifyOperation implements WriteOperation { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java index 4dbf1d1f03fdf..bc112dd3a60ad 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java @@ -7,6 +7,7 @@ import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.utils.ExceptionsHelper; import javax.net.ssl.SSLEngine; @@ -32,14 +33,14 @@ * * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be * called to determine if this driver needs to produce more data to advance the handshake or close process. - * If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the + * If that method returns true, {@link #nonApplicationWrite()} should be called (and the * data produced then flushed to the channel) until no further non-application writes are needed. * * If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine * if the driver is ready to consume application data. (Note: It is possible that * {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the * driver is waiting on non-application data from the peer.) If the driver indicates it is ready for - * application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will + * application writes, {@link #write(FlushOperation)} can be called. This method will * encrypt flush operation application data and place it in the outbound buffer for flushing to a channel. * * If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the @@ -53,6 +54,8 @@ public class SSLDriver implements AutoCloseable { private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {}); private final SSLEngine engine; + // TODO: When the bytes are actually recycled, we need to test that they are released on driver close + private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); private final boolean isClientMode; // This should only be accessed by the network thread associated with this channel, so nothing needs to // be volatile. @@ -107,6 +110,10 @@ public ByteBuffer getNetworkReadBuffer() { return networkReadBuffer; } + public SSLOutboundBuffer getOutboundBuffer() { + return outboundBuffer; + } + public void read(InboundChannelBuffer buffer) throws SSLException { Mode modePriorToRead; do { @@ -125,14 +132,14 @@ public boolean needsNonApplicationWrite() { return currentMode.needsNonApplicationWrite(); } - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { - return currentMode.write(applicationBytes, outboundBuffer); + public int write(FlushOperation applicationBytes) throws SSLException { + return currentMode.write(applicationBytes); } - public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException { + public void nonApplicationWrite() throws SSLException { assert currentMode.isApplication() == false : "Should not be called if driver is in application mode"; if (currentMode.isApplication() == false) { - currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer); + currentMode.write(EMPTY_FLUSH_OPERATION); } else { throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName()); } @@ -148,6 +155,7 @@ public boolean isClosed() { @Override public void close() throws SSLException { + outboundBuffer.close(); ArrayList closingExceptions = new ArrayList<>(2); closingInternal(); CloseMode closeMode = (CloseMode) this.currentMode; @@ -276,7 +284,7 @@ private interface Mode { void read(InboundChannelBuffer buffer) throws SSLException; - int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException; + int write(FlushOperation applicationBytes) throws SSLException; boolean needsNonApplicationWrite(); @@ -296,10 +304,9 @@ private class HandshakeMode implements Mode { private void startHandshake() throws SSLException { handshakeStatus = engine.getHandshakeStatus(); - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP && - handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) { + if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { try { - handshake(null); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -307,7 +314,7 @@ private void startHandshake() throws SSLException { } } - private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { + private void handshake() throws SSLException { boolean continueHandshaking = true; while (continueHandshaking) { switch (handshakeStatus) { @@ -316,15 +323,7 @@ private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException { continueHandshaking = false; break; case NEED_WRAP: - if (outboundBuffer != null) { - handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); - // If we need NEED_TASK we should run the tasks immediately - if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) { - continueHandshaking = false; - } - } else { - continueHandshaking = false; - } + handshakeStatus = wrap(outboundBuffer).getHandshakeStatus(); break; case NEED_TASK: runTasks(); @@ -351,7 +350,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { try { SSLEngineResult result = unwrap(buffer); handshakeStatus = result.getHandshakeStatus(); - handshake(null); + handshake(); // If we are done handshaking we should exit the handshake read continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake(); } catch (SSLException e) { @@ -362,9 +361,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { try { - handshake(outboundBuffer); + handshake(); } catch (SSLException e) { closingInternal(); throw e; @@ -444,7 +443,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { boolean continueWrap = true; int totalBytesProduced = 0; while (continueWrap && applicationBytes.isFullyFlushed() == false) { @@ -538,7 +537,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException { } @Override - public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException { + public int write(FlushOperation applicationBytes) throws SSLException { int bytesProduced = 0; if (engine.isOutboundDone() == false) { bytesProduced += wrap(outboundBuffer).bytesProduced(); @@ -549,6 +548,8 @@ public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuff closeInboundAndSwallowPeerDidNotCloseException(); } } + } else { + needToSendClose = false; } return bytesProduced; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java index 893af2140b9b0..dcccb23f1f665 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.Page; import org.elasticsearch.nio.TaskScheduler; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.test.ESTestCase; @@ -45,6 +46,7 @@ public class SSLChannelContextTests extends ESTestCase { private SocketChannel rawChannel; private SSLChannelContext context; private InboundChannelBuffer channelBuffer; + private SSLOutboundBuffer outboundBuffer; private NioSelector selector; private TaskScheduler nioTimer; private BiConsumer listener; @@ -67,6 +69,7 @@ public void init() { rawChannel = mock(SocketChannel.class); sslDriver = mock(SSLDriver.class); channelBuffer = InboundChannelBuffer.allocatingInstance(); + outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {})); when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); context = new SSLChannelContext(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer); @@ -74,6 +77,7 @@ public void init() { when(selector.isOnCurrentThread()).thenReturn(true); when(selector.getTaskScheduler()).thenReturn(nioTimer); when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer); + when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer); ByteBuffer buffer = ByteBuffer.allocate(1 << 14); when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> { buffer.clear(); @@ -183,7 +187,7 @@ public void testQueuedWritesAreIgnoredWhenNotReadyForAppWrites() { public void testPendingEncryptedFlushMeansWriteInterested() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true, false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); // Call will put bytes in buffer to flush context.flushChannel(); @@ -208,7 +212,7 @@ public void testNoNonAppWriteInterestInAppMode() { public void testFirstFlushMustFinishForWriteToContinue() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); // First call will put bytes in buffer to flush context.flushChannel(); @@ -217,30 +221,30 @@ public void testFirstFlushMustFinishForWriteToContinue() throws Exception { context.flushChannel(); assertTrue(context.readyForFlush()); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(1)).nonApplicationWrite(); } public void testNonAppWrites() throws Exception { when(sslDriver.needsNonApplicationWrite()).thenReturn(true, true, false); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(1); context.flushChannel(); - verify(sslDriver, times(2)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(2)).nonApplicationWrite(); verify(rawChannel, times(2)).write(same(selector.getIoBuffer())); } public void testNonAppWritesStopIfBufferNotFullyFlushed() throws Exception { when(sslDriver.needsNonApplicationWrite()).thenReturn(true); when(sslDriver.readyForApplicationWrites()).thenReturn(false); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(0); context.flushChannel(); - verify(sslDriver, times(1)).nonApplicationWrite(any(SSLOutboundBuffer.class)); + verify(sslDriver, times(1)).nonApplicationWrite(); verify(rawChannel, times(1)).write(same(selector.getIoBuffer())); } @@ -250,7 +254,7 @@ public void testQueuedWriteIsFlushedInFlushCall() throws Exception { context.queueWriteOperation(flushOperation); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(10, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(10); context.flushChannel(); @@ -266,7 +270,7 @@ public void testPartialFlush() throws IOException { context.queueWriteOperation(flushOperation); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(4); context.flushChannel(); @@ -286,7 +290,7 @@ public void testMultipleWritesPartialFlushes() throws IOException { context.queueWriteOperation(flushOperation2); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(any(FlushOperation.class)); when(rawChannel.write(same(selector.getIoBuffer()))).thenReturn(5, 5, 2); context.flushChannel(); @@ -303,7 +307,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { IOException exception = new IOException(); when(sslDriver.readyForApplicationWrites()).thenReturn(true); - doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation), any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(5, true)).when(sslDriver).write(eq(flushOperation)); when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception); expectThrows(IOException.class, () -> context.flushChannel()); @@ -314,7 +318,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException { public void testWriteIOExceptionMeansChannelReadyToClose() throws Exception { when(sslDriver.readyForApplicationWrites()).thenReturn(false); when(sslDriver.needsNonApplicationWrite()).thenReturn(true); - doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(any(SSLOutboundBuffer.class)); + doAnswer(getWriteAnswer(1, false)).when(sslDriver).nonApplicationWrite(); context.flushChannel(); @@ -406,12 +410,6 @@ public void testRegisterInitiatesDriver() throws IOException { private Answer getWriteAnswer(int bytesToEncrypt, boolean isApp) { return invocationOnMock -> { - SSLOutboundBuffer outboundBuffer; - if (isApp) { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[1]; - } else { - outboundBuffer = (SSLOutboundBuffer) invocationOnMock.getArguments()[0]; - } ByteBuffer byteBuffer = outboundBuffer.nextWriteBuffer(bytesToEncrypt + 1); for (int i = 0; i < bytesToEncrypt; ++i) { byteBuffer.put((byte) i); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java index 4b86d3223b061..5003d029043e9 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java @@ -181,7 +181,7 @@ public void testCloseDuringHandshakeJDK11() throws Exception { clientDriver.init(); serverDriver.init(); - assertTrue(clientDriver.needsNonApplicationWrite()); + assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); sendHandshakeMessages(serverDriver, clientDriver); @@ -296,12 +296,12 @@ private void normalClose(SSLDriver sendDriver, SSLDriver receiveDriver) throws I } private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException { - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); } else { - sendDriver.nonApplicationWrite(outboundBuffer); + sendDriver.nonApplicationWrite(); } } } @@ -316,7 +316,7 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i serverDriver.init(); } - assertTrue(clientDriver.needsNonApplicationWrite()); + assertTrue(clientDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); assertFalse(serverDriver.needsNonApplicationWrite()); sendHandshakeMessages(clientDriver, serverDriver); @@ -331,7 +331,6 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i sendHandshakeMessages(clientDriver, serverDriver); assertTrue(clientDriver.isHandshaking()); - assertTrue(serverDriver.isHandshaking()); sendHandshakeMessages(serverDriver, clientDriver); @@ -340,20 +339,20 @@ private void handshake(SSLDriver clientDriver, SSLDriver serverDriver, boolean i } private void sendHandshakeMessages(SSLDriver sendDriver, SSLDriver receiveDriver) throws IOException { - assertTrue(sendDriver.needsNonApplicationWrite()); + assertTrue(sendDriver.needsNonApplicationWrite() || sendDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) { if (outboundBuffer.hasEncryptedBytesToFlush()) { sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); receiveDriver.read(genericBuffer); } else { - sendDriver.nonApplicationWrite(outboundBuffer); + sendDriver.nonApplicationWrite(); } } if (receiveDriver.isHandshaking()) { - assertTrue(receiveDriver.needsNonApplicationWrite()); + assertTrue(receiveDriver.needsNonApplicationWrite() || receiveDriver.getOutboundBuffer().hasEncryptedBytesToFlush()); } } @@ -361,12 +360,12 @@ private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuff assertFalse(sendDriver.needsNonApplicationWrite()); int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum(); - SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n))); + SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer(); FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {}); int bytesEncrypted = 0; while (bytesToEncrypt > bytesEncrypted) { - bytesEncrypted += sendDriver.write(flushOperation, outboundBuffer); + bytesEncrypted += sendDriver.write(flushOperation); sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver); } }