Skip to content

Commit

Permalink
Handle WRAP ops during SSL read (elastic#41611)
Browse files Browse the repository at this point in the history
It is possible that a WRAP operation can occur while decrypting
handshake data in TLS 1.3. The SSLDriver does not currently handle this
well as it does not have access to the outbound buffer during read call.
This commit moves the buffer into the Driver to fix this issue. Data
wrapped during a read call will be queued for writing after the read
call is complete.
  • Loading branch information
Tim-Brooks authored and Gurkan Kaymak committed May 27, 2019
1 parent 4483e62 commit da6db68
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.WriteOperation;

import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
Expand All @@ -37,8 +36,7 @@ public final class SSLChannelContext extends SocketChannelContext {
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};

private final SSLDriver sslDriver;
private final SSLOutboundBuffer outboundBuffer;
private FlushOperation encryptedFlush;
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;

SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
Expand All @@ -51,14 +49,16 @@ public final class SSLChannelContext extends SocketChannelContext {
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
this.sslDriver = sslDriver;
// TODO: When the bytes are actually recycled, we need to test that they are released on context close
this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
}

@Override
public void register() throws IOException {
super.register();
sslDriver.init();
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
}
}

@Override
Expand Down Expand Up @@ -98,11 +98,12 @@ public void flushChannel() throws IOException {
try {
// Attempt to encrypt application write data. The encrypted data ends up in the
// outbound write buffer.
sslDriver.write(unencryptedFlush, outboundBuffer);
sslDriver.write(unencryptedFlush);
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush() == false) {
break;
}
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
// Flush the write buffer to the channel
flushEncryptedOperation();
} catch (IOException e) {
Expand All @@ -115,10 +116,11 @@ public void flushChannel() throws IOException {
// We are not ready for application writes, check if the driver has non-application writes. We
// only want to continue producing new writes if the outbound write buffer is fully flushed.
while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) {
sslDriver.nonApplicationWrite(outboundBuffer);
sslDriver.nonApplicationWrite();
// If non-application writes were produced, flush the outbound write buffer.
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
encryptedFlushes.addFirst(outboundBuffer.buildNetworkFlushOperation());
flushEncryptedOperation();
}
}
Expand All @@ -127,14 +129,14 @@ public void flushChannel() throws IOException {

private void flushEncryptedOperation() throws IOException {
try {
FlushOperation encryptedFlush = encryptedFlushes.getFirst();
flushToChannel(encryptedFlush);
if (encryptedFlush.isFullyFlushed()) {
getSelector().executeListener(encryptedFlush.getListener(), null);
encryptedFlush = null;
encryptedFlushes.removeFirst();
}
} catch (IOException e) {
getSelector().executeFailedListener(encryptedFlush.getListener(), e);
encryptedFlush = null;
getSelector().executeFailedListener(encryptedFlushes.removeFirst().getListener(), e);
throw e;
}
}
Expand Down Expand Up @@ -163,6 +165,11 @@ public int read() throws IOException {
sslDriver.read(channelBuffer);

handleReadBytes();
// It is possible that a read call produced non-application bytes to flush
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
if (outboundBuffer.hasEncryptedBytesToFlush()) {
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
}

return bytesRead;
}
Expand Down Expand Up @@ -190,10 +197,11 @@ public void closeFromSelector() throws IOException {
getSelector().assertOnSelectorThread();
if (channel.isOpen()) {
closeTimeoutCanceller.run();
if (encryptedFlush != null) {
for (FlushOperation encryptedFlush : encryptedFlushes) {
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
}
IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close);
encryptedFlushes.clear();
IOUtils.close(super::closeFromSelector, sslDriver::close);
}
}

Expand All @@ -208,7 +216,7 @@ private void channelCloseTimeout() {
}

private boolean pendingChannelFlush() {
return encryptedFlush != null;
return encryptedFlushes.isEmpty() == false;
}

private static class CloseNotifyOperation implements WriteOperation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.Page;
import org.elasticsearch.nio.utils.ExceptionsHelper;

import javax.net.ssl.SSLEngine;
Expand All @@ -32,14 +33,14 @@
*
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
* called to determine if this driver needs to produce more data to advance the handshake or close process.
* If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the
* If that method returns true, {@link #nonApplicationWrite()} should be called (and the
* data produced then flushed to the channel) until no further non-application writes are needed.
*
* If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine
* if the driver is ready to consume application data. (Note: It is possible that
* {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the
* driver is waiting on non-application data from the peer.) If the driver indicates it is ready for
* application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will
* application writes, {@link #write(FlushOperation)} can be called. This method will
* encrypt flush operation application data and place it in the outbound buffer for flushing to a channel.
*
* If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the
Expand All @@ -53,6 +54,8 @@ public class SSLDriver implements AutoCloseable {
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});

private final SSLEngine engine;
// TODO: When the bytes are actually recycled, we need to test that they are released on driver close
private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
private final boolean isClientMode;
// This should only be accessed by the network thread associated with this channel, so nothing needs to
// be volatile.
Expand Down Expand Up @@ -107,6 +110,10 @@ public ByteBuffer getNetworkReadBuffer() {
return networkReadBuffer;
}

public SSLOutboundBuffer getOutboundBuffer() {
return outboundBuffer;
}

public void read(InboundChannelBuffer buffer) throws SSLException {
Mode modePriorToRead;
do {
Expand All @@ -125,14 +132,14 @@ public boolean needsNonApplicationWrite() {
return currentMode.needsNonApplicationWrite();
}

public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
return currentMode.write(applicationBytes, outboundBuffer);
public int write(FlushOperation applicationBytes) throws SSLException {
return currentMode.write(applicationBytes);
}

public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException {
public void nonApplicationWrite() throws SSLException {
assert currentMode.isApplication() == false : "Should not be called if driver is in application mode";
if (currentMode.isApplication() == false) {
currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer);
currentMode.write(EMPTY_FLUSH_OPERATION);
} else {
throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName());
}
Expand All @@ -148,6 +155,7 @@ public boolean isClosed() {

@Override
public void close() throws SSLException {
outboundBuffer.close();
ArrayList<SSLException> closingExceptions = new ArrayList<>(2);
closingInternal();
CloseMode closeMode = (CloseMode) this.currentMode;
Expand Down Expand Up @@ -276,7 +284,7 @@ private interface Mode {

void read(InboundChannelBuffer buffer) throws SSLException;

int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException;
int write(FlushOperation applicationBytes) throws SSLException;

boolean needsNonApplicationWrite();

Expand All @@ -296,18 +304,17 @@ private class HandshakeMode implements Mode {

private void startHandshake() throws SSLException {
handshakeStatus = engine.getHandshakeStatus();
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
try {
handshake(null);
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
}
}
}

private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException {
private void handshake() throws SSLException {
boolean continueHandshaking = true;
while (continueHandshaking) {
switch (handshakeStatus) {
Expand All @@ -316,15 +323,7 @@ private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException {
continueHandshaking = false;
break;
case NEED_WRAP:
if (outboundBuffer != null) {
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
// If we need NEED_TASK we should run the tasks immediately
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) {
continueHandshaking = false;
}
} else {
continueHandshaking = false;
}
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
break;
case NEED_TASK:
runTasks();
Expand All @@ -351,7 +350,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
try {
SSLEngineResult result = unwrap(buffer);
handshakeStatus = result.getHandshakeStatus();
handshake(null);
handshake();
// If we are done handshaking we should exit the handshake read
continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake();
} catch (SSLException e) {
Expand All @@ -362,9 +361,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
}

@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
public int write(FlushOperation applicationBytes) throws SSLException {
try {
handshake(outboundBuffer);
handshake();
} catch (SSLException e) {
closingInternal();
throw e;
Expand Down Expand Up @@ -444,7 +443,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
}

@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
public int write(FlushOperation applicationBytes) throws SSLException {
boolean continueWrap = true;
int totalBytesProduced = 0;
while (continueWrap && applicationBytes.isFullyFlushed() == false) {
Expand Down Expand Up @@ -538,7 +537,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
}

@Override
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
public int write(FlushOperation applicationBytes) throws SSLException {
int bytesProduced = 0;
if (engine.isOutboundDone() == false) {
bytesProduced += wrap(outboundBuffer).bytesProduced();
Expand All @@ -549,6 +548,8 @@ public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuff
closeInboundAndSwallowPeerDidNotCloseException();
}
}
} else {
needToSendClose = false;
}
return bytesProduced;
}
Expand Down
Loading

0 comments on commit da6db68

Please sign in to comment.