diff --git a/README.md b/README.md index 288afeef263..56d668b0a65 100644 --- a/README.md +++ b/README.md @@ -57,13 +57,13 @@ implementation 'com.google.cloud:google-cloud-spanner' If you are using Gradle without BOM, add this to your dependencies: ```Groovy -implementation 'com.google.cloud:google-cloud-spanner:6.57.0' +implementation 'com.google.cloud:google-cloud-spanner:6.58.0' ``` If you are using SBT, add this to your dependencies: ```Scala -libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.57.0" +libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.58.0" ``` @@ -444,7 +444,7 @@ Java is a registered trademark of Oracle and/or its affiliates. [kokoro-badge-link-5]: http://storage.googleapis.com/cloud-devrel-public/java/badges/java-spanner/java11.html [stability-image]: https://img.shields.io/badge/stability-stable-green [maven-version-image]: https://img.shields.io/maven-central/v/com.google.cloud/google-cloud-spanner.svg -[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.57.0 +[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.58.0 [authentication]: https://github.com/googleapis/google-cloud-java#authentication [auth-scopes]: https://developers.google.com/identity/protocols/oauth2/scopes [predefined-iam-roles]: https://cloud.google.com/iam/docs/understanding-roles#predefined_roles diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 0714b7651cc..ae18a1e4f5f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -28,9 +28,6 @@ import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; -import com.google.cloud.spanner.AbstractResultSet.GrpcResultSet; -import com.google.cloud.spanner.AbstractResultSet.GrpcStreamIterator; -import com.google.cloud.spanner.AbstractResultSet.ResumableStreamIterator; import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.cloud.spanner.Options.QueryOption; @@ -73,6 +70,7 @@ abstract static class Builder, T extends AbstractReadCon private TraceWrapper tracer; private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS; private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS; + private DecodeMode defaultDecodeMode = SpannerOptions.Builder.DEFAULT_DECODE_MODE; private DirectedReadOptions defaultDirectedReadOption; private ExecutorProvider executorProvider; private Clock clock = new Clock(); @@ -114,6 +112,11 @@ B setDefaultQueryOptions(QueryOptions defaultQueryOptions) { return self(); } + B setDefaultDecodeMode(DecodeMode defaultDecodeMode) { + this.defaultDecodeMode = defaultDecodeMode; + return self(); + } + B setExecutorProvider(ExecutorProvider executorProvider) { this.executorProvider = executorProvider; return self(); @@ -414,8 +417,8 @@ void initTransaction() { TraceWrapper tracer; private final int defaultPrefetchChunks; private final QueryOptions defaultQueryOptions; - private final DirectedReadOptions defaultDirectedReadOptions; + private final DecodeMode defaultDecodeMode; private final Clock clock; @GuardedBy("lock") @@ -441,6 +444,7 @@ void initTransaction() { this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; this.defaultDirectedReadOptions = builder.defaultDirectedReadOption; + this.defaultDecodeMode = builder.defaultDecodeMode; this.span = builder.span; this.executorProvider = builder.executorProvider; this.clock = builder.clock; @@ -730,7 +734,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, options.hasDecodeMode() ? options.decodeMode() : defaultDecodeMode); } /** @@ -874,7 +879,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, readOptions.hasDecodeMode() ? readOptions.decodeMode() : defaultDecodeMode); } private Struct consumeSingleRow(ResultSet resultSet) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index 4904a382f92..6cce03e72cb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -17,72 +17,32 @@ package com.google.cloud.spanner; import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; -import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.api.client.util.BackOff; -import com.google.api.client.util.ExponentialBackOff; -import com.google.api.gax.grpc.GrpcStatusCode; -import com.google.api.gax.retrying.RetrySettings; -import com.google.api.gax.rpc.ApiCallContext; -import com.google.api.gax.rpc.StatusCode.Code; + import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.Type.StructField; -import com.google.cloud.spanner.spi.v1.SpannerRpc; -import com.google.cloud.spanner.v1.stub.SpannerStubSettings; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.Lists; -import com.google.common.io.CharSource; -import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.AbstractMessage; -import com.google.protobuf.ByteString; import com.google.protobuf.ListValue; -import com.google.protobuf.NullValue; import com.google.protobuf.ProtocolMessageEnum; import com.google.protobuf.Value.KindCase; -import com.google.spanner.v1.PartialResultSet; -import com.google.spanner.v1.ResultSetMetadata; -import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.Transaction; -import com.google.spanner.v1.TypeCode; -import io.grpc.Context; import java.io.IOException; import java.io.Serializable; import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; import java.util.AbstractList; -import java.util.ArrayList; import java.util.Base64; import java.util.BitSet; -import java.util.Collections; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Objects; -import java.util.Set; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.function.Function; -import java.util.logging.Level; -import java.util.logging.Logger; -import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.threeten.bp.Duration; /** Implementation of {@link ResultSet}. */ abstract class AbstractResultSet extends AbstractStructReader implements ResultSet { - private static final com.google.protobuf.Value NULL_VALUE = - com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); interface Listener { /** @@ -99,271 +59,6 @@ void onTransactionMetadata(Transaction transaction, boolean shouldIncludeId) void onDone(boolean withBeginTransaction); } - @VisibleForTesting - static class GrpcResultSet extends AbstractResultSet> { - private final GrpcValueIterator iterator; - private final Listener listener; - private ResultSetMetadata metadata; - private GrpcStruct currRow; - private SpannerException error; - private ResultSetStats statistics; - private boolean closed; - - GrpcResultSet(CloseableIterator iterator, Listener listener) { - this.iterator = new GrpcValueIterator(iterator); - this.listener = listener; - } - - @Override - protected GrpcStruct currRow() { - checkState(!closed, "ResultSet is closed"); - checkState(currRow != null, "next() call required"); - return currRow; - } - - @Override - public boolean next() throws SpannerException { - if (error != null) { - throw newSpannerException(error); - } - try { - if (currRow == null) { - metadata = iterator.getMetadata(); - if (metadata.hasTransaction()) { - listener.onTransactionMetadata( - metadata.getTransaction(), iterator.isWithBeginTransaction()); - } else if (iterator.isWithBeginTransaction()) { - // The query should have returned a transaction. - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); - } - currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); - } - boolean hasNext = currRow.consumeRow(iterator); - if (!hasNext) { - statistics = iterator.getStats(); - } - return hasNext; - } catch (Throwable t) { - throw yieldError( - SpannerExceptionFactory.asSpannerException(t), - iterator.isWithBeginTransaction() && currRow == null); - } - } - - @Override - @Nullable - public ResultSetStats getStats() { - return statistics; - } - - @Override - public ResultSetMetadata getMetadata() { - checkState(metadata != null, "next() call required"); - return metadata; - } - - @Override - public void close() { - listener.onDone(iterator.isWithBeginTransaction()); - iterator.close("ResultSet closed"); - closed = true; - } - - @Override - public Type getType() { - checkState(currRow != null, "next() call required"); - return currRow.getType(); - } - - private SpannerException yieldError(SpannerException e, boolean beginTransaction) { - SpannerException toThrow = listener.onError(e, beginTransaction); - close(); - throw toThrow; - } - } - /** - * Adapts a stream of {@code PartialResultSet} messages into a stream of {@code Value} messages. - */ - private static class GrpcValueIterator extends AbstractIterator { - private enum StreamValue { - METADATA, - RESULT, - } - - private final CloseableIterator stream; - private ResultSetMetadata metadata; - private Type type; - private PartialResultSet current; - private int pos; - private ResultSetStats statistics; - - GrpcValueIterator(CloseableIterator stream) { - this.stream = stream; - } - - @SuppressWarnings("unchecked") - @Override - protected com.google.protobuf.Value computeNext() { - if (!ensureReady(StreamValue.RESULT)) { - endOfData(); - return null; - } - com.google.protobuf.Value value = current.getValues(pos++); - KindCase kind = value.getKindCase(); - - if (!isMergeable(kind)) { - if (pos == current.getValuesCount() && current.getChunkedValue()) { - throw newSpannerException(ErrorCode.INTERNAL, "Unexpected chunked PartialResultSet."); - } else { - return value; - } - } - if (!current.getChunkedValue() || pos != current.getValuesCount()) { - return value; - } - - Object merged = - kind == KindCase.STRING_VALUE - ? value.getStringValue() - : new ArrayList<>(value.getListValue().getValuesList()); - while (current.getChunkedValue() && pos == current.getValuesCount()) { - if (!ensureReady(StreamValue.RESULT)) { - throw newSpannerException( - ErrorCode.INTERNAL, "Stream closed in the middle of chunked value"); - } - com.google.protobuf.Value newValue = current.getValues(pos++); - if (newValue.getKindCase() != kind) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Unexpected type in middle of chunked value. Expected: " - + kind - + " but got: " - + newValue.getKindCase()); - } - if (kind == KindCase.STRING_VALUE) { - merged = merged + newValue.getStringValue(); - } else { - concatLists( - (List) merged, newValue.getListValue().getValuesList()); - } - } - if (kind == KindCase.STRING_VALUE) { - return com.google.protobuf.Value.newBuilder().setStringValue((String) merged).build(); - } else { - return com.google.protobuf.Value.newBuilder() - .setListValue( - ListValue.newBuilder().addAllValues((List) merged)) - .build(); - } - } - - ResultSetMetadata getMetadata() throws SpannerException { - if (metadata == null) { - if (!ensureReady(StreamValue.METADATA)) { - throw newSpannerException(ErrorCode.INTERNAL, "Stream closed without sending metadata"); - } - } - return metadata; - } - - /** - * Get the query statistics. Query statistics are delivered with the last PartialResultSet in - * the stream. Any attempt to call this method before the caller has finished consuming the - * results will return null. - */ - @Nullable - ResultSetStats getStats() { - return statistics; - } - - Type type() { - checkState(type != null, "metadata has not been received"); - return type; - } - - private boolean ensureReady(StreamValue requiredValue) throws SpannerException { - while (current == null || pos >= current.getValuesCount()) { - if (!stream.hasNext()) { - return false; - } - current = stream.next(); - pos = 0; - if (type == null) { - // This is the first message on the stream. - if (!current.hasMetadata() || !current.getMetadata().hasRowType()) { - throw newSpannerException(ErrorCode.INTERNAL, "Missing type metadata in first message"); - } - metadata = current.getMetadata(); - com.google.spanner.v1.Type typeProto = - com.google.spanner.v1.Type.newBuilder() - .setCode(TypeCode.STRUCT) - .setStructType(metadata.getRowType()) - .build(); - try { - type = Type.fromProto(typeProto); - } catch (IllegalArgumentException e) { - throw newSpannerException( - ErrorCode.INTERNAL, "Invalid type metadata: " + e.getMessage(), e); - } - } - if (current.hasStats()) { - statistics = current.getStats(); - } - if (requiredValue == StreamValue.METADATA) { - return true; - } - } - return true; - } - - void close(@Nullable String message) { - stream.close(message); - } - - boolean isWithBeginTransaction() { - return stream.isWithBeginTransaction(); - } - - /** @param a is a mutable list and b will be concatenated into a. */ - private void concatLists(List a, List b) { - if (a.size() == 0 || b.size() == 0) { - a.addAll(b); - return; - } else { - com.google.protobuf.Value last = a.get(a.size() - 1); - com.google.protobuf.Value first = b.get(0); - KindCase lastKind = last.getKindCase(); - KindCase firstKind = first.getKindCase(); - if (isMergeable(lastKind) && lastKind == firstKind) { - com.google.protobuf.Value merged; - if (lastKind == KindCase.STRING_VALUE) { - String lastStr = last.getStringValue(); - String firstStr = first.getStringValue(); - merged = - com.google.protobuf.Value.newBuilder().setStringValue(lastStr + firstStr).build(); - } else { // List - List mergedList = new ArrayList<>(); - mergedList.addAll(last.getListValue().getValuesList()); - concatLists(mergedList, first.getListValue().getValuesList()); - merged = - com.google.protobuf.Value.newBuilder() - .setListValue(ListValue.newBuilder().addAllValues(mergedList)) - .build(); - } - a.set(a.size() - 1, merged); - a.addAll(b.subList(1, b.size())); - } else { - a.addAll(b); - } - } - } - - private boolean isMergeable(KindCase kind) { - return kind == KindCase.STRING_VALUE || kind == KindCase.LIST_VALUE; - } - } - static final class LazyByteArray implements Serializable { private static final Base64.Encoder ENCODER = Base64.getEncoder(); private static final Base64.Decoder DECODER = Base64.getDecoder(); @@ -437,603 +132,6 @@ private boolean lazyByteArraysEqual(LazyByteArray other) { } } - static class GrpcStruct extends Struct implements Serializable { - private final Type type; - private final List rowData; - - /** - * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as - * a serialization proxy. - */ - private Object writeReplace() { - Builder builder = Struct.newBuilder(); - List structFields = getType().getStructFields(); - for (int i = 0; i < structFields.size(); i++) { - Type.StructField field = structFields.get(i); - String fieldName = field.getName(); - Object value = rowData.get(i); - Type fieldType = field.getType(); - switch (fieldType.getCode()) { - case BOOL: - builder.set(fieldName).to((Boolean) value); - break; - case INT64: - builder.set(fieldName).to((Long) value); - break; - case FLOAT64: - builder.set(fieldName).to((Double) value); - break; - case NUMERIC: - builder.set(fieldName).to((BigDecimal) value); - break; - case PG_NUMERIC: - builder.set(fieldName).to((String) value); - break; - case STRING: - builder.set(fieldName).to((String) value); - break; - case JSON: - builder.set(fieldName).to(Value.json((String) value)); - break; - case PROTO: - builder - .set(fieldName) - .to( - Value.protoMessage( - value == null ? null : ((LazyByteArray) value).getByteArray(), - fieldType.getProtoTypeFqn())); - break; - case ENUM: - builder.set(fieldName).to(Value.protoEnum((Long) value, fieldType.getProtoTypeFqn())); - break; - case PG_JSONB: - builder.set(fieldName).to(Value.pgJsonb((String) value)); - break; - case BYTES: - builder - .set(fieldName) - .to( - Value.bytesFromBase64( - value == null ? null : ((LazyByteArray) value).getBase64String())); - break; - case TIMESTAMP: - builder.set(fieldName).to((Timestamp) value); - break; - case DATE: - builder.set(fieldName).to((Date) value); - break; - case ARRAY: - final Type elementType = fieldType.getArrayElementType(); - switch (elementType.getCode()) { - case BOOL: - builder.set(fieldName).toBoolArray((Iterable) value); - break; - case INT64: - case ENUM: - builder.set(fieldName).toInt64Array((Iterable) value); - break; - case FLOAT64: - builder.set(fieldName).toFloat64Array((Iterable) value); - break; - case NUMERIC: - builder.set(fieldName).toNumericArray((Iterable) value); - break; - case PG_NUMERIC: - builder.set(fieldName).toPgNumericArray((Iterable) value); - break; - case STRING: - builder.set(fieldName).toStringArray((Iterable) value); - break; - case JSON: - builder.set(fieldName).toJsonArray((Iterable) value); - break; - case PG_JSONB: - builder.set(fieldName).toPgJsonbArray((Iterable) value); - break; - case BYTES: - case PROTO: - builder - .set(fieldName) - .toBytesArrayFromBase64( - value == null - ? null - : ((List) value) - .stream() - .map( - element -> - element == null ? null : element.getBase64String()) - .collect(Collectors.toList())); - break; - case TIMESTAMP: - builder.set(fieldName).toTimestampArray((Iterable) value); - break; - case DATE: - builder.set(fieldName).toDateArray((Iterable) value); - break; - case STRUCT: - builder.set(fieldName).toStructArray(elementType, (Iterable) value); - break; - default: - throw new AssertionError("Unhandled array type code: " + elementType); - } - break; - case STRUCT: - if (value == null) { - builder.set(fieldName).to(fieldType, null); - } else { - builder.set(fieldName).to((Struct) value); - } - break; - default: - throw new AssertionError("Unhandled type code: " + fieldType.getCode()); - } - } - return builder.build(); - } - - GrpcStruct(Type type, List rowData) { - this.type = type; - this.rowData = rowData; - } - - @Override - public String toString() { - return this.rowData.toString(); - } - - boolean consumeRow(Iterator iterator) { - rowData.clear(); - if (!iterator.hasNext()) { - return false; - } - for (Type.StructField fieldType : getType().getStructFields()) { - if (!iterator.hasNext()) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Invalid value stream: end of stream reached before row is complete"); - } - com.google.protobuf.Value value = iterator.next(); - rowData.add(decodeValue(fieldType.getType(), value)); - } - return true; - } - - private static Object decodeValue(Type fieldType, com.google.protobuf.Value proto) { - if (proto.getKindCase() == KindCase.NULL_VALUE) { - return null; - } - switch (fieldType.getCode()) { - case BOOL: - checkType(fieldType, proto, KindCase.BOOL_VALUE); - return proto.getBoolValue(); - case INT64: - case ENUM: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Long.parseLong(proto.getStringValue()); - case FLOAT64: - return valueProtoToFloat64(proto); - case NUMERIC: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return new BigDecimal(proto.getStringValue()); - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return proto.getStringValue(); - case BYTES: - case PROTO: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return new LazyByteArray(proto.getStringValue()); - case TIMESTAMP: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Timestamp.parseTimestamp(proto.getStringValue()); - case DATE: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Date.parseDate(proto.getStringValue()); - case ARRAY: - checkType(fieldType, proto, KindCase.LIST_VALUE); - ListValue listValue = proto.getListValue(); - return decodeArrayValue(fieldType.getArrayElementType(), listValue); - case STRUCT: - checkType(fieldType, proto, KindCase.LIST_VALUE); - ListValue structValue = proto.getListValue(); - return decodeStructValue(fieldType, structValue); - case UNRECOGNIZED: - return proto; - default: - throw new AssertionError("Unhandled type code: " + fieldType.getCode()); - } - } - - private static Struct decodeStructValue(Type structType, ListValue structValue) { - List fieldTypes = structType.getStructFields(); - checkArgument( - structValue.getValuesCount() == fieldTypes.size(), - "Size mismatch between type descriptor and actual values."); - List fields = new ArrayList<>(fieldTypes.size()); - List fieldValues = structValue.getValuesList(); - for (int i = 0; i < fieldTypes.size(); ++i) { - fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); - } - return new GrpcStruct(structType, fields); - } - - static Object decodeArrayValue(Type elementType, ListValue listValue) { - switch (elementType.getCode()) { - case INT64: - case ENUM: - // For int64/float64/enum types, use custom containers. These avoid wrapper object - // creation for non-null arrays. - return new Int64Array(listValue); - case FLOAT64: - return new Float64Array(listValue); - case BOOL: - case NUMERIC: - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - case BYTES: - case TIMESTAMP: - case DATE: - case STRUCT: - case PROTO: - return Lists.transform( - listValue.getValuesList(), input -> decodeValue(elementType, input)); - default: - throw new AssertionError("Unhandled type code: " + elementType.getCode()); - } - } - - private static void checkType( - Type fieldType, com.google.protobuf.Value proto, KindCase expected) { - if (proto.getKindCase() != expected) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Invalid value for column type " - + fieldType - + " expected " - + expected - + " but was " - + proto.getKindCase()); - } - } - - Struct immutableCopy() { - return new GrpcStruct(type, new ArrayList<>(rowData)); - } - - @Override - public Type getType() { - return type; - } - - @Override - public boolean isNull(int columnIndex) { - return rowData.get(columnIndex) == null; - } - - @Override - protected T getProtoMessageInternal(int columnIndex, T message) { - Preconditions.checkNotNull( - message, - "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); - try { - return (T) - message - .toBuilder() - .mergeFrom( - Base64.getDecoder() - .wrap( - CharSource.wrap(((LazyByteArray) rowData.get(columnIndex)).base64String) - .asByteSource(StandardCharsets.UTF_8) - .openStream())) - .build(); - } catch (IOException ioException) { - throw SpannerExceptionFactory.asSpannerException(ioException); - } - } - - @Override - protected T getProtoEnumInternal( - int columnIndex, Function method) { - Preconditions.checkNotNull( - method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); - return (T) method.apply((int) getLongInternal(columnIndex)); - } - - @Override - protected boolean getBooleanInternal(int columnIndex) { - return (Boolean) rowData.get(columnIndex); - } - - @Override - protected long getLongInternal(int columnIndex) { - return (Long) rowData.get(columnIndex); - } - - @Override - protected double getDoubleInternal(int columnIndex) { - return (Double) rowData.get(columnIndex); - } - - @Override - protected BigDecimal getBigDecimalInternal(int columnIndex) { - return (BigDecimal) rowData.get(columnIndex); - } - - @Override - protected String getStringInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected String getJsonInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected String getPgJsonbInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected ByteArray getBytesInternal(int columnIndex) { - return getLazyBytesInternal(columnIndex).getByteArray(); - } - - LazyByteArray getLazyBytesInternal(int columnIndex) { - return (LazyByteArray) rowData.get(columnIndex); - } - - @Override - protected Timestamp getTimestampInternal(int columnIndex) { - return (Timestamp) rowData.get(columnIndex); - } - - @Override - protected Date getDateInternal(int columnIndex) { - return (Date) rowData.get(columnIndex); - } - - protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { - return (com.google.protobuf.Value) rowData.get(columnIndex); - } - - @Override - protected Value getValueInternal(int columnIndex) { - final List structFields = getType().getStructFields(); - final StructField structField = structFields.get(columnIndex); - final Type columnType = structField.getType(); - final boolean isNull = rowData.get(columnIndex) == null; - switch (columnType.getCode()) { - case BOOL: - return Value.bool(isNull ? null : getBooleanInternal(columnIndex)); - case INT64: - return Value.int64(isNull ? null : getLongInternal(columnIndex)); - case ENUM: - return Value.protoEnum( - isNull ? null : getLongInternal(columnIndex), columnType.getProtoTypeFqn()); - case NUMERIC: - return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); - case PG_NUMERIC: - return Value.pgNumeric(isNull ? null : getStringInternal(columnIndex)); - case FLOAT64: - return Value.float64(isNull ? null : getDoubleInternal(columnIndex)); - case STRING: - return Value.string(isNull ? null : getStringInternal(columnIndex)); - case JSON: - return Value.json(isNull ? null : getJsonInternal(columnIndex)); - case PG_JSONB: - return Value.pgJsonb(isNull ? null : getPgJsonbInternal(columnIndex)); - case BYTES: - return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); - case PROTO: - return Value.protoMessage( - isNull ? null : getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); - case TIMESTAMP: - return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); - case DATE: - return Value.date(isNull ? null : getDateInternal(columnIndex)); - case STRUCT: - return Value.struct(isNull ? null : getStructInternal(columnIndex)); - case UNRECOGNIZED: - return Value.unrecognized( - isNull ? NULL_VALUE : getProtoValueInternal(columnIndex), columnType); - case ARRAY: - final Type elementType = columnType.getArrayElementType(); - switch (elementType.getCode()) { - case BOOL: - return Value.boolArray(isNull ? null : getBooleanListInternal(columnIndex)); - case INT64: - return Value.int64Array(isNull ? null : getLongListInternal(columnIndex)); - case NUMERIC: - return Value.numericArray(isNull ? null : getBigDecimalListInternal(columnIndex)); - case PG_NUMERIC: - return Value.pgNumericArray(isNull ? null : getStringListInternal(columnIndex)); - case FLOAT64: - return Value.float64Array(isNull ? null : getDoubleListInternal(columnIndex)); - case STRING: - return Value.stringArray(isNull ? null : getStringListInternal(columnIndex)); - case JSON: - return Value.jsonArray(isNull ? null : getJsonListInternal(columnIndex)); - case PG_JSONB: - return Value.pgJsonbArray(isNull ? null : getPgJsonbListInternal(columnIndex)); - case BYTES: - return Value.bytesArray(isNull ? null : getBytesListInternal(columnIndex)); - case PROTO: - return Value.protoMessageArray( - isNull ? null : getBytesListInternal(columnIndex), elementType.getProtoTypeFqn()); - case ENUM: - return Value.protoEnumArray( - isNull ? null : getLongListInternal(columnIndex), elementType.getProtoTypeFqn()); - case TIMESTAMP: - return Value.timestampArray(isNull ? null : getTimestampListInternal(columnIndex)); - case DATE: - return Value.dateArray(isNull ? null : getDateListInternal(columnIndex)); - case STRUCT: - return Value.structArray( - elementType, isNull ? null : getStructListInternal(columnIndex)); - default: - throw new IllegalArgumentException( - "Invalid array value type " + this.type.getArrayElementType()); - } - default: - throw new IllegalArgumentException("Invalid value type " + this.type); - } - } - - @Override - protected Struct getStructInternal(int columnIndex) { - return (Struct) rowData.get(columnIndex); - } - - @Override - protected boolean[] getBooleanArrayInternal(int columnIndex) { - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - List values = (List) rowData.get(columnIndex); - boolean[] r = new boolean[values.size()]; - for (int i = 0; i < values.size(); ++i) { - if (values.get(i) == null) { - throw throwNotNull(columnIndex); - } - r[i] = values.get(i); - } - return r; - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBooleanListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - protected long[] getLongArrayInternal(int columnIndex) { - return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); - } - - @Override - protected Int64Array getLongListInternal(int columnIndex) { - return (Int64Array) rowData.get(columnIndex); - } - - @Override - protected double[] getDoubleArrayInternal(int columnIndex) { - return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); - } - - @Override - protected Float64Array getDoubleListInternal(int columnIndex) { - return (Float64Array) rowData.get(columnIndex); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBigDecimalListInternal(int columnIndex) { - return (List) rowData.get(columnIndex); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getStringListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getJsonListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getProtoMessageListInternal( - int columnIndex, T message) { - Preconditions.checkNotNull( - message, - "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); - - List bytesArray = (List) rowData.get(columnIndex); - - try { - List protoMessagesList = new ArrayList<>(bytesArray.size()); - for (LazyByteArray protoMessageBytes : bytesArray) { - if (protoMessageBytes == null) { - protoMessagesList.add(null); - } else { - protoMessagesList.add( - (T) - message - .toBuilder() - .mergeFrom( - Base64.getDecoder() - .wrap( - CharSource.wrap(protoMessageBytes.base64String) - .asByteSource(StandardCharsets.UTF_8) - .openStream())) - .build()); - } - } - return protoMessagesList; - } catch (IOException ioException) { - throw SpannerExceptionFactory.asSpannerException(ioException); - } - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getProtoEnumListInternal( - int columnIndex, Function method) { - Preconditions.checkNotNull( - method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); - - List enumIntArray = (List) rowData.get(columnIndex); - List protoEnumList = new ArrayList<>(enumIntArray.size()); - for (Long enumIntValue : enumIntArray) { - if (enumIntValue == null) { - protoEnumList.add(null); - } else { - protoEnumList.add((T) method.apply(enumIntValue.intValue())); - } - } - - return protoEnumList; - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getPgJsonbListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBytesListInternal(int columnIndex) { - return Lists.transform( - (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getTimestampListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getDateListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY> produces a List. - protected List getStructListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - } - @VisibleForTesting interface CloseableIterator extends Iterator { @@ -1047,378 +145,6 @@ interface CloseableIterator extends Iterator { boolean isWithBeginTransaction(); } - /** Adapts a streaming read/query call into an iterator over partial result sets. */ - @VisibleForTesting - static class GrpcStreamIterator extends AbstractIterator - implements CloseableIterator { - private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); - private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); - - private final ConsumerImpl consumer = new ConsumerImpl(); - private final BlockingQueue stream; - private final Statement statement; - - private SpannerRpc.StreamingCall call; - private volatile boolean withBeginTransaction; - private TimeUnit streamWaitTimeoutUnit; - private long streamWaitTimeoutValue; - private SpannerException error; - - @VisibleForTesting - GrpcStreamIterator(int prefetchChunks) { - this(null, prefetchChunks); - } - - @VisibleForTesting - GrpcStreamIterator(Statement statement, int prefetchChunks) { - this.statement = statement; - // One extra to allow for END_OF_STREAM message. - this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1); - } - - protected final SpannerRpc.ResultStreamConsumer consumer() { - return consumer; - } - - public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { - this.call = call; - this.withBeginTransaction = withBeginTransaction; - ApiCallContext callContext = call.getCallContext(); - Duration streamWaitTimeout = callContext == null ? null : callContext.getStreamWaitTimeout(); - if (streamWaitTimeout != null) { - // Determine the timeout unit to use. This reduces the precision to seconds if the timeout - // value is more than 1 second, which is lower than the precision that would normally be - // used by the stream watchdog (which uses a precision of 10 seconds by default). - if (streamWaitTimeout.getSeconds() > 0L) { - streamWaitTimeoutValue = streamWaitTimeout.getSeconds(); - streamWaitTimeoutUnit = TimeUnit.SECONDS; - } else if (streamWaitTimeout.getNano() > 0) { - streamWaitTimeoutValue = streamWaitTimeout.getNano(); - streamWaitTimeoutUnit = TimeUnit.NANOSECONDS; - } - // Note that if the stream-wait-timeout is zero, we won't set a timeout at all. - // That is consistent with ApiCallContext#withStreamWaitTimeout(Duration.ZERO). - } - } - - @Override - public void close(@Nullable String message) { - if (call != null) { - call.cancel(message); - } - } - - @Override - public boolean isWithBeginTransaction() { - return withBeginTransaction; - } - - @Override - protected final PartialResultSet computeNext() { - PartialResultSet next; - try { - if (streamWaitTimeoutUnit != null) { - next = stream.poll(streamWaitTimeoutValue, streamWaitTimeoutUnit); - if (next == null) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "stream wait timeout"); - } - } else { - next = stream.take(); - } - } catch (InterruptedException e) { - // Treat interrupt as a request to cancel the read. - throw SpannerExceptionFactory.propagateInterrupt(e); - } - if (next != END_OF_STREAM) { - call.request(1); - return next; - } - - // All done - close() no longer needs to cancel the call. - call = null; - - if (error != null) { - throw SpannerExceptionFactory.newSpannerException(error); - } - - endOfData(); - return null; - } - - private void addToStream(PartialResultSet results) { - // We assume that nothing from the user will interrupt gRPC event threads. - Uninterruptibles.putUninterruptibly(stream, results); - } - - private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { - @Override - public void onPartialResultSet(PartialResultSet results) { - addToStream(results); - } - - @Override - public void onCompleted() { - addToStream(END_OF_STREAM); - } - - @Override - public void onError(SpannerException e) { - if (statement != null) { - if (logger.isLoggable(Level.FINEST)) { - // Include parameter values if logging level is set to FINEST or higher. - e = - SpannerExceptionFactory.newSpannerExceptionPreformatted( - e.getErrorCode(), - String.format("%s - Statement: '%s'", e.getMessage(), statement.toString()), - e); - logger.log(Level.FINEST, "Error executing statement", e); - } else { - e = - SpannerExceptionFactory.newSpannerExceptionPreformatted( - e.getErrorCode(), - String.format("%s - Statement: '%s'", e.getMessage(), statement.getSql()), - e); - } - } - error = e; - addToStream(END_OF_STREAM); - } - } - } - - /** - * Wraps an iterator over partial result sets, supporting resuming RPCs on error. This class keeps - * track of the most recent resume token seen, and will buffer partial result set chunks that do - * not have a resume token until one is seen or buffer space is exceeded, which reduces the chance - * of yielding data to the caller that cannot be resumed. - */ - @VisibleForTesting - abstract static class ResumableStreamIterator extends AbstractIterator - implements CloseableIterator { - private static final RetrySettings DEFAULT_STREAMING_RETRY_SETTINGS = - SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetrySettings(); - private final RetrySettings streamingRetrySettings; - private final Set retryableCodes; - private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); - private final BackOff backOff; - private final LinkedList buffer = new LinkedList<>(); - private final int maxBufferSize; - private final ISpan span; - private final TraceWrapper tracer; - private CloseableIterator stream; - private ByteString resumeToken; - private boolean finished; - /** - * Indicates whether it is currently safe to retry RPCs. This will be {@code false} if we have - * reached the maximum buffer size without seeing a restart token; in this case, we will drain - * the buffer and remain in this state until we see a new restart token. - */ - private boolean safeToRetry = true; - - protected ResumableStreamIterator( - int maxBufferSize, - String streamName, - ISpan parent, - TraceWrapper tracer, - RetrySettings streamingRetrySettings, - Set retryableCodes) { - checkArgument(maxBufferSize >= 0); - this.maxBufferSize = maxBufferSize; - this.tracer = tracer; - this.span = tracer.spanBuilderWithExplicitParent(streamName, parent); - this.streamingRetrySettings = Preconditions.checkNotNull(streamingRetrySettings); - this.retryableCodes = Preconditions.checkNotNull(retryableCodes); - this.backOff = newBackOff(); - } - - private ExponentialBackOff newBackOff() { - if (Objects.equals(streamingRetrySettings, DEFAULT_STREAMING_RETRY_SETTINGS)) { - return new ExponentialBackOff.Builder() - .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) - .setInitialIntervalMillis( - Math.max(10, (int) streamingRetrySettings.getInitialRetryDelay().toMillis())) - .setMaxIntervalMillis( - Math.max(1000, (int) streamingRetrySettings.getMaxRetryDelay().toMillis())) - .setMaxElapsedTimeMillis( - Integer.MAX_VALUE) // Prevent Backoff.STOP from getting returned. - .build(); - } - return new ExponentialBackOff.Builder() - .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) - // All of these values must be > 0. - .setInitialIntervalMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getInitialRetryDelay().toMillis(), - Integer.MAX_VALUE))) - .setMaxIntervalMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getMaxRetryDelay().toMillis(), Integer.MAX_VALUE))) - .setMaxElapsedTimeMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getTotalTimeout().toMillis(), Integer.MAX_VALUE))) - .build(); - } - - private void backoffSleep(Context context, BackOff backoff) throws SpannerException { - backoffSleep(context, nextBackOffMillis(backoff)); - } - - private static long nextBackOffMillis(BackOff backoff) throws SpannerException { - try { - return backoff.nextBackOffMillis(); - } catch (IOException e) { - throw newSpannerException(ErrorCode.INTERNAL, e.getMessage(), e); - } - } - - private void backoffSleep(Context context, long backoffMillis) throws SpannerException { - tracer.getCurrentSpan().addAnnotation("Backing off", "Delay", backoffMillis); - final CountDownLatch latch = new CountDownLatch(1); - final Context.CancellationListener listener = - ignored -> { - // Wakeup on cancellation / DEADLINE_EXCEEDED. - latch.countDown(); - }; - - context.addListener(listener, DirectExecutor.INSTANCE); - try { - if (backoffMillis == BackOff.STOP) { - // Highly unlikely but we handle it just in case. - backoffMillis = streamingRetrySettings.getMaxRetryDelay().toMillis(); - } - if (latch.await(backoffMillis, TimeUnit.MILLISECONDS)) { - // Woken by context cancellation. - throw newSpannerExceptionForCancellation(context, null); - } - } catch (InterruptedException interruptExcept) { - throw newSpannerExceptionForCancellation(context, interruptExcept); - } finally { - context.removeListener(listener); - } - } - - private enum DirectExecutor implements Executor { - INSTANCE; - - @Override - public void execute(Runnable command) { - command.run(); - } - } - - abstract CloseableIterator startStream(@Nullable ByteString resumeToken); - - @Override - public void close(@Nullable String message) { - if (stream != null) { - stream.close(message); - span.end(); - stream = null; - } - } - - @Override - public boolean isWithBeginTransaction() { - return stream != null && stream.isWithBeginTransaction(); - } - - @Override - protected PartialResultSet computeNext() { - Context context = Context.current(); - while (true) { - // Eagerly start stream before consuming any buffered items. - if (stream == null) { - span.addAnnotation( - "Starting/Resuming stream", - "ResumeToken", - resumeToken == null ? "null" : resumeToken.toStringUtf8()); - try (IScope scope = tracer.withSpan(span)) { - // When start a new stream set the Span as current to make the gRPC Span a child of - // this Span. - stream = checkNotNull(startStream(resumeToken)); - } - } - // Buffer contains items up to a resume token or has reached capacity: flush. - if (!buffer.isEmpty() - && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { - return buffer.pop(); - } - try { - if (stream.hasNext()) { - PartialResultSet next = stream.next(); - boolean hasResumeToken = !next.getResumeToken().isEmpty(); - if (hasResumeToken) { - resumeToken = next.getResumeToken(); - safeToRetry = true; - } - // If the buffer is empty and this chunk has a resume token or we cannot resume safely - // anyway, we can yield it immediately rather than placing it in the buffer to be - // returned on the next iteration. - if ((hasResumeToken || !safeToRetry) && buffer.isEmpty()) { - return next; - } - buffer.add(next); - if (buffer.size() > maxBufferSize && buffer.getLast().getResumeToken().isEmpty()) { - // We need to flush without a restart token. Errors encountered until we see - // such a token will fail the read. - safeToRetry = false; - } - } else { - finished = true; - if (buffer.isEmpty()) { - endOfData(); - return null; - } - } - } catch (SpannerException spannerException) { - if (safeToRetry && isRetryable(spannerException)) { - span.addAnnotation("Stream broken. Safe to retry", spannerException); - logger.log(Level.FINE, "Retryable exception, will sleep and retry", spannerException); - // Truncate any items in the buffer before the last retry token. - while (!buffer.isEmpty() && buffer.getLast().getResumeToken().isEmpty()) { - buffer.removeLast(); - } - assert buffer.isEmpty() || buffer.getLast().getResumeToken().equals(resumeToken); - stream = null; - try (IScope s = tracer.withSpan(span)) { - long delay = spannerException.getRetryDelayInMillis(); - if (delay != -1) { - backoffSleep(context, delay); - } else { - backoffSleep(context, backOff); - } - } - - continue; - } - span.addAnnotation("Stream broken. Not safe to retry", spannerException); - span.setStatus(spannerException); - throw spannerException; - } catch (RuntimeException e) { - span.addAnnotation("Stream broken. Not safe to retry", e); - span.setStatus(e); - throw e; - } - } - } - - boolean isRetryable(SpannerException spannerException) { - return spannerException.isRetryable() - || retryableCodes.contains( - GrpcStatusCode.of(spannerException.getErrorCode().getGrpcStatusCode()).getCode()); - } - } - static double valueProtoToFloat64(com.google.protobuf.Value proto) { if (proto.getKindCase() == KindCase.STRING_VALUE) { switch (proto.getStringValue()) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 664cde1edbb..22fb9f710c1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -60,6 +60,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) @@ -81,6 +82,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java new file mode 100644 index 00000000000..c1bea9a3ce1 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +/** Specifies how and when to decode a value from protobuf to a plain Java object. */ +public enum DecodeMode { + /** + * Decodes all columns of a row directly when a {@link ResultSet} is advanced to the next row with + * {@link ResultSet#next()} + */ + DIRECT, + /** + * Decodes all columns of a row the first time a {@link ResultSet} value is retrieved from the + * row. + */ + LAZY_PER_ROW, + /** + * Decodes a columns of a row the first time the value of that column is retrieved from the row. + */ + LAZY_PER_COL, +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index c29282879ed..18ecbeceb0f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -23,7 +23,7 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { +public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { private Supplier delegate; @@ -55,6 +55,22 @@ public boolean next() throws SpannerException { return delegate.get().next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + return (resultSetDelegate instanceof ProtobufResultSet) + && ((ProtobufResultSet) resultSetDelegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + Preconditions.checkState( + resultSetDelegate instanceof ProtobufResultSet, + "The result set does not support protobuf values"); + return ((ProtobufResultSet) resultSetDelegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { return delegate.get().getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java new file mode 100644 index 00000000000..37a4792ad87 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Value; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +@VisibleForTesting +class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { + private final GrpcValueIterator iterator; + private final Listener listener; + private final DecodeMode decodeMode; + private ResultSetMetadata metadata; + private GrpcStruct currRow; + private SpannerException error; + private ResultSetStats statistics; + private boolean closed; + + GrpcResultSet(CloseableIterator iterator, Listener listener) { + this(iterator, listener, DecodeMode.DIRECT); + } + + GrpcResultSet( + CloseableIterator iterator, Listener listener, DecodeMode decodeMode) { + this.iterator = new GrpcValueIterator(iterator); + this.listener = listener; + this.decodeMode = decodeMode; + } + + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && currRow != null && currRow.canGetProtoValue(columnIndex); + } + + @Override + public Value getProtobufValue(int columnIndex) { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow.getProtoValueInternal(columnIndex); + } + + @Override + protected GrpcStruct currRow() { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow; + } + + @Override + public boolean next() throws SpannerException { + if (error != null) { + throw newSpannerException(error); + } + try { + if (currRow == null) { + metadata = iterator.getMetadata(); + if (metadata.hasTransaction()) { + listener.onTransactionMetadata( + metadata.getTransaction(), iterator.isWithBeginTransaction()); + } else if (iterator.isWithBeginTransaction()) { + // The query should have returned a transaction. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); + } + currRow = new GrpcStruct(iterator.type(), new ArrayList<>(), decodeMode); + } + boolean hasNext = currRow.consumeRow(iterator); + if (!hasNext) { + statistics = iterator.getStats(); + } + return hasNext; + } catch (Throwable t) { + throw yieldError( + SpannerExceptionFactory.asSpannerException(t), + iterator.isWithBeginTransaction() && currRow == null); + } + } + + @Override + @Nullable + public ResultSetStats getStats() { + return statistics; + } + + @Override + public ResultSetMetadata getMetadata() { + checkState(metadata != null, "next() call required"); + return metadata; + } + + @Override + public void close() { + listener.onDone(iterator.isWithBeginTransaction()); + iterator.close("ResultSet closed"); + closed = true; + } + + @Override + public Type getType() { + checkState(currRow != null, "next() call required"); + return currRow.getType(); + } + + private SpannerException yieldError(SpannerException e, boolean beginTransaction) { + SpannerException toThrow = listener.onError(e, beginTransaction); + close(); + throw toThrow; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java new file mode 100644 index 00000000000..dde6b69c461 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -0,0 +1,172 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.gax.rpc.ApiCallContext; +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.cloud.spanner.spi.v1.SpannerRpc; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.AbstractIterator; +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.spanner.v1.PartialResultSet; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import org.threeten.bp.Duration; + +/** Adapts a streaming read/query call into an iterator over partial result sets. */ +@VisibleForTesting +class GrpcStreamIterator extends AbstractIterator + implements CloseableIterator { + private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); + private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); + + private final ConsumerImpl consumer = new ConsumerImpl(); + private final BlockingQueue stream; + private final Statement statement; + + private SpannerRpc.StreamingCall call; + private volatile boolean withBeginTransaction; + private TimeUnit streamWaitTimeoutUnit; + private long streamWaitTimeoutValue; + private SpannerException error; + + @VisibleForTesting + GrpcStreamIterator(int prefetchChunks) { + this(null, prefetchChunks); + } + + @VisibleForTesting + GrpcStreamIterator(Statement statement, int prefetchChunks) { + this.statement = statement; + // One extra to allow for END_OF_STREAM message. + this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1); + } + + protected final SpannerRpc.ResultStreamConsumer consumer() { + return consumer; + } + + public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { + this.call = call; + this.withBeginTransaction = withBeginTransaction; + ApiCallContext callContext = call.getCallContext(); + Duration streamWaitTimeout = callContext == null ? null : callContext.getStreamWaitTimeout(); + if (streamWaitTimeout != null) { + // Determine the timeout unit to use. This reduces the precision to seconds if the timeout + // value is more than 1 second, which is lower than the precision that would normally be + // used by the stream watchdog (which uses a precision of 10 seconds by default). + if (streamWaitTimeout.getSeconds() > 0L) { + streamWaitTimeoutValue = streamWaitTimeout.getSeconds(); + streamWaitTimeoutUnit = TimeUnit.SECONDS; + } else if (streamWaitTimeout.getNano() > 0) { + streamWaitTimeoutValue = streamWaitTimeout.getNano(); + streamWaitTimeoutUnit = TimeUnit.NANOSECONDS; + } + // Note that if the stream-wait-timeout is zero, we won't set a timeout at all. + // That is consistent with ApiCallContext#withStreamWaitTimeout(Duration.ZERO). + } + } + + @Override + public void close(@Nullable String message) { + if (call != null) { + call.cancel(message); + } + } + + @Override + public boolean isWithBeginTransaction() { + return withBeginTransaction; + } + + @Override + protected final PartialResultSet computeNext() { + PartialResultSet next; + try { + if (streamWaitTimeoutUnit != null) { + next = stream.poll(streamWaitTimeoutValue, streamWaitTimeoutUnit); + if (next == null) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.DEADLINE_EXCEEDED, "stream wait timeout"); + } + } else { + next = stream.take(); + } + } catch (InterruptedException e) { + // Treat interrupt as a request to cancel the read. + throw SpannerExceptionFactory.propagateInterrupt(e); + } + if (next != END_OF_STREAM) { + call.request(1); + return next; + } + + // All done - close() no longer needs to cancel the call. + call = null; + + if (error != null) { + throw SpannerExceptionFactory.newSpannerException(error); + } + + endOfData(); + return null; + } + + private void addToStream(PartialResultSet results) { + // We assume that nothing from the user will interrupt gRPC event threads. + Uninterruptibles.putUninterruptibly(stream, results); + } + + private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { + @Override + public void onPartialResultSet(PartialResultSet results) { + addToStream(results); + } + + @Override + public void onCompleted() { + addToStream(END_OF_STREAM); + } + + @Override + public void onError(SpannerException e) { + if (statement != null) { + if (logger.isLoggable(Level.FINEST)) { + // Include parameter values if logging level is set to FINEST or higher. + e = + SpannerExceptionFactory.newSpannerExceptionPreformatted( + e.getErrorCode(), + String.format("%s - Statement: '%s'", e.getMessage(), statement.toString()), + e); + logger.log(Level.FINEST, "Error executing statement", e); + } else { + e = + SpannerExceptionFactory.newSpannerExceptionPreformatted( + e.getErrorCode(), + String.format("%s - Statement: '%s'", e.getMessage(), statement.getSql()), + e); + } + } + error = e; + addToStream(END_OF_STREAM); + } + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java new file mode 100644 index 00000000000..e4951d7bee4 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java @@ -0,0 +1,766 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.AbstractResultSet.throwNotNull; +import static com.google.cloud.spanner.AbstractResultSet.valueProtoToFloat64; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AbstractResultSet.Float64Array; +import com.google.cloud.spanner.AbstractResultSet.Int64Array; +import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; +import com.google.cloud.spanner.Type.Code; +import com.google.cloud.spanner.Type.StructField; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.io.CharSource; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.ProtocolMessageEnum; +import com.google.protobuf.Value.KindCase; +import java.io.IOException; +import java.io.Serializable; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.BitSet; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +class GrpcStruct extends Struct implements Serializable { + private static final com.google.protobuf.Value NULL_VALUE = + com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); + + private final Type type; + private final List rowData; + private final DecodeMode decodeMode; + private final BitSet colDecoded; + private boolean rowDecoded; + + /** + * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a + * serialization proxy. + */ + private Object writeReplace() { + Builder builder = Struct.newBuilder(); + List structFields = getType().getStructFields(); + for (int i = 0; i < structFields.size(); i++) { + Type.StructField field = structFields.get(i); + String fieldName = field.getName(); + Object value = rowData.get(i); + Type fieldType = field.getType(); + switch (fieldType.getCode()) { + case BOOL: + builder.set(fieldName).to((Boolean) value); + break; + case INT64: + builder.set(fieldName).to((Long) value); + break; + case FLOAT64: + builder.set(fieldName).to((Double) value); + break; + case NUMERIC: + builder.set(fieldName).to((BigDecimal) value); + break; + case PG_NUMERIC: + builder.set(fieldName).to((String) value); + break; + case STRING: + builder.set(fieldName).to((String) value); + break; + case JSON: + builder.set(fieldName).to(Value.json((String) value)); + break; + case PROTO: + builder + .set(fieldName) + .to( + Value.protoMessage( + value == null ? null : ((LazyByteArray) value).getByteArray(), + fieldType.getProtoTypeFqn())); + break; + case ENUM: + builder.set(fieldName).to(Value.protoEnum((Long) value, fieldType.getProtoTypeFqn())); + break; + case PG_JSONB: + builder.set(fieldName).to(Value.pgJsonb((String) value)); + break; + case BYTES: + builder + .set(fieldName) + .to( + Value.bytesFromBase64( + value == null ? null : ((LazyByteArray) value).getBase64String())); + break; + case TIMESTAMP: + builder.set(fieldName).to((Timestamp) value); + break; + case DATE: + builder.set(fieldName).to((Date) value); + break; + case ARRAY: + final Type elementType = fieldType.getArrayElementType(); + switch (elementType.getCode()) { + case BOOL: + builder.set(fieldName).toBoolArray((Iterable) value); + break; + case INT64: + case ENUM: + builder.set(fieldName).toInt64Array((Iterable) value); + break; + case FLOAT64: + builder.set(fieldName).toFloat64Array((Iterable) value); + break; + case NUMERIC: + builder.set(fieldName).toNumericArray((Iterable) value); + break; + case PG_NUMERIC: + builder.set(fieldName).toPgNumericArray((Iterable) value); + break; + case STRING: + builder.set(fieldName).toStringArray((Iterable) value); + break; + case JSON: + builder.set(fieldName).toJsonArray((Iterable) value); + break; + case PG_JSONB: + builder.set(fieldName).toPgJsonbArray((Iterable) value); + break; + case BYTES: + case PROTO: + builder + .set(fieldName) + .toBytesArrayFromBase64( + value == null + ? null + : ((List) value) + .stream() + .map( + element -> element == null ? null : element.getBase64String()) + .collect(Collectors.toList())); + break; + case TIMESTAMP: + builder.set(fieldName).toTimestampArray((Iterable) value); + break; + case DATE: + builder.set(fieldName).toDateArray((Iterable) value); + break; + case STRUCT: + builder.set(fieldName).toStructArray(elementType, (Iterable) value); + break; + default: + throw new AssertionError("Unhandled array type code: " + elementType); + } + break; + case STRUCT: + if (value == null) { + builder.set(fieldName).to(fieldType, null); + } else { + builder.set(fieldName).to((Struct) value); + } + break; + default: + throw new AssertionError("Unhandled type code: " + fieldType.getCode()); + } + } + return builder.build(); + } + + GrpcStruct(Type type, List rowData, DecodeMode decodeMode) { + this( + type, + rowData, + decodeMode, + /* rowDecoded = */ false, + /* colDecoded = */ decodeMode == DecodeMode.LAZY_PER_COL + ? new BitSet(type.getStructFields().size()) + : null); + } + + private GrpcStruct( + Type type, + List rowData, + DecodeMode decodeMode, + boolean rowDecoded, + BitSet colDecoded) { + this.type = type; + this.rowData = rowData; + this.decodeMode = decodeMode; + this.rowDecoded = rowDecoded; + this.colDecoded = colDecoded; + } + + @Override + public String toString() { + return this.rowData.toString(); + } + + boolean consumeRow(Iterator iterator) { + rowData.clear(); + if (decodeMode == DecodeMode.LAZY_PER_ROW) { + rowDecoded = false; + } else if (decodeMode == DecodeMode.LAZY_PER_COL) { + colDecoded.clear(); + } + if (!iterator.hasNext()) { + return false; + } + for (Type.StructField fieldType : getType().getStructFields()) { + if (!iterator.hasNext()) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Invalid value stream: end of stream reached before row is complete"); + } + com.google.protobuf.Value value = iterator.next(); + if (decodeMode == DecodeMode.DIRECT) { + rowData.add(decodeValue(fieldType.getType(), value)); + } else { + rowData.add(value); + } + } + return true; + } + + private static Object decodeValue(Type fieldType, com.google.protobuf.Value proto) { + if (proto.getKindCase() == KindCase.NULL_VALUE) { + return null; + } + switch (fieldType.getCode()) { + case BOOL: + checkType(fieldType, proto, KindCase.BOOL_VALUE); + return proto.getBoolValue(); + case INT64: + case ENUM: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Long.parseLong(proto.getStringValue()); + case FLOAT64: + return valueProtoToFloat64(proto); + case NUMERIC: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return new BigDecimal(proto.getStringValue()); + case PG_NUMERIC: + case STRING: + case JSON: + case PG_JSONB: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return proto.getStringValue(); + case BYTES: + case PROTO: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return new LazyByteArray(proto.getStringValue()); + case TIMESTAMP: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Timestamp.parseTimestamp(proto.getStringValue()); + case DATE: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Date.parseDate(proto.getStringValue()); + case ARRAY: + checkType(fieldType, proto, KindCase.LIST_VALUE); + ListValue listValue = proto.getListValue(); + return decodeArrayValue(fieldType.getArrayElementType(), listValue); + case STRUCT: + checkType(fieldType, proto, KindCase.LIST_VALUE); + ListValue structValue = proto.getListValue(); + return decodeStructValue(fieldType, structValue); + case UNRECOGNIZED: + return proto; + default: + throw new AssertionError("Unhandled type code: " + fieldType.getCode()); + } + } + + private static Struct decodeStructValue(Type structType, ListValue structValue) { + List fieldTypes = structType.getStructFields(); + checkArgument( + structValue.getValuesCount() == fieldTypes.size(), + "Size mismatch between type descriptor and actual values."); + List fields = new ArrayList<>(fieldTypes.size()); + List fieldValues = structValue.getValuesList(); + for (int i = 0; i < fieldTypes.size(); ++i) { + fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); + } + return new GrpcStruct(structType, fields, DecodeMode.DIRECT); + } + + static Object decodeArrayValue(Type elementType, ListValue listValue) { + switch (elementType.getCode()) { + case INT64: + case ENUM: + // For int64/float64/enum types, use custom containers. These avoid wrapper object + // creation for non-null arrays. + return new Int64Array(listValue); + case FLOAT64: + return new Float64Array(listValue); + case BOOL: + case NUMERIC: + case PG_NUMERIC: + case STRING: + case JSON: + case PG_JSONB: + case BYTES: + case TIMESTAMP: + case DATE: + case STRUCT: + case PROTO: + return Lists.transform(listValue.getValuesList(), input -> decodeValue(elementType, input)); + default: + throw new AssertionError("Unhandled type code: " + elementType.getCode()); + } + } + + private static void checkType( + Type fieldType, com.google.protobuf.Value proto, KindCase expected) { + if (proto.getKindCase() != expected) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Invalid value for column type " + + fieldType + + " expected " + + expected + + " but was " + + proto.getKindCase()); + } + } + + Struct immutableCopy() { + return new GrpcStruct( + type, + new ArrayList<>(rowData), + this.decodeMode, + this.rowDecoded, + this.colDecoded == null ? null : (BitSet) this.colDecoded.clone()); + } + + @Override + public Type getType() { + return type; + } + + @Override + public boolean isNull(int columnIndex) { + if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) { + return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue(); + } + return rowData.get(columnIndex) == null; + } + + @Override + protected T getProtoMessageInternal(int columnIndex, T message) { + Preconditions.checkNotNull( + message, + "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + try { + return (T) + message + .toBuilder() + .mergeFrom( + Base64.getDecoder() + .wrap( + CharSource.wrap( + ((LazyByteArray) rowData.get(columnIndex)).getBase64String()) + .asByteSource(StandardCharsets.UTF_8) + .openStream())) + .build(); + } catch (IOException ioException) { + throw SpannerExceptionFactory.asSpannerException(ioException); + } + } + + @Override + protected T getProtoEnumInternal( + int columnIndex, Function method) { + Preconditions.checkNotNull( + method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + return (T) method.apply((int) getLongInternal(columnIndex)); + } + + @Override + protected boolean getBooleanInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Boolean) rowData.get(columnIndex); + } + + @Override + protected long getLongInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Long) rowData.get(columnIndex); + } + + @Override + protected double getDoubleInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Double) rowData.get(columnIndex); + } + + @Override + protected BigDecimal getBigDecimalInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (BigDecimal) rowData.get(columnIndex); + } + + @Override + protected String getStringInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (String) rowData.get(columnIndex); + } + + @Override + protected String getJsonInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (String) rowData.get(columnIndex); + } + + @Override + protected String getPgJsonbInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (String) rowData.get(columnIndex); + } + + @Override + protected ByteArray getBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); + return getLazyBytesInternal(columnIndex).getByteArray(); + } + + LazyByteArray getLazyBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (LazyByteArray) rowData.get(columnIndex); + } + + @Override + protected Timestamp getTimestampInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Timestamp) rowData.get(columnIndex); + } + + @Override + protected Date getDateInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Date) rowData.get(columnIndex); + } + + private boolean isUnrecognizedType(int columnIndex) { + return type.getStructFields().get(columnIndex).getType().getCode() == Code.UNRECOGNIZED; + } + + boolean canGetProtoValue(int columnIndex) { + return isUnrecognizedType(columnIndex) + || (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)); + } + + protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { + checkProtoValueSupported(columnIndex); + return (com.google.protobuf.Value) rowData.get(columnIndex); + } + + private void checkProtoValueSupported(int columnIndex) { + // Unrecognized types are returned as protobuf values. + if (isUnrecognizedType(columnIndex)) { + return; + } + Preconditions.checkState( + decodeMode != DecodeMode.DIRECT, + "Getting proto value is not supported when DecodeMode#DIRECT is used."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded), + "Getting proto value after the row has been decoded is not supported."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_COL && colDecoded.get(columnIndex)), + "Getting proto value after the column has been decoded is not supported."); + } + + private void ensureDecoded(int columnIndex) { + if (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) { + for (int i = 0; i < rowData.size(); i++) { + rowData.set( + i, + decodeValue( + type.getStructFields().get(i).getType(), + (com.google.protobuf.Value) rowData.get(i))); + } + rowDecoded = true; + } else if (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)) { + rowData.set( + columnIndex, + decodeValue( + type.getStructFields().get(columnIndex).getType(), + (com.google.protobuf.Value) rowData.get(columnIndex))); + colDecoded.set(columnIndex); + } + } + + @Override + protected Value getValueInternal(int columnIndex) { + ensureDecoded(columnIndex); + final List structFields = getType().getStructFields(); + final StructField structField = structFields.get(columnIndex); + final Type columnType = structField.getType(); + final boolean isNull = rowData.get(columnIndex) == null; + switch (columnType.getCode()) { + case BOOL: + return Value.bool(isNull ? null : getBooleanInternal(columnIndex)); + case INT64: + return Value.int64(isNull ? null : getLongInternal(columnIndex)); + case ENUM: + return Value.protoEnum( + isNull ? null : getLongInternal(columnIndex), columnType.getProtoTypeFqn()); + case NUMERIC: + return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); + case PG_NUMERIC: + return Value.pgNumeric(isNull ? null : getStringInternal(columnIndex)); + case FLOAT64: + return Value.float64(isNull ? null : getDoubleInternal(columnIndex)); + case STRING: + return Value.string(isNull ? null : getStringInternal(columnIndex)); + case JSON: + return Value.json(isNull ? null : getJsonInternal(columnIndex)); + case PG_JSONB: + return Value.pgJsonb(isNull ? null : getPgJsonbInternal(columnIndex)); + case BYTES: + return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); + case PROTO: + return Value.protoMessage( + isNull ? null : getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); + case TIMESTAMP: + return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); + case DATE: + return Value.date(isNull ? null : getDateInternal(columnIndex)); + case STRUCT: + return Value.struct(isNull ? null : getStructInternal(columnIndex)); + case UNRECOGNIZED: + return Value.unrecognized( + isNull ? NULL_VALUE : getProtoValueInternal(columnIndex), columnType); + case ARRAY: + final Type elementType = columnType.getArrayElementType(); + switch (elementType.getCode()) { + case BOOL: + return Value.boolArray(isNull ? null : getBooleanListInternal(columnIndex)); + case INT64: + return Value.int64Array(isNull ? null : getLongListInternal(columnIndex)); + case NUMERIC: + return Value.numericArray(isNull ? null : getBigDecimalListInternal(columnIndex)); + case PG_NUMERIC: + return Value.pgNumericArray(isNull ? null : getStringListInternal(columnIndex)); + case FLOAT64: + return Value.float64Array(isNull ? null : getDoubleListInternal(columnIndex)); + case STRING: + return Value.stringArray(isNull ? null : getStringListInternal(columnIndex)); + case JSON: + return Value.jsonArray(isNull ? null : getJsonListInternal(columnIndex)); + case PG_JSONB: + return Value.pgJsonbArray(isNull ? null : getPgJsonbListInternal(columnIndex)); + case BYTES: + return Value.bytesArray(isNull ? null : getBytesListInternal(columnIndex)); + case PROTO: + return Value.protoMessageArray( + isNull ? null : getBytesListInternal(columnIndex), elementType.getProtoTypeFqn()); + case ENUM: + return Value.protoEnumArray( + isNull ? null : getLongListInternal(columnIndex), elementType.getProtoTypeFqn()); + case TIMESTAMP: + return Value.timestampArray(isNull ? null : getTimestampListInternal(columnIndex)); + case DATE: + return Value.dateArray(isNull ? null : getDateListInternal(columnIndex)); + case STRUCT: + return Value.structArray( + elementType, isNull ? null : getStructListInternal(columnIndex)); + default: + throw new IllegalArgumentException( + "Invalid array value type " + this.type.getArrayElementType()); + } + default: + throw new IllegalArgumentException("Invalid value type " + this.type); + } + } + + @Override + protected Struct getStructInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Struct) rowData.get(columnIndex); + } + + @Override + protected boolean[] getBooleanArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + List values = (List) rowData.get(columnIndex); + boolean[] r = new boolean[values.size()]; + for (int i = 0; i < values.size(); ++i) { + if (values.get(i) == null) { + throw throwNotNull(columnIndex); + } + r[i] = values.get(i); + } + return r; + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBooleanListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + protected long[] getLongArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); + return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); + } + + @Override + protected Int64Array getLongListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Int64Array) rowData.get(columnIndex); + } + + @Override + protected double[] getDoubleArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); + return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); + } + + @Override + protected Float64Array getDoubleListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (Float64Array) rowData.get(columnIndex); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBigDecimalListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return (List) rowData.get(columnIndex); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getStringListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getJsonListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getProtoMessageListInternal( + int columnIndex, T message) { + Preconditions.checkNotNull( + message, + "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + ensureDecoded(columnIndex); + + List bytesArray = (List) rowData.get(columnIndex); + + try { + List protoMessagesList = new ArrayList<>(bytesArray.size()); + for (LazyByteArray protoMessageBytes : bytesArray) { + if (protoMessageBytes == null) { + protoMessagesList.add(null); + } else { + protoMessagesList.add( + (T) + message + .toBuilder() + .mergeFrom( + Base64.getDecoder() + .wrap( + CharSource.wrap(protoMessageBytes.getBase64String()) + .asByteSource(StandardCharsets.UTF_8) + .openStream())) + .build()); + } + } + return protoMessagesList; + } catch (IOException ioException) { + throw SpannerExceptionFactory.asSpannerException(ioException); + } + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getProtoEnumListInternal( + int columnIndex, Function method) { + Preconditions.checkNotNull( + method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + ensureDecoded(columnIndex); + + List enumIntArray = (List) rowData.get(columnIndex); + List protoEnumList = new ArrayList<>(enumIntArray.size()); + for (Long enumIntValue : enumIntArray) { + if (enumIntValue == null) { + protoEnumList.add(null); + } else { + protoEnumList.add((T) method.apply(enumIntValue.intValue())); + } + } + + return protoEnumList; + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getPgJsonbListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBytesListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Lists.transform( + (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getTimestampListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getDateListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY> produces a List. + protected List getStructListInternal(int columnIndex) { + ensureDecoded(columnIndex); + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java new file mode 100644 index 00000000000..0a2e17bd2b5 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java @@ -0,0 +1,212 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.common.collect.AbstractIterator; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value.KindCase; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.TypeCode; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +/** Adapts a stream of {@code PartialResultSet} messages into a stream of {@code Value} messages. */ +class GrpcValueIterator extends AbstractIterator { + private enum StreamValue { + METADATA, + RESULT, + } + + private final CloseableIterator stream; + private ResultSetMetadata metadata; + private Type type; + private PartialResultSet current; + private int pos; + private ResultSetStats statistics; + + GrpcValueIterator(CloseableIterator stream) { + this.stream = stream; + } + + @SuppressWarnings("unchecked") + @Override + protected com.google.protobuf.Value computeNext() { + if (!ensureReady(StreamValue.RESULT)) { + endOfData(); + return null; + } + com.google.protobuf.Value value = current.getValues(pos++); + KindCase kind = value.getKindCase(); + + if (!isMergeable(kind)) { + if (pos == current.getValuesCount() && current.getChunkedValue()) { + throw newSpannerException(ErrorCode.INTERNAL, "Unexpected chunked PartialResultSet."); + } else { + return value; + } + } + if (!current.getChunkedValue() || pos != current.getValuesCount()) { + return value; + } + + Object merged = + kind == KindCase.STRING_VALUE + ? value.getStringValue() + : new ArrayList<>(value.getListValue().getValuesList()); + while (current.getChunkedValue() && pos == current.getValuesCount()) { + if (!ensureReady(StreamValue.RESULT)) { + throw newSpannerException( + ErrorCode.INTERNAL, "Stream closed in the middle of chunked value"); + } + com.google.protobuf.Value newValue = current.getValues(pos++); + if (newValue.getKindCase() != kind) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Unexpected type in middle of chunked value. Expected: " + + kind + + " but got: " + + newValue.getKindCase()); + } + if (kind == KindCase.STRING_VALUE) { + merged = merged + newValue.getStringValue(); + } else { + concatLists( + (List) merged, newValue.getListValue().getValuesList()); + } + } + if (kind == KindCase.STRING_VALUE) { + return com.google.protobuf.Value.newBuilder().setStringValue((String) merged).build(); + } else { + return com.google.protobuf.Value.newBuilder() + .setListValue( + ListValue.newBuilder().addAllValues((List) merged)) + .build(); + } + } + + ResultSetMetadata getMetadata() throws SpannerException { + if (metadata == null) { + if (!ensureReady(StreamValue.METADATA)) { + throw newSpannerException(ErrorCode.INTERNAL, "Stream closed without sending metadata"); + } + } + return metadata; + } + + /** + * Get the query statistics. Query statistics are delivered with the last PartialResultSet in the + * stream. Any attempt to call this method before the caller has finished consuming the results + * will return null. + */ + @Nullable + ResultSetStats getStats() { + return statistics; + } + + Type type() { + checkState(type != null, "metadata has not been received"); + return type; + } + + private boolean ensureReady(StreamValue requiredValue) throws SpannerException { + while (current == null || pos >= current.getValuesCount()) { + if (!stream.hasNext()) { + return false; + } + current = stream.next(); + pos = 0; + if (type == null) { + // This is the first message on the stream. + if (!current.hasMetadata() || !current.getMetadata().hasRowType()) { + throw newSpannerException(ErrorCode.INTERNAL, "Missing type metadata in first message"); + } + metadata = current.getMetadata(); + com.google.spanner.v1.Type typeProto = + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRUCT) + .setStructType(metadata.getRowType()) + .build(); + try { + type = Type.fromProto(typeProto); + } catch (IllegalArgumentException e) { + throw newSpannerException( + ErrorCode.INTERNAL, "Invalid type metadata: " + e.getMessage(), e); + } + } + if (current.hasStats()) { + statistics = current.getStats(); + } + if (requiredValue == StreamValue.METADATA) { + return true; + } + } + return true; + } + + void close(@Nullable String message) { + stream.close(message); + } + + boolean isWithBeginTransaction() { + return stream.isWithBeginTransaction(); + } + + /** @param a is a mutable list and b will be concatenated into a. */ + private void concatLists(List a, List b) { + if (a.size() == 0 || b.size() == 0) { + a.addAll(b); + return; + } else { + com.google.protobuf.Value last = a.get(a.size() - 1); + com.google.protobuf.Value first = b.get(0); + KindCase lastKind = last.getKindCase(); + KindCase firstKind = first.getKindCase(); + if (isMergeable(lastKind) && lastKind == firstKind) { + com.google.protobuf.Value merged; + if (lastKind == KindCase.STRING_VALUE) { + String lastStr = last.getStringValue(); + String firstStr = first.getStringValue(); + merged = + com.google.protobuf.Value.newBuilder().setStringValue(lastStr + firstStr).build(); + } else { // List + List mergedList = new ArrayList<>(); + mergedList.addAll(last.getListValue().getValuesList()); + concatLists(mergedList, first.getListValue().getValuesList()); + merged = + com.google.protobuf.Value.newBuilder() + .setListValue(ListValue.newBuilder().addAllValues(mergedList)) + .build(); + } + a.set(a.size() - 1, merged); + a.addAll(b.subList(1, b.size())); + } else { + a.addAll(b); + } + } + } + + private boolean isMergeable(KindCase kind) { + return kind == KindCase.STRING_VALUE || kind == KindCase.LIST_VALUE; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 57feabbfcca..76d0f24225a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -243,6 +243,10 @@ public static ReadAndQueryOption directedRead(DirectedReadOptions directedReadOp return new DirectedReadOption(directedReadOptions); } + public static ReadAndQueryOption decodeMode(DecodeMode decodeMode) { + return new DecodeOption(decodeMode); + } + /** Option to request {@link CommitStats} for read/write transactions. */ static final class CommitStatsOption extends InternalOption implements TransactionOption { @Override @@ -374,6 +378,19 @@ void appendToOptions(Options options) { } } + static final class DecodeOption extends InternalOption implements ReadAndQueryOption { + private final DecodeMode decodeMode; + + DecodeOption(DecodeMode decodeMode) { + this.decodeMode = Preconditions.checkNotNull(decodeMode, "DecodeMode cannot be null"); + } + + @Override + void appendToOptions(Options options) { + options.decodeMode = decodeMode; + } + } + private boolean withCommitStats; private Duration maxCommitDelay; @@ -391,6 +408,7 @@ void appendToOptions(Options options) { private Boolean withOptimisticLock; private Boolean dataBoostEnabled; private DirectedReadOptions directedReadOptions; + private DecodeMode decodeMode; // Construction is via factory methods below. private Options() {} @@ -507,6 +525,14 @@ DirectedReadOptions directedReadOptions() { return directedReadOptions; } + boolean hasDecodeMode() { + return decodeMode != null; + } + + DecodeMode decodeMode() { + return decodeMode; + } + @Override public String toString() { StringBuilder b = new StringBuilder(); @@ -552,6 +578,9 @@ public String toString() { if (directedReadOptions != null) { b.append("directedReadOptions: ").append(directedReadOptions).append(' '); } + if (decodeMode != null) { + b.append("decodeMode: ").append(decodeMode).append(' '); + } return b.toString(); } @@ -640,6 +669,9 @@ public int hashCode() { if (directedReadOptions != null) { result = 31 * result + directedReadOptions.hashCode(); } + if (decodeMode != null) { + result = 31 * result + decodeMode.hashCode(); + } return result; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java new file mode 100644 index 00000000000..bbd8c41291f --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; +import com.google.protobuf.Value; + +/** Interface for {@link ResultSet}s that can return a protobuf value. */ +@InternalApi +public interface ProtobufResultSet extends ResultSet { + + /** Returns true if the protobuf value for the given column is still available. */ + boolean canGetProtobufValue(int columnIndex); + + /** + * Returns the column value as a protobuf value. + * + *

This is an internal method not intended for external usage. + * + *

This method may only be called before the column value has been decoded to a plain Java + * object. This means that the {@link DecodeMode} that is used for the {@link ResultSet} must be + * one of {@link DecodeMode#LAZY_PER_ROW} and {@link DecodeMode#LAZY_PER_COL}, and that the + * corresponding {@link ResultSet#getValue(int)}, {@link ResultSet#getBoolean(int)}, ... method + * may not yet have been called for the column. + */ + @InternalApi + Value getProtobufValue(int columnIndex); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java index d55d4091b9f..a6cc7c729e5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java @@ -109,7 +109,7 @@ public ResultSet get() { } } - private static class PrePopulatedResultSet implements ResultSet { + private static class PrePopulatedResultSet implements ProtobufResultSet { private final List rows; private final Type type; private int index = -1; @@ -137,6 +137,19 @@ public boolean next() throws SpannerException { return ++index < rows.size(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && index >= 0 && index < rows.size(); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(!closed, "ResultSet is closed"); + Preconditions.checkState(index >= 0, "Must be preceded by a next() call"); + Preconditions.checkElementIndex(index, rows.size(), "All rows have been yielded"); + return getValue(columnIndex).toProto(); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(!closed, "ResultSet is closed"); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java new file mode 100644 index 00000000000..590797c0999 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java @@ -0,0 +1,277 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.api.client.util.BackOff; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.StatusCode.Code; +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.cloud.spanner.v1.stub.SpannerStubSettings; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.protobuf.ByteString; +import com.google.spanner.v1.PartialResultSet; +import io.grpc.Context; +import java.io.IOException; +import java.util.LinkedList; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Wraps an iterator over partial result sets, supporting resuming RPCs on error. This class keeps + * track of the most recent resume token seen, and will buffer partial result set chunks that do not + * have a resume token until one is seen or buffer space is exceeded, which reduces the chance of + * yielding data to the caller that cannot be resumed. + */ +@VisibleForTesting +abstract class ResumableStreamIterator extends AbstractIterator + implements CloseableIterator { + private static final RetrySettings DEFAULT_STREAMING_RETRY_SETTINGS = + SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetrySettings(); + private final RetrySettings streamingRetrySettings; + private final Set retryableCodes; + private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); + private final BackOff backOff; + private final LinkedList buffer = new LinkedList<>(); + private final int maxBufferSize; + private final ISpan span; + private final TraceWrapper tracer; + private CloseableIterator stream; + private ByteString resumeToken; + private boolean finished; + /** + * Indicates whether it is currently safe to retry RPCs. This will be {@code false} if we have + * reached the maximum buffer size without seeing a restart token; in this case, we will drain the + * buffer and remain in this state until we see a new restart token. + */ + private boolean safeToRetry = true; + + protected ResumableStreamIterator( + int maxBufferSize, + String streamName, + ISpan parent, + TraceWrapper tracer, + RetrySettings streamingRetrySettings, + Set retryableCodes) { + checkArgument(maxBufferSize >= 0); + this.maxBufferSize = maxBufferSize; + this.tracer = tracer; + this.span = tracer.spanBuilderWithExplicitParent(streamName, parent); + this.streamingRetrySettings = Preconditions.checkNotNull(streamingRetrySettings); + this.retryableCodes = Preconditions.checkNotNull(retryableCodes); + this.backOff = newBackOff(); + } + + private ExponentialBackOff newBackOff() { + if (Objects.equals(streamingRetrySettings, DEFAULT_STREAMING_RETRY_SETTINGS)) { + return new ExponentialBackOff.Builder() + .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) + .setInitialIntervalMillis( + Math.max(10, (int) streamingRetrySettings.getInitialRetryDelay().toMillis())) + .setMaxIntervalMillis( + Math.max(1000, (int) streamingRetrySettings.getMaxRetryDelay().toMillis())) + .setMaxElapsedTimeMillis(Integer.MAX_VALUE) // Prevent Backoff.STOP from getting returned. + .build(); + } + return new ExponentialBackOff.Builder() + .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) + // All of these values must be > 0. + .setInitialIntervalMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getInitialRetryDelay().toMillis(), + Integer.MAX_VALUE))) + .setMaxIntervalMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getMaxRetryDelay().toMillis(), Integer.MAX_VALUE))) + .setMaxElapsedTimeMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getTotalTimeout().toMillis(), Integer.MAX_VALUE))) + .build(); + } + + private void backoffSleep(Context context, BackOff backoff) throws SpannerException { + backoffSleep(context, nextBackOffMillis(backoff)); + } + + private static long nextBackOffMillis(BackOff backoff) throws SpannerException { + try { + return backoff.nextBackOffMillis(); + } catch (IOException e) { + throw newSpannerException(ErrorCode.INTERNAL, e.getMessage(), e); + } + } + + private void backoffSleep(Context context, long backoffMillis) throws SpannerException { + tracer.getCurrentSpan().addAnnotation("Backing off", "Delay", backoffMillis); + final CountDownLatch latch = new CountDownLatch(1); + final Context.CancellationListener listener = + ignored -> { + // Wakeup on cancellation / DEADLINE_EXCEEDED. + latch.countDown(); + }; + + context.addListener(listener, DirectExecutor.INSTANCE); + try { + if (backoffMillis == BackOff.STOP) { + // Highly unlikely but we handle it just in case. + backoffMillis = streamingRetrySettings.getMaxRetryDelay().toMillis(); + } + if (latch.await(backoffMillis, TimeUnit.MILLISECONDS)) { + // Woken by context cancellation. + throw newSpannerExceptionForCancellation(context, null); + } + } catch (InterruptedException interruptExcept) { + throw newSpannerExceptionForCancellation(context, interruptExcept); + } finally { + context.removeListener(listener); + } + } + + private enum DirectExecutor implements Executor { + INSTANCE; + + @Override + public void execute(Runnable command) { + command.run(); + } + } + + abstract CloseableIterator startStream(@Nullable ByteString resumeToken); + + @Override + public void close(@Nullable String message) { + if (stream != null) { + stream.close(message); + span.end(); + stream = null; + } + } + + @Override + public boolean isWithBeginTransaction() { + return stream != null && stream.isWithBeginTransaction(); + } + + @Override + protected PartialResultSet computeNext() { + Context context = Context.current(); + while (true) { + // Eagerly start stream before consuming any buffered items. + if (stream == null) { + span.addAnnotation( + "Starting/Resuming stream", + "ResumeToken", + resumeToken == null ? "null" : resumeToken.toStringUtf8()); + try (IScope scope = tracer.withSpan(span)) { + // When start a new stream set the Span as current to make the gRPC Span a child of + // this Span. + stream = checkNotNull(startStream(resumeToken)); + } + } + // Buffer contains items up to a resume token or has reached capacity: flush. + if (!buffer.isEmpty() + && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { + return buffer.pop(); + } + try { + if (stream.hasNext()) { + PartialResultSet next = stream.next(); + boolean hasResumeToken = !next.getResumeToken().isEmpty(); + if (hasResumeToken) { + resumeToken = next.getResumeToken(); + safeToRetry = true; + } + // If the buffer is empty and this chunk has a resume token or we cannot resume safely + // anyway, we can yield it immediately rather than placing it in the buffer to be + // returned on the next iteration. + if ((hasResumeToken || !safeToRetry) && buffer.isEmpty()) { + return next; + } + buffer.add(next); + if (buffer.size() > maxBufferSize && buffer.getLast().getResumeToken().isEmpty()) { + // We need to flush without a restart token. Errors encountered until we see + // such a token will fail the read. + safeToRetry = false; + } + } else { + finished = true; + if (buffer.isEmpty()) { + endOfData(); + return null; + } + } + } catch (SpannerException spannerException) { + if (safeToRetry && isRetryable(spannerException)) { + span.addAnnotation("Stream broken. Safe to retry", spannerException); + logger.log(Level.FINE, "Retryable exception, will sleep and retry", spannerException); + // Truncate any items in the buffer before the last retry token. + while (!buffer.isEmpty() && buffer.getLast().getResumeToken().isEmpty()) { + buffer.removeLast(); + } + assert buffer.isEmpty() || buffer.getLast().getResumeToken().equals(resumeToken); + stream = null; + try (IScope s = tracer.withSpan(span)) { + long delay = spannerException.getRetryDelayInMillis(); + if (delay != -1) { + backoffSleep(context, delay); + } else { + backoffSleep(context, backOff); + } + } + + continue; + } + span.addAnnotation("Stream broken. Not safe to retry", spannerException); + span.setStatus(spannerException); + throw spannerException; + } catch (RuntimeException e) { + span.addAnnotation("Stream broken. Not safe to retry", e); + span.setStatus(e); + throw e; + } + } + } + + boolean isRetryable(SpannerException spannerException) { + return spannerException.isRetryable() + || retryableCodes.contains( + GrpcStatusCode.of(spannerException.getErrorCode().getGrpcStatusCode()).getCode()); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 29928f61cec..81b00001105 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -263,6 +263,7 @@ public ReadContext singleUse(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -284,6 +285,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -305,6 +307,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -423,6 +426,7 @@ TransactionContextImpl newTransaction(Options options) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setSpan(currentSpan) .setTracer(tracer) .setExecutorProvider(spanner.getAsyncExecutorProvider()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 326a51d803e..8fe06f76cc8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -151,6 +151,10 @@ int getDefaultPrefetchChunks() { return getOptions().getPrefetchChunks(); } + DecodeMode getDefaultDecodeMode() { + return getOptions().getDecodeMode(); + } + /** Returns the default query options that should be used for the specified database. */ QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return getOptions().getDefaultQueryOptions(databaseId); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 9c6044aa938..a16be179ce3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -113,6 +113,7 @@ public class SpannerOptions extends ServiceOptions { private final GrpcInterceptorProvider interceptorProvider; private final SessionPoolOptions sessionPoolOptions; private final int prefetchChunks; + private final DecodeMode decodeMode; private final int numChannels; private final String transportChannelExecutorThreadNameFormat; private final String databaseRole; @@ -616,6 +617,7 @@ protected SpannerOptions(Builder builder) { ? builder.sessionPoolOptions : SessionPoolOptions.newBuilder().build(); prefetchChunks = builder.prefetchChunks; + decodeMode = builder.decodeMode; databaseRole = builder.databaseRole; sessionLabels = builder.sessionLabels; try { @@ -704,6 +706,9 @@ public static class Builder extends ServiceOptions.Builder { static final int DEFAULT_PREFETCH_CHUNKS = 4; static final QueryOptions DEFAULT_QUERY_OPTIONS = QueryOptions.getDefaultInstance(); + // TODO: Set the default to DecodeMode.DIRECT before merging to keep the current default. + // It is currently set to LAZY_PER_COL so it is used in all tests. + static final DecodeMode DEFAULT_DECODE_MODE = DecodeMode.LAZY_PER_COL; static final RetrySettings DEFAULT_ADMIN_REQUESTS_LIMIT_EXCEEDED_RETRY_SETTINGS = RetrySettings.newBuilder() .setInitialRetryDelay(Duration.ofSeconds(5L)) @@ -730,6 +735,7 @@ public static class Builder private String transportChannelExecutorThreadNameFormat = "Cloud-Spanner-TransportChannel-%d"; private int prefetchChunks = DEFAULT_PREFETCH_CHUNKS; + private DecodeMode decodeMode = DEFAULT_DECODE_MODE; private SessionPoolOptions sessionPoolOptions; private String databaseRole; private ImmutableMap sessionLabels; @@ -797,6 +803,7 @@ protected Builder() { options.transportChannelExecutorThreadNameFormat; this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; + this.decodeMode = options.decodeMode; this.databaseRole = options.databaseRole; this.sessionLabels = options.sessionLabels; this.spannerStubSettingsBuilder = options.spannerStubSettings.toBuilder(); @@ -1224,6 +1231,15 @@ public Builder setPrefetchChunks(int prefetchChunks) { return this; } + /** + * Specifies how values that are returned from a query should be decoded and converted from + * protobuf values into plain Java objects. + */ + public Builder setDecodeMode(DecodeMode decodeMode) { + this.decodeMode = decodeMode; + return this; + } + @Override public Builder setHost(String host) { super.setHost(host); @@ -1568,6 +1584,10 @@ public int getPrefetchChunks() { return prefetchChunks; } + public DecodeMode getDecodeMode() { + return decodeMode; + } + public static GrpcTransportOptions getDefaultGrpcTransportOptions() { return GrpcTransportOptions.newBuilder().build(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java index a845eb118bf..3f0155e4a5e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java @@ -1518,6 +1518,10 @@ void valueToString(StringBuilder b) { @Override boolean valueEquals(Value v) { + // NaN == NaN always returns false, so we need a custom check. + if (Double.isNaN(this.value)) { + return Double.isNaN(((Float64Impl) v).value); + } return ((Float64Impl) v).value == value; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java index dc373cf03bd..c642d7e505a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java @@ -16,28 +16,28 @@ package com.google.cloud.spanner.connection; -import com.google.cloud.ByteArray; -import com.google.cloud.Date; -import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; -import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; +import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.ReadWriteTransaction.RetriableStatement; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.hash.Funnel; import com.google.common.hash.HashCode; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; -import com.google.common.hash.PrimitiveSink; -import java.math.BigDecimal; -import java.util.Objects; +import com.google.protobuf.Value; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicLong; @@ -71,11 +71,11 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria private final ParsedStatement statement; private final AnalyzeMode analyzeMode; private final QueryOption[] options; - private final ChecksumResultSet.ChecksumCalculator checksumCalculator = new ChecksumCalculator(); + private final ChecksumCalculator checksumCalculator = new ChecksumCalculator(); ChecksumResultSet( ReadWriteTransaction transaction, - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -91,6 +91,13 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria this.options = options; } + @Override + public Value getProtobufValue(int columnIndex) { + // We can safely cast to ProtobufResultSet here, as the constructor of this class only accepts + // instances of ProtobufResultSet. + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + /** Simple {@link Callable} for calling {@link ResultSet#next()} */ private final class NextCallable implements Callable { @Override @@ -102,7 +109,7 @@ public Boolean call() { boolean res = ChecksumResultSet.super.next(); // Only update the checksum if there was another row to be consumed. if (res) { - checksumCalculator.calculateNextChecksum(getCurrentRowAsStruct()); + checksumCalculator.calculateNextChecksum(ChecksumResultSet.this); } numberOfNextCalls.incrementAndGet(); return res; @@ -118,8 +125,9 @@ public boolean next() { } @VisibleForTesting - HashCode getChecksum() { - // HashCode is immutable and can be safely returned. + byte[] getChecksum() { + // Getting the checksum from the checksumCalculator will create a clone of the current digest + // and return the checksum from the clone, so it is safe to return this value. return checksumCalculator.getChecksum(); } @@ -132,8 +140,8 @@ HashCode getChecksum() { @Override public void retry(AbortedException aborted) throws AbortedException { // Execute the same query and consume the result set to the same point as the original. - ChecksumResultSet.ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); - ResultSet resultSet = null; + ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); + ProtobufResultSet resultSet = null; long counter = 0L; try { transaction @@ -150,7 +158,7 @@ public void retry(AbortedException aborted) throws AbortedException { statement, StatementExecutionStep.RETRY_NEXT_ON_RESULT_SET, transaction); next = resultSet.next(); if (next) { - newChecksumCalculator.calculateNextChecksum(resultSet.getCurrentRowAsStruct()); + newChecksumCalculator.calculateNextChecksum(resultSet); } counter++; } @@ -168,9 +176,9 @@ public void retry(AbortedException aborted) throws AbortedException { throw e; } // Check that we have the same number of rows and the same checksum. - HashCode newChecksum = newChecksumCalculator.getChecksum(); - HashCode currentChecksum = checksumCalculator.getChecksum(); - if (counter == numberOfNextCalls.get() && Objects.equals(newChecksum, currentChecksum)) { + byte[] newChecksum = newChecksumCalculator.getChecksum(); + byte[] currentChecksum = checksumCalculator.getChecksum(); + if (counter == numberOfNextCalls.get() && Arrays.equals(newChecksum, currentChecksum)) { // Checksum is ok, we only need to replace the delegate result set if it's still open. if (isClosed()) { resultSet.close(); @@ -184,222 +192,165 @@ public void retry(AbortedException aborted) throws AbortedException { } } - /** Calculates and keeps the current checksum of a {@link ChecksumResultSet} */ + /** + * Calculates a running checksum for all the data that has been seen sofar in this result set. The + * calculation is performed on the protobuf values that were returned by Cloud Spanner, which + * means that no decoding of the results is needed (or allowed!) before calculating the checksum. + * This is more efficient, both in terms of CPU usage and memory consumption, especially if the + * consumer of the result set does not read all values, or is only reading the underlying protobuf + * values. + */ private static final class ChecksumCalculator { - private static final HashFunction SHA256_FUNCTION = Hashing.sha256(); - private HashCode currentChecksum; + // Use a buffer of max 1Mb to hash string data. This means that strings of up to 1Mb in size + // will be hashed in one go, while strings larger than 1Mb will be chunked into pieces of at + // most 1Mb and then fed into the digest. The digest internally creates a copy of the string + // that is being hashed, so chunking large strings prevents them from being loaded into memory + // twice. + private static final int MAX_BUFFER_SIZE = 1 << 20; - private void calculateNextChecksum(Struct row) { - Hasher hasher = SHA256_FUNCTION.newHasher(); - if (currentChecksum != null) { - hasher.putBytes(currentChecksum.asBytes()); + private boolean firstRow = true; + private final MessageDigest digest; + private ByteBuffer buffer; + private ByteBuffer float64Buffer; + + ChecksumCalculator() { + try { + // This is safe, as all Java implementations are required to have MD5 implemented. + // See https://docs.oracle.com/javase/8/docs/api/java/security/MessageDigest.html + // MD5 requires less CPU power than SHA-256, and still offers a low enough collision + // probability for the use case at hand here. + digest = MessageDigest.getInstance("MD5"); + } catch (Throwable t) { + throw SpannerExceptionFactory.asSpannerException(t); } - hasher.putObject(row, StructFunnel.INSTANCE); - currentChecksum = hasher.hash(); } - private HashCode getChecksum() { - return currentChecksum; + private byte[] getChecksum() { + try { + // This is safe, as the MD5 MessageDigest is known to be cloneable. + MessageDigest clone = (MessageDigest) digest.clone(); + return clone.digest(); + } catch (CloneNotSupportedException e) { + throw SpannerExceptionFactory.asSpannerException(e); + } } - } - /** - * A {@link Funnel} implementation for calculating a {@link HashCode} for each row in a {@link - * ResultSet}. - */ - private enum StructFunnel implements Funnel { - INSTANCE; - private static final String NULL = "null"; - - @Override - public void funnel(Struct row, PrimitiveSink into) { - for (int i = 0; i < row.getColumnCount(); i++) { - if (row.isNull(i)) { - funnelValue(Code.STRING, null, into); + private void calculateNextChecksum(ProtobufResultSet resultSet) { + if (firstRow) { + for (StructField field : resultSet.getType().getStructFields()) { + digest.update(field.getType().toString().getBytes(StandardCharsets.UTF_8)); + } + } + for (int col = 0; col < resultSet.getColumnCount(); col++) { + Type type = resultSet.getColumnType(col); + if (resultSet.canGetProtobufValue(col)) { + Value value = resultSet.getProtobufValue(col); + digest.update((byte) value.getKindCase().getNumber()); + pushValue(type, value); } else { - Code type = row.getColumnType(i).getCode(); - switch (type) { - case ARRAY: - funnelArray(row.getColumnType(i).getArrayElementType().getCode(), row, i, into); - break; - case BOOL: - funnelValue(type, row.getBoolean(i), into); - break; - case BYTES: - case PROTO: - funnelValue(type, row.getBytes(i), into); - break; - case DATE: - funnelValue(type, row.getDate(i), into); - break; - case FLOAT64: - funnelValue(type, row.getDouble(i), into); - break; - case NUMERIC: - funnelValue(type, row.getBigDecimal(i), into); - break; - case PG_NUMERIC: - funnelValue(type, row.getString(i), into); - break; - case INT64: - case ENUM: - funnelValue(type, row.getLong(i), into); - break; - case STRING: - funnelValue(type, row.getString(i), into); - break; - case JSON: - funnelValue(type, row.getJson(i), into); - break; - case PG_JSONB: - funnelValue(type, row.getPgJsonb(i), into); - break; - case TIMESTAMP: - funnelValue(type, row.getTimestamp(i), into); - break; - - case STRUCT: - default: - throw new IllegalArgumentException("unsupported row type"); - } + // This will normally not happen, unless the user explicitly sets the decoding mode to + // DIRECT for a query in a read/write transaction. The default decoding mode in the + // Connection API is set to LAZY_PER_COL. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Failed to get the underlying protobuf value for the column " + + resultSet.getMetadata().getRowType().getFields(col).getName() + + ". " + + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions."); } } + firstRow = false; } - private void funnelArray( - Code arrayElementType, Struct row, int columnIndex, PrimitiveSink into) { - funnelValue(Code.STRING, "BeginArray", into); - switch (arrayElementType) { - case BOOL: - into.putInt(row.getBooleanList(columnIndex).size()); - for (Boolean value : row.getBooleanList(columnIndex)) { - funnelValue(Code.BOOL, value, into); - } + private void pushValue(Type type, Value value) { + // Protobuf Value has a very limited set of possible types of values. All Cloud Spanner types + // are mapped to one of the protobuf values listed here, meaning that no changes are needed to + // this calculation when a new type is added to Cloud Spanner. + switch (value.getKindCase()) { + case NULL_VALUE: + // nothing needed, writing the KindCase is enough. break; - case BYTES: - case PROTO: - into.putInt(row.getBytesList(columnIndex).size()); - for (ByteArray value : row.getBytesList(columnIndex)) { - funnelValue(Code.BYTES, value, into); - } + case BOOL_VALUE: + digest.update(value.getBoolValue() ? (byte) 1 : 0); break; - case DATE: - into.putInt(row.getDateList(columnIndex).size()); - for (Date value : row.getDateList(columnIndex)) { - funnelValue(Code.DATE, value, into); - } + case STRING_VALUE: + putString(value.getStringValue()); break; - case FLOAT64: - into.putInt(row.getDoubleList(columnIndex).size()); - for (Double value : row.getDoubleList(columnIndex)) { - funnelValue(Code.FLOAT64, value, into); + case NUMBER_VALUE: + if (float64Buffer == null) { + // Create an 8-byte buffer that can be re-used for all float64 values in this result + // set. + float64Buffer = ByteBuffer.allocate(Double.BYTES); + } else { + float64Buffer.clear(); } + float64Buffer.putDouble(value.getNumberValue()); + float64Buffer.flip(); + digest.update(float64Buffer); break; - case NUMERIC: - into.putInt(row.getBigDecimalList(columnIndex).size()); - for (BigDecimal value : row.getBigDecimalList(columnIndex)) { - funnelValue(Code.NUMERIC, value, into); + case LIST_VALUE: + if (type.getCode() == Code.ARRAY) { + for (Value item : value.getListValue().getValuesList()) { + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getArrayElementType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "List values that are not an ARRAY are not supported"); } break; - case PG_NUMERIC: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); + case STRUCT_VALUE: + if (type.getCode() == Code.STRUCT) { + for (int col = 0; col < type.getStructFields().size(); col++) { + String name = type.getStructFields().get(col).getName(); + putString(name); + Value item = value.getStructValue().getFieldsMap().get(name); + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getStructFields().get(col).getType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Struct values without a struct type are not supported"); } break; - case INT64: - case ENUM: - into.putInt(row.getLongList(columnIndex).size()); - for (Long value : row.getLongList(columnIndex)) { - funnelValue(Code.INT64, value, into); - } - break; - case STRING: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); - } - break; - case JSON: - into.putInt(row.getJsonList(columnIndex).size()); - for (String value : row.getJsonList(columnIndex)) { - funnelValue(Code.JSON, value, into); - } - break; - case PG_JSONB: - into.putInt(row.getPgJsonbList(columnIndex).size()); - for (String value : row.getPgJsonbList(columnIndex)) { - funnelValue(Code.PG_JSONB, value, into); - } - break; - case TIMESTAMP: - into.putInt(row.getTimestampList(columnIndex).size()); - for (Timestamp value : row.getTimestampList(columnIndex)) { - funnelValue(Code.TIMESTAMP, value, into); - } - break; - - case ARRAY: - case STRUCT: default: - throw new IllegalArgumentException("unsupported array element type"); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "Unsupported protobuf value: " + value.getKindCase()); } - funnelValue(Code.STRING, "EndArray", into); } - private void funnelValue(Code type, T value, PrimitiveSink into) { - // Include the type name in case the type of a column has changed. - into.putUnencodedChars(type.name()); - if (value == null) { - if (type == Code.BYTES || type == Code.STRING) { - // Put length -1 to distinguish from the string value 'null'. - into.putInt(-1); - } - into.putUnencodedChars(NULL); + /** Hashes a string value in blocks of max MAX_BUFFER_SIZE. */ + private void putString(String stringValue) { + int length = stringValue.length(); + if (buffer == null || (buffer.capacity() < MAX_BUFFER_SIZE && buffer.capacity() < length)) { + // Create a ByteBuffer with a maximum buffer size. + // This buffer is re-used for all string values in the result set. + buffer = ByteBuffer.allocate(Math.min(MAX_BUFFER_SIZE, length)); } else { - switch (type) { - case BOOL: - into.putBoolean((Boolean) value); - break; - case BYTES: - case PROTO: - ByteArray byteArray = (ByteArray) value; - into.putInt(byteArray.length()); - into.putBytes(byteArray.toByteArray()); - break; - case DATE: - Date date = (Date) value; - into.putInt(date.getYear()).putInt(date.getMonth()).putInt(date.getDayOfMonth()); - break; - case FLOAT64: - into.putDouble((Double) value); - break; - case NUMERIC: - String stringRepresentation = value.toString(); - into.putInt(stringRepresentation.length()); - into.putUnencodedChars(stringRepresentation); - break; - case INT64: - case ENUM: - into.putLong((Long) value); - break; - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - String stringValue = (String) value; - into.putInt(stringValue.length()); - into.putUnencodedChars(stringValue); - break; - case TIMESTAMP: - Timestamp timestamp = (Timestamp) value; - into.putLong(timestamp.getSeconds()).putInt(timestamp.getNanos()); - break; - case ARRAY: - case STRUCT: - default: - throw new IllegalArgumentException("invalid type for single value"); - } + buffer.clear(); + } + + // Wrap the string in a CharBuffer. This allows us to read from the string in sections without + // creating a new copy of (a part of) the string. E.g. using something like substring(..) + // would create a copy of that part of the string, using CharBuffer.wrap(..) does not. + CharBuffer source = CharBuffer.wrap(stringValue); + CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); + // source.hasRemaining() returns false when all the characters in the string have been + // processed. + while (source.hasRemaining()) { + // Encode the string into bytes and write them into the byte buffer. + // At most MAX_BUFFER_SIZE bytes will be written. + encoder.encode(source, buffer, false); + // Flip the buffer so we can read from the start. + buffer.flip(); + // Put the bytes from the buffer into the digest. + digest.update(buffer); + // Flip the buffer again, so we can repeat and write to the start of the buffer again. + buffer.flip(); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java index dff915e2cce..1b15ec50822 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java @@ -19,6 +19,7 @@ import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Struct; @@ -40,7 +41,7 @@ * to the actual query execution. It also ensures that any invalid query will throw an exception at * execution instead of the first next() call by a client. */ -class DirectExecuteResultSet implements ResultSet { +class DirectExecuteResultSet implements ProtobufResultSet { private static final String MISSING_NEXT_CALL = "Must be preceded by a next() call"; private final ResultSet delegate; private boolean nextCalledByClient = false; @@ -79,6 +80,21 @@ public boolean next() throws SpannerException { return initialNextResult; } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return nextCalledByClient + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) delegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index e1fb87e4ade..6c4290c3b18 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -39,6 +39,7 @@ import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; @@ -427,7 +428,7 @@ public ApiFuture executeQueryAsync( statement, StatementExecutionStep.EXECUTE_STATEMENT, ReadWriteTransaction.this); - ResultSet delegate = + DirectExecuteResultSet delegate = DirectExecuteResultSet.ofResultSet( internalExecuteQuery(statement, analyzeMode, options)); return createAndAddRetryResultSet( @@ -797,7 +798,7 @@ T runWithRetry(Callable callable) throws SpannerException { * returns a retryable {@link ResultSet}. */ private ResultSet createAndAddRetryResultSet( - ResultSet resultSet, + ProtobufResultSet resultSet, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -1091,7 +1092,7 @@ interface RetriableStatement { /** Creates a {@link ChecksumResultSet} for this {@link ReadWriteTransaction}. */ @VisibleForTesting ChecksumResultSet createChecksumResultSet( - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java index 7370551a46f..a8de14e5121 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java @@ -20,6 +20,7 @@ import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -42,7 +43,7 @@ * that is fetched using the new transaction. This is achieved by wrapping the returned result sets * in a {@link ReplaceableForwardingResultSet} that replaces its delegate after a transaction retry. */ -class ReplaceableForwardingResultSet implements ResultSet { +class ReplaceableForwardingResultSet implements ProtobufResultSet { private ResultSet delegate; private boolean closed; @@ -60,6 +61,10 @@ void replaceDelegate(ResultSet delegate) { this.delegate = delegate; } + protected ResultSet getDelegate() { + return this.delegate; + } + private void checkClosed() { if (closed) { throw SpannerExceptionFactory.newSpannerException( @@ -77,6 +82,21 @@ public boolean next() throws SpannerException { return delegate.next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + checkClosed(); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { checkClosed(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java index 2a5a805c2c7..da8da78d92e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner.connection; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.DecodeMode; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; @@ -342,6 +343,9 @@ Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { .setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN)) .setHost(key.host) .setProjectId(key.projectId) + // Use lazy decoding, so we can use the protobuf values for calculating the checksum that is + // needed for read/write transactions. + .setDecodeMode(DecodeMode.LAZY_PER_COL) .setDatabaseRole(options.getDatabaseRole()) .setCredentials(options.getCredentials()); builder.setSessionPoolOption(key.sessionPoolOptions); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index ccbf3c0b2b9..8dfbb986eb8 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -46,7 +46,6 @@ import com.google.cloud.ByteArray; import com.google.cloud.NoCredentials; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.AbstractResultSet.GrpcStreamIterator; import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index 914ce391f4a..cb73618d998 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -56,13 +56,13 @@ import org.junit.runners.JUnit4; import org.threeten.bp.Duration; -/** Unit tests for {@link com.google.cloud.spanner.AbstractResultSet.GrpcResultSet}. */ +/** Unit tests for {@link GrpcResultSet}. */ @RunWith(JUnit4.class) public class GrpcResultSetTest { - private AbstractResultSet.GrpcResultSet resultSet; + private GrpcResultSet resultSet; private SpannerRpc.ResultStreamConsumer consumer; - private AbstractResultSet.GrpcStreamIterator stream; + private GrpcStreamIterator stream; private final Duration streamWaitTimeout = Duration.ofNanos(1L); private static class NoOpListener implements AbstractResultSet.Listener { @@ -81,7 +81,7 @@ public void onDone(boolean withBeginTransaction) {} @Before public void setUp() { - stream = new AbstractResultSet.GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10); stream.setCall( new SpannerRpc.StreamingCall() { @Override @@ -97,11 +97,11 @@ public void request(int numMessages) {} }, false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); } - public AbstractResultSet.GrpcResultSet resultSetWithMode(QueryMode queryMode) { - return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + public GrpcResultSet resultSetWithMode(QueryMode queryMode) { + return new GrpcResultSet(stream, new NoOpListener()); } @Test @@ -609,7 +609,7 @@ public com.google.protobuf.Value apply(@Nullable Value input) { private void verifySerialization( Function protoFn, Value... values) { - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); PartialResultSet.Builder builder = PartialResultSet.newBuilder(); List types = new ArrayList<>(); for (Value value : values) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 7bf9f51a4ea..f27aa405aaa 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -19,7 +19,6 @@ import com.google.api.gax.grpc.testing.MockGrpcService; import com.google.cloud.ByteArray; import com.google.cloud.Date; -import com.google.cloud.spanner.AbstractResultSet.GrpcStruct; import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; import com.google.cloud.spanner.SessionPool.SessionPoolTransactionContext; import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java index af558d14dd4..a72c9872faf 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java @@ -104,9 +104,9 @@ protected List getChildren() { } private static class TestCaseRunner { - private AbstractResultSet.GrpcResultSet resultSet; + private GrpcResultSet resultSet; private SpannerRpc.ResultStreamConsumer consumer; - private AbstractResultSet.GrpcStreamIterator stream; + private GrpcStreamIterator stream; private JSONObject testCase; TestCaseRunner(JSONObject testCase) { @@ -114,7 +114,7 @@ private static class TestCaseRunner { } private void run() throws Exception { - stream = new AbstractResultSet.GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10); stream.setCall( new SpannerRpc.StreamingCall() { @Override @@ -130,7 +130,7 @@ public void request(int numMessages) {} }, false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); JSONArray chunks = testCase.getJSONArray("chunks"); JSONObject expectedResult = testCase.getJSONObject("result"); @@ -143,8 +143,7 @@ public void request(int numMessages) {} assertResultSet(resultSet, expectedResult.getJSONArray("value")); } - private void assertResultSet(AbstractResultSet.GrpcResultSet actual, JSONArray expected) - throws Exception { + private void assertResultSet(GrpcResultSet actual, JSONArray expected) throws Exception { int i = 0; while (actual.next()) { Struct actualRow = actual.getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java index 51cca1bc684..fc494c6f3ff 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java @@ -17,7 +17,6 @@ package com.google.cloud.spanner; import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; -import com.google.cloud.spanner.AbstractResultSet.GrpcResultSet; import com.google.cloud.spanner.AbstractResultSet.Listener; import com.google.protobuf.ListValue; import com.google.spanner.v1.PartialResultSet; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index 217e818d42c..d153696ab45 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -24,7 +24,6 @@ import static org.mockito.Mockito.when; import com.google.api.client.util.BackOff; -import com.google.cloud.spanner.AbstractResultSet.ResumableStreamIterator; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Lists; @@ -56,7 +55,7 @@ import org.junit.runners.JUnit4; import org.mockito.Mockito; -/** Unit tests for {@link AbstractResultSet.ResumableStreamIterator}. */ +/** Unit tests for {@link ResumableStreamIterator}. */ @RunWith(JUnit4.class) public class ResumableStreamIteratorTest { interface Starter { @@ -131,7 +130,7 @@ public boolean isWithBeginTransaction() { } Starter starter = Mockito.mock(Starter.class); - AbstractResultSet.ResumableStreamIterator resumableStreamIterator; + ResumableStreamIterator resumableStreamIterator; @Before public void setUp() { @@ -143,7 +142,7 @@ public void setUp() { private void initWithLimit(int maxBufferSize) { resumableStreamIterator = - new AbstractResultSet.ResumableStreamIterator( + new ResumableStreamIterator( maxBufferSize, "", new OpenTelemetrySpan(mock(io.opentelemetry.api.trace.Span.class)), diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java new file mode 100644 index 00000000000..6a6125e1dda --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.DecodeMode; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.Options; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DecodeModeTest extends AbstractMockServerTest { + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testAllDecodeModes() { + int numRows = 10; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + for (boolean readonly : new boolean[] {true, false}) { + for (boolean autocommit : new boolean[] {true, false}) { + connection.setReadOnly(readonly); + connection.setAutocommit(autocommit); + + int receivedRows = 0; + // DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf value is + // used for checksum calculation. + try (ResultSet direct = + connection.executeQuery( + statement, + !readonly && !autocommit + ? Options.decodeMode(DecodeMode.LAZY_PER_ROW) + : Options.decodeMode(DecodeMode.DIRECT)); + ResultSet lazyPerRow = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW)); + ResultSet lazyPerCol = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) { + while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) { + assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount()); + assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount()); + for (int col = 0; col < direct.getColumnCount(); col++) { + // Test getting the entire row as a struct both as the first thing we do, and as the + // last thing we do. This ensures that the method works as expected both when a row + // is lazily decoded by this method, and when it has already been decoded by another + // method. + if (col % 2 == 0) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + assertEquals(direct.isNull(col), lazyPerRow.isNull(col)); + assertEquals(direct.isNull(col), lazyPerCol.isNull(col)); + assertEquals(direct.getValue(col), lazyPerRow.getValue(col)); + assertEquals(direct.getValue(col), lazyPerCol.getValue(col)); + if (col % 2 == 1) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + } + receivedRows++; + } + assertEquals(numRows, receivedRows); + } + if (!autocommit) { + connection.commit(); + } + } + } + } + } + + @Test + public void testDecodeModeDirect_failsInReadWriteTransaction() { + int numRows = 1; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + try (ResultSet resultSet = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.DIRECT))) { + SpannerException exception = assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode()); + assertTrue( + exception.getMessage(), + exception + .getMessage() + .contains( + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions.")); + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java index 1e4f96d1568..b14f837ff7b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java @@ -59,6 +59,7 @@ public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -79,6 +80,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -101,6 +103,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java index 3091364e17a..2067d36b5ea 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.TypeAnnotationCode; import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -239,7 +240,10 @@ private void setRandomValue(Value.Builder builder, Type type) { if (dialect == Dialect.POSTGRESQL && randomNaN()) { builder.setStringValue("NaN"); } else { - builder.setStringValue(BigDecimal.valueOf(random.nextDouble()).toString()); + builder.setStringValue( + BigDecimal.valueOf(random.nextDouble()) + .setScale(9, RoundingMode.HALF_UP) + .toString()); } break; case INT64: diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java index 0f083fd1e50..8e643cf6e24 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java @@ -37,6 +37,7 @@ import com.google.cloud.spanner.CommitResponse; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.ResultSets; @@ -518,193 +519,197 @@ public void testChecksumResultSet() { .setGenre(Genre.FOLK) .build(); ProtocolMessageEnum protoEnumVal = Genre.ROCK; - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(BigDecimal.valueOf(550, 2)) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(BigDecimal.valueOf(750, 2)) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(BigDecimal.valueOf(550, 2)) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(BigDecimal.valueOf(750, 2)) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); // rs1 and rs2 are equal, rs3 contains the same rows, but in a different order - ResultSet delegate3 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build())); + ProtobufResultSet delegate3 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build())); ChecksumResultSet rs3 = transaction.createChecksumResultSet(delegate3, parsedStatement, AnalyzeMode.NONE); // rs4 contains the same rows as rs1 and rs2, but also an additional row - ResultSet delegate4 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(3L) - .set("NAME") - .to("TEST 3") - .set("AMOUNT") - .to(new BigDecimal("9.99")) - .set("JSON") - .to(Value.json(emptyArrayJson)) - .set("PROTO") - .to(null, SingerInfo.getDescriptor()) - .set("PROTOENUM") - .to(Genre.POP) - .build())); + ProtobufResultSet delegate4 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(3L) + .set("NAME") + .to("TEST 3") + .set("AMOUNT") + .to(new BigDecimal("9.99")) + .set("JSON") + .to(Value.json(emptyArrayJson)) + .set("PROTO") + .to(null, SingerInfo.getDescriptor()) + .set("PROTOENUM") + .to(Genre.POP) + .build())); ChecksumResultSet rs4 = transaction.createChecksumResultSet(delegate4, parsedStatement, AnalyzeMode.NONE); @@ -736,44 +741,46 @@ public void testChecksumResultSetWithArray() { ParsedStatement parsedStatement = mock(ParsedStatement.class); Statement statement = Statement.of("SELECT * FROM FOO"); when(parsedStatement.getStatement()).thenReturn(statement); - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 4L}) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 4L}) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 5L}) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 5L}) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java index bbb34675147..4617c47bc6b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java @@ -104,7 +104,14 @@ public void testReplace() { public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = - Arrays.asList("getStats", "getMetadata", "next", "close", "equals", "hashCode"); + Arrays.asList( + "canGetProtobufValue", + "getStats", + "getMetadata", + "next", + "close", + "equals", + "hashCode"); ReplaceableForwardingResultSet subject = createSubject(); // Test that all methods throw an IllegalStateException except the excluded methods when called // before a call to ResultSet#next(). @@ -116,6 +123,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -140,6 +148,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next",