diff --git a/google-cloud-clients/google-cloud-spanner/pom.xml b/google-cloud-clients/google-cloud-spanner/pom.xml index cecfe3530085..41d13f064747 100644 --- a/google-cloud-clients/google-cloud-spanner/pom.xml +++ b/google-cloud-clients/google-cloud-spanner/pom.xml @@ -61,6 +61,7 @@ com.google.cloud.spanner.IntegrationTest com.google.cloud.spanner.FlakyTest + 2400 diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 8840abc7415a..3e19ee30e026 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -46,7 +46,7 @@ public class BatchClientImpl implements BatchClient { @Override public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { - SessionImpl session = (SessionImpl) spanner.createSession(db); + SessionImpl session = spanner.createSession(db); return new BatchReadOnlyTransactionImpl(spanner, session, checkNotNull(bound)); } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index 749d01278b65..da644d5a0fc3 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -17,6 +17,9 @@ package com.google.cloud.spanner; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; import com.google.common.util.concurrent.ListenableFuture; import io.opencensus.common.Scope; import io.opencensus.trace.Span; @@ -33,17 +36,33 @@ class DatabaseClientImpl implements DatabaseClient { TraceUtil.exportSpans(READ_WRITE_TRANSACTION, READ_ONLY_TRANSACTION, PARTITION_DML_TRANSACTION); } - private final SessionPool pool; + @VisibleForTesting final SessionPool pool; DatabaseClientImpl(SessionPool pool) { this.pool = pool; } + @VisibleForTesting + PooledSession getReadSession() { + return pool.getReadSession(); + } + + @VisibleForTesting + PooledSession getReadWriteSession() { + return pool.getReadWriteSession(); + } + @Override - public Timestamp write(Iterable mutations) throws SpannerException { + public Timestamp write(final Iterable mutations) throws SpannerException { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadWriteSession().write(mutations); + return runWithSessionRetry( + new Function() { + @Override + public Timestamp apply(Session session) { + return session.write(mutations); + } + }); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -53,10 +72,16 @@ public Timestamp write(Iterable mutations) throws SpannerException { } @Override - public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerException { + public Timestamp writeAtLeastOnce(final Iterable mutations) throws SpannerException { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadWriteSession().writeAtLeastOnce(mutations); + return runWithSessionRetry( + new Function() { + @Override + public Timestamp apply(Session session) { + return session.writeAtLeastOnce(mutations); + } + }); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -69,7 +94,7 @@ public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerEx public ReadContext singleUse() { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().singleUse(); + return getReadSession().singleUse(); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -80,7 +105,7 @@ public ReadContext singleUse() { public ReadContext singleUse(TimestampBound bound) { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().singleUse(bound); + return getReadSession().singleUse(bound); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -91,7 +116,7 @@ public ReadContext singleUse(TimestampBound bound) { public ReadOnlyTransaction singleUseReadOnlyTransaction() { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().singleUseReadOnlyTransaction(); + return getReadSession().singleUseReadOnlyTransaction(); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -102,7 +127,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction() { public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().singleUseReadOnlyTransaction(bound); + return getReadSession().singleUseReadOnlyTransaction(bound); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -113,7 +138,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { public ReadOnlyTransaction readOnlyTransaction() { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().readOnlyTransaction(); + return getReadSession().readOnlyTransaction(); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -124,7 +149,7 @@ public ReadOnlyTransaction readOnlyTransaction() { public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { Span span = tracer.spanBuilder(READ_ONLY_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadSession().readOnlyTransaction(bound); + return getReadSession().readOnlyTransaction(bound); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -135,7 +160,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { public TransactionRunner readWriteTransaction() { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadWriteSession().readWriteTransaction(); + return getReadWriteSession().readWriteTransaction(); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -146,7 +171,7 @@ public TransactionRunner readWriteTransaction() { public TransactionManager transactionManager() { Span span = tracer.spanBuilder(READ_WRITE_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadWriteSession().transactionManager(); + return getReadWriteSession().transactionManager(); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; @@ -154,16 +179,33 @@ public TransactionManager transactionManager() { } @Override - public long executePartitionedUpdate(Statement stmt) { + public long executePartitionedUpdate(final Statement stmt) { Span span = tracer.spanBuilder(PARTITION_DML_TRANSACTION).startSpan(); try (Scope s = tracer.withSpan(span)) { - return pool.getReadWriteSession().executePartitionedUpdate(stmt); + return runWithSessionRetry( + new Function() { + @Override + public Long apply(Session session) { + return session.executePartitionedUpdate(stmt); + } + }); } catch (RuntimeException e) { TraceUtil.endSpanWithFailure(span, e); throw e; } } + private T runWithSessionRetry(Function callable) { + PooledSession session = getReadWriteSession(); + while (true) { + try { + return callable.apply(session); + } catch (SessionNotFoundException e) { + session = pool.replaceReadWriteSession(e, session); + } + } + } + ListenableFuture closeAsync() { return pool.closeAsync(); } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index 85b2f25f5d2a..753c3f6f3909 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -22,13 +22,26 @@ /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { - private final ResultSet delegate; + private ResultSet delegate; public ForwardingResultSet(ResultSet delegate) { super(delegate); this.delegate = Preconditions.checkNotNull(delegate); } + /** + * Replaces the underlying {@link ResultSet}. It is the responsibility of the caller to ensure + * that the new delegate has the same properties and is in the same state as the original + * delegate. This method can be used if the underlying delegate needs to be replaced after a + * session or transaction needed to be restarted after the {@link ResultSet} had already been + * returned to the user. + */ + void replaceDelegate(ResultSet newDelegate) { + Preconditions.checkNotNull(newDelegate); + super.replaceDelegate(newDelegate); + this.delegate = newDelegate; + } + @Override public boolean next() throws SpannerException { return delegate.next(); diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java index 6ad5b9a6c940..9b30b8998522 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingStructReader.java @@ -25,12 +25,23 @@ /** Forwarding implements of StructReader */ public class ForwardingStructReader implements StructReader { - private final StructReader delegate; + private StructReader delegate; public ForwardingStructReader(StructReader delegate) { this.delegate = Preconditions.checkNotNull(delegate); } + /** + * Replaces the underlying {@link StructReader}. It is the responsibility of the caller to ensure + * that the new delegate has the same properties and is in the same state as the original + * delegate. This method can be used if the underlying delegate needs to be replaced after a + * session or transaction needed to be restarted after the {@link StructReader} had already been + * returned to the user. + */ + void replaceDelegate(StructReader newDelegate) { + this.delegate = Preconditions.checkNotNull(newDelegate); + } + @Override public Type getType() { return delegate.getType(); diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionNotFoundException.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionNotFoundException.java new file mode 100644 index 000000000000..5fe18eff56ea --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionNotFoundException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import javax.annotation.Nullable; + +/** + * Exception thrown by Cloud Spanner when an operation detects that the session that is being used + * is no longer valid. This type of error has its own subclass as it is a condition that should + * normally be hidden from the user, and the client library should try to fix this internally. + */ +public class SessionNotFoundException extends SpannerException { + private static final long serialVersionUID = -6395746612598975751L; + + /** Private constructor. Use {@link SpannerExceptionFactory} to create instances. */ + SessionNotFoundException( + DoNotConstructDirectly token, @Nullable String message, @Nullable Throwable cause) { + super(token, ErrorCode.NOT_FOUND, false, message, cause); + } +} diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index 695ce9685961..670608a6f551 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -24,7 +24,9 @@ import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.ReadOption; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; @@ -75,35 +77,79 @@ Instant instant() { * Wrapper around {@code ReadContext} that releases the session to the pool once the call is * finished, if it is a single use context. */ - private static class AutoClosingReadContext implements ReadContext { - private final ReadContext delegate; - private final PooledSession session; + private static class AutoClosingReadContext implements ReadContext { + private final Function readContextDelegateSupplier; + private T readContextDelegate; + private final SessionPool sessionPool; + private PooledSession session; private final boolean isSingleUse; private boolean closed; + private boolean sessionUsedForQuery = false; private AutoClosingReadContext( - ReadContext delegate, PooledSession session, boolean isSingleUse) { - this.delegate = delegate; + Function delegateSupplier, + SessionPool sessionPool, + PooledSession session, + boolean isSingleUse) { + this.readContextDelegateSupplier = delegateSupplier; + this.sessionPool = sessionPool; this.session = session; this.isSingleUse = isSingleUse; + while (true) { + try { + this.readContextDelegate = readContextDelegateSupplier.apply(this.session); + break; + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + } + } } - private ResultSet wrap(final ResultSet resultSet) { - session.markUsed(); - if (!isSingleUse) { - return resultSet; + T getReadContextDelegate() { + return readContextDelegate; + } + + private ResultSet wrap(final Supplier resultSetSupplier) { + ResultSet res; + while (true) { + try { + res = resultSetSupplier.get(); + break; + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + } } - return new ForwardingResultSet(resultSet) { + return new ForwardingResultSet(res) { + private boolean beforeFirst = true; + @Override public boolean next() throws SpannerException { + while (true) { + try { + return internalNext(); + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + replaceDelegate(resultSetSupplier.get()); + } + } + } + + private boolean internalNext() { try { boolean ret = super.next(); - if (!ret) { + if (beforeFirst) { + session.markUsed(); + beforeFirst = false; + sessionUsedForQuery = true; + } + if (!ret && isSingleUse) { close(); } return ret; + } catch (SessionNotFoundException e) { + throw e; } catch (SpannerException e) { - if (!closed) { + if (!closed && isSingleUse) { session.lastException = e; AutoClosingReadContext.this.close(); } @@ -114,30 +160,69 @@ public boolean next() throws SpannerException { @Override public void close() { super.close(); - AutoClosingReadContext.this.close(); + if (isSingleUse) { + AutoClosingReadContext.this.close(); + } } }; } + private void replaceSessionIfPossible(SessionNotFoundException e) { + if (isSingleUse || !sessionUsedForQuery) { + // This class is only used by read-only transactions, so we know that we only need a + // read-only session. + session = sessionPool.replaceReadSession(e, session); + readContextDelegate = readContextDelegateSupplier.apply(session); + } else { + throw e; + } + } + @Override public ResultSet read( - String table, KeySet keys, Iterable columns, ReadOption... options) { - return wrap(delegate.read(table, keys, columns, options)); + final String table, + final KeySet keys, + final Iterable columns, + final ReadOption... options) { + return wrap( + new Supplier() { + @Override + public ResultSet get() { + return readContextDelegate.read(table, keys, columns, options); + } + }); } @Override public ResultSet readUsingIndex( - String table, String index, KeySet keys, Iterable columns, ReadOption... options) { - return wrap(delegate.readUsingIndex(table, index, keys, columns, options)); + final String table, + final String index, + final KeySet keys, + final Iterable columns, + final ReadOption... options) { + return wrap( + new Supplier() { + @Override + public ResultSet get() { + return readContextDelegate.readUsingIndex(table, index, keys, columns, options); + } + }); } @Override @Nullable public Struct readRow(String table, Key key, Iterable columns) { try { - session.markUsed(); - return delegate.readRow(table, key, columns); + while (true) { + try { + session.markUsed(); + return readContextDelegate.readRow(table, key, columns); + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + } + } } finally { + sessionUsedForQuery = true; if (isSingleUse) { close(); } @@ -148,9 +233,16 @@ public Struct readRow(String table, Key key, Iterable columns) { @Nullable public Struct readRowUsingIndex(String table, String index, Key key, Iterable columns) { try { - session.markUsed(); - return delegate.readRowUsingIndex(table, index, key, columns); + while (true) { + try { + session.markUsed(); + return readContextDelegate.readRowUsingIndex(table, index, key, columns); + } catch (SessionNotFoundException e) { + replaceSessionIfPossible(e); + } + } } finally { + sessionUsedForQuery = true; if (isSingleUse) { close(); } @@ -158,13 +250,25 @@ public Struct readRowUsingIndex(String table, String index, Key key, Iterable() { + @Override + public ResultSet get() { + return readContextDelegate.executeQuery(statement, options); + } + }); } @Override - public ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode queryMode) { - return wrap(delegate.analyzeQuery(statement, queryMode)); + public ResultSet analyzeQuery(final Statement statement, final QueryAnalyzeMode queryMode) { + return wrap( + new Supplier() { + @Override + public ResultSet get() { + return readContextDelegate.analyzeQuery(statement, queryMode); + } + }); } @Override @@ -173,46 +277,181 @@ public void close() { return; } closed = true; - delegate.close(); + readContextDelegate.close(); session.close(); } } - private static class AutoClosingReadTransaction extends AutoClosingReadContext - implements ReadOnlyTransaction { - private final ReadOnlyTransaction txn; + private static class AutoClosingReadTransaction + extends AutoClosingReadContext implements ReadOnlyTransaction { AutoClosingReadTransaction( - ReadOnlyTransaction txn, PooledSession session, boolean isSingleUse) { - super(txn, session, isSingleUse); - this.txn = txn; + Function txnSupplier, + SessionPool sessionPool, + PooledSession session, + boolean isSingleUse) { + super(txnSupplier, sessionPool, session, isSingleUse); } @Override public Timestamp getReadTimestamp() { - return txn.getReadTimestamp(); + return getReadContextDelegate().getReadTimestamp(); } } private static class AutoClosingTransactionManager implements TransactionManager { - final TransactionManager delegate; - final PooledSession session; + private class SessionPoolResultSet extends ForwardingResultSet { + private SessionPoolResultSet(ResultSet delegate) { + super(delegate); + } + + @Override + public boolean next() { + try { + return super.next(); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + } + + /** + * {@link TransactionContext} that is used in combination with an {@link + * AutoClosingTransactionManager}. This {@link TransactionContext} handles {@link + * SessionNotFoundException}s by replacing the underlying session with a fresh one, and then + * throws an {@link AbortedException} to trigger the retry-loop that has been created by the + * caller. + */ + private class SessionPoolTransactionContext implements TransactionContext { + private final TransactionContext delegate; + + private SessionPoolTransactionContext(TransactionContext delegate) { + this.delegate = delegate; + } + + @Override + public ResultSet read( + String table, KeySet keys, Iterable columns, ReadOption... options) { + return new SessionPoolResultSet(delegate.read(table, keys, columns, options)); + } + + @Override + public ResultSet readUsingIndex( + String table, + String index, + KeySet keys, + Iterable columns, + ReadOption... options) { + return new SessionPoolResultSet( + delegate.readUsingIndex(table, index, keys, columns, options)); + } + + @Override + public Struct readRow(String table, Key key, Iterable columns) { + try { + return delegate.readRow(table, key, columns); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + + @Override + public void buffer(Mutation mutation) { + delegate.buffer(mutation); + } + + @Override + public Struct readRowUsingIndex( + String table, String index, Key key, Iterable columns) { + try { + return delegate.readRowUsingIndex(table, index, key, columns); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + + @Override + public void buffer(Iterable mutations) { + delegate.buffer(mutations); + } + + @Override + public long executeUpdate(Statement statement) { + try { + return delegate.executeUpdate(statement); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + + @Override + public long[] batchUpdate(Iterable statements) { + try { + return delegate.batchUpdate(statements); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); + } + } + + @Override + public ResultSet executeQuery(Statement statement, QueryOption... options) { + return new SessionPoolResultSet(delegate.executeQuery(statement, options)); + } + + @Override + public ResultSet analyzeQuery(Statement statement, QueryAnalyzeMode queryMode) { + return new SessionPoolResultSet(delegate.analyzeQuery(statement, queryMode)); + } + + @Override + public void close() { + delegate.close(); + } + } + + private TransactionManager delegate; + private final SessionPool sessionPool; + private PooledSession session; private boolean closed; + private boolean restartedAfterSessionNotFound; - AutoClosingTransactionManager(TransactionManager delegate, PooledSession session) { - this.delegate = delegate; + AutoClosingTransactionManager(SessionPool sessionPool, PooledSession session) { + this.sessionPool = sessionPool; this.session = session; + this.delegate = session.delegate.transactionManager(); } @Override public TransactionContext begin() { - return delegate.begin(); + while (true) { + try { + return internalBegin(); + } catch (SessionNotFoundException e) { + session = sessionPool.replaceReadWriteSession(e, session); + delegate = session.delegate.transactionManager(); + } + } + } + + private TransactionContext internalBegin() { + TransactionContext res = new SessionPoolTransactionContext(delegate.begin()); + session.markUsed(); + return res; + } + + private SpannerException handleSessionNotFound(SessionNotFoundException e) { + session = sessionPool.replaceReadWriteSession(e, session); + delegate = session.delegate.transactionManager(); + restartedAfterSessionNotFound = true; + return SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, e.getMessage(), e); } @Override public void commit() { try { delegate.commit(); + } catch (SessionNotFoundException e) { + throw handleSessionNotFound(e); } finally { if (getState() != TransactionState.ABORTED) { close(); @@ -231,7 +470,21 @@ public void rollback() { @Override public TransactionContext resetForRetry() { - return delegate.resetForRetry(); + while (true) { + try { + if (restartedAfterSessionNotFound) { + TransactionContext res = new SessionPoolTransactionContext(delegate.begin()); + restartedAfterSessionNotFound = false; + return res; + } else { + return new SessionPoolTransactionContext(delegate.resetForRetry()); + } + } catch (SessionNotFoundException e) { + session = sessionPool.replaceReadWriteSession(e, session); + delegate = session.delegate.transactionManager(); + restartedAfterSessionNotFound = true; + } + } } @Override @@ -254,7 +507,61 @@ public void close() { @Override public TransactionState getState() { - return delegate.getState(); + if (restartedAfterSessionNotFound) { + return TransactionState.ABORTED; + } else { + return delegate.getState(); + } + } + } + + /** + * {@link TransactionRunner} that automatically handles {@link SessionNotFoundException}s by + * replacing the underlying read/write session and then restarts the transaction. + */ + private static final class SessionPoolTransactionRunner implements TransactionRunner { + private final SessionPool sessionPool; + private PooledSession session; + private TransactionRunner runner; + + private SessionPoolTransactionRunner(SessionPool sessionPool, PooledSession session) { + this.sessionPool = sessionPool; + this.session = session; + this.runner = session.delegate.readWriteTransaction(); + } + + @Override + @Nullable + public T run(TransactionCallable callable) { + try { + T result; + while (true) { + try { + result = runner.run(callable); + break; + } catch (SessionNotFoundException e) { + session = sessionPool.replaceReadWriteSession(e, session); + runner = session.delegate.readWriteTransaction(); + } + } + session.markUsed(); + return result; + } catch (SpannerException e) { + throw session.lastException = e; + } finally { + session.close(); + } + } + + @Override + public Timestamp getCommitTimestamp() { + return runner.getCommitTimestamp(); + } + + @Override + public TransactionRunner allowNestedTransaction() { + runner.allowNestedTransaction(); + return runner; } } @@ -275,18 +582,24 @@ private enum SessionState { } final class PooledSession implements Session { - @VisibleForTesting final Session delegate; + @VisibleForTesting SessionImpl delegate; private volatile Instant lastUseTime; private volatile SpannerException lastException; private volatile LeakedSessionException leakedException; + private volatile boolean allowReplacing = true; @GuardedBy("lock") private SessionState state; - private PooledSession(Session delegate) { + private PooledSession(SessionImpl delegate) { this.delegate = delegate; this.state = SessionState.AVAILABLE; - markUsed(); + this.lastUseTime = clock.instant(); + } + + @VisibleForTesting + void setAllowReplacing(boolean allowReplacing) { + this.allowReplacing = allowReplacing; } private void markBusy() { @@ -337,7 +650,16 @@ public Timestamp writeAtLeastOnce(Iterable mutations) throws SpannerEx @Override public ReadContext singleUse() { try { - return new AutoClosingReadContext(delegate.singleUse(), this, true); + return new AutoClosingReadContext<>( + new Function() { + @Override + public ReadContext apply(PooledSession session) { + return session.delegate.singleUse(); + } + }, + SessionPool.this, + this, + true); } catch (Exception e) { close(); throw e; @@ -345,9 +667,18 @@ public ReadContext singleUse() { } @Override - public ReadContext singleUse(TimestampBound bound) { + public ReadContext singleUse(final TimestampBound bound) { try { - return new AutoClosingReadContext(delegate.singleUse(bound), this, true); + return new AutoClosingReadContext<>( + new Function() { + @Override + public ReadContext apply(PooledSession session) { + return session.delegate.singleUse(bound); + } + }, + SessionPool.this, + this, + true); } catch (Exception e) { close(); throw e; @@ -356,39 +687,57 @@ public ReadContext singleUse(TimestampBound bound) { @Override public ReadOnlyTransaction singleUseReadOnlyTransaction() { - try { - return new AutoClosingReadTransaction(delegate.singleUseReadOnlyTransaction(), this, true); - } catch (Exception e) { - close(); - throw e; - } + return internalReadOnlyTransaction( + new Function() { + @Override + public ReadOnlyTransaction apply(PooledSession session) { + return session.delegate.singleUseReadOnlyTransaction(); + } + }, + true); } @Override - public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { - try { - return new AutoClosingReadTransaction( - delegate.singleUseReadOnlyTransaction(bound), this, true); - } catch (Exception e) { - close(); - throw e; - } + public ReadOnlyTransaction singleUseReadOnlyTransaction(final TimestampBound bound) { + return internalReadOnlyTransaction( + new Function() { + @Override + public ReadOnlyTransaction apply(PooledSession session) { + return session.delegate.singleUseReadOnlyTransaction(bound); + } + }, + true); } @Override public ReadOnlyTransaction readOnlyTransaction() { - try { - return new AutoClosingReadTransaction(delegate.readOnlyTransaction(), this, false); - } catch (Exception e) { - close(); - throw e; - } + return internalReadOnlyTransaction( + new Function() { + @Override + public ReadOnlyTransaction apply(PooledSession session) { + return session.delegate.readOnlyTransaction(); + } + }, + false); } @Override - public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { + public ReadOnlyTransaction readOnlyTransaction(final TimestampBound bound) { + return internalReadOnlyTransaction( + new Function() { + @Override + public ReadOnlyTransaction apply(PooledSession session) { + return session.delegate.readOnlyTransaction(bound); + } + }, + false); + } + + private ReadOnlyTransaction internalReadOnlyTransaction( + Function transactionSupplier, boolean isSingleUse) { try { - return new AutoClosingReadTransaction(delegate.readOnlyTransaction(bound), this, false); + return new AutoClosingReadTransaction( + transactionSupplier, SessionPool.this, this, isSingleUse); } catch (Exception e) { close(); throw e; @@ -397,34 +746,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { @Override public TransactionRunner readWriteTransaction() { - final TransactionRunner runner = delegate.readWriteTransaction(); - return new TransactionRunner() { - - @Override - @Nullable - public T run(TransactionCallable callable) { - try { - markUsed(); - T result = runner.run(callable); - return result; - } catch (SpannerException e) { - throw lastException = e; - } finally { - close(); - } - } - - @Override - public Timestamp getCommitTimestamp() { - return runner.getCommitTimestamp(); - } - - @Override - public TransactionRunner allowNestedTransaction() { - runner.allowNestedTransaction(); - return runner; - } - }; + return new SessionPoolTransactionRunner(SessionPool.this, this); } @Override @@ -469,8 +791,7 @@ private void markUsed() { @Override public TransactionManager transactionManager() { - markUsed(); - return new AutoClosingTransactionManager(delegate.transactionManager(), this); + return new AutoClosingTransactionManager(SessionPool.this, this); } } @@ -756,6 +1077,11 @@ private SessionPool( this.poolMaintainer = new PoolMaintainer(); } + @VisibleForTesting + int getNumberOfAvailableWritePreparedSessions() { + return writePreparedSessions.size(); + } + private void initPool() { synchronized (lock) { poolMaintainer.init(); @@ -823,7 +1149,7 @@ private PooledSession findSessionToKeepAlive( * session being returned to the pool or a new session being created. * */ - Session getReadSession() throws SpannerException { + PooledSession getReadSession() throws SpannerException { Span span = Tracing.getTracer().getCurrentSpan(); span.addAnnotation("Acquiring session"); Waiter waiter = null; @@ -879,7 +1205,7 @@ Session getReadSession() throws SpannerException { * to the pool which is then write prepared. * */ - Session getReadWriteSession() { + PooledSession getReadWriteSession() { Span span = Tracing.getTracer().getCurrentSpan(); span.addAnnotation("Acquiring read write session"); Waiter waiter = null; @@ -919,6 +1245,24 @@ Session getReadWriteSession() { return sess; } + PooledSession replaceReadSession(SessionNotFoundException e, PooledSession session) { + if (!options.isFailIfSessionNotFound() && session.allowReplacing) { + closeSessionAsync(session); + return getReadSession(); + } else { + throw e; + } + } + + PooledSession replaceReadWriteSession(SessionNotFoundException e, PooledSession session) { + if (!options.isFailIfSessionNotFound() && session.allowReplacing) { + closeSessionAsync(session); + return getReadWriteSession(); + } else { + throw e; + } + } + private Annotation sessionAnnotation(Session session) { AttributeValue sessionId = AttributeValue.stringAttributeValue(session.getName()); return Annotation.fromDescriptionAndAttributes( @@ -1178,7 +1522,7 @@ private void createSession() { new Runnable() { @Override public void run() { - Session session = null; + SessionImpl session = null; try { session = spanner.createSession(db); logger.log(Level.FINE, "Session created"); diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java index 0ae3b8aa7b03..cda7341b6e5a 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPoolOptions.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; /** Options for the session pool used by {@code DatabaseClient}. */ @@ -29,6 +30,7 @@ public class SessionPoolOptions { private final float writeSessionsFraction; private final ActionOnExhaustion actionOnExhaustion; private final int keepAliveIntervalMinutes; + private final ActionOnSessionNotFound actionOnSessionNotFound; private SessionPoolOptions(Builder builder) { this.minSessions = builder.minSessions; @@ -36,6 +38,7 @@ private SessionPoolOptions(Builder builder) { this.maxIdleSessions = builder.maxIdleSessions; this.writeSessionsFraction = builder.writeSessionsFraction; this.actionOnExhaustion = builder.actionOnExhaustion; + this.actionOnSessionNotFound = builder.actionOnSessionNotFound; this.keepAliveIntervalMinutes = builder.keepAliveIntervalMinutes; } @@ -67,6 +70,11 @@ public boolean isBlockIfPoolExhausted() { return actionOnExhaustion == ActionOnExhaustion.BLOCK; } + @VisibleForTesting + boolean isFailIfSessionNotFound() { + return actionOnSessionNotFound == ActionOnSessionNotFound.FAIL; + } + public static Builder newBuilder() { return new Builder(); } @@ -76,6 +84,11 @@ private static enum ActionOnExhaustion { FAIL, } + private static enum ActionOnSessionNotFound { + RETRY, + FAIL; + } + /** Builder for creating SessionPoolOptions. */ public static class Builder { private int minSessions; @@ -83,6 +96,7 @@ public static class Builder { private int maxIdleSessions; private float writeSessionsFraction = 0.2f; private ActionOnExhaustion actionOnExhaustion = DEFAULT_ACTION; + private ActionOnSessionNotFound actionOnSessionNotFound = ActionOnSessionNotFound.RETRY; private int keepAliveIntervalMinutes = 30; /** @@ -146,6 +160,16 @@ public Builder setBlockIfPoolExhausted() { return this; } + /** + * If a session has been invalidated by the server, the {@link SessionPool} will by default + * retry the session. Set this option to throw an exception instead of retrying. + */ + @VisibleForTesting + Builder setFailIfSessionNotFound() { + this.actionOnSessionNotFound = ActionOnSessionNotFound.FAIL; + return this; + } + /** * Fraction of sessions to be kept prepared for write transactions. This is an optimisation to * avoid the cost of sending a BeginTransaction() rpc. If all such sessions are in use and a diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java index 7577a06bfbce..6a34b0a86082 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java @@ -147,6 +147,11 @@ private static SpannerException newSpannerExceptionPreformatted( switch (code) { case ABORTED: return new AbortedException(token, message, cause); + case NOT_FOUND: + if (message != null && message.contains("Session not found")) { + return new SessionNotFoundException(token, message, cause); + } + // Fall through to the default. default: return new SpannerException(token, code, isRetryable(code, cause), message, cause); } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 712917fc1a18..5e60686c7e5b 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -29,6 +29,7 @@ import com.google.cloud.PageImpl.NextPageFetcher; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.cloud.spanner.spi.v1.SpannerRpc.Paginated; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; @@ -203,8 +204,7 @@ int getDefaultPrefetchChunks() { return defaultPrefetchChunks; } - // TODO(user): change this to return SessionImpl and modify all corresponding references. - Session createSession(final DatabaseId db) throws SpannerException { + SessionImpl createSession(final DatabaseId db) throws SpannerException { final Map options = optionMap(SessionOption.channelHint(random.nextLong())); Span span = tracer.spanBuilder(CREATE_SESSION).startSpan(); @@ -250,13 +250,18 @@ public DatabaseClient getDatabaseClient(DatabaseId db) { return dbClients.get(db); } else { SessionPool pool = SessionPool.createPool(getOptions(), db, SpannerImpl.this); - DatabaseClientImpl dbClient = new DatabaseClientImpl(pool); + DatabaseClientImpl dbClient = createDatabaseClient(pool); dbClients.put(db, dbClient); return dbClient; } } } + @VisibleForTesting + DatabaseClientImpl createDatabaseClient(SessionPool pool) { + return new DatabaseClientImpl(pool); + } + @Override public BatchClient getBatchClient(DatabaseId db) { return new BatchClientImpl(db, SpannerImpl.this); diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/testing/RemoteSpannerHelper.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/testing/RemoteSpannerHelper.java index 09a314464295..9ab9e19435eb 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/testing/RemoteSpannerHelper.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/testing/RemoteSpannerHelper.java @@ -47,7 +47,7 @@ public class RemoteSpannerHelper { private static int dbPrefix = new Random().nextInt(Integer.MAX_VALUE); private final List dbs = new ArrayList<>(); - private RemoteSpannerHelper(SpannerOptions options, InstanceId instanceId, Spanner client) { + protected RemoteSpannerHelper(SpannerOptions options, InstanceId instanceId, Spanner client) { this.options = options; this.instanceId = instanceId; this.client = client; diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java index 8209fe2fd3a9..340327336373 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java @@ -56,8 +56,8 @@ public void release(ScheduledExecutorService executor) { } } - Session mockSession() { - Session session = mock(Session.class); + SessionImpl mockSession() { + SessionImpl session = mock(SessionImpl.class); when(session.getName()) .thenReturn( "projects/dummy/instances/dummy/database/dummy/sessions/session" + sessionIndex); diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestEnv.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestEnv.java index fa5ee4257683..63741cd2f7a9 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestEnv.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestEnv.java @@ -85,7 +85,7 @@ protected void before() throws Throwable { instanceId = InstanceId.of(config.spannerOptions().getProjectId(), "test-instance"); isOwnedInstance = true; } - testHelper = RemoteSpannerHelper.create(options, instanceId); + testHelper = createTestHelper(options, instanceId); instanceAdminClient = testHelper.getClient().getInstanceAdminClient(); logger.log(Level.FINE, "Test env endpoint is {0}", options.getHost()); if (isOwnedInstance) { @@ -93,6 +93,11 @@ protected void before() throws Throwable { } } + RemoteSpannerHelper createTestHelper(SpannerOptions options, InstanceId instanceId) + throws Throwable { + return RemoteSpannerHelper.create(options, instanceId); + } + @Override protected void after() { cleanUpInstance(); @@ -138,7 +143,7 @@ private void cleanUpInstance() { } } - private void checkInitialized() { + void checkInitialized() { checkState(testHelper != null, "Setup has not completed successfully"); } } diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java new file mode 100644 index 000000000000..66e0893be7b0 --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/IntegrationTestWithClosedSessionsEnv.java @@ -0,0 +1,128 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.testing.RemoteSpannerHelper; + +/** + * Subclass of {@link IntegrationTestEnv} that allows the user to specify when the underlying + * session of a {@link PooledSession} should be closed. This can be used to ensure that the + * recreation of sessions that have been invalidated by the server works. + */ +public class IntegrationTestWithClosedSessionsEnv extends IntegrationTestEnv { + private static class RemoteSpannerHelperWithClosedSessions extends RemoteSpannerHelper { + private RemoteSpannerHelperWithClosedSessions( + SpannerOptions options, InstanceId instanceId, Spanner client) { + super(options, instanceId, client); + } + } + + @Override + RemoteSpannerHelper createTestHelper(SpannerOptions options, InstanceId instanceId) + throws Throwable { + SpannerWithClosedSessionsImpl spanner = new SpannerWithClosedSessionsImpl(options); + return new RemoteSpannerHelperWithClosedSessions(options, instanceId, spanner); + } + + private static class SpannerWithClosedSessionsImpl extends SpannerImpl { + SpannerWithClosedSessionsImpl(SpannerOptions options) { + super(options); + } + + @Override + DatabaseClientImpl createDatabaseClient(SessionPool pool) { + return new DatabaseClientWithClosedSessionImpl(pool); + } + } + + /** + * {@link DatabaseClient} that allows the user to specify when an underlying session of a {@link + * PooledSession} should be closed. + */ + public static class DatabaseClientWithClosedSessionImpl extends DatabaseClientImpl { + private boolean invalidateNextSession = false; + private boolean allowReplacing = true; + + DatabaseClientWithClosedSessionImpl(SessionPool pool) { + super(pool); + } + + /** Invalidate the next session that is checked out from the pool. */ + public void invalidateNextSession() { + invalidateNextSession = true; + } + + /** Sets whether invalidated sessions should be replaced or not. */ + public void setAllowSessionReplacing(boolean allow) { + this.allowReplacing = allow; + } + + @Override + PooledSession getReadSession() { + PooledSession session = super.getReadSession(); + if (invalidateNextSession) { + session.delegate.close(); + session.setAllowReplacing(false); + awaitDeleted(session.delegate); + session.setAllowReplacing(allowReplacing); + invalidateNextSession = false; + } + session.setAllowReplacing(allowReplacing); + return session; + } + + @Override + PooledSession getReadWriteSession() { + PooledSession session = super.getReadWriteSession(); + if (invalidateNextSession) { + session.delegate.close(); + session.setAllowReplacing(false); + awaitDeleted(session.delegate); + session.setAllowReplacing(allowReplacing); + invalidateNextSession = false; + } + session.setAllowReplacing(allowReplacing); + return session; + } + + /** + * Deleting a session server side takes some time. This method checks and waits until the + * session really has been deleted. + */ + private void awaitDeleted(Session session) { + // Wait until the session has actually been deleted. + while (true) { + try (ResultSet rs = session.singleUse().executeQuery(Statement.of("SELECT 1"))) { + while (rs.next()) { + // Do nothing. + } + Thread.sleep(500L); + } catch (SpannerException e) { + if (e.getErrorCode() == ErrorCode.NOT_FOUND + && e.getMessage().contains("Session not found")) { + break; + } else { + throw e; + } + } catch (InterruptedException e) { + break; + } + } + } + } +} diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 8564b51b66cd..8cd369c6c687 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -1248,7 +1248,6 @@ public void commit(CommitRequest request, StreamObserver respons .asRuntimeException()); return; } - if (transaction == null) { setTransactionNotFound(request.getTransactionId(), responseObserver); return; diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java new file mode 100644 index 000000000000..ac14701e84d0 --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/RetryOnInvalidatedSessionTest.java @@ -0,0 +1,1411 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.api.gax.core.NoCredentialsProvider; +import com.google.api.gax.grpc.testing.LocalChannelProvider; +import com.google.cloud.NoCredentials; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.cloud.spanner.v1.SpannerClient; +import com.google.cloud.spanner.v1.SpannerClient.ListSessionsPagedResponse; +import com.google.cloud.spanner.v1.SpannerSettings; +import com.google.common.base.Stopwatch; +import com.google.protobuf.ListValue; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.TypeCode; +import io.grpc.Server; +import io.grpc.inprocess.InProcessServerBuilder; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class RetryOnInvalidatedSessionTest { + @Rule public ExpectedException expected = ExpectedException.none(); + + @Parameter(0) + public boolean failOnInvalidatedSession; + + @Parameters(name = "fail on invalidated session = {0}") + public static Collection data() { + List params = new ArrayList<>(); + params.add(new Object[] {false}); + params.add(new Object[] {true}); + return params; + } + + private static final ResultSetMetadata READ_METADATA = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("BAR") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.INT64) + .build()) + .build()) + .build()) + .build(); + private static final com.google.spanner.v1.ResultSet READ_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("2").build()) + .build()) + .setMetadata(READ_METADATA) + .build(); + private static final com.google.spanner.v1.ResultSet READ_ROW_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .build()) + .setMetadata(READ_METADATA) + .build(); + private static final Statement SELECT1AND2 = + Statement.of("SELECT 1 AS COL1 UNION ALL SELECT 2 AS COL1"); + private static final ResultSetMetadata SELECT1AND2_METADATA = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("COL1") + .setType( + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.INT64) + .build()) + .build()) + .build()) + .build(); + private static final com.google.spanner.v1.ResultSet SELECT1_RESULTSET = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("1").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(com.google.protobuf.Value.newBuilder().setStringValue("2").build()) + .build()) + .setMetadata(SELECT1AND2_METADATA) + .build(); + private static final Statement UPDATE_STATEMENT = + Statement.of("UPDATE FOO SET BAR=1 WHERE BAZ=2"); + private static final long UPDATE_COUNT = 1L; + private static final int MAX_SESSIONS = 10; + private static final float WRITE_SESSIONS_FRACTION = 0.5f; + private static MockSpannerServiceImpl mockSpanner; + private static Server server; + private static LocalChannelProvider channelProvider; + private static SpannerClient spannerClient; + private static Spanner spanner; + private static DatabaseClient client; + + @BeforeClass + public static void startStaticServer() throws IOException { + mockSpanner = new MockSpannerServiceImpl(); + mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions. + mockSpanner.putStatementResult( + StatementResult.read("FOO", KeySet.all(), Arrays.asList("BAR"), READ_RESULTSET)); + mockSpanner.putStatementResult( + StatementResult.read( + "FOO", KeySet.singleKey(Key.of()), Arrays.asList("BAR"), READ_ROW_RESULTSET)); + mockSpanner.putStatementResult(StatementResult.query(SELECT1AND2, SELECT1_RESULTSET)); + mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + + String uniqueName = InProcessServerBuilder.generateName(); + server = + InProcessServerBuilder.forName(uniqueName) + .directExecutor() + .addService(mockSpanner) + .build() + .start(); + channelProvider = LocalChannelProvider.create(uniqueName); + + SpannerSettings settings = + SpannerSettings.newBuilder() + .setTransportChannelProvider(channelProvider) + .setCredentialsProvider(NoCredentialsProvider.create()) + .build(); + spannerClient = SpannerClient.create(settings); + } + + @AfterClass + public static void stopServer() { + spannerClient.close(); + server.shutdown(); + } + + @Before + public void setUp() throws IOException { + mockSpanner.reset(); + SessionPoolOptions.Builder builder = + SessionPoolOptions.newBuilder() + .setMaxSessions(MAX_SESSIONS) + .setWriteSessionsFraction(WRITE_SESSIONS_FRACTION); + if (failOnInvalidatedSession) { + builder.setFailIfSessionNotFound(); + } + spanner = + SpannerOptions.newBuilder() + .setProjectId("[PROJECT]") + .setChannelProvider(channelProvider) + .setSessionPoolOption(builder.build()) + .setCredentials(NoCredentials.getInstance()) + .build() + .getService(); + client = spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + } + + @After + public void tearDown() throws Exception { + spanner.close(); + } + + private static void initReadOnlySessionPool() { + // Do a simple query in order to make sure there is one read-only session in the pool. + try (ReadContext context = client.singleUse()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + // do nothing. + } + } + } + } + + private static void initReadWriteSessionPool() throws InterruptedException { + // Do enough queries to ensure that a read/write session will be prepared. + ExecutorService service = Executors.newFixedThreadPool(MAX_SESSIONS); + for (int i = 0; i < MAX_SESSIONS; i++) { + service.submit( + new Callable() { + @Override + public Void call() throws Exception { + try (ReadContext context = client.singleUse()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + // Make sure the transactions are actually running simultaneously to ensure that + // there are multiple sessions being created. + Thread.sleep(20L); + } + } + } + return null; + } + }); + } + service.shutdown(); + service.awaitTermination(10L, TimeUnit.SECONDS); + Stopwatch watch = Stopwatch.createStarted(); + while (((DatabaseClientImpl) client).pool.getNumberOfAvailableWritePreparedSessions() == 0) { + if (watch.elapsed(TimeUnit.MILLISECONDS) > 1000L) { + fail("No read/write sessions prepared"); + } + Thread.sleep(5L); + } + } + + private static void invalidateSessionPool() throws InterruptedException { + ListSessionsPagedResponse response = + spannerClient.listSessions("projects/[PROJECT]/instances/[INSTANCE]/databases/[DATABASE]"); + for (com.google.spanner.v1.Session session : response.iterateAll()) { + spannerClient.deleteSession(session.getName()); + } + } + + @Test + public void singleUseSelect() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + // This call will receive an invalidated session that will be replaced on the first call to + // rs.next(). + int count = 0; + try (ReadContext context = client.singleUse()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + } + assertThat(count, is(equalTo(2))); + } + + @Test + public void singleUseRead() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.singleUse()) { + try (ResultSet rs = context.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void singleUseReadUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.singleUse()) { + try (ResultSet rs = + context.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void singleUseReadRow() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.singleUse()) { + Struct row = context.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test + public void singleUseReadRowUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.singleUse()) { + Struct row = context.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test + public void singleUseReadOnlyTransactionSelect() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.singleUseReadOnlyTransaction()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + } + assertThat(count, is(equalTo(2))); + } + + @Test + public void singleUseReadOnlyTransactionRead() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.singleUseReadOnlyTransaction()) { + try (ResultSet rs = context.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void singlUseReadOnlyTransactionReadUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.singleUseReadOnlyTransaction()) { + try (ResultSet rs = + context.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void singleUseReadOnlyTransactionReadRow() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.singleUseReadOnlyTransaction()) { + Struct row = context.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test + public void singleUseReadOnlyTransactionReadRowUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.singleUseReadOnlyTransaction()) { + Struct row = context.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test + public void readOnlyTransactionSelect() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void readOnlyTransactionRead() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = context.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void readOnlyTransactionReadUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = + context.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + } + } + + @Test + public void readOnlyTransactionReadRow() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.readOnlyTransaction()) { + Struct row = context.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test + public void readOnlyTransactionReadRowUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadOnlySessionPool(); + invalidateSessionPool(); + try (ReadContext context = client.readOnlyTransaction()) { + Struct row = context.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + } + ; + } + + @Test(expected = SessionNotFoundException.class) + public void readOnlyTransactionSelectNonRecoverable() throws InterruptedException { + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + // Invalidate the session pool while in a transaction. This is not recoverable. + invalidateSessionPool(); + try (ResultSet rs = context.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + } + } + + @Test(expected = SessionNotFoundException.class) + public void readOnlyTransactionReadNonRecoverable() throws InterruptedException { + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = context.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + invalidateSessionPool(); + try (ResultSet rs = context.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + } + } + + @Test(expected = SessionNotFoundException.class) + public void readOnlyTransactionReadUsingIndexNonRecoverable() throws InterruptedException { + int count = 0; + try (ReadContext context = client.readOnlyTransaction()) { + try (ResultSet rs = + context.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + invalidateSessionPool(); + try (ResultSet rs = + context.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + } + } + + @Test(expected = SessionNotFoundException.class) + public void readOnlyTransactionReadRowNonRecoverable() throws InterruptedException { + try (ReadContext context = client.readOnlyTransaction()) { + Struct row = context.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + invalidateSessionPool(); + row = context.readRow("FOO", Key.of(), Arrays.asList("BAR")); + } + ; + } + + @Test(expected = SessionNotFoundException.class) + public void readOnlyTransactionReadRowUsingIndexNonRecoverable() throws InterruptedException { + try (ReadContext context = client.readOnlyTransaction()) { + Struct row = context.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + invalidateSessionPool(); + row = context.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + } + ; + } + + /** + * Test with one read-only session in the pool that is invalidated. The session pool will try to + * prepare this session for read/write, which will fail with a {@link SessionNotFoundException}. + * That again will trigger the creation of a new session. This will always succeed. + */ + @Test + public void readWriteTransactionReadOnlySessionInPool() throws InterruptedException { + initReadOnlySessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + int count = + runner.run( + new TransactionCallable() { + @Override + public Integer run(TransactionContext transaction) throws Exception { + int count = 0; + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + return count; + } + }); + assertThat(count, is(equalTo(2))); + } + + @Test + public void readWriteTransactionSelect() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + int count = + runner.run( + new TransactionCallable() { + @Override + public Integer run(TransactionContext transaction) throws Exception { + int count = 0; + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + return count; + } + }); + assertThat(count, is(equalTo(2))); + } + + @Test + public void readWriteTransactionRead() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + int count = + runner.run( + new TransactionCallable() { + @Override + public Integer run(TransactionContext transaction) throws Exception { + int count = 0; + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + return count; + } + }); + assertThat(count, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + int count = + runner.run( + new TransactionCallable() { + @Override + public Integer run(TransactionContext transaction) throws Exception { + int count = 0; + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + return count; + } + }); + assertThat(count, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadRow() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + Struct row = + runner.run( + new TransactionCallable() { + @Override + public Struct run(TransactionContext transaction) throws Exception { + return transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + } + }); + assertThat(row.getLong(0), is(equalTo(1L))); + } + + @Test + public void readWriteTransactionReadRowUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + Struct row = + runner.run( + new TransactionCallable() { + @Override + public Struct run(TransactionContext transaction) throws Exception { + return transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + } + }); + assertThat(row.getLong(0), is(equalTo(1L))); + } + + @Test + public void readWriteTransactionUpdate() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + long count = + runner.run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + return transaction.executeUpdate(UPDATE_STATEMENT); + } + }); + assertThat(count, is(equalTo(UPDATE_COUNT))); + } + + @Test + public void readWriteTransactionBatchUpdate() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + long[] count = + runner.run( + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) throws Exception { + return transaction.batchUpdate(Arrays.asList(UPDATE_STATEMENT)); + } + }); + assertThat(count.length, is(equalTo(1))); + assertThat(count[0], is(equalTo(UPDATE_COUNT))); + } + + @Test + public void readWriteTransactionBuffer() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + TransactionRunner runner = client.readWriteTransaction(); + runner.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + transaction.buffer(Mutation.newInsertBuilder("FOO").set("BAR").to(1L).build()); + return null; + } + }); + assertThat(runner.getCommitTimestamp(), is(notNullValue())); + } + + @Test + public void readWriteTransactionSelectInvalidatedDuringTransaction() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + TransactionRunner runner = client.readWriteTransaction(); + int attempts = + runner.run( + new TransactionCallable() { + private int attempt = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + attempt++; + int count = 0; + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempt == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + return attempt; + } + }); + assertThat(attempts, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadInvalidatedDuringTransaction() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + TransactionRunner runner = client.readWriteTransaction(); + int attempts = + runner.run( + new TransactionCallable() { + private int attempt = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + attempt++; + int count = 0; + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempt == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + return attempt; + } + }); + assertThat(attempts, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadUsingIndexInvalidatedDuringTransaction() + throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + TransactionRunner runner = client.readWriteTransaction(); + int attempts = + runner.run( + new TransactionCallable() { + private int attempt = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + attempt++; + int count = 0; + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempt == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + return attempt; + } + }); + assertThat(attempts, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadRowInvalidatedDuringTransaction() + throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + TransactionRunner runner = client.readWriteTransaction(); + int attempts = + runner.run( + new TransactionCallable() { + private int attempt = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + attempt++; + Struct row = transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + if (attempt == 1) { + invalidateSessionPool(); + } + row = transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + return attempt; + } + }); + assertThat(attempts, is(equalTo(2))); + } + + @Test + public void readWriteTransactionReadRowUsingIndexInvalidatedDuringTransaction() + throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + TransactionRunner runner = client.readWriteTransaction(); + int attempts = + runner.run( + new TransactionCallable() { + private int attempt = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + attempt++; + Struct row = + transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + if (attempt == 1) { + invalidateSessionPool(); + } + row = transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + return attempt; + } + }); + assertThat(attempts, is(equalTo(2))); + } + + /** + * Test with one read-only session in the pool that is invalidated. The session pool will try to + * prepare this session for read/write, which will fail with a {@link SessionNotFoundException}. + * That again will trigger the creation of a new session. This will always succeed. + */ + @SuppressWarnings("resource") + @Test + public void transactionManagerReadOnlySessionInPool() throws InterruptedException { + initReadOnlySessionPool(); + invalidateSessionPool(); + int count = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerSelect() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + int count = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerRead() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + int count = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + int count = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadRow() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + Struct row; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + row = transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(row.getLong(0), is(equalTo(1L))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadRowUsingIndex() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + Struct row; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + row = transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(row.getLong(0), is(equalTo(1L))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerUpdate() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + long count; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + count = transaction.executeUpdate(UPDATE_STATEMENT); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count, is(equalTo(UPDATE_COUNT))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerBatchUpdate() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + long[] count; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + try { + count = transaction.batchUpdate(Arrays.asList(UPDATE_STATEMENT)); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(count.length, is(equalTo(1))); + assertThat(count[0], is(equalTo(UPDATE_COUNT))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerBuffer() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + transaction.buffer(Mutation.newInsertBuilder("FOO").set("BAR").to(1L).build()); + try { + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + assertThat(manager.getCommitTimestamp(), is(notNullValue())); + } + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerSelectInvalidatedDuringTransaction() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + int attempts = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + attempts++; + int count = 0; + try { + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempts == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = transaction.executeQuery(SELECT1AND2)) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(attempts, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadInvalidatedDuringTransaction() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + int attempts = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + attempts++; + int count = 0; + try { + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempts == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = transaction.read("FOO", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(attempts, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadUsingIndexInvalidatedDuringTransaction() + throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + int attempts = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + attempts++; + int count = 0; + try { + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + assertThat(count, is(equalTo(2))); + if (attempts == 1) { + invalidateSessionPool(); + } + try (ResultSet rs = + transaction.readUsingIndex("FOO", "IDX", KeySet.all(), Arrays.asList("BAR"))) { + while (rs.next()) { + count++; + } + } + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(attempts, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadRowInvalidatedDuringTransaction() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + int attempts = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + attempts++; + try { + Struct row = transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + if (attempts == 1) { + invalidateSessionPool(); + } + row = transaction.readRow("FOO", Key.of(), Arrays.asList("BAR")); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(attempts, is(equalTo(2))); + } + + @SuppressWarnings("resource") + @Test + public void transactionManagerReadRowUsingIndexInvalidatedDuringTransaction() + throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + int attempts = 0; + try (TransactionManager manager = client.transactionManager()) { + TransactionContext transaction = manager.begin(); + while (true) { + attempts++; + try { + Struct row = transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + assertThat(row.getLong(0), is(equalTo(1L))); + if (attempts == 1) { + invalidateSessionPool(); + } + row = transaction.readRowUsingIndex("FOO", "IDX", Key.of(), Arrays.asList("BAR")); + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + transaction = manager.resetForRetry(); + } + } + } + assertThat(attempts, is(equalTo(2))); + } + + @Test + public void partitionedDml() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + assertThat(client.executePartitionedUpdate(UPDATE_STATEMENT), is(equalTo(UPDATE_COUNT))); + } + + @Test + public void write() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + Timestamp timestamp = client.write(Arrays.asList(Mutation.delete("FOO", KeySet.all()))); + assertThat(timestamp, is(notNullValue())); + } + + @Test + public void writeAtLeastOnce() throws InterruptedException { + if (failOnInvalidatedSession) { + expected.expect(SessionNotFoundException.class); + } + initReadWriteSessionPool(); + invalidateSessionPool(); + Timestamp timestamp = + client.writeAtLeastOnce(Arrays.asList(Mutation.delete("FOO", KeySet.all()))); + assertThat(timestamp, is(notNullValue())); + } +} diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index 6c95cfcf8476..897430ac7978 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -18,6 +18,7 @@ import static com.google.cloud.spanner.SpannerMatchers.isSpannerException; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; @@ -29,13 +30,27 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.SessionPool.Clock; import com.google.cloud.spanner.SessionPool.PooledSession; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; +import com.google.cloud.spanner.spi.v1.SpannerRpc; +import com.google.cloud.spanner.spi.v1.SpannerRpc.ResultStreamConsumer; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.ByteString; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.RollbackRequest; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -49,6 +64,7 @@ import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -118,8 +134,8 @@ public void poolClosure() throws Exception { @Test public void poolClosureClosesLeakedSessions() throws Exception { - Session mockSession1 = mockSession(); - Session mockSession2 = mockSession(); + SessionImpl mockSession1 = mockSession(); + SessionImpl mockSession2 = mockSession(); when(client.createSession(db)).thenReturn(mockSession1).thenReturn(mockSession2); pool = createPool(); Session session1 = pool.getReadSession(); @@ -157,7 +173,7 @@ public void run() { public void poolClosureFailsPendingReadWaiters() throws Exception { final CountDownLatch insideCreation = new CountDownLatch(1); final CountDownLatch releaseCreation = new CountDownLatch(1); - Session session1 = mockSession(); + SessionImpl session1 = mockSession(); final Session session2 = mockSession(); when(client.createSession(db)) .thenReturn(session1) @@ -186,7 +202,7 @@ public Session answer(InvocationOnMock invocation) throws Throwable { public void poolClosureFailsPendingWriteWaiters() throws Exception { final CountDownLatch insideCreation = new CountDownLatch(1); final CountDownLatch releaseCreation = new CountDownLatch(1); - Session session1 = mockSession(); + SessionImpl session1 = mockSession(); final Session session2 = mockSession(); when(client.createSession(db)) .thenReturn(session1) @@ -238,7 +254,7 @@ public Session answer(InvocationOnMock invocation) throws Throwable { @Test public void poolClosesEvenIfPreparationFails() throws Exception { - Session session = mockSession(); + SessionImpl session = mockSession(); when(client.createSession(db)).thenReturn(session); final CountDownLatch insidePrepare = new CountDownLatch(1); final CountDownLatch releasePrepare = new CountDownLatch(1); @@ -266,7 +282,7 @@ public Session answer(InvocationOnMock invocation) throws Throwable { @Test public void poolClosureFailsNewRequests() throws Exception { - Session session = mockSession(); + SessionImpl session = mockSession(); when(client.createSession(db)).thenReturn(session); pool = createPool(); pool.getReadSession(); @@ -310,7 +326,7 @@ public void creationExceptionPropagatesToReadWriteSession() { @Test public void prepareExceptionPropagatesToReadWriteSession() { - Session session = mockSession(); + SessionImpl session = mockSession(); when(client.createSession(db)).thenReturn(session); doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.INTERNAL, "")) .when(session) @@ -322,7 +338,7 @@ public void prepareExceptionPropagatesToReadWriteSession() { @Test public void getReadWriteSession() { - Session mockSession = mockSession(); + SessionImpl mockSession = mockSession(); when(client.createSession(db)).thenReturn(mockSession); pool = createPool(); try (Session session = pool.getReadWriteSession()) { @@ -333,8 +349,8 @@ public void getReadWriteSession() { @Test public void getMultipleReadWriteSessions() { - Session mockSession1 = mockSession(); - Session mockSession2 = mockSession(); + SessionImpl mockSession1 = mockSession(); + SessionImpl mockSession2 = mockSession(); when(client.createSession(db)).thenReturn(mockSession1).thenReturn(mockSession2); pool = createPool(); Session session1 = pool.getReadWriteSession(); @@ -348,7 +364,7 @@ public void getMultipleReadWriteSessions() { @Test public void getMultipleConcurrentReadWriteSessions() { AtomicBoolean failed = new AtomicBoolean(false); - Session session = mockSession(); + SessionImpl session = mockSession(); when(client.createSession(db)).thenReturn(session); pool = createPool(); int numSessions = 5; @@ -361,8 +377,8 @@ public void getMultipleConcurrentReadWriteSessions() { @Test public void sessionIsPrePrepared() { - Session mockSession1 = mockSession(); - Session mockSession2 = mockSession(); + SessionImpl mockSession1 = mockSession(); + SessionImpl mockSession2 = mockSession(); final CountDownLatch prepareLatch = new CountDownLatch(1); doAnswer( new Answer() { @@ -397,8 +413,8 @@ public Void answer(InvocationOnMock arg0) throws Throwable { pool = createPool(); // One of the sessions would be pre prepared. Uninterruptibles.awaitUninterruptibly(prepareLatch); - PooledSession readSession = (PooledSession) pool.getReadSession(); - PooledSession writeSession = (PooledSession) pool.getReadWriteSession(); + PooledSession readSession = pool.getReadSession(); + PooledSession writeSession = pool.getReadWriteSession(); verify(writeSession.delegate, times(1)).prepareReadWriteTransaction(); verify(readSession.delegate, never()).prepareReadWriteTransaction(); readSession.close(); @@ -407,7 +423,7 @@ public Void answer(InvocationOnMock arg0) throws Throwable { @Test public void getReadSessionFallsBackToWritePreparedSession() throws Exception { - Session mockSession1 = mockSession(); + SessionImpl mockSession1 = mockSession(); final CountDownLatch prepareLatch = new CountDownLatch(2); doAnswer( new Answer() { @@ -430,7 +446,7 @@ public Void answer(InvocationOnMock arg0) throws Throwable { pool.getReadWriteSession().close(); prepareLatch.await(); // This session should also be write prepared. - PooledSession readSession = (PooledSession) pool.getReadSession(); + PooledSession readSession = pool.getReadSession(); verify(readSession.delegate, times(2)).prepareReadWriteTransaction(); } @@ -442,7 +458,7 @@ public void failOnPoolExhaustion() { .setMaxSessions(1) .setFailIfPoolExhausted() .build(); - Session mockSession = mockSession(); + SessionImpl mockSession = mockSession(); when(client.createSession(db)).thenReturn(mockSession); pool = createPool(); Session session1 = pool.getReadSession(); @@ -456,14 +472,14 @@ public void failOnPoolExhaustion() { @Test public void poolWorksWhenSessionNotFound() { - Session mockSession1 = mockSession(); - Session mockSession2 = mockSession(); + SessionImpl mockSession1 = mockSession(); + SessionImpl mockSession2 = mockSession(); doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found")) .when(mockSession1) .prepareReadWriteTransaction(); when(client.createSession(db)).thenReturn(mockSession1).thenReturn(mockSession2); pool = createPool(); - assertThat(((PooledSession) pool.getReadWriteSession()).delegate).isEqualTo(mockSession2); + assertThat(pool.getReadWriteSession().delegate).isEqualTo(mockSession2); } @Test @@ -474,9 +490,9 @@ public void idleSessionCleanup() throws Exception { .setMaxSessions(3) .setMaxIdleSessions(0) .build(); - Session session1 = mockSession(); - Session session2 = mockSession(); - Session session3 = mockSession(); + SessionImpl session1 = mockSession(); + SessionImpl session2 = mockSession(); + SessionImpl session3 = mockSession(); final AtomicInteger numSessionClosed = new AtomicInteger(); when(client.createSession(db)).thenReturn(session1).thenReturn(session2).thenReturn(session3); for (Session session : new Session[] {session1, session2, session3}) { @@ -523,7 +539,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { @Test public void keepAlive() throws Exception { options = SessionPoolOptions.newBuilder().setMinSessions(2).setMaxSessions(3).build(); - Session session = mockSession(); + SessionImpl session = mockSession(); mockKeepAlive(session); // This is cheating as we are returning the same session each but it makes the verification // easier. @@ -548,6 +564,285 @@ public void keepAlive() throws Exception { pool.closeAsync().get(); } + @Test + public void testSessionNotFoundSingleUse() { + Statement statement = Statement.of("SELECT 1"); + SessionImpl closedSession = mockSession(); + ReadContext closedContext = mock(ReadContext.class); + ResultSet closedResultSet = mock(ResultSet.class); + when(closedResultSet.next()) + .thenThrow( + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found")); + when(closedContext.executeQuery(statement)).thenReturn(closedResultSet); + when(closedSession.singleUse()).thenReturn(closedContext); + + SessionImpl openSession = mockSession(); + ReadContext openContext = mock(ReadContext.class); + ResultSet openResultSet = mock(ResultSet.class); + when(openResultSet.next()).thenReturn(true, false); + when(openContext.executeQuery(statement)).thenReturn(openResultSet); + when(openSession.singleUse()).thenReturn(openContext); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + ReadContext context = pool.getReadSession().singleUse(); + ResultSet resultSet = context.executeQuery(statement); + assertThat(resultSet.next()).isTrue(); + } + + @Test + public void testSessionNotFoundReadOnlyTransaction() { + Statement statement = Statement.of("SELECT 1"); + SessionImpl closedSession = mockSession(); + when(closedSession.readOnlyTransaction()) + .thenThrow( + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found")); + + SessionImpl openSession = mockSession(); + ReadOnlyTransaction openTransaction = mock(ReadOnlyTransaction.class); + ResultSet openResultSet = mock(ResultSet.class); + when(openResultSet.next()).thenReturn(true, false); + when(openTransaction.executeQuery(statement)).thenReturn(openResultSet); + when(openSession.readOnlyTransaction()).thenReturn(openTransaction); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + ReadOnlyTransaction transaction = pool.getReadSession().readOnlyTransaction(); + ResultSet resultSet = transaction.executeQuery(statement); + assertThat(resultSet.next()).isTrue(); + } + + private enum ReadWriteTransactionTestStatementType { + QUERY, + ANALYZE, + UPDATE, + BATCH_UPDATE, + WRITE, + EXCEPTION; + } + + @SuppressWarnings("unchecked") + @Test + public void testSessionNotFoundReadWriteTransaction() { + final Statement queryStatement = Statement.of("SELECT 1"); + final Statement updateStatement = Statement.of("UPDATE FOO SET BAR=1 WHERE ID=2"); + final SpannerException sessionNotFound = + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found"); + for (ReadWriteTransactionTestStatementType statementType : + ReadWriteTransactionTestStatementType.values()) { + final ReadWriteTransactionTestStatementType executeStatementType = statementType; + for (boolean prepared : new boolean[] {true, false}) { + final boolean hasPreparedTransaction = prepared; + SpannerRpc.StreamingCall closedStreamingCall = mock(SpannerRpc.StreamingCall.class); + doThrow(sessionNotFound).when(closedStreamingCall).request(Mockito.anyInt()); + SpannerRpc rpc = mock(SpannerRpc.class); + when(rpc.executeQuery( + any(ExecuteSqlRequest.class), any(ResultStreamConsumer.class), any(Map.class))) + .thenReturn(closedStreamingCall); + when(rpc.executeQuery(any(ExecuteSqlRequest.class), any(Map.class))) + .thenThrow(sessionNotFound); + when(rpc.executeBatchDml(any(ExecuteBatchDmlRequest.class), any(Map.class))) + .thenThrow(sessionNotFound); + when(rpc.commit(any(CommitRequest.class), any(Map.class))).thenThrow(sessionNotFound); + doThrow(sessionNotFound).when(rpc).rollback(any(RollbackRequest.class), any(Map.class)); + SessionImpl closedSession = mock(SessionImpl.class); + when(closedSession.getName()) + .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-closed"); + ByteString preparedTransactionId = + hasPreparedTransaction ? ByteString.copyFromUtf8("test-txn") : null; + final TransactionContextImpl closedTransactionContext = + new TransactionContextImpl(closedSession, preparedTransactionId, rpc, 10); + when(closedSession.newTransaction()).thenReturn(closedTransactionContext); + when(closedSession.beginTransaction()).thenThrow(sessionNotFound); + TransactionRunnerImpl closedTransactionRunner = + new TransactionRunnerImpl(closedSession, rpc, 10); + when(closedSession.readWriteTransaction()).thenReturn(closedTransactionRunner); + + SessionImpl openSession = mock(SessionImpl.class); + when(openSession.getName()) + .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-open"); + final TransactionContextImpl openTransactionContext = mock(TransactionContextImpl.class); + when(openSession.newTransaction()).thenReturn(openTransactionContext); + when(openSession.beginTransaction()).thenReturn(ByteString.copyFromUtf8("open-txn")); + TransactionRunnerImpl openTransactionRunner = + new TransactionRunnerImpl(openSession, mock(SpannerRpc.class), 10); + when(openSession.readWriteTransaction()).thenReturn(openTransactionRunner); + + ResultSet openResultSet = mock(ResultSet.class); + when(openResultSet.next()).thenReturn(true, false); + ResultSet planResultSet = mock(ResultSet.class); + when(planResultSet.getStats()).thenReturn(ResultSetStats.getDefaultInstance()); + when(openTransactionContext.executeQuery(queryStatement)).thenReturn(openResultSet); + when(openTransactionContext.analyzeQuery(queryStatement, QueryAnalyzeMode.PLAN)) + .thenReturn(planResultSet); + when(openTransactionContext.executeUpdate(updateStatement)).thenReturn(1L); + when(openTransactionContext.batchUpdate(Arrays.asList(updateStatement, updateStatement))) + .thenReturn(new long[] {1L, 1L}); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + SessionPoolOptions options = + SessionPoolOptions.newBuilder() + .setMinSessions(0) // The pool should not auto-create any sessions + .setMaxSessions(2) + .setBlockIfPoolExhausted() + .build(); + SessionPool pool = + SessionPool.createPool(options, new TestExecutorFactory(), db, client, clock); + TransactionRunner runner = pool.getReadWriteSession().readWriteTransaction(); + try { + runner.run( + new TransactionCallable() { + private int callNumber = 0; + + @Override + public Integer run(TransactionContext transaction) throws Exception { + callNumber++; + if (hasPreparedTransaction) { + // If the session had a prepared read/write transaction, that transaction will + // be given to the runner in the first place and the SessionNotFoundException + // will occur on the first query / update statement. + if (callNumber == 1) { + assertThat(transaction).isEqualTo(closedTransactionContext); + } else { + assertThat(transaction).isEqualTo(openTransactionContext); + } + } else { + // If the session did not have a prepared read/write transaction, the library + // tried to create a new transaction before handing it to the transaction + // runner. + // The creation of the new transaction failed with a SessionNotFoundException, + // and the session was re-created before the run method was called. + assertThat(transaction).isEqualTo(openTransactionContext); + } + switch (executeStatementType) { + case QUERY: + ResultSet resultSet = transaction.executeQuery(queryStatement); + assertThat(resultSet.next()).isTrue(); + break; + case ANALYZE: + ResultSet planResultSet = + transaction.analyzeQuery(queryStatement, QueryAnalyzeMode.PLAN); + assertThat(planResultSet.next()).isFalse(); + assertThat(planResultSet.getStats()).isNotNull(); + break; + case UPDATE: + long updateCount = transaction.executeUpdate(updateStatement); + assertThat(updateCount).isEqualTo(1L); + break; + case BATCH_UPDATE: + long[] updateCounts = + transaction.batchUpdate(Arrays.asList(updateStatement, updateStatement)); + assertThat(updateCounts).isEqualTo(new long[] {1L, 1L}); + break; + case WRITE: + transaction.buffer(Mutation.delete("FOO", Key.of(1L))); + break; + case EXCEPTION: + throw new RuntimeException("rollback at call " + callNumber); + default: + fail("Unknown statement type: " + executeStatementType); + } + return callNumber; + } + }); + } catch (Exception e) { + // The rollback will also cause a SessionNotFoundException, but this is caught, logged and + // further ignored by the library, meaning that the session will not be re-created for + // retry. Hence rollback at call 1. + assertThat( + executeStatementType == ReadWriteTransactionTestStatementType.EXCEPTION + && e.getMessage().contains("rollback at call 1")) + .isTrue(); + } + } + } + } + + @Test + public void testSessionNotFoundOnPrepareTransaction() { + final SpannerException sessionNotFound = + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found"); + SessionImpl closedSession = mock(SessionImpl.class); + when(closedSession.getName()) + .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-closed"); + when(closedSession.beginTransaction()).thenThrow(sessionNotFound); + doThrow(sessionNotFound).when(closedSession).prepareReadWriteTransaction(); + + SessionImpl openSession = mock(SessionImpl.class); + when(openSession.getName()) + .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-open"); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + PooledSession session = pool.getReadWriteSession(); + assertThat(session.delegate).isEqualTo(openSession); + } + + @Test + public void testSessionNotFoundWrite() { + SpannerException sessionNotFound = + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found"); + List mutations = Arrays.asList(Mutation.newInsertBuilder("FOO").build()); + SessionImpl closedSession = mockSession(); + when(closedSession.write(mutations)).thenThrow(sessionNotFound); + + SessionImpl openSession = mockSession(); + when(openSession.write(mutations)).thenReturn(Timestamp.now()); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + DatabaseClientImpl impl = new DatabaseClientImpl(pool); + assertThat(impl.write(mutations)).isNotNull(); + } + + @Test + public void testSessionNotFoundWriteAtLeastOnce() { + SpannerException sessionNotFound = + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found"); + List mutations = Arrays.asList(Mutation.newInsertBuilder("FOO").build()); + SessionImpl closedSession = mockSession(); + when(closedSession.writeAtLeastOnce(mutations)).thenThrow(sessionNotFound); + + SessionImpl openSession = mockSession(); + when(openSession.writeAtLeastOnce(mutations)).thenReturn(Timestamp.now()); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + DatabaseClientImpl impl = new DatabaseClientImpl(pool); + assertThat(impl.writeAtLeastOnce(mutations)).isNotNull(); + } + + @Test + public void testSessionNotFoundPartitionedUpdate() { + SpannerException sessionNotFound = + SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Session not found"); + Statement statement = Statement.of("UPDATE FOO SET BAR=1 WHERE 1=1"); + SessionImpl closedSession = mockSession(); + when(closedSession.executePartitionedUpdate(statement)).thenThrow(sessionNotFound); + + SessionImpl openSession = mockSession(); + when(openSession.executePartitionedUpdate(statement)).thenReturn(1L); + + when(client.createSession(db)).thenReturn(closedSession, openSession); + FakeClock clock = new FakeClock(); + clock.currentTimeMillis = System.currentTimeMillis(); + pool = createPool(clock); + DatabaseClientImpl impl = new DatabaseClientImpl(pool); + assertThat(impl.executePartitionedUpdate(statement)).isEqualTo(1L); + } + private void mockKeepAlive(Session session) { ReadContext context = mock(ReadContext.class); ResultSet resultSet = mock(ResultSet.class); diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITClosedSessionTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITClosedSessionTest.java new file mode 100644 index 000000000000..043430a2cc4e --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITClosedSessionTest.java @@ -0,0 +1,273 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.it; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.IntegrationTest; +import com.google.cloud.spanner.IntegrationTestWithClosedSessionsEnv; +import com.google.cloud.spanner.IntegrationTestWithClosedSessionsEnv.DatabaseClientWithClosedSessionImpl; +import com.google.cloud.spanner.ReadOnlyTransaction; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SessionNotFoundException; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.TimestampBound; +import com.google.cloud.spanner.TransactionContext; +import com.google.cloud.spanner.TransactionManager; +import com.google.cloud.spanner.TransactionRunner; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test the automatic re-creation of sessions that have been invalidated by the server. */ +@Category(IntegrationTest.class) +@RunWith(JUnit4.class) +public class ITClosedSessionTest { + // Run each test case twice to ensure that a retried session does not affect subsequent + // transactions. + private static final int RUNS_PER_TEST_CASE = 2; + + @ClassRule + public static IntegrationTestWithClosedSessionsEnv env = + new IntegrationTestWithClosedSessionsEnv(); + + private static Database db; + @Rule public ExpectedException expectedException = ExpectedException.none(); + private static DatabaseClientWithClosedSessionImpl client; + + @BeforeClass + public static void setUpDatabase() { + // Empty database. + db = env.getTestHelper().createTestDatabase(); + client = (DatabaseClientWithClosedSessionImpl) env.getTestHelper().getDatabaseClient(db); + } + + @Before + public void setup() { + client.setAllowSessionReplacing(true); + } + + @Test + public void testSingleUse() { + // This should trigger an exception with code NOT_FOUND and the text 'Session not found'. + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ResultSet rs = Statement.of("SELECT 1").executeQuery(client.singleUse())) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void testSingleUseNoRecreation() { + // This should trigger an exception with code NOT_FOUND and the text 'Session not found'. + client.setAllowSessionReplacing(false); + client.invalidateNextSession(); + expectedException.expect(SessionNotFoundException.class); + try (ResultSet rs = Statement.of("SELECT 1").executeQuery(client.singleUse())) { + rs.next(); + } + } + + @Test + public void testSingleUseBound() { + // This should trigger an exception with code NOT_FOUND and the text 'Session not found'. + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ResultSet rs = + Statement.of("SELECT 1") + .executeQuery( + client.singleUse(TimestampBound.ofExactStaleness(10L, TimeUnit.SECONDS)))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + } + + @Test + public void testSingleUseReadOnlyTransaction() { + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ReadOnlyTransaction txn = client.singleUseReadOnlyTransaction()) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + assertThat(txn.getReadTimestamp()).isNotNull(); + } + } + } + + @Test + public void testSingleUseReadOnlyTransactionBound() { + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ReadOnlyTransaction txn = + client.singleUseReadOnlyTransaction( + TimestampBound.ofMaxStaleness(10L, TimeUnit.SECONDS))) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + assertThat(txn.getReadTimestamp()).isNotNull(); + } + } + } + + @Test + public void testReadOnlyTransaction() { + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ReadOnlyTransaction txn = client.readOnlyTransaction()) { + for (int i = 0; i < 2; i++) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + assertThat(txn.getReadTimestamp()).isNotNull(); + } + } + } + + @Test + public void testReadOnlyTransactionNoRecreation() { + client.setAllowSessionReplacing(false); + client.invalidateNextSession(); + expectedException.expect(SessionNotFoundException.class); + try (ReadOnlyTransaction txn = client.readOnlyTransaction()) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + } + } + } + + @Test + public void testReadOnlyTransactionBound() { + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + try (ReadOnlyTransaction txn = + client.readOnlyTransaction(TimestampBound.ofExactStaleness(10L, TimeUnit.SECONDS))) { + for (int i = 0; i < 2; i++) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + assertThat(txn.getReadTimestamp()).isNotNull(); + } + } + } + + @Test + public void testReadWriteTransaction() { + client.invalidateNextSession(); + for (int run = 0; run < RUNS_PER_TEST_CASE; run++) { + TransactionRunner txn = client.readWriteTransaction(); + txn.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + for (int i = 0; i < 2; i++) { + try (ResultSet rs = transaction.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + return null; + } + }); + } + } + + @Test + public void testReadWriteTransactionNoRecreation() { + client.setAllowSessionReplacing(false); + client.invalidateNextSession(); + expectedException.expect(SessionNotFoundException.class); + TransactionRunner txn = client.readWriteTransaction(); + txn.run( + new TransactionCallable() { + @Override + public Void run(TransactionContext transaction) throws Exception { + try (ResultSet rs = transaction.executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + } + return null; + } + }); + } + + @Test + public void testTransactionManager() throws InterruptedException { + client.invalidateNextSession(); + for (int run = 0; run < 2; run++) { + try (TransactionManager manager = client.transactionManager()) { + TransactionContext txn = manager.begin(); + while (true) { + for (int i = 0; i < 2; i++) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + assertThat(rs.next()).isTrue(); + assertThat(rs.getLong(0)).isEqualTo(1L); + assertThat(rs.next()).isFalse(); + } + } + try { + manager.commit(); + break; + } catch (AbortedException e) { + Thread.sleep(e.getRetryDelayInMillis() / 1000); + txn = manager.resetForRetry(); + } + } + } + } + } + + @Test + public void testTransactionManagerNoRecreation() throws InterruptedException { + client.setAllowSessionReplacing(false); + client.invalidateNextSession(); + expectedException.expect(SessionNotFoundException.class); + try (TransactionManager manager = client.transactionManager()) { + TransactionContext txn = manager.begin(); + while (true) { + try (ResultSet rs = txn.executeQuery(Statement.of("SELECT 1"))) { + rs.next(); + } + } + } + } +}