Skip to content

Commit

Permalink
chore: add execute method with allowed return types (#2592)
Browse files Browse the repository at this point in the history
* chore: add execute method with allowed return types

Adds an execute method to the Connection API that allows the caller to
specify the allowed result types. This can be used by driver
implementations, such as JDBC, to use a single execute method in the
Connection API, while still making sure that the method only executes
statements that it should.
The current executeUpdate method in the Connection API does not overlap
completely with semantics of executeUpdate in JDBC, as JDBC allows any
statement type that does not return a ResultSet to be executed with that
method. The Connection API only allows statements that return an update
count.
Instead of modifying the executeUpdate method in the Connection API to
match the semantics of JDBC (which would be a breaking change), this
method can be used generically for all execute*** methods in the JDBC
driver, which again can be used to clean up some code there.

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fix: bad merge for clirr

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
olavloite and gcf-owl-bot[bot] authored Sep 6, 2023
1 parent 1f850e9 commit 2851ce8
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 1 deletion.
12 changes: 12 additions & 0 deletions google-cloud-spanner/clirr-ignored-differences.xml
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,18 @@
<className>com/google/cloud/spanner/connection/Connection</className>
<method>void setMaxPartitions(int)</method>
</difference>
<!-- Add an execute method that allows the driver to state what types should be allowed or not.
This fixes the gap between what JDBC allows, and what is currently allowed in the Connection
API:
1. JDBC allows executeUpdate to be used for everything that does not return a ResultSet.
2. Connection API requires executeUpdate to be used with something that returns an update
count (i.e. no DDL and no client-side statements. -->
<difference>
<differenceType>7012</differenceType>
<className>com/google/cloud/spanner/connection/Connection</className>
<method>com.google.cloud.spanner.connection.StatementResult execute(com.google.cloud.spanner.Statement, java.util.Set)</method>
</difference>

<!-- (Internal change, use stream timeout) -->
<difference>
<differenceType>7012</differenceType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ResultSetStats;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -949,6 +950,35 @@ default boolean isDelayTransactionStartUntilFirstWrite() {
*/
StatementResult execute(Statement statement);

/**
* Executes the given statement if allowed in the current {@link TransactionMode} and connection
* state, and if the result that would be returned is in the set of allowed result types. The
* statement will not be sent to Cloud Spanner if the result type would not be allowed. This
* method can be used by drivers that must limit the type of statements that are allowed for a
* given method, e.g. for the {@link java.sql.Statement#executeQuery(String)} and {@link
* java.sql.Statement#executeUpdate(String)} methods.
*
* <p>The returned value depends on the type of statement:
*
* <ul>
* <li>Queries and DML statements with returning clause will return a {@link ResultSet}.
* <li>Simple DML statements will return an update count
* <li>DDL statements will return a {@link ResultType#NO_RESULT}
* <li>Connection and transaction statements (SET AUTOCOMMIT=TRUE|FALSE, SHOW AUTOCOMMIT, SET
* TRANSACTION READ ONLY, etc) will return either a {@link ResultSet} or {@link
* ResultType#NO_RESULT}, depending on the type of statement (SHOW or SET)
* </ul>
*
* @param statement The statement to execute
* @param allowedResultTypes The result types that this method may return. The statement will not
* be sent to Cloud Spanner if the statement would return a result that is not one of the
* types in this set.
* @return the result of the statement
*/
default StatementResult execute(Statement statement, Set<ResultType> allowedResultTypes) {
throw new UnsupportedOperationException("Not implemented");
}

/**
* Executes the given statement if allowed in the current {@link TransactionMode} and connection
* state asynchronously. The returned value depends on the type of statement:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
import com.google.cloud.spanner.connection.StatementExecutor.StatementTimeout;
import com.google.cloud.spanner.connection.StatementResult.ResultType;
import com.google.cloud.spanner.connection.UnitOfWork.CallType;
import com.google.cloud.spanner.connection.UnitOfWork.UnitOfWorkState;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -60,12 +61,15 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.threeten.bp.Instant;

/** Implementation for {@link Connection}, the generic Spanner connection API (not JDBC). */
Expand Down Expand Up @@ -940,9 +944,20 @@ public void rollbackToSavepoint(String name) {

@Override
public StatementResult execute(Statement statement) {
Preconditions.checkNotNull(statement);
return internalExecute(Preconditions.checkNotNull(statement), null);
}

@Override
public StatementResult execute(Statement statement, Set<ResultType> allowedResultTypes) {
return internalExecute(
Preconditions.checkNotNull(statement), Preconditions.checkNotNull(allowedResultTypes));
}

private StatementResult internalExecute(
Statement statement, @Nullable Set<ResultType> allowedResultTypes) {
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
ParsedStatement parsedStatement = getStatementParser().parse(statement, this.queryOptions);
checkResultTypeAllowed(parsedStatement, allowedResultTypes);
switch (parsedStatement.getType()) {
case CLIENT_SIDE:
return parsedStatement
Expand All @@ -969,6 +984,53 @@ public StatementResult execute(Statement statement) {
"Unknown statement: " + parsedStatement.getSqlWithoutComments());
}

@VisibleForTesting
static void checkResultTypeAllowed(
ParsedStatement parsedStatement, @Nullable Set<ResultType> allowedResultTypes) {
if (allowedResultTypes == null) {
return;
}
ResultType resultType = getResultType(parsedStatement);
if (!allowedResultTypes.contains(resultType)) {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
"This statement returns a result of type "
+ resultType
+ ". Only statements that return a result of one of the following types are allowed: "
+ allowedResultTypes.stream()
.map(ResultType::toString)
.collect(Collectors.joining(", ")));
}
}

private static ResultType getResultType(ParsedStatement parsedStatement) {
switch (parsedStatement.getType()) {
case CLIENT_SIDE:
if (parsedStatement.getClientSideStatement().isQuery()) {
return ResultType.RESULT_SET;
} else if (parsedStatement.getClientSideStatement().isUpdate()) {
return ResultType.UPDATE_COUNT;
} else {
return ResultType.NO_RESULT;
}
case QUERY:
return ResultType.RESULT_SET;
case UPDATE:
if (parsedStatement.hasReturningClause()) {
return ResultType.RESULT_SET;
} else {
return ResultType.UPDATE_COUNT;
}
case DDL:
return ResultType.NO_RESULT;
case UNKNOWN:
default:
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
"Unknown statement: " + parsedStatement.getSqlWithoutComments());
}
}

@Override
public AsyncStatementResult executeAsync(Statement statement) {
Preconditions.checkNotNull(statement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.SELECT;
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.UPDATE;
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.expectSpannerException;
import static com.google.cloud.spanner.connection.ConnectionImpl.checkResultTypeAllowed;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
Expand All @@ -28,6 +29,7 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
Expand Down Expand Up @@ -73,6 +75,7 @@
import com.google.cloud.spanner.connection.StatementResult.ResultType;
import com.google.cloud.spanner.connection.UnitOfWork.CallType;
import com.google.cloud.spanner.connection.UnitOfWork.UnitOfWorkState;
import com.google.common.collect.ImmutableSet;
import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata;
import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions;
import com.google.spanner.v1.ResultSetStats;
Expand Down Expand Up @@ -1624,4 +1627,115 @@ UnitOfWork createNewUnitOfWork(boolean isInternalMetadataQuery) {
assertNull(connection.getTransactionTag());
}
}

@Test
public void testCheckResultTypeAllowed() {
AbstractStatementParser parser =
AbstractStatementParser.getInstance(Dialect.GOOGLE_STANDARD_SQL);
String query = "select * from foo";
String dml = "update foo set bar=1 where true";
String dmlReturning = "insert into foo (id, value) values (1, 'One') then return id";
String ddl = "create table foo";
String set = "set readonly=true";
String show = "show variable readonly";
String start = "start batch dml";

// null means all statements should be allowed.
ImmutableSet<ResultType> allowedResultTypes = null;
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);

allowedResultTypes = ImmutableSet.of();
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
assertThrowResultNotAllowed(parser, start, allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET);
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
assertThrowResultNotAllowed(parser, start, allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT);
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
assertThrowResultNotAllowed(parser, start, allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.NO_RESULT);
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT);
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
assertThrowResultNotAllowed(parser, start, allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.NO_RESULT);
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);

allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT, ResultType.NO_RESULT);
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);

allowedResultTypes =
ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT, ResultType.NO_RESULT);
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
}

private void assertThrowResultNotAllowed(
AbstractStatementParser parser, String sql, ImmutableSet<ResultType> allowedResultTypes) {
SpannerException exception =
assertThrows(
SpannerException.class,
() -> checkResultTypeAllowed(parser.parse(Statement.of(sql)), allowedResultTypes));
assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode());
assertTrue(
exception.getMessage(),
exception
.getMessage()
.contains(
"Only statements that return a result of one of the following types are allowed"));
}
}
Loading

0 comments on commit 2851ce8

Please sign in to comment.