Skip to content

Commit

Permalink
fix: copy could return wrong error message (#252)
Browse files Browse the repository at this point in the history
* fix: copy could return wrong error message

COPY statements could return wrong error messages or wrongfully detect
an invalid message stream because of a race condition in the error
handling of the COPY protocol.

* chore: cleanup and add comments

* chore: re-active commented code

* fix: stop batch processing after error in flush
  • Loading branch information
olavloite authored Jul 4, 2022
1 parent cd34476 commit 6ad4aa2
Show file tree
Hide file tree
Showing 25 changed files with 556 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage;
import com.google.spanner.admin.database.v1.InstanceName;
import com.google.spanner.v1.DatabaseName;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
Expand Down Expand Up @@ -79,7 +76,6 @@
@InternalApi
public class ConnectionHandler extends Thread {
private static final Logger logger = Logger.getLogger(ConnectionHandler.class.getName());
private static final int SOCKET_BUFFER_SIZE = 1 << 16;
private static final AtomicLong CONNECTION_HANDLER_ID_GENERATOR = new AtomicLong(0L);
private static final String CHANNEL_PROVIDER_PROPERTY = "CHANNEL_PROVIDER";

Expand Down Expand Up @@ -221,22 +217,17 @@ public void run() {
"Connection handler with ID %s starting for client %s",
getName(), socket.getInetAddress().getHostAddress()));

try (DataInputStream input =
new DataInputStream(
new BufferedInputStream(this.socket.getInputStream(), SOCKET_BUFFER_SIZE));
DataOutputStream output =
new DataOutputStream(
new BufferedOutputStream(this.socket.getOutputStream(), SOCKET_BUFFER_SIZE))) {
try (ConnectionMetadata connectionMetadata =
new ConnectionMetadata(this.socket.getInputStream(), this.socket.getOutputStream())) {
this.connectionMetadata = connectionMetadata;
if (!server.getOptions().disableLocalhostCheck()
&& !this.socket.getInetAddress().isAnyLocalAddress()
&& !this.socket.getInetAddress().isLoopbackAddress()) {
handleError(
output, new IllegalAccessException("This proxy may only be accessed from localhost."));
handleError(new IllegalAccessException("This proxy may only be accessed from localhost."));
return;
}

try {
this.connectionMetadata = new ConnectionMetadata(input, output);
this.message = this.server.recordMessage(BootstrapMessage.create(this));
this.message.send();
while (this.status == ConnectionStatus.UNAUTHENTICATED) {
Expand All @@ -255,7 +246,7 @@ public void run() {
handleMessages();
}
} catch (Exception e) {
this.handleError(output, e);
this.handleError(e);
}
} catch (Exception e) {
logger.log(
Expand Down Expand Up @@ -296,14 +287,14 @@ public void handleMessages() throws Exception {
message.nextHandler();
message.send();
} catch (IllegalArgumentException | IllegalStateException | EOFException fatalException) {
this.handleError(getConnectionMetadata().getOutputStream(), fatalException);
if (this.status == ConnectionStatus.COPY_IN) {
this.status = ConnectionStatus.COPY_FAILED;
} else {
this.handleError(fatalException);
// Only terminate the connection if we are not in COPY_IN mode. In COPY_IN mode the mode will
// switch to normal mode in these cases.
if (this.status != ConnectionStatus.COPY_IN) {
terminate();
}
} catch (Exception e) {
this.handleError(getConnectionMetadata().getOutputStream(), e);
this.handleError(e);
}
}

Expand Down Expand Up @@ -342,12 +333,13 @@ void terminate() {
* @param exception The exception to be related.
* @throws IOException if there is some issue in the sending of the error messages.
*/
private void handleError(DataOutputStream output, Exception exception) throws Exception {
private void handleError(Exception exception) throws Exception {
logger.log(
Level.WARNING,
exception,
() ->
String.format("Exception on connection handler with ID %s: %s", getName(), exception));
DataOutputStream output = getConnectionMetadata().peekOutputStream();
if (this.status == ConnectionStatus.TERMINATED) {
new ErrorResponse(output, exception, ErrorResponse.State.InternalError, Severity.FATAL)
.send();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,71 @@
package com.google.cloud.spanner.pgadapter.metadata;

import com.google.api.core.InternalApi;
import com.google.cloud.Tuple;
import com.google.common.base.Preconditions;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Stack;

@InternalApi
public class ConnectionMetadata {
public class ConnectionMetadata implements AutoCloseable {
private static final int SOCKET_BUFFER_SIZE = 1 << 16;

private final DataInputStream inputStream;
private final DataOutputStream outputStream;
private final InputStream rawInputStream;
private final OutputStream rawOutputStream;
private final Stack<DataInputStream> inputStream = new Stack<>();
private final Stack<DataOutputStream> outputStream = new Stack<>();

public ConnectionMetadata(DataInputStream input, DataOutputStream output) {
this.inputStream = input;
this.outputStream = output;
/**
* Creates a {@link DataInputStream} and a {@link DataOutputStream} from the given raw streams and
* pushes these as the current streams to use for communication for a connection.
*/
public ConnectionMetadata(InputStream rawInputStream, OutputStream rawOutputStream) {
this.rawInputStream = Preconditions.checkNotNull(rawInputStream);
this.rawOutputStream = Preconditions.checkNotNull(rawOutputStream);
pushNewStreams();
}

public DataInputStream getInputStream() {
return inputStream;
@Override
public void close() throws Exception {
this.rawInputStream.close();
this.rawOutputStream.close();
}

public DataOutputStream getOutputStream() {
return outputStream;
/**
* Creates a new buffered {@link DataInputStream} and {@link DataOutputStream} tuple to use for
* the connection. This is used for the COPY sub-protocol to prevent mixing the buffers used for
* the normal protocol with the COPY protocol, as that would cause multithreaded access to those
* buffers.
*/
public Tuple<DataInputStream, DataOutputStream> pushNewStreams() {
return Tuple.of(
this.inputStream.push(
new DataInputStream(new BufferedInputStream(this.rawInputStream, SOCKET_BUFFER_SIZE))),
this.outputStream.push(
new DataOutputStream(
new BufferedOutputStream(this.rawOutputStream, SOCKET_BUFFER_SIZE))));
}

/**
* Pops the current {@link DataInputStream} and {@link DataOutputStream} from the connection. This
* is done when the COPY sub-protocol has finished.
*/
public Tuple<DataInputStream, DataOutputStream> popStreams() {
return Tuple.of(this.inputStream.pop(), this.outputStream.pop());
}

/** Returns the current {@link DataInputStream} for the connection. */
public DataInputStream peekInputStream() {
return inputStream.peek();
}

/** Returns the current {@link DataOutputStream} for the connection. */
public DataOutputStream peekOutputStream() {
return outputStream.peek();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ void execute() {
ListenableFuture<StatementResult> statementResultFuture = executor.submit(mutationWriter);
ListenableFuture<Void> copyDataReceiverFuture = executor.submit(copyDataReceiver);
this.result.setFuture(statementResultFuture);

// Make sure both the front-end CopyDataReceiver and the backend MutationWriter processes
// have finished before we proceed.
//noinspection UnstableApiUsage
Futures.successfulAsList(copyDataReceiverFuture, statementResultFuture).get();
//noinspection UnstableApiUsage
Futures.allAsList(copyDataReceiverFuture, statementResultFuture).get();
} catch (ExecutionException executionException) {
result.setException(executionException.getCause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,6 @@ public IntermediatePortalStatement bind(
return this;
}

@Override
public void handleExecutionException(SpannerException exception) {
executor.shutdownNow();
super.handleExecutionException(exception);
}

@Override
public void executeAsync(BackendConnection backendConnection) {
this.executed = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ void setException(SpannerException exception) {
*
* @param exception The exception to store.
*/
protected void handleExecutionException(SpannerException exception) {
public void handleExecutionException(SpannerException exception) {
setException(exception);
this.hasMoreData = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.commands.Command;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.ConnectionState;
import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
import com.google.cloud.spanner.pgadapter.utils.StatementParser;
import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage;
Expand Down Expand Up @@ -66,32 +67,41 @@ public void execute() throws Exception {
// Do a Parse-Describe-Bind-Execute round-trip for each statement in the query string.
// Finish with a Sync to close any implicit transaction and to return the results.
for (Statement originalStatement : this.statements) {
ParsedStatement originalParsedStatement = PARSER.parse(originalStatement);
ParsedStatement parsedStatement = originalParsedStatement;
if (options.requiresMatcher()
|| connectionHandler.getWellKnownClient() == WellKnownClient.PSQL) {
parsedStatement = translatePotentialMetadataCommand(parsedStatement, connectionHandler);
}
parsedStatement =
replaceKnownUnsupportedQueries(
this.connectionHandler.getWellKnownClient(), this.options, parsedStatement);
if (parsedStatement != originalParsedStatement) {
// The original statement was replaced.
originalStatement = Statement.of(parsedStatement.getSqlWithoutComments());
}
// We need to flush the entire pipeline if we encounter a COPY statement, as COPY statements
// require additional messages to be sent back and forth, and this ensures that we get
// everything in the correct order.
boolean isCopy = StatementParser.isCommand(COPY, parsedStatement.getSqlWithoutComments());
if (isCopy) {
new FlushMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
}
new ParseMessage(connectionHandler, parsedStatement, originalStatement).send();
new BindMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
new DescribeMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
new ExecuteMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
if (isCopy) {
new FlushMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
try {
ParsedStatement originalParsedStatement = PARSER.parse(originalStatement);
ParsedStatement parsedStatement = originalParsedStatement;
if (options.requiresMatcher()
|| connectionHandler.getWellKnownClient() == WellKnownClient.PSQL) {
parsedStatement = translatePotentialMetadataCommand(parsedStatement, connectionHandler);
}
parsedStatement =
replaceKnownUnsupportedQueries(
this.connectionHandler.getWellKnownClient(), this.options, parsedStatement);
if (parsedStatement != originalParsedStatement) {
// The original statement was replaced.
originalStatement = Statement.of(parsedStatement.getSqlWithoutComments());
}
// We need to flush the entire pipeline if we encounter a COPY statement, as COPY statements
// require additional messages to be sent back and forth, and this ensures that we get
// everything in the correct order.
boolean isCopy = StatementParser.isCommand(COPY, parsedStatement.getSqlWithoutComments());
if (isCopy) {
new FlushMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
if (connectionHandler
.getExtendedQueryProtocolHandler()
.getBackendConnection()
.getConnectionState()
== ConnectionState.ABORTED) {
break;
}
}
new ParseMessage(connectionHandler, parsedStatement, originalStatement).send();
new BindMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
new DescribeMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
new ExecuteMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
} catch (Exception ignore) {
// Stop further processing if an exception occurs.
break;
}
}
new SyncMessage(connectionHandler, ManuallyCreatedToken.MANUALLY_CREATED_TOKEN).send();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus;
import com.google.cloud.spanner.pgadapter.statements.CopyStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement.ResultNotReadyBehavior;
import com.google.cloud.spanner.pgadapter.wireoutput.CommandCompleteResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.CopyInResponse;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.Callable;
Expand Down Expand Up @@ -60,32 +61,52 @@ void handleCopy() throws Exception {
if (copyStatement.hasException()) {
throw copyStatement.getException();
} else {
this.connectionHandler.addActiveStatement(copyStatement);
new CopyInResponse(
this.connectionHandler.getConnectionMetadata().getOutputStream(),
copyStatement.getTableColumns().size(),
copyStatement.getFormatCode())
.send();
ConnectionStatus initialConnectionStatus = this.connectionHandler.getStatus();
try {
this.connectionHandler.setStatus(ConnectionStatus.COPY_IN);
// Loop here until COPY_IN mode has finished.
while (this.connectionHandler.getStatus() == ConnectionStatus.COPY_IN) {
this.connectionHandler.handleMessages();
}
if (copyStatement.hasException(ResultNotReadyBehavior.BLOCK)
|| this.connectionHandler.getStatus() == ConnectionStatus.COPY_FAILED) {
if (copyStatement.hasException(ResultNotReadyBehavior.BLOCK)) {
throw copyStatement.getException();
} else {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INTERNAL, "Copy failed with unknown reason");
// Push a new set of input/output streams for the connection. This ensures that the COPY
// process does not use the buffers of the normal protocol.
this.connectionHandler.getConnectionMetadata().pushNewStreams();
this.connectionHandler.addActiveStatement(copyStatement);
new CopyInResponse(
this.connectionHandler.getConnectionMetadata().peekOutputStream(),
copyStatement.getTableColumns().size(),
copyStatement.getFormatCode())
.send();
ConnectionStatus initialConnectionStatus = this.connectionHandler.getStatus();
try {
this.connectionHandler.setStatus(ConnectionStatus.COPY_IN);
// Loop here until COPY_IN mode has finished.
while (this.connectionHandler.getStatus() == ConnectionStatus.COPY_IN) {
this.connectionHandler.handleMessages();
}
// Return CommandComplete if the COPY succeeded. This should not be cached until a flush.
// Note that if an error occurred during the COPY, the message handler will automatically
// respond with an ErrorResponse. That is why we do not check for COPY_FAILED here, and do
// not return an ErrorResponse.
if (connectionHandler.getStatus() == ConnectionStatus.COPY_DONE) {
new CommandCompleteResponse(
this.connectionHandler.getConnectionMetadata().peekOutputStream(),
"COPY " + copyStatement.getUpdateCount(ResultNotReadyBehavior.BLOCK))
.send();
}
// Throw an exception if the COPY failed. This ensures that the BackendConnection receives
// an error and marks the current (implicit) transaction as aborted.
if (copyStatement.hasException(ResultNotReadyBehavior.BLOCK)
|| this.connectionHandler.getStatus() == ConnectionStatus.COPY_FAILED) {
if (copyStatement.hasException(ResultNotReadyBehavior.BLOCK)) {
throw copyStatement.getException();
} else {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INTERNAL, "Copy failed with unknown reason");
}
}
} finally {
this.connectionHandler.removeActiveStatement(copyStatement);
this.copyStatement.close();
this.connectionHandler.setStatus(initialConnectionStatus);
}
} finally {
this.connectionHandler.removeActiveStatement(copyStatement);
this.copyStatement.close();
this.connectionHandler.setStatus(initialConnectionStatus);
// Pop the COPY input/output streams to return to the normal protocol and streams.
this.connectionHandler.getConnectionMetadata().popStreams();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ public BootstrapMessage(ConnectionHandler connection, int length) {
* @throws Exception If construction or reading fails.
*/
public static BootstrapMessage create(ConnectionHandler connection) throws Exception {
int length = connection.getConnectionMetadata().getInputStream().readInt();
int length = connection.getConnectionMetadata().peekInputStream().readInt();
if (length > MAX_BOOTSTRAP_MESSAGE_LENGTH) {
throw new IllegalArgumentException("Invalid bootstrap message length: " + length);
}
int protocol = connection.getConnectionMetadata().getInputStream().readInt();
int protocol = connection.getConnectionMetadata().peekInputStream().readInt();
switch (protocol) {
case SSLMessage.IDENTIFIER:
return new SSLMessage(connection);
Expand Down
Loading

0 comments on commit 6ad4aa2

Please sign in to comment.