From cdaa7da10b37e07f507119bc3a19abac15c47c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 27 Jan 2024 10:09:07 +0100 Subject: [PATCH 1/3] refactor: move inner classes to top level Move the gRPC-related inner classes from AbstractResultSet to top-level classes, so they are easier to modify and maintain. This change only contains modifications that are needed to move these inner classes. There are no functional changes. --- .../cloud/spanner/AbstractReadContext.java | 3 - .../cloud/spanner/AbstractResultSet.java | 1283 +---------------- .../google/cloud/spanner/GrpcResultSet.java | 113 ++ .../cloud/spanner/GrpcStreamIterator.java | 174 +++ .../com/google/cloud/spanner/GrpcStruct.java | 641 ++++++++ .../cloud/spanner/GrpcValueIterator.java | 211 +++ .../spanner/ResumableStreamIterator.java | 289 ++++ .../cloud/spanner/DatabaseClientImplTest.java | 1 - .../cloud/spanner/GrpcResultSetTest.java | 16 +- .../cloud/spanner/MockSpannerServiceImpl.java | 1 - .../cloud/spanner/ReadFormatTestRunner.java | 11 +- .../cloud/spanner/ResultSetsHelper.java | 1 - .../spanner/ResumableStreamIteratorTest.java | 7 +- 13 files changed, 1445 insertions(+), 1306 deletions(-) create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 0f4310f9b4d..1006314198e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -28,9 +28,6 @@ import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; -import com.google.cloud.spanner.AbstractResultSet.GrpcResultSet; -import com.google.cloud.spanner.AbstractResultSet.GrpcStreamIterator; -import com.google.cloud.spanner.AbstractResultSet.ResumableStreamIterator; import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.cloud.spanner.Options.QueryOption; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index c18e64165bc..6cce03e72cb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -17,79 +17,32 @@ package com.google.cloud.spanner; import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; -import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.api.client.util.BackOff; -import com.google.api.client.util.ExponentialBackOff; -import com.google.api.gax.grpc.GrpcStatusCode; -import com.google.api.gax.retrying.RetrySettings; -import com.google.api.gax.rpc.ApiCallContext; -import com.google.api.gax.rpc.StatusCode.Code; + import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.Type.StructField; -import com.google.cloud.spanner.spi.v1.SpannerRpc; -import com.google.cloud.spanner.v1.stub.SpannerStubSettings; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; -import com.google.common.io.CharSource; -import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.AbstractMessage; -import com.google.protobuf.ByteString; import com.google.protobuf.ListValue; -import com.google.protobuf.NullValue; import com.google.protobuf.ProtocolMessageEnum; import com.google.protobuf.Value.KindCase; -import com.google.spanner.v1.PartialResultSet; -import com.google.spanner.v1.ResultSetMetadata; -import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.Transaction; -import com.google.spanner.v1.TypeCode; -import io.grpc.Context; -import io.opencensus.common.Scope; -import io.opencensus.trace.AttributeValue; -import io.opencensus.trace.Span; -import io.opencensus.trace.Tracer; -import io.opencensus.trace.Tracing; import java.io.IOException; import java.io.Serializable; import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; import java.util.AbstractList; -import java.util.ArrayList; import java.util.Base64; import java.util.BitSet; -import java.util.Collections; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Objects; -import java.util.Set; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.function.Function; -import java.util.logging.Level; -import java.util.logging.Logger; -import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.threeten.bp.Duration; /** Implementation of {@link ResultSet}. */ abstract class AbstractResultSet extends AbstractStructReader implements ResultSet { - private static final Tracer tracer = Tracing.getTracer(); - private static final com.google.protobuf.Value NULL_VALUE = - com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); interface Listener { /** @@ -106,271 +59,6 @@ void onTransactionMetadata(Transaction transaction, boolean shouldIncludeId) void onDone(boolean withBeginTransaction); } - @VisibleForTesting - static class GrpcResultSet extends AbstractResultSet> { - private final GrpcValueIterator iterator; - private final Listener listener; - private ResultSetMetadata metadata; - private GrpcStruct currRow; - private SpannerException error; - private ResultSetStats statistics; - private boolean closed; - - GrpcResultSet(CloseableIterator iterator, Listener listener) { - this.iterator = new GrpcValueIterator(iterator); - this.listener = listener; - } - - @Override - protected GrpcStruct currRow() { - checkState(!closed, "ResultSet is closed"); - checkState(currRow != null, "next() call required"); - return currRow; - } - - @Override - public boolean next() throws SpannerException { - if (error != null) { - throw newSpannerException(error); - } - try { - if (currRow == null) { - metadata = iterator.getMetadata(); - if (metadata.hasTransaction()) { - listener.onTransactionMetadata( - metadata.getTransaction(), iterator.isWithBeginTransaction()); - } else if (iterator.isWithBeginTransaction()) { - // The query should have returned a transaction. - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); - } - currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); - } - boolean hasNext = currRow.consumeRow(iterator); - if (!hasNext) { - statistics = iterator.getStats(); - } - return hasNext; - } catch (Throwable t) { - throw yieldError( - SpannerExceptionFactory.asSpannerException(t), - iterator.isWithBeginTransaction() && currRow == null); - } - } - - @Override - @Nullable - public ResultSetStats getStats() { - return statistics; - } - - @Override - public ResultSetMetadata getMetadata() { - checkState(metadata != null, "next() call required"); - return metadata; - } - - @Override - public void close() { - listener.onDone(iterator.isWithBeginTransaction()); - iterator.close("ResultSet closed"); - closed = true; - } - - @Override - public Type getType() { - checkState(currRow != null, "next() call required"); - return currRow.getType(); - } - - private SpannerException yieldError(SpannerException e, boolean beginTransaction) { - SpannerException toThrow = listener.onError(e, beginTransaction); - close(); - throw toThrow; - } - } - /** - * Adapts a stream of {@code PartialResultSet} messages into a stream of {@code Value} messages. - */ - private static class GrpcValueIterator extends AbstractIterator { - private enum StreamValue { - METADATA, - RESULT, - } - - private final CloseableIterator stream; - private ResultSetMetadata metadata; - private Type type; - private PartialResultSet current; - private int pos; - private ResultSetStats statistics; - - GrpcValueIterator(CloseableIterator stream) { - this.stream = stream; - } - - @SuppressWarnings("unchecked") - @Override - protected com.google.protobuf.Value computeNext() { - if (!ensureReady(StreamValue.RESULT)) { - endOfData(); - return null; - } - com.google.protobuf.Value value = current.getValues(pos++); - KindCase kind = value.getKindCase(); - - if (!isMergeable(kind)) { - if (pos == current.getValuesCount() && current.getChunkedValue()) { - throw newSpannerException(ErrorCode.INTERNAL, "Unexpected chunked PartialResultSet."); - } else { - return value; - } - } - if (!current.getChunkedValue() || pos != current.getValuesCount()) { - return value; - } - - Object merged = - kind == KindCase.STRING_VALUE - ? value.getStringValue() - : new ArrayList<>(value.getListValue().getValuesList()); - while (current.getChunkedValue() && pos == current.getValuesCount()) { - if (!ensureReady(StreamValue.RESULT)) { - throw newSpannerException( - ErrorCode.INTERNAL, "Stream closed in the middle of chunked value"); - } - com.google.protobuf.Value newValue = current.getValues(pos++); - if (newValue.getKindCase() != kind) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Unexpected type in middle of chunked value. Expected: " - + kind - + " but got: " - + newValue.getKindCase()); - } - if (kind == KindCase.STRING_VALUE) { - merged = merged + newValue.getStringValue(); - } else { - concatLists( - (List) merged, newValue.getListValue().getValuesList()); - } - } - if (kind == KindCase.STRING_VALUE) { - return com.google.protobuf.Value.newBuilder().setStringValue((String) merged).build(); - } else { - return com.google.protobuf.Value.newBuilder() - .setListValue( - ListValue.newBuilder().addAllValues((List) merged)) - .build(); - } - } - - ResultSetMetadata getMetadata() throws SpannerException { - if (metadata == null) { - if (!ensureReady(StreamValue.METADATA)) { - throw newSpannerException(ErrorCode.INTERNAL, "Stream closed without sending metadata"); - } - } - return metadata; - } - - /** - * Get the query statistics. Query statistics are delivered with the last PartialResultSet in - * the stream. Any attempt to call this method before the caller has finished consuming the - * results will return null. - */ - @Nullable - ResultSetStats getStats() { - return statistics; - } - - Type type() { - checkState(type != null, "metadata has not been received"); - return type; - } - - private boolean ensureReady(StreamValue requiredValue) throws SpannerException { - while (current == null || pos >= current.getValuesCount()) { - if (!stream.hasNext()) { - return false; - } - current = stream.next(); - pos = 0; - if (type == null) { - // This is the first message on the stream. - if (!current.hasMetadata() || !current.getMetadata().hasRowType()) { - throw newSpannerException(ErrorCode.INTERNAL, "Missing type metadata in first message"); - } - metadata = current.getMetadata(); - com.google.spanner.v1.Type typeProto = - com.google.spanner.v1.Type.newBuilder() - .setCode(TypeCode.STRUCT) - .setStructType(metadata.getRowType()) - .build(); - try { - type = Type.fromProto(typeProto); - } catch (IllegalArgumentException e) { - throw newSpannerException( - ErrorCode.INTERNAL, "Invalid type metadata: " + e.getMessage(), e); - } - } - if (current.hasStats()) { - statistics = current.getStats(); - } - if (requiredValue == StreamValue.METADATA) { - return true; - } - } - return true; - } - - void close(@Nullable String message) { - stream.close(message); - } - - boolean isWithBeginTransaction() { - return stream.isWithBeginTransaction(); - } - - /** @param a is a mutable list and b will be concatenated into a. */ - private void concatLists(List a, List b) { - if (a.size() == 0 || b.size() == 0) { - a.addAll(b); - return; - } else { - com.google.protobuf.Value last = a.get(a.size() - 1); - com.google.protobuf.Value first = b.get(0); - KindCase lastKind = last.getKindCase(); - KindCase firstKind = first.getKindCase(); - if (isMergeable(lastKind) && lastKind == firstKind) { - com.google.protobuf.Value merged; - if (lastKind == KindCase.STRING_VALUE) { - String lastStr = last.getStringValue(); - String firstStr = first.getStringValue(); - merged = - com.google.protobuf.Value.newBuilder().setStringValue(lastStr + firstStr).build(); - } else { // List - List mergedList = new ArrayList<>(); - mergedList.addAll(last.getListValue().getValuesList()); - concatLists(mergedList, first.getListValue().getValuesList()); - merged = - com.google.protobuf.Value.newBuilder() - .setListValue(ListValue.newBuilder().addAllValues(mergedList)) - .build(); - } - a.set(a.size() - 1, merged); - a.addAll(b.subList(1, b.size())); - } else { - a.addAll(b); - } - } - } - - private boolean isMergeable(KindCase kind) { - return kind == KindCase.STRING_VALUE || kind == KindCase.LIST_VALUE; - } - } - static final class LazyByteArray implements Serializable { private static final Base64.Encoder ENCODER = Base64.getEncoder(); private static final Base64.Decoder DECODER = Base64.getDecoder(); @@ -444,598 +132,6 @@ private boolean lazyByteArraysEqual(LazyByteArray other) { } } - static class GrpcStruct extends Struct implements Serializable { - private final Type type; - private final List rowData; - - /** - * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as - * a serialization proxy. - */ - private Object writeReplace() { - Builder builder = Struct.newBuilder(); - List structFields = getType().getStructFields(); - for (int i = 0; i < structFields.size(); i++) { - Type.StructField field = structFields.get(i); - String fieldName = field.getName(); - Object value = rowData.get(i); - Type fieldType = field.getType(); - switch (fieldType.getCode()) { - case BOOL: - builder.set(fieldName).to((Boolean) value); - break; - case INT64: - builder.set(fieldName).to((Long) value); - break; - case FLOAT64: - builder.set(fieldName).to((Double) value); - break; - case NUMERIC: - builder.set(fieldName).to((BigDecimal) value); - break; - case PG_NUMERIC: - builder.set(fieldName).to((String) value); - break; - case STRING: - builder.set(fieldName).to((String) value); - break; - case JSON: - builder.set(fieldName).to(Value.json((String) value)); - break; - case PROTO: - builder - .set(fieldName) - .to(Value.protoMessage((ByteArray) value, fieldType.getProtoTypeFqn())); - break; - case ENUM: - builder.set(fieldName).to(Value.protoEnum((Long) value, fieldType.getProtoTypeFqn())); - break; - case PG_JSONB: - builder.set(fieldName).to(Value.pgJsonb((String) value)); - break; - case BYTES: - builder - .set(fieldName) - .to( - Value.bytesFromBase64( - value == null ? null : ((LazyByteArray) value).getBase64String())); - break; - case TIMESTAMP: - builder.set(fieldName).to((Timestamp) value); - break; - case DATE: - builder.set(fieldName).to((Date) value); - break; - case ARRAY: - final Type elementType = fieldType.getArrayElementType(); - switch (elementType.getCode()) { - case BOOL: - builder.set(fieldName).toBoolArray((Iterable) value); - break; - case INT64: - case ENUM: - builder.set(fieldName).toInt64Array((Iterable) value); - break; - case FLOAT64: - builder.set(fieldName).toFloat64Array((Iterable) value); - break; - case NUMERIC: - builder.set(fieldName).toNumericArray((Iterable) value); - break; - case PG_NUMERIC: - builder.set(fieldName).toPgNumericArray((Iterable) value); - break; - case STRING: - builder.set(fieldName).toStringArray((Iterable) value); - break; - case JSON: - builder.set(fieldName).toJsonArray((Iterable) value); - break; - case PG_JSONB: - builder.set(fieldName).toPgJsonbArray((Iterable) value); - break; - case BYTES: - case PROTO: - builder - .set(fieldName) - .toBytesArrayFromBase64( - value == null - ? null - : ((List) value) - .stream() - .map( - element -> - element == null ? null : element.getBase64String()) - .collect(Collectors.toList())); - break; - case TIMESTAMP: - builder.set(fieldName).toTimestampArray((Iterable) value); - break; - case DATE: - builder.set(fieldName).toDateArray((Iterable) value); - break; - case STRUCT: - builder.set(fieldName).toStructArray(elementType, (Iterable) value); - break; - default: - throw new AssertionError("Unhandled array type code: " + elementType); - } - break; - case STRUCT: - if (value == null) { - builder.set(fieldName).to(fieldType, null); - } else { - builder.set(fieldName).to((Struct) value); - } - break; - default: - throw new AssertionError("Unhandled type code: " + fieldType.getCode()); - } - } - return builder.build(); - } - - GrpcStruct(Type type, List rowData) { - this.type = type; - this.rowData = rowData; - } - - @Override - public String toString() { - return this.rowData.toString(); - } - - boolean consumeRow(Iterator iterator) { - rowData.clear(); - if (!iterator.hasNext()) { - return false; - } - for (Type.StructField fieldType : getType().getStructFields()) { - if (!iterator.hasNext()) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Invalid value stream: end of stream reached before row is complete"); - } - com.google.protobuf.Value value = iterator.next(); - rowData.add(decodeValue(fieldType.getType(), value)); - } - return true; - } - - private static Object decodeValue(Type fieldType, com.google.protobuf.Value proto) { - if (proto.getKindCase() == KindCase.NULL_VALUE) { - return null; - } - switch (fieldType.getCode()) { - case BOOL: - checkType(fieldType, proto, KindCase.BOOL_VALUE); - return proto.getBoolValue(); - case INT64: - case ENUM: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Long.parseLong(proto.getStringValue()); - case FLOAT64: - return valueProtoToFloat64(proto); - case NUMERIC: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return new BigDecimal(proto.getStringValue()); - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return proto.getStringValue(); - case BYTES: - case PROTO: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return new LazyByteArray(proto.getStringValue()); - case TIMESTAMP: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Timestamp.parseTimestamp(proto.getStringValue()); - case DATE: - checkType(fieldType, proto, KindCase.STRING_VALUE); - return Date.parseDate(proto.getStringValue()); - case ARRAY: - checkType(fieldType, proto, KindCase.LIST_VALUE); - ListValue listValue = proto.getListValue(); - return decodeArrayValue(fieldType.getArrayElementType(), listValue); - case STRUCT: - checkType(fieldType, proto, KindCase.LIST_VALUE); - ListValue structValue = proto.getListValue(); - return decodeStructValue(fieldType, structValue); - case UNRECOGNIZED: - return proto; - default: - throw new AssertionError("Unhandled type code: " + fieldType.getCode()); - } - } - - private static Struct decodeStructValue(Type structType, ListValue structValue) { - List fieldTypes = structType.getStructFields(); - checkArgument( - structValue.getValuesCount() == fieldTypes.size(), - "Size mismatch between type descriptor and actual values."); - List fields = new ArrayList<>(fieldTypes.size()); - List fieldValues = structValue.getValuesList(); - for (int i = 0; i < fieldTypes.size(); ++i) { - fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); - } - return new GrpcStruct(structType, fields); - } - - static Object decodeArrayValue(Type elementType, ListValue listValue) { - switch (elementType.getCode()) { - case INT64: - case ENUM: - // For int64/float64/enum types, use custom containers. These avoid wrapper object - // creation for non-null arrays. - return new Int64Array(listValue); - case FLOAT64: - return new Float64Array(listValue); - case BOOL: - case NUMERIC: - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - case BYTES: - case TIMESTAMP: - case DATE: - case STRUCT: - case PROTO: - return Lists.transform( - listValue.getValuesList(), input -> decodeValue(elementType, input)); - default: - throw new AssertionError("Unhandled type code: " + elementType.getCode()); - } - } - - private static void checkType( - Type fieldType, com.google.protobuf.Value proto, KindCase expected) { - if (proto.getKindCase() != expected) { - throw newSpannerException( - ErrorCode.INTERNAL, - "Invalid value for column type " - + fieldType - + " expected " - + expected - + " but was " - + proto.getKindCase()); - } - } - - Struct immutableCopy() { - return new GrpcStruct(type, new ArrayList<>(rowData)); - } - - @Override - public Type getType() { - return type; - } - - @Override - public boolean isNull(int columnIndex) { - return rowData.get(columnIndex) == null; - } - - @Override - protected T getProtoMessageInternal(int columnIndex, T message) { - Preconditions.checkNotNull( - message, - "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); - try { - return (T) - message - .toBuilder() - .mergeFrom( - Base64.getDecoder() - .wrap( - CharSource.wrap(((LazyByteArray) rowData.get(columnIndex)).base64String) - .asByteSource(StandardCharsets.UTF_8) - .openStream())) - .build(); - } catch (IOException ioException) { - throw SpannerExceptionFactory.asSpannerException(ioException); - } - } - - @Override - protected T getProtoEnumInternal( - int columnIndex, Function method) { - Preconditions.checkNotNull( - method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); - return (T) method.apply((int) getLongInternal(columnIndex)); - } - - @Override - protected boolean getBooleanInternal(int columnIndex) { - return (Boolean) rowData.get(columnIndex); - } - - @Override - protected long getLongInternal(int columnIndex) { - return (Long) rowData.get(columnIndex); - } - - @Override - protected double getDoubleInternal(int columnIndex) { - return (Double) rowData.get(columnIndex); - } - - @Override - protected BigDecimal getBigDecimalInternal(int columnIndex) { - return (BigDecimal) rowData.get(columnIndex); - } - - @Override - protected String getStringInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected String getJsonInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected String getPgJsonbInternal(int columnIndex) { - return (String) rowData.get(columnIndex); - } - - @Override - protected ByteArray getBytesInternal(int columnIndex) { - return getLazyBytesInternal(columnIndex).getByteArray(); - } - - LazyByteArray getLazyBytesInternal(int columnIndex) { - return (LazyByteArray) rowData.get(columnIndex); - } - - @Override - protected Timestamp getTimestampInternal(int columnIndex) { - return (Timestamp) rowData.get(columnIndex); - } - - @Override - protected Date getDateInternal(int columnIndex) { - return (Date) rowData.get(columnIndex); - } - - protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { - return (com.google.protobuf.Value) rowData.get(columnIndex); - } - - @Override - protected Value getValueInternal(int columnIndex) { - final List structFields = getType().getStructFields(); - final StructField structField = structFields.get(columnIndex); - final Type columnType = structField.getType(); - final boolean isNull = rowData.get(columnIndex) == null; - switch (columnType.getCode()) { - case BOOL: - return Value.bool(isNull ? null : getBooleanInternal(columnIndex)); - case INT64: - return Value.int64(isNull ? null : getLongInternal(columnIndex)); - case ENUM: - return Value.protoEnum(getLongInternal(columnIndex), columnType.getProtoTypeFqn()); - case NUMERIC: - return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); - case PG_NUMERIC: - return Value.pgNumeric(isNull ? null : getStringInternal(columnIndex)); - case FLOAT64: - return Value.float64(isNull ? null : getDoubleInternal(columnIndex)); - case STRING: - return Value.string(isNull ? null : getStringInternal(columnIndex)); - case JSON: - return Value.json(isNull ? null : getJsonInternal(columnIndex)); - case PG_JSONB: - return Value.pgJsonb(isNull ? null : getPgJsonbInternal(columnIndex)); - case BYTES: - return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); - case PROTO: - return Value.protoMessage(getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); - case TIMESTAMP: - return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); - case DATE: - return Value.date(isNull ? null : getDateInternal(columnIndex)); - case STRUCT: - return Value.struct(isNull ? null : getStructInternal(columnIndex)); - case UNRECOGNIZED: - return Value.unrecognized( - isNull ? NULL_VALUE : getProtoValueInternal(columnIndex), columnType); - case ARRAY: - final Type elementType = columnType.getArrayElementType(); - switch (elementType.getCode()) { - case BOOL: - return Value.boolArray(isNull ? null : getBooleanListInternal(columnIndex)); - case INT64: - return Value.int64Array(isNull ? null : getLongListInternal(columnIndex)); - case NUMERIC: - return Value.numericArray(isNull ? null : getBigDecimalListInternal(columnIndex)); - case PG_NUMERIC: - return Value.pgNumericArray(isNull ? null : getStringListInternal(columnIndex)); - case FLOAT64: - return Value.float64Array(isNull ? null : getDoubleListInternal(columnIndex)); - case STRING: - return Value.stringArray(isNull ? null : getStringListInternal(columnIndex)); - case JSON: - return Value.jsonArray(isNull ? null : getJsonListInternal(columnIndex)); - case PG_JSONB: - return Value.pgJsonbArray(isNull ? null : getPgJsonbListInternal(columnIndex)); - case BYTES: - return Value.bytesArray(isNull ? null : getBytesListInternal(columnIndex)); - case PROTO: - return Value.protoMessageArray( - isNull ? null : getBytesListInternal(columnIndex), elementType.getProtoTypeFqn()); - case ENUM: - return Value.protoEnumArray( - isNull ? null : getLongListInternal(columnIndex), elementType.getProtoTypeFqn()); - case TIMESTAMP: - return Value.timestampArray(isNull ? null : getTimestampListInternal(columnIndex)); - case DATE: - return Value.dateArray(isNull ? null : getDateListInternal(columnIndex)); - case STRUCT: - return Value.structArray( - elementType, isNull ? null : getStructListInternal(columnIndex)); - default: - throw new IllegalArgumentException( - "Invalid array value type " + this.type.getArrayElementType()); - } - default: - throw new IllegalArgumentException("Invalid value type " + this.type); - } - } - - @Override - protected Struct getStructInternal(int columnIndex) { - return (Struct) rowData.get(columnIndex); - } - - @Override - protected boolean[] getBooleanArrayInternal(int columnIndex) { - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - List values = (List) rowData.get(columnIndex); - boolean[] r = new boolean[values.size()]; - for (int i = 0; i < values.size(); ++i) { - if (values.get(i) == null) { - throw throwNotNull(columnIndex); - } - r[i] = values.get(i); - } - return r; - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBooleanListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - protected long[] getLongArrayInternal(int columnIndex) { - return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); - } - - @Override - protected Int64Array getLongListInternal(int columnIndex) { - return (Int64Array) rowData.get(columnIndex); - } - - @Override - protected double[] getDoubleArrayInternal(int columnIndex) { - return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); - } - - @Override - protected Float64Array getDoubleListInternal(int columnIndex) { - return (Float64Array) rowData.get(columnIndex); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBigDecimalListInternal(int columnIndex) { - return (List) rowData.get(columnIndex); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getStringListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getJsonListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getProtoMessageListInternal( - int columnIndex, T message) { - Preconditions.checkNotNull( - message, - "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); - - List bytesArray = (List) rowData.get(columnIndex); - - try { - List protoMessagesList = new ArrayList<>(bytesArray.size()); - for (LazyByteArray protoMessageBytes : bytesArray) { - if (protoMessageBytes == null) { - protoMessagesList.add(null); - } else { - protoMessagesList.add( - (T) - message - .toBuilder() - .mergeFrom( - Base64.getDecoder() - .wrap( - CharSource.wrap(protoMessageBytes.base64String) - .asByteSource(StandardCharsets.UTF_8) - .openStream())) - .build()); - } - } - return protoMessagesList; - } catch (IOException ioException) { - throw SpannerExceptionFactory.asSpannerException(ioException); - } - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getProtoEnumListInternal( - int columnIndex, Function method) { - Preconditions.checkNotNull( - method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); - - List enumIntArray = (List) rowData.get(columnIndex); - List protoEnumList = new ArrayList<>(enumIntArray.size()); - for (Long enumIntValue : enumIntArray) { - if (enumIntValue == null) { - protoEnumList.add(null); - } else { - protoEnumList.add((T) method.apply(enumIntValue.intValue())); - } - } - - return protoEnumList; - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getPgJsonbListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getBytesListInternal(int columnIndex) { - return Lists.transform( - (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getTimestampListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY produces a List. - protected List getDateListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - - @Override - @SuppressWarnings("unchecked") // We know ARRAY> produces a List. - protected List getStructListInternal(int columnIndex) { - return Collections.unmodifiableList((List) rowData.get(columnIndex)); - } - } - @VisibleForTesting interface CloseableIterator extends Iterator { @@ -1049,383 +145,6 @@ interface CloseableIterator extends Iterator { boolean isWithBeginTransaction(); } - /** Adapts a streaming read/query call into an iterator over partial result sets. */ - @VisibleForTesting - static class GrpcStreamIterator extends AbstractIterator - implements CloseableIterator { - private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); - private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); - - private final ConsumerImpl consumer = new ConsumerImpl(); - private final BlockingQueue stream; - private final Statement statement; - - private SpannerRpc.StreamingCall call; - private volatile boolean withBeginTransaction; - private TimeUnit streamWaitTimeoutUnit; - private long streamWaitTimeoutValue; - private SpannerException error; - - @VisibleForTesting - GrpcStreamIterator(int prefetchChunks) { - this(null, prefetchChunks); - } - - @VisibleForTesting - GrpcStreamIterator(Statement statement, int prefetchChunks) { - this.statement = statement; - // One extra to allow for END_OF_STREAM message. - this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1); - } - - protected final SpannerRpc.ResultStreamConsumer consumer() { - return consumer; - } - - public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { - this.call = call; - this.withBeginTransaction = withBeginTransaction; - ApiCallContext callContext = call.getCallContext(); - Duration streamWaitTimeout = callContext == null ? null : callContext.getStreamWaitTimeout(); - if (streamWaitTimeout != null) { - // Determine the timeout unit to use. This reduces the precision to seconds if the timeout - // value is more than 1 second, which is lower than the precision that would normally be - // used by the stream watchdog (which uses a precision of 10 seconds by default). - if (streamWaitTimeout.getSeconds() > 0L) { - streamWaitTimeoutValue = streamWaitTimeout.getSeconds(); - streamWaitTimeoutUnit = TimeUnit.SECONDS; - } else if (streamWaitTimeout.getNano() > 0) { - streamWaitTimeoutValue = streamWaitTimeout.getNano(); - streamWaitTimeoutUnit = TimeUnit.NANOSECONDS; - } - // Note that if the stream-wait-timeout is zero, we won't set a timeout at all. - // That is consistent with ApiCallContext#withStreamWaitTimeout(Duration.ZERO). - } - } - - @Override - public void close(@Nullable String message) { - if (call != null) { - call.cancel(message); - } - } - - @Override - public boolean isWithBeginTransaction() { - return withBeginTransaction; - } - - @Override - protected final PartialResultSet computeNext() { - PartialResultSet next; - try { - if (streamWaitTimeoutUnit != null) { - next = stream.poll(streamWaitTimeoutValue, streamWaitTimeoutUnit); - if (next == null) { - throw SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "stream wait timeout"); - } - } else { - next = stream.take(); - } - } catch (InterruptedException e) { - // Treat interrupt as a request to cancel the read. - throw SpannerExceptionFactory.propagateInterrupt(e); - } - if (next != END_OF_STREAM) { - call.request(1); - return next; - } - - // All done - close() no longer needs to cancel the call. - call = null; - - if (error != null) { - throw SpannerExceptionFactory.newSpannerException(error); - } - - endOfData(); - return null; - } - - private void addToStream(PartialResultSet results) { - // We assume that nothing from the user will interrupt gRPC event threads. - Uninterruptibles.putUninterruptibly(stream, results); - } - - private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { - @Override - public void onPartialResultSet(PartialResultSet results) { - addToStream(results); - } - - @Override - public void onCompleted() { - addToStream(END_OF_STREAM); - } - - @Override - public void onError(SpannerException e) { - if (statement != null) { - if (logger.isLoggable(Level.FINEST)) { - // Include parameter values if logging level is set to FINEST or higher. - e = - SpannerExceptionFactory.newSpannerExceptionPreformatted( - e.getErrorCode(), - String.format("%s - Statement: '%s'", e.getMessage(), statement.toString()), - e); - logger.log(Level.FINEST, "Error executing statement", e); - } else { - e = - SpannerExceptionFactory.newSpannerExceptionPreformatted( - e.getErrorCode(), - String.format("%s - Statement: '%s'", e.getMessage(), statement.getSql()), - e); - } - } - error = e; - addToStream(END_OF_STREAM); - } - } - } - - /** - * Wraps an iterator over partial result sets, supporting resuming RPCs on error. This class keeps - * track of the most recent resume token seen, and will buffer partial result set chunks that do - * not have a resume token until one is seen or buffer space is exceeded, which reduces the chance - * of yielding data to the caller that cannot be resumed. - */ - @VisibleForTesting - abstract static class ResumableStreamIterator extends AbstractIterator - implements CloseableIterator { - private static final RetrySettings DEFAULT_STREAMING_RETRY_SETTINGS = - SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetrySettings(); - private final RetrySettings streamingRetrySettings; - private final Set retryableCodes; - private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); - private final BackOff backOff; - private final LinkedList buffer = new LinkedList<>(); - private final int maxBufferSize; - private final Span span; - private CloseableIterator stream; - private ByteString resumeToken; - private boolean finished; - /** - * Indicates whether it is currently safe to retry RPCs. This will be {@code false} if we have - * reached the maximum buffer size without seeing a restart token; in this case, we will drain - * the buffer and remain in this state until we see a new restart token. - */ - private boolean safeToRetry = true; - - protected ResumableStreamIterator( - int maxBufferSize, - String streamName, - Span parent, - RetrySettings streamingRetrySettings, - Set retryableCodes) { - checkArgument(maxBufferSize >= 0); - this.maxBufferSize = maxBufferSize; - this.span = tracer.spanBuilderWithExplicitParent(streamName, parent).startSpan(); - this.streamingRetrySettings = Preconditions.checkNotNull(streamingRetrySettings); - this.retryableCodes = Preconditions.checkNotNull(retryableCodes); - this.backOff = newBackOff(); - } - - private ExponentialBackOff newBackOff() { - if (Objects.equals(streamingRetrySettings, DEFAULT_STREAMING_RETRY_SETTINGS)) { - return new ExponentialBackOff.Builder() - .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) - .setInitialIntervalMillis( - Math.max(10, (int) streamingRetrySettings.getInitialRetryDelay().toMillis())) - .setMaxIntervalMillis( - Math.max(1000, (int) streamingRetrySettings.getMaxRetryDelay().toMillis())) - .setMaxElapsedTimeMillis( - Integer.MAX_VALUE) // Prevent Backoff.STOP from getting returned. - .build(); - } - return new ExponentialBackOff.Builder() - .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) - // All of these values must be > 0. - .setInitialIntervalMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getInitialRetryDelay().toMillis(), - Integer.MAX_VALUE))) - .setMaxIntervalMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getMaxRetryDelay().toMillis(), Integer.MAX_VALUE))) - .setMaxElapsedTimeMillis( - Math.max( - 1, - (int) - Math.min( - streamingRetrySettings.getTotalTimeout().toMillis(), Integer.MAX_VALUE))) - .build(); - } - - private void backoffSleep(Context context, BackOff backoff) throws SpannerException { - backoffSleep(context, nextBackOffMillis(backoff)); - } - - private static long nextBackOffMillis(BackOff backoff) throws SpannerException { - try { - return backoff.nextBackOffMillis(); - } catch (IOException e) { - throw newSpannerException(ErrorCode.INTERNAL, e.getMessage(), e); - } - } - - private void backoffSleep(Context context, long backoffMillis) throws SpannerException { - tracer - .getCurrentSpan() - .addAnnotation( - "Backing off", - ImmutableMap.of("Delay", AttributeValue.longAttributeValue(backoffMillis))); - final CountDownLatch latch = new CountDownLatch(1); - final Context.CancellationListener listener = - ignored -> { - // Wakeup on cancellation / DEADLINE_EXCEEDED. - latch.countDown(); - }; - - context.addListener(listener, DirectExecutor.INSTANCE); - try { - if (backoffMillis == BackOff.STOP) { - // Highly unlikely but we handle it just in case. - backoffMillis = streamingRetrySettings.getMaxRetryDelay().toMillis(); - } - if (latch.await(backoffMillis, TimeUnit.MILLISECONDS)) { - // Woken by context cancellation. - throw newSpannerExceptionForCancellation(context, null); - } - } catch (InterruptedException interruptExcept) { - throw newSpannerExceptionForCancellation(context, interruptExcept); - } finally { - context.removeListener(listener); - } - } - - private enum DirectExecutor implements Executor { - INSTANCE; - - @Override - public void execute(Runnable command) { - command.run(); - } - } - - abstract CloseableIterator startStream(@Nullable ByteString resumeToken); - - @Override - public void close(@Nullable String message) { - if (stream != null) { - stream.close(message); - span.end(TraceUtil.END_SPAN_OPTIONS); - stream = null; - } - } - - @Override - public boolean isWithBeginTransaction() { - return stream != null && stream.isWithBeginTransaction(); - } - - @Override - protected PartialResultSet computeNext() { - Context context = Context.current(); - while (true) { - // Eagerly start stream before consuming any buffered items. - if (stream == null) { - span.addAnnotation( - "Starting/Resuming stream", - ImmutableMap.of( - "ResumeToken", - AttributeValue.stringAttributeValue( - resumeToken == null ? "null" : resumeToken.toStringUtf8()))); - try (Scope s = tracer.withSpan(span)) { - // When start a new stream set the Span as current to make the gRPC Span a child of - // this Span. - stream = checkNotNull(startStream(resumeToken)); - } - } - // Buffer contains items up to a resume token or has reached capacity: flush. - if (!buffer.isEmpty() - && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { - return buffer.pop(); - } - try { - if (stream.hasNext()) { - PartialResultSet next = stream.next(); - boolean hasResumeToken = !next.getResumeToken().isEmpty(); - if (hasResumeToken) { - resumeToken = next.getResumeToken(); - safeToRetry = true; - } - // If the buffer is empty and this chunk has a resume token or we cannot resume safely - // anyway, we can yield it immediately rather than placing it in the buffer to be - // returned on the next iteration. - if ((hasResumeToken || !safeToRetry) && buffer.isEmpty()) { - return next; - } - buffer.add(next); - if (buffer.size() > maxBufferSize && buffer.getLast().getResumeToken().isEmpty()) { - // We need to flush without a restart token. Errors encountered until we see - // such a token will fail the read. - safeToRetry = false; - } - } else { - finished = true; - if (buffer.isEmpty()) { - endOfData(); - return null; - } - } - } catch (SpannerException spannerException) { - if (safeToRetry && isRetryable(spannerException)) { - span.addAnnotation( - "Stream broken. Safe to retry", - TraceUtil.getExceptionAnnotations(spannerException)); - logger.log(Level.FINE, "Retryable exception, will sleep and retry", spannerException); - // Truncate any items in the buffer before the last retry token. - while (!buffer.isEmpty() && buffer.getLast().getResumeToken().isEmpty()) { - buffer.removeLast(); - } - assert buffer.isEmpty() || buffer.getLast().getResumeToken().equals(resumeToken); - stream = null; - try (Scope s = tracer.withSpan(span)) { - long delay = spannerException.getRetryDelayInMillis(); - if (delay != -1) { - backoffSleep(context, delay); - } else { - backoffSleep(context, backOff); - } - } - - continue; - } - span.addAnnotation("Stream broken. Not safe to retry"); - TraceUtil.setWithFailure(span, spannerException); - throw spannerException; - } catch (RuntimeException e) { - span.addAnnotation("Stream broken. Not safe to retry"); - TraceUtil.setWithFailure(span, e); - throw e; - } - } - } - - boolean isRetryable(SpannerException spannerException) { - return spannerException.isRetryable() - || retryableCodes.contains( - GrpcStatusCode.of(spannerException.getErrorCode().getGrpcStatusCode()).getCode()); - } - } - static double valueProtoToFloat64(com.google.protobuf.Value proto) { if (proto.getKindCase() == KindCase.STRING_VALUE) { switch (proto.getStringValue()) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java new file mode 100644 index 00000000000..6100f580404 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -0,0 +1,113 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +@VisibleForTesting +class GrpcResultSet extends AbstractResultSet> { + + private final GrpcValueIterator iterator; + private final Listener listener; + private ResultSetMetadata metadata; + private GrpcStruct currRow; + private SpannerException error; + private ResultSetStats statistics; + private boolean closed; + + GrpcResultSet(CloseableIterator iterator, Listener listener) { + this.iterator = new GrpcValueIterator(iterator); + this.listener = listener; + } + + @Override + protected GrpcStruct currRow() { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow; + } + + @Override + public boolean next() throws SpannerException { + if (error != null) { + throw newSpannerException(error); + } + try { + if (currRow == null) { + metadata = iterator.getMetadata(); + if (metadata.hasTransaction()) { + listener.onTransactionMetadata( + metadata.getTransaction(), iterator.isWithBeginTransaction()); + } else if (iterator.isWithBeginTransaction()) { + // The query should have returned a transaction. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); + } + currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); + } + boolean hasNext = currRow.consumeRow(iterator); + if (!hasNext) { + statistics = iterator.getStats(); + } + return hasNext; + } catch (Throwable t) { + throw yieldError( + SpannerExceptionFactory.asSpannerException(t), + iterator.isWithBeginTransaction() && currRow == null); + } + } + + @Override + @Nullable + public ResultSetStats getStats() { + return statistics; + } + + @Override + public ResultSetMetadata getMetadata() { + checkState(metadata != null, "next() call required"); + return metadata; + } + + @Override + public void close() { + listener.onDone(iterator.isWithBeginTransaction()); + iterator.close("ResultSet closed"); + closed = true; + } + + @Override + public Type getType() { + checkState(currRow != null, "next() call required"); + return currRow.getType(); + } + + private SpannerException yieldError(SpannerException e, boolean beginTransaction) { + SpannerException toThrow = listener.onError(e, beginTransaction); + close(); + throw toThrow; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java new file mode 100644 index 00000000000..11b645a3ff9 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -0,0 +1,174 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.gax.rpc.ApiCallContext; +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.cloud.spanner.spi.v1.SpannerRpc; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.AbstractIterator; +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.spanner.v1.PartialResultSet; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import org.threeten.bp.Duration; + +/** Adapts a streaming read/query call into an iterator over partial result sets. */ +@VisibleForTesting +class GrpcStreamIterator extends AbstractIterator + implements CloseableIterator { + + private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); + private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); + + private final ConsumerImpl consumer = new ConsumerImpl(); + private final BlockingQueue stream; + private final Statement statement; + + private SpannerRpc.StreamingCall call; + private volatile boolean withBeginTransaction; + private TimeUnit streamWaitTimeoutUnit; + private long streamWaitTimeoutValue; + private SpannerException error; + + @VisibleForTesting + GrpcStreamIterator(int prefetchChunks) { + this(null, prefetchChunks); + } + + @VisibleForTesting + GrpcStreamIterator(Statement statement, int prefetchChunks) { + this.statement = statement; + // One extra to allow for END_OF_STREAM message. + this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1); + } + + protected final SpannerRpc.ResultStreamConsumer consumer() { + return consumer; + } + + public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { + this.call = call; + this.withBeginTransaction = withBeginTransaction; + ApiCallContext callContext = call.getCallContext(); + Duration streamWaitTimeout = callContext == null ? null : callContext.getStreamWaitTimeout(); + if (streamWaitTimeout != null) { + // Determine the timeout unit to use. This reduces the precision to seconds if the timeout + // value is more than 1 second, which is lower than the precision that would normally be + // used by the stream watchdog (which uses a precision of 10 seconds by default). + if (streamWaitTimeout.getSeconds() > 0L) { + streamWaitTimeoutValue = streamWaitTimeout.getSeconds(); + streamWaitTimeoutUnit = TimeUnit.SECONDS; + } else if (streamWaitTimeout.getNano() > 0) { + streamWaitTimeoutValue = streamWaitTimeout.getNano(); + streamWaitTimeoutUnit = TimeUnit.NANOSECONDS; + } + // Note that if the stream-wait-timeout is zero, we won't set a timeout at all. + // That is consistent with ApiCallContext#withStreamWaitTimeout(Duration.ZERO). + } + } + + @Override + public void close(@Nullable String message) { + if (call != null) { + call.cancel(message); + } + } + + @Override + public boolean isWithBeginTransaction() { + return withBeginTransaction; + } + + @Override + protected final PartialResultSet computeNext() { + PartialResultSet next; + try { + if (streamWaitTimeoutUnit != null) { + next = stream.poll(streamWaitTimeoutValue, streamWaitTimeoutUnit); + if (next == null) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.DEADLINE_EXCEEDED, "stream wait timeout"); + } + } else { + next = stream.take(); + } + } catch (InterruptedException e) { + // Treat interrupt as a request to cancel the read. + throw SpannerExceptionFactory.propagateInterrupt(e); + } + if (next != END_OF_STREAM) { + call.request(1); + return next; + } + + // All done - close() no longer needs to cancel the call. + call = null; + + if (error != null) { + throw SpannerExceptionFactory.newSpannerException(error); + } + + endOfData(); + return null; + } + + private void addToStream(PartialResultSet results) { + // We assume that nothing from the user will interrupt gRPC event threads. + Uninterruptibles.putUninterruptibly(stream, results); + } + + private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { + + @Override + public void onPartialResultSet(PartialResultSet results) { + addToStream(results); + } + + @Override + public void onCompleted() { + addToStream(END_OF_STREAM); + } + + @Override + public void onError(SpannerException e) { + if (statement != null) { + if (logger.isLoggable(Level.FINEST)) { + // Include parameter values if logging level is set to FINEST or higher. + e = + SpannerExceptionFactory.newSpannerExceptionPreformatted( + e.getErrorCode(), + String.format("%s - Statement: '%s'", e.getMessage(), statement.toString()), + e); + logger.log(Level.FINEST, "Error executing statement", e); + } else { + e = + SpannerExceptionFactory.newSpannerExceptionPreformatted( + e.getErrorCode(), + String.format("%s - Statement: '%s'", e.getMessage(), statement.getSql()), + e); + } + } + error = e; + addToStream(END_OF_STREAM); + } + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java new file mode 100644 index 00000000000..e1b347a26f1 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java @@ -0,0 +1,641 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.ByteArray; +import com.google.cloud.Date; +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.AbstractResultSet.Float64Array; +import com.google.cloud.spanner.AbstractResultSet.Int64Array; +import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; +import com.google.cloud.spanner.Type.StructField; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.io.CharSource; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.ListValue; +import com.google.protobuf.NullValue; +import com.google.protobuf.ProtocolMessageEnum; +import com.google.protobuf.Value.KindCase; +import java.io.IOException; +import java.io.Serializable; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +class GrpcStruct extends Struct implements Serializable { + private static final com.google.protobuf.Value NULL_VALUE = + com.google.protobuf.Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build(); + + private final Type type; + private final List rowData; + + /** + * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a + * serialization proxy. + */ + private Object writeReplace() { + Builder builder = Struct.newBuilder(); + List structFields = getType().getStructFields(); + for (int i = 0; i < structFields.size(); i++) { + Type.StructField field = structFields.get(i); + String fieldName = field.getName(); + Object value = rowData.get(i); + Type fieldType = field.getType(); + switch (fieldType.getCode()) { + case BOOL: + builder.set(fieldName).to((Boolean) value); + break; + case INT64: + builder.set(fieldName).to((Long) value); + break; + case FLOAT64: + builder.set(fieldName).to((Double) value); + break; + case NUMERIC: + builder.set(fieldName).to((BigDecimal) value); + break; + case PG_NUMERIC: + builder.set(fieldName).to((String) value); + break; + case STRING: + builder.set(fieldName).to((String) value); + break; + case JSON: + builder.set(fieldName).to(Value.json((String) value)); + break; + case PROTO: + builder + .set(fieldName) + .to(Value.protoMessage((ByteArray) value, fieldType.getProtoTypeFqn())); + break; + case ENUM: + builder.set(fieldName).to(Value.protoEnum((Long) value, fieldType.getProtoTypeFqn())); + break; + case PG_JSONB: + builder.set(fieldName).to(Value.pgJsonb((String) value)); + break; + case BYTES: + builder + .set(fieldName) + .to( + Value.bytesFromBase64( + value == null ? null : ((LazyByteArray) value).getBase64String())); + break; + case TIMESTAMP: + builder.set(fieldName).to((Timestamp) value); + break; + case DATE: + builder.set(fieldName).to((Date) value); + break; + case ARRAY: + final Type elementType = fieldType.getArrayElementType(); + switch (elementType.getCode()) { + case BOOL: + builder.set(fieldName).toBoolArray((Iterable) value); + break; + case INT64: + case ENUM: + builder.set(fieldName).toInt64Array((Iterable) value); + break; + case FLOAT64: + builder.set(fieldName).toFloat64Array((Iterable) value); + break; + case NUMERIC: + builder.set(fieldName).toNumericArray((Iterable) value); + break; + case PG_NUMERIC: + builder.set(fieldName).toPgNumericArray((Iterable) value); + break; + case STRING: + builder.set(fieldName).toStringArray((Iterable) value); + break; + case JSON: + builder.set(fieldName).toJsonArray((Iterable) value); + break; + case PG_JSONB: + builder.set(fieldName).toPgJsonbArray((Iterable) value); + break; + case BYTES: + case PROTO: + builder + .set(fieldName) + .toBytesArrayFromBase64( + value == null + ? null + : ((List) value) + .stream() + .map( + element -> element == null ? null : element.getBase64String()) + .collect(Collectors.toList())); + break; + case TIMESTAMP: + builder.set(fieldName).toTimestampArray((Iterable) value); + break; + case DATE: + builder.set(fieldName).toDateArray((Iterable) value); + break; + case STRUCT: + builder.set(fieldName).toStructArray(elementType, (Iterable) value); + break; + default: + throw new AssertionError("Unhandled array type code: " + elementType); + } + break; + case STRUCT: + if (value == null) { + builder.set(fieldName).to(fieldType, null); + } else { + builder.set(fieldName).to((Struct) value); + } + break; + default: + throw new AssertionError("Unhandled type code: " + fieldType.getCode()); + } + } + return builder.build(); + } + + GrpcStruct(Type type, List rowData) { + this.type = type; + this.rowData = rowData; + } + + @Override + public String toString() { + return this.rowData.toString(); + } + + boolean consumeRow(Iterator iterator) { + rowData.clear(); + if (!iterator.hasNext()) { + return false; + } + for (Type.StructField fieldType : getType().getStructFields()) { + if (!iterator.hasNext()) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Invalid value stream: end of stream reached before row is complete"); + } + com.google.protobuf.Value value = iterator.next(); + rowData.add(decodeValue(fieldType.getType(), value)); + } + return true; + } + + private static Object decodeValue(Type fieldType, com.google.protobuf.Value proto) { + if (proto.getKindCase() == KindCase.NULL_VALUE) { + return null; + } + switch (fieldType.getCode()) { + case BOOL: + checkType(fieldType, proto, KindCase.BOOL_VALUE); + return proto.getBoolValue(); + case INT64: + case ENUM: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Long.parseLong(proto.getStringValue()); + case FLOAT64: + return AbstractResultSet.valueProtoToFloat64(proto); + case NUMERIC: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return new BigDecimal(proto.getStringValue()); + case PG_NUMERIC: + case STRING: + case JSON: + case PG_JSONB: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return proto.getStringValue(); + case BYTES: + case PROTO: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return new LazyByteArray(proto.getStringValue()); + case TIMESTAMP: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Timestamp.parseTimestamp(proto.getStringValue()); + case DATE: + checkType(fieldType, proto, KindCase.STRING_VALUE); + return Date.parseDate(proto.getStringValue()); + case ARRAY: + checkType(fieldType, proto, KindCase.LIST_VALUE); + ListValue listValue = proto.getListValue(); + return decodeArrayValue(fieldType.getArrayElementType(), listValue); + case STRUCT: + checkType(fieldType, proto, KindCase.LIST_VALUE); + ListValue structValue = proto.getListValue(); + return decodeStructValue(fieldType, structValue); + case UNRECOGNIZED: + return proto; + default: + throw new AssertionError("Unhandled type code: " + fieldType.getCode()); + } + } + + private static Struct decodeStructValue(Type structType, ListValue structValue) { + List fieldTypes = structType.getStructFields(); + checkArgument( + structValue.getValuesCount() == fieldTypes.size(), + "Size mismatch between type descriptor and actual values."); + List fields = new ArrayList<>(fieldTypes.size()); + List fieldValues = structValue.getValuesList(); + for (int i = 0; i < fieldTypes.size(); ++i) { + fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); + } + return new GrpcStruct(structType, fields); + } + + static Object decodeArrayValue(Type elementType, ListValue listValue) { + switch (elementType.getCode()) { + case INT64: + case ENUM: + // For int64/float64/enum types, use custom containers. These avoid wrapper object + // creation for non-null arrays. + return new Int64Array(listValue); + case FLOAT64: + return new Float64Array(listValue); + case BOOL: + case NUMERIC: + case PG_NUMERIC: + case STRING: + case JSON: + case PG_JSONB: + case BYTES: + case TIMESTAMP: + case DATE: + case STRUCT: + case PROTO: + return Lists.transform(listValue.getValuesList(), input -> decodeValue(elementType, input)); + default: + throw new AssertionError("Unhandled type code: " + elementType.getCode()); + } + } + + private static void checkType( + Type fieldType, com.google.protobuf.Value proto, KindCase expected) { + if (proto.getKindCase() != expected) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Invalid value for column type " + + fieldType + + " expected " + + expected + + " but was " + + proto.getKindCase()); + } + } + + Struct immutableCopy() { + return new GrpcStruct(type, new ArrayList<>(rowData)); + } + + @Override + public Type getType() { + return type; + } + + @Override + public boolean isNull(int columnIndex) { + return rowData.get(columnIndex) == null; + } + + @Override + protected T getProtoMessageInternal(int columnIndex, T message) { + Preconditions.checkNotNull( + message, + "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + try { + return (T) + message + .toBuilder() + .mergeFrom( + Base64.getDecoder() + .wrap( + CharSource.wrap( + ((LazyByteArray) rowData.get(columnIndex)).getBase64String()) + .asByteSource(StandardCharsets.UTF_8) + .openStream())) + .build(); + } catch (IOException ioException) { + throw SpannerExceptionFactory.asSpannerException(ioException); + } + } + + @Override + protected T getProtoEnumInternal( + int columnIndex, Function method) { + Preconditions.checkNotNull( + method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + return (T) method.apply((int) getLongInternal(columnIndex)); + } + + @Override + protected boolean getBooleanInternal(int columnIndex) { + return (Boolean) rowData.get(columnIndex); + } + + @Override + protected long getLongInternal(int columnIndex) { + return (Long) rowData.get(columnIndex); + } + + @Override + protected double getDoubleInternal(int columnIndex) { + return (Double) rowData.get(columnIndex); + } + + @Override + protected BigDecimal getBigDecimalInternal(int columnIndex) { + return (BigDecimal) rowData.get(columnIndex); + } + + @Override + protected String getStringInternal(int columnIndex) { + return (String) rowData.get(columnIndex); + } + + @Override + protected String getJsonInternal(int columnIndex) { + return (String) rowData.get(columnIndex); + } + + @Override + protected String getPgJsonbInternal(int columnIndex) { + return (String) rowData.get(columnIndex); + } + + @Override + protected ByteArray getBytesInternal(int columnIndex) { + return getLazyBytesInternal(columnIndex).getByteArray(); + } + + LazyByteArray getLazyBytesInternal(int columnIndex) { + return (LazyByteArray) rowData.get(columnIndex); + } + + @Override + protected Timestamp getTimestampInternal(int columnIndex) { + return (Timestamp) rowData.get(columnIndex); + } + + @Override + protected Date getDateInternal(int columnIndex) { + return (Date) rowData.get(columnIndex); + } + + protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { + return (com.google.protobuf.Value) rowData.get(columnIndex); + } + + @Override + protected Value getValueInternal(int columnIndex) { + final List structFields = getType().getStructFields(); + final StructField structField = structFields.get(columnIndex); + final Type columnType = structField.getType(); + final boolean isNull = rowData.get(columnIndex) == null; + switch (columnType.getCode()) { + case BOOL: + return Value.bool(isNull ? null : getBooleanInternal(columnIndex)); + case INT64: + return Value.int64(isNull ? null : getLongInternal(columnIndex)); + case ENUM: + return Value.protoEnum(getLongInternal(columnIndex), columnType.getProtoTypeFqn()); + case NUMERIC: + return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); + case PG_NUMERIC: + return Value.pgNumeric(isNull ? null : getStringInternal(columnIndex)); + case FLOAT64: + return Value.float64(isNull ? null : getDoubleInternal(columnIndex)); + case STRING: + return Value.string(isNull ? null : getStringInternal(columnIndex)); + case JSON: + return Value.json(isNull ? null : getJsonInternal(columnIndex)); + case PG_JSONB: + return Value.pgJsonb(isNull ? null : getPgJsonbInternal(columnIndex)); + case BYTES: + return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); + case PROTO: + return Value.protoMessage(getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); + case TIMESTAMP: + return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); + case DATE: + return Value.date(isNull ? null : getDateInternal(columnIndex)); + case STRUCT: + return Value.struct(isNull ? null : getStructInternal(columnIndex)); + case UNRECOGNIZED: + return Value.unrecognized( + isNull ? NULL_VALUE : getProtoValueInternal(columnIndex), columnType); + case ARRAY: + final Type elementType = columnType.getArrayElementType(); + switch (elementType.getCode()) { + case BOOL: + return Value.boolArray(isNull ? null : getBooleanListInternal(columnIndex)); + case INT64: + return Value.int64Array(isNull ? null : getLongListInternal(columnIndex)); + case NUMERIC: + return Value.numericArray(isNull ? null : getBigDecimalListInternal(columnIndex)); + case PG_NUMERIC: + return Value.pgNumericArray(isNull ? null : getStringListInternal(columnIndex)); + case FLOAT64: + return Value.float64Array(isNull ? null : getDoubleListInternal(columnIndex)); + case STRING: + return Value.stringArray(isNull ? null : getStringListInternal(columnIndex)); + case JSON: + return Value.jsonArray(isNull ? null : getJsonListInternal(columnIndex)); + case PG_JSONB: + return Value.pgJsonbArray(isNull ? null : getPgJsonbListInternal(columnIndex)); + case BYTES: + return Value.bytesArray(isNull ? null : getBytesListInternal(columnIndex)); + case PROTO: + return Value.protoMessageArray( + isNull ? null : getBytesListInternal(columnIndex), elementType.getProtoTypeFqn()); + case ENUM: + return Value.protoEnumArray( + isNull ? null : getLongListInternal(columnIndex), elementType.getProtoTypeFqn()); + case TIMESTAMP: + return Value.timestampArray(isNull ? null : getTimestampListInternal(columnIndex)); + case DATE: + return Value.dateArray(isNull ? null : getDateListInternal(columnIndex)); + case STRUCT: + return Value.structArray( + elementType, isNull ? null : getStructListInternal(columnIndex)); + default: + throw new IllegalArgumentException( + "Invalid array value type " + this.type.getArrayElementType()); + } + default: + throw new IllegalArgumentException("Invalid value type " + this.type); + } + } + + @Override + protected Struct getStructInternal(int columnIndex) { + return (Struct) rowData.get(columnIndex); + } + + @Override + protected boolean[] getBooleanArrayInternal(int columnIndex) { + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + List values = (List) rowData.get(columnIndex); + boolean[] r = new boolean[values.size()]; + for (int i = 0; i < values.size(); ++i) { + if (values.get(i) == null) { + throw AbstractResultSet.throwNotNull(columnIndex); + } + r[i] = values.get(i); + } + return r; + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBooleanListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + protected long[] getLongArrayInternal(int columnIndex) { + return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); + } + + @Override + protected Int64Array getLongListInternal(int columnIndex) { + return (Int64Array) rowData.get(columnIndex); + } + + @Override + protected double[] getDoubleArrayInternal(int columnIndex) { + return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); + } + + @Override + protected Float64Array getDoubleListInternal(int columnIndex) { + return (Float64Array) rowData.get(columnIndex); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBigDecimalListInternal(int columnIndex) { + return (List) rowData.get(columnIndex); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getStringListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getJsonListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getProtoMessageListInternal( + int columnIndex, T message) { + Preconditions.checkNotNull( + message, + "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + + List bytesArray = (List) rowData.get(columnIndex); + + try { + List protoMessagesList = new ArrayList<>(bytesArray.size()); + for (LazyByteArray protoMessageBytes : bytesArray) { + if (protoMessageBytes == null) { + protoMessagesList.add(null); + } else { + protoMessagesList.add( + (T) + message + .toBuilder() + .mergeFrom( + Base64.getDecoder() + .wrap( + CharSource.wrap(protoMessageBytes.getBase64String()) + .asByteSource(StandardCharsets.UTF_8) + .openStream())) + .build()); + } + } + return protoMessagesList; + } catch (IOException ioException) { + throw SpannerExceptionFactory.asSpannerException(ioException); + } + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getProtoEnumListInternal( + int columnIndex, Function method) { + Preconditions.checkNotNull( + method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + + List enumIntArray = (List) rowData.get(columnIndex); + List protoEnumList = new ArrayList<>(enumIntArray.size()); + for (Long enumIntValue : enumIntArray) { + if (enumIntValue == null) { + protoEnumList.add(null); + } else { + protoEnumList.add((T) method.apply(enumIntValue.intValue())); + } + } + + return protoEnumList; + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getPgJsonbListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getBytesListInternal(int columnIndex) { + return Lists.transform( + (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getTimestampListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY produces a List. + protected List getDateListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } + + @Override + @SuppressWarnings("unchecked") // We know ARRAY> produces a List. + protected List getStructListInternal(int columnIndex) { + return Collections.unmodifiableList((List) rowData.get(columnIndex)); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java new file mode 100644 index 00000000000..b7d9328577b --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java @@ -0,0 +1,211 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.common.base.Preconditions.checkState; + +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.common.collect.AbstractIterator; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.protobuf.Value.KindCase; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.ResultSetStats; +import com.google.spanner.v1.TypeCode; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +/** Adapts a stream of {@code PartialResultSet} messages into a stream of {@code Value} messages. */ +class GrpcValueIterator extends AbstractIterator { + + private enum StreamValue { + METADATA, + RESULT, + } + + private final CloseableIterator stream; + private ResultSetMetadata metadata; + private Type type; + private PartialResultSet current; + private int pos; + private ResultSetStats statistics; + + GrpcValueIterator(CloseableIterator stream) { + this.stream = stream; + } + + @SuppressWarnings("unchecked") + @Override + protected Value computeNext() { + if (!ensureReady(StreamValue.RESULT)) { + endOfData(); + return null; + } + Value value = current.getValues(pos++); + KindCase kind = value.getKindCase(); + + if (!isMergeable(kind)) { + if (pos == current.getValuesCount() && current.getChunkedValue()) { + throw newSpannerException(ErrorCode.INTERNAL, "Unexpected chunked PartialResultSet."); + } else { + return value; + } + } + if (!current.getChunkedValue() || pos != current.getValuesCount()) { + return value; + } + + Object merged = + kind == KindCase.STRING_VALUE + ? value.getStringValue() + : new ArrayList<>(value.getListValue().getValuesList()); + while (current.getChunkedValue() && pos == current.getValuesCount()) { + if (!ensureReady(StreamValue.RESULT)) { + throw newSpannerException( + ErrorCode.INTERNAL, "Stream closed in the middle of chunked value"); + } + Value newValue = current.getValues(pos++); + if (newValue.getKindCase() != kind) { + throw newSpannerException( + ErrorCode.INTERNAL, + "Unexpected type in middle of chunked value. Expected: " + + kind + + " but got: " + + newValue.getKindCase()); + } + if (kind == KindCase.STRING_VALUE) { + merged = merged + newValue.getStringValue(); + } else { + concatLists((List) merged, newValue.getListValue().getValuesList()); + } + } + if (kind == KindCase.STRING_VALUE) { + return Value.newBuilder().setStringValue((String) merged).build(); + } else { + return Value.newBuilder() + .setListValue(ListValue.newBuilder().addAllValues((List) merged)) + .build(); + } + } + + ResultSetMetadata getMetadata() throws SpannerException { + if (metadata == null) { + if (!ensureReady(StreamValue.METADATA)) { + throw newSpannerException(ErrorCode.INTERNAL, "Stream closed without sending metadata"); + } + } + return metadata; + } + + /** + * Get the query statistics. Query statistics are delivered with the last PartialResultSet in the + * stream. Any attempt to call this method before the caller has finished consuming the results + * will return null. + */ + @Nullable + ResultSetStats getStats() { + return statistics; + } + + Type type() { + checkState(type != null, "metadata has not been received"); + return type; + } + + private boolean ensureReady(StreamValue requiredValue) throws SpannerException { + while (current == null || pos >= current.getValuesCount()) { + if (!stream.hasNext()) { + return false; + } + current = stream.next(); + pos = 0; + if (type == null) { + // This is the first message on the stream. + if (!current.hasMetadata() || !current.getMetadata().hasRowType()) { + throw newSpannerException(ErrorCode.INTERNAL, "Missing type metadata in first message"); + } + metadata = current.getMetadata(); + com.google.spanner.v1.Type typeProto = + com.google.spanner.v1.Type.newBuilder() + .setCode(TypeCode.STRUCT) + .setStructType(metadata.getRowType()) + .build(); + try { + type = Type.fromProto(typeProto); + } catch (IllegalArgumentException e) { + throw newSpannerException( + ErrorCode.INTERNAL, "Invalid type metadata: " + e.getMessage(), e); + } + } + if (current.hasStats()) { + statistics = current.getStats(); + } + if (requiredValue == StreamValue.METADATA) { + return true; + } + } + return true; + } + + void close(@Nullable String message) { + stream.close(message); + } + + boolean isWithBeginTransaction() { + return stream.isWithBeginTransaction(); + } + + /** @param a is a mutable list and b will be concatenated into a. */ + private void concatLists(List a, List b) { + if (a.size() == 0 || b.size() == 0) { + a.addAll(b); + return; + } else { + Value last = a.get(a.size() - 1); + Value first = b.get(0); + KindCase lastKind = last.getKindCase(); + KindCase firstKind = first.getKindCase(); + if (isMergeable(lastKind) && lastKind == firstKind) { + Value merged; + if (lastKind == KindCase.STRING_VALUE) { + String lastStr = last.getStringValue(); + String firstStr = first.getStringValue(); + merged = Value.newBuilder().setStringValue(lastStr + firstStr).build(); + } else { // List + List mergedList = new ArrayList<>(); + mergedList.addAll(last.getListValue().getValuesList()); + concatLists(mergedList, first.getListValue().getValuesList()); + merged = + Value.newBuilder() + .setListValue(ListValue.newBuilder().addAllValues(mergedList)) + .build(); + } + a.set(a.size() - 1, merged); + a.addAll(b.subList(1, b.size())); + } else { + a.addAll(b); + } + } + } + + private boolean isMergeable(KindCase kind) { + return kind == KindCase.STRING_VALUE || kind == KindCase.LIST_VALUE; + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java new file mode 100644 index 00000000000..da065d7ce7a --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java @@ -0,0 +1,289 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.api.client.util.BackOff; +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.retrying.RetrySettings; +import com.google.api.gax.rpc.StatusCode.Code; +import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; +import com.google.cloud.spanner.v1.stub.SpannerStubSettings; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import com.google.spanner.v1.PartialResultSet; +import io.grpc.Context; +import io.opencensus.common.Scope; +import io.opencensus.trace.AttributeValue; +import io.opencensus.trace.Span; +import io.opencensus.trace.Tracer; +import io.opencensus.trace.Tracing; +import java.io.IOException; +import java.util.LinkedList; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Wraps an iterator over partial result sets, supporting resuming RPCs on error. This class keeps + * track of the most recent resume token seen, and will buffer partial result set chunks that do not + * have a resume token until one is seen or buffer space is exceeded, which reduces the chance of + * yielding data to the caller that cannot be resumed. + */ +@VisibleForTesting +abstract class ResumableStreamIterator extends AbstractIterator + implements CloseableIterator { + private static final Tracer tracer = Tracing.getTracer(); + + private static final RetrySettings DEFAULT_STREAMING_RETRY_SETTINGS = + SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetrySettings(); + private final RetrySettings streamingRetrySettings; + private final Set retryableCodes; + private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); + private final BackOff backOff; + private final LinkedList buffer = new LinkedList<>(); + private final int maxBufferSize; + private final Span span; + private CloseableIterator stream; + private ByteString resumeToken; + private boolean finished; + /** + * Indicates whether it is currently safe to retry RPCs. This will be {@code false} if we have + * reached the maximum buffer size without seeing a restart token; in this case, we will drain the + * buffer and remain in this state until we see a new restart token. + */ + private boolean safeToRetry = true; + + protected ResumableStreamIterator( + int maxBufferSize, + String streamName, + Span parent, + RetrySettings streamingRetrySettings, + Set retryableCodes) { + checkArgument(maxBufferSize >= 0); + this.maxBufferSize = maxBufferSize; + this.span = tracer.spanBuilderWithExplicitParent(streamName, parent).startSpan(); + this.streamingRetrySettings = Preconditions.checkNotNull(streamingRetrySettings); + this.retryableCodes = Preconditions.checkNotNull(retryableCodes); + this.backOff = newBackOff(); + } + + private ExponentialBackOff newBackOff() { + if (Objects.equals(streamingRetrySettings, DEFAULT_STREAMING_RETRY_SETTINGS)) { + return new ExponentialBackOff.Builder() + .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) + .setInitialIntervalMillis( + Math.max(10, (int) streamingRetrySettings.getInitialRetryDelay().toMillis())) + .setMaxIntervalMillis( + Math.max(1000, (int) streamingRetrySettings.getMaxRetryDelay().toMillis())) + .setMaxElapsedTimeMillis(Integer.MAX_VALUE) // Prevent Backoff.STOP from getting returned. + .build(); + } + return new ExponentialBackOff.Builder() + .setMultiplier(streamingRetrySettings.getRetryDelayMultiplier()) + // All of these values must be > 0. + .setInitialIntervalMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getInitialRetryDelay().toMillis(), + Integer.MAX_VALUE))) + .setMaxIntervalMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getMaxRetryDelay().toMillis(), Integer.MAX_VALUE))) + .setMaxElapsedTimeMillis( + Math.max( + 1, + (int) + Math.min( + streamingRetrySettings.getTotalTimeout().toMillis(), Integer.MAX_VALUE))) + .build(); + } + + private void backoffSleep(Context context, BackOff backoff) throws SpannerException { + backoffSleep(context, nextBackOffMillis(backoff)); + } + + private static long nextBackOffMillis(BackOff backoff) throws SpannerException { + try { + return backoff.nextBackOffMillis(); + } catch (IOException e) { + throw newSpannerException(ErrorCode.INTERNAL, e.getMessage(), e); + } + } + + private void backoffSleep(Context context, long backoffMillis) throws SpannerException { + tracer + .getCurrentSpan() + .addAnnotation( + "Backing off", + ImmutableMap.of("Delay", AttributeValue.longAttributeValue(backoffMillis))); + final CountDownLatch latch = new CountDownLatch(1); + final Context.CancellationListener listener = + ignored -> { + // Wakeup on cancellation / DEADLINE_EXCEEDED. + latch.countDown(); + }; + + context.addListener(listener, DirectExecutor.INSTANCE); + try { + if (backoffMillis == BackOff.STOP) { + // Highly unlikely but we handle it just in case. + backoffMillis = streamingRetrySettings.getMaxRetryDelay().toMillis(); + } + if (latch.await(backoffMillis, TimeUnit.MILLISECONDS)) { + // Woken by context cancellation. + throw newSpannerExceptionForCancellation(context, null); + } + } catch (InterruptedException interruptExcept) { + throw newSpannerExceptionForCancellation(context, interruptExcept); + } finally { + context.removeListener(listener); + } + } + + private enum DirectExecutor implements Executor { + INSTANCE; + + @Override + public void execute(Runnable command) { + command.run(); + } + } + + abstract CloseableIterator startStream(@Nullable ByteString resumeToken); + + @Override + public void close(@Nullable String message) { + if (stream != null) { + stream.close(message); + span.end(TraceUtil.END_SPAN_OPTIONS); + stream = null; + } + } + + @Override + public boolean isWithBeginTransaction() { + return stream != null && stream.isWithBeginTransaction(); + } + + @Override + protected PartialResultSet computeNext() { + Context context = Context.current(); + while (true) { + // Eagerly start stream before consuming any buffered items. + if (stream == null) { + span.addAnnotation( + "Starting/Resuming stream", + ImmutableMap.of( + "ResumeToken", + AttributeValue.stringAttributeValue( + resumeToken == null ? "null" : resumeToken.toStringUtf8()))); + try (Scope s = tracer.withSpan(span)) { + // When start a new stream set the Span as current to make the gRPC Span a child of + // this Span. + stream = checkNotNull(startStream(resumeToken)); + } + } + // Buffer contains items up to a resume token or has reached capacity: flush. + if (!buffer.isEmpty() + && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { + return buffer.pop(); + } + try { + if (stream.hasNext()) { + PartialResultSet next = stream.next(); + boolean hasResumeToken = !next.getResumeToken().isEmpty(); + if (hasResumeToken) { + resumeToken = next.getResumeToken(); + safeToRetry = true; + } + // If the buffer is empty and this chunk has a resume token or we cannot resume safely + // anyway, we can yield it immediately rather than placing it in the buffer to be + // returned on the next iteration. + if ((hasResumeToken || !safeToRetry) && buffer.isEmpty()) { + return next; + } + buffer.add(next); + if (buffer.size() > maxBufferSize && buffer.getLast().getResumeToken().isEmpty()) { + // We need to flush without a restart token. Errors encountered until we see + // such a token will fail the read. + safeToRetry = false; + } + } else { + finished = true; + if (buffer.isEmpty()) { + endOfData(); + return null; + } + } + } catch (SpannerException spannerException) { + if (safeToRetry && isRetryable(spannerException)) { + span.addAnnotation( + "Stream broken. Safe to retry", TraceUtil.getExceptionAnnotations(spannerException)); + logger.log(Level.FINE, "Retryable exception, will sleep and retry", spannerException); + // Truncate any items in the buffer before the last retry token. + while (!buffer.isEmpty() && buffer.getLast().getResumeToken().isEmpty()) { + buffer.removeLast(); + } + assert buffer.isEmpty() || buffer.getLast().getResumeToken().equals(resumeToken); + stream = null; + try (Scope s = tracer.withSpan(span)) { + long delay = spannerException.getRetryDelayInMillis(); + if (delay != -1) { + backoffSleep(context, delay); + } else { + backoffSleep(context, backOff); + } + } + + continue; + } + span.addAnnotation("Stream broken. Not safe to retry"); + TraceUtil.setWithFailure(span, spannerException); + throw spannerException; + } catch (RuntimeException e) { + span.addAnnotation("Stream broken. Not safe to retry"); + TraceUtil.setWithFailure(span, e); + throw e; + } + } + } + + boolean isRetryable(SpannerException spannerException) { + return spannerException.isRetryable() + || retryableCodes.contains( + GrpcStatusCode.of(spannerException.getErrorCode().getGrpcStatusCode()).getCode()); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index 24c7e3fd95d..c40e98fe9ae 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -46,7 +46,6 @@ import com.google.cloud.ByteArray; import com.google.cloud.NoCredentials; import com.google.cloud.Timestamp; -import com.google.cloud.spanner.AbstractResultSet.GrpcStreamIterator; import com.google.cloud.spanner.AsyncResultSet.CallbackResponse; import com.google.cloud.spanner.AsyncTransactionManager.TransactionContextFuture; import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index 914ce391f4a..cb73618d998 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -56,13 +56,13 @@ import org.junit.runners.JUnit4; import org.threeten.bp.Duration; -/** Unit tests for {@link com.google.cloud.spanner.AbstractResultSet.GrpcResultSet}. */ +/** Unit tests for {@link GrpcResultSet}. */ @RunWith(JUnit4.class) public class GrpcResultSetTest { - private AbstractResultSet.GrpcResultSet resultSet; + private GrpcResultSet resultSet; private SpannerRpc.ResultStreamConsumer consumer; - private AbstractResultSet.GrpcStreamIterator stream; + private GrpcStreamIterator stream; private final Duration streamWaitTimeout = Duration.ofNanos(1L); private static class NoOpListener implements AbstractResultSet.Listener { @@ -81,7 +81,7 @@ public void onDone(boolean withBeginTransaction) {} @Before public void setUp() { - stream = new AbstractResultSet.GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10); stream.setCall( new SpannerRpc.StreamingCall() { @Override @@ -97,11 +97,11 @@ public void request(int numMessages) {} }, false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); } - public AbstractResultSet.GrpcResultSet resultSetWithMode(QueryMode queryMode) { - return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + public GrpcResultSet resultSetWithMode(QueryMode queryMode) { + return new GrpcResultSet(stream, new NoOpListener()); } @Test @@ -609,7 +609,7 @@ public com.google.protobuf.Value apply(@Nullable Value input) { private void verifySerialization( Function protoFn, Value... values) { - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); PartialResultSet.Builder builder = PartialResultSet.newBuilder(); List types = new ArrayList<>(); for (Value value : values) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 7bf9f51a4ea..f27aa405aaa 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -19,7 +19,6 @@ import com.google.api.gax.grpc.testing.MockGrpcService; import com.google.cloud.ByteArray; import com.google.cloud.Date; -import com.google.cloud.spanner.AbstractResultSet.GrpcStruct; import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; import com.google.cloud.spanner.SessionPool.SessionPoolTransactionContext; import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java index af558d14dd4..a72c9872faf 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java @@ -104,9 +104,9 @@ protected List getChildren() { } private static class TestCaseRunner { - private AbstractResultSet.GrpcResultSet resultSet; + private GrpcResultSet resultSet; private SpannerRpc.ResultStreamConsumer consumer; - private AbstractResultSet.GrpcStreamIterator stream; + private GrpcStreamIterator stream; private JSONObject testCase; TestCaseRunner(JSONObject testCase) { @@ -114,7 +114,7 @@ private static class TestCaseRunner { } private void run() throws Exception { - stream = new AbstractResultSet.GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10); stream.setCall( new SpannerRpc.StreamingCall() { @Override @@ -130,7 +130,7 @@ public void request(int numMessages) {} }, false); consumer = stream.consumer(); - resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener()); + resultSet = new GrpcResultSet(stream, new NoOpListener()); JSONArray chunks = testCase.getJSONArray("chunks"); JSONObject expectedResult = testCase.getJSONObject("result"); @@ -143,8 +143,7 @@ public void request(int numMessages) {} assertResultSet(resultSet, expectedResult.getJSONArray("value")); } - private void assertResultSet(AbstractResultSet.GrpcResultSet actual, JSONArray expected) - throws Exception { + private void assertResultSet(GrpcResultSet actual, JSONArray expected) throws Exception { int i = 0; while (actual.next()) { Struct actualRow = actual.getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java index 51cca1bc684..fc494c6f3ff 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResultSetsHelper.java @@ -17,7 +17,6 @@ package com.google.cloud.spanner; import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; -import com.google.cloud.spanner.AbstractResultSet.GrpcResultSet; import com.google.cloud.spanner.AbstractResultSet.Listener; import com.google.protobuf.ListValue; import com.google.spanner.v1.PartialResultSet; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index 06c1725d76a..00916414a85 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -23,7 +23,6 @@ import static org.mockito.Mockito.verify; import com.google.api.client.util.BackOff; -import com.google.cloud.spanner.AbstractResultSet.ResumableStreamIterator; import com.google.cloud.spanner.v1.stub.SpannerStubSettings; import com.google.common.collect.AbstractIterator; import com.google.common.collect.Lists; @@ -53,7 +52,7 @@ import org.junit.runners.JUnit4; import org.mockito.Mockito; -/** Unit tests for {@link AbstractResultSet.ResumableStreamIterator}. */ +/** Unit tests for {@link ResumableStreamIterator}. */ @RunWith(JUnit4.class) public class ResumableStreamIteratorTest { interface Starter { @@ -128,7 +127,7 @@ public boolean isWithBeginTransaction() { } Starter starter = Mockito.mock(Starter.class); - AbstractResultSet.ResumableStreamIterator resumableStreamIterator; + ResumableStreamIterator resumableStreamIterator; @Before public void setUp() { @@ -137,7 +136,7 @@ public void setUp() { private void initWithLimit(int maxBufferSize) { resumableStreamIterator = - new AbstractResultSet.ResumableStreamIterator( + new ResumableStreamIterator( maxBufferSize, "", null, From 6fc57e47c86566f600c56e96d0f2e1721baa6af6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 9 Feb 2024 16:01:09 +0100 Subject: [PATCH 2/3] feat: support lazy decoding of query results (#2847) * feat: support lazy decoding of query results Adds an option for lazy decoding of query results. Currently, all values in a query result row are decoded from protobuf values to plain Java objects at the moment that the result set is advanced to the next row. This means that all values are decoded, regardless whether the application actually fetches these or not. Lazy decoding also enables the possibility for (internal) consumers of a result set to access the protobuf value before it is converted to a plain Java object. This for example allows ChecksumResultSet to calculate the checksum based on the protobuf value, instead of a Java object, which can be more efficient. * fix: add null check * perf: calculate checksum using protobuf values (#2848) * perf: calculate checksum using protobuf values * chore: cleanup * test: remove unrelated test * fix: undo change to public API * chore: cleanup| --- .../cloud/spanner/AbstractReadContext.java | 15 +- .../google/cloud/spanner/BatchClientImpl.java | 2 + .../com/google/cloud/spanner/DecodeMode.java | 35 ++ .../cloud/spanner/ForwardingResultSet.java | 18 +- .../google/cloud/spanner/GrpcResultSet.java | 24 +- .../com/google/cloud/spanner/GrpcStruct.java | 132 +++++- .../com/google/cloud/spanner/Options.java | 32 ++ .../cloud/spanner/ProtobufResultSet.java | 42 ++ .../com/google/cloud/spanner/ResultSets.java | 15 +- .../com/google/cloud/spanner/SessionImpl.java | 4 + .../com/google/cloud/spanner/SpannerImpl.java | 4 + .../google/cloud/spanner/SpannerOptions.java | 20 + .../java/com/google/cloud/spanner/Value.java | 4 + .../spanner/connection/ChecksumResultSet.java | 375 +++++++-------- .../connection/DirectExecuteResultSet.java | 18 +- .../connection/ReadWriteTransaction.java | 7 +- .../ReplaceableForwardingResultSet.java | 22 +- .../cloud/spanner/connection/SpannerPool.java | 4 + .../spanner/connection/DecodeModeTest.java | 128 ++++++ .../DirectExecuteResultSetTest.java | 3 + .../connection/RandomResultSetGenerator.java | 6 +- .../connection/ReadWriteTransactionTest.java | 435 +++++++++--------- .../ReplaceableForwardingResultSetTest.java | 11 +- 23 files changed, 910 insertions(+), 446 deletions(-) create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 2505b8b4985..ae18a1e4f5f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -70,6 +70,7 @@ abstract static class Builder, T extends AbstractReadCon private TraceWrapper tracer; private int defaultPrefetchChunks = SpannerOptions.Builder.DEFAULT_PREFETCH_CHUNKS; private QueryOptions defaultQueryOptions = SpannerOptions.Builder.DEFAULT_QUERY_OPTIONS; + private DecodeMode defaultDecodeMode = SpannerOptions.Builder.DEFAULT_DECODE_MODE; private DirectedReadOptions defaultDirectedReadOption; private ExecutorProvider executorProvider; private Clock clock = new Clock(); @@ -111,6 +112,11 @@ B setDefaultQueryOptions(QueryOptions defaultQueryOptions) { return self(); } + B setDefaultDecodeMode(DecodeMode defaultDecodeMode) { + this.defaultDecodeMode = defaultDecodeMode; + return self(); + } + B setExecutorProvider(ExecutorProvider executorProvider) { this.executorProvider = executorProvider; return self(); @@ -411,8 +417,8 @@ void initTransaction() { TraceWrapper tracer; private final int defaultPrefetchChunks; private final QueryOptions defaultQueryOptions; - private final DirectedReadOptions defaultDirectedReadOptions; + private final DecodeMode defaultDecodeMode; private final Clock clock; @GuardedBy("lock") @@ -438,6 +444,7 @@ void initTransaction() { this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; this.defaultDirectedReadOptions = builder.defaultDirectedReadOption; + this.defaultDecodeMode = builder.defaultDecodeMode; this.span = builder.span; this.executorProvider = builder.executorProvider; this.clock = builder.clock; @@ -727,7 +734,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, options.hasDecodeMode() ? options.decodeMode() : defaultDecodeMode); } /** @@ -871,7 +879,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken return stream; } }; - return new GrpcResultSet(stream, this); + return new GrpcResultSet( + stream, this, readOptions.hasDecodeMode() ? readOptions.decodeMode() : defaultDecodeMode); } private Struct consumeSingleRow(ResultSet resultSet) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 664cde1edbb..22fb9f710c1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -60,6 +60,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) @@ -81,6 +82,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc sessionClient.getSpanner().getDefaultQueryOptions(sessionClient.getDatabaseId())) .setExecutorProvider(sessionClient.getSpanner().getAsyncExecutorProvider()) .setDefaultPrefetchChunks(sessionClient.getSpanner().getDefaultPrefetchChunks()) + .setDefaultDecodeMode(sessionClient.getSpanner().getDefaultDecodeMode()) .setDefaultDirectedReadOptions( sessionClient.getSpanner().getOptions().getDirectedReadOptions()) .setSpan(sessionClient.getSpanner().getTracer().getCurrentSpan()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java new file mode 100644 index 00000000000..c1bea9a3ce1 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DecodeMode.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +/** Specifies how and when to decode a value from protobuf to a plain Java object. */ +public enum DecodeMode { + /** + * Decodes all columns of a row directly when a {@link ResultSet} is advanced to the next row with + * {@link ResultSet#next()} + */ + DIRECT, + /** + * Decodes all columns of a row the first time a {@link ResultSet} value is retrieved from the + * row. + */ + LAZY_PER_ROW, + /** + * Decodes a columns of a row the first time the value of that column is retrieved from the row. + */ + LAZY_PER_COL, +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index c29282879ed..18ecbeceb0f 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -23,7 +23,7 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader implements ResultSet { +public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { private Supplier delegate; @@ -55,6 +55,22 @@ public boolean next() throws SpannerException { return delegate.get().next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + return (resultSetDelegate instanceof ProtobufResultSet) + && ((ProtobufResultSet) resultSetDelegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + ResultSet resultSetDelegate = delegate.get(); + Preconditions.checkState( + resultSetDelegate instanceof ProtobufResultSet, + "The result set does not support protobuf values"); + return ((ProtobufResultSet) resultSetDelegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { return delegate.get().getCurrentRowAsStruct(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java index b5d64bce3bb..37a4792ad87 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Value; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; @@ -28,9 +29,10 @@ import javax.annotation.Nullable; @VisibleForTesting -class GrpcResultSet extends AbstractResultSet> { +class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { private final GrpcValueIterator iterator; private final Listener listener; + private final DecodeMode decodeMode; private ResultSetMetadata metadata; private GrpcStruct currRow; private SpannerException error; @@ -38,8 +40,26 @@ class GrpcResultSet extends AbstractResultSet> { private boolean closed; GrpcResultSet(CloseableIterator iterator, Listener listener) { + this(iterator, listener, DecodeMode.DIRECT); + } + + GrpcResultSet( + CloseableIterator iterator, Listener listener, DecodeMode decodeMode) { this.iterator = new GrpcValueIterator(iterator); this.listener = listener; + this.decodeMode = decodeMode; + } + + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && currRow != null && currRow.canGetProtoValue(columnIndex); + } + + @Override + public Value getProtobufValue(int columnIndex) { + checkState(!closed, "ResultSet is closed"); + checkState(currRow != null, "next() call required"); + return currRow.getProtoValueInternal(columnIndex); } @Override @@ -65,7 +85,7 @@ public boolean next() throws SpannerException { throw SpannerExceptionFactory.newSpannerException( ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG); } - currRow = new GrpcStruct(iterator.type(), new ArrayList<>()); + currRow = new GrpcStruct(iterator.type(), new ArrayList<>(), decodeMode); } boolean hasNext = currRow.consumeRow(iterator); if (!hasNext) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java index 0d8a5545a90..152c82e9ca9 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStruct.java @@ -27,6 +27,7 @@ import com.google.cloud.spanner.AbstractResultSet.Float64Array; import com.google.cloud.spanner.AbstractResultSet.Int64Array; import com.google.cloud.spanner.AbstractResultSet.LazyByteArray; +import com.google.cloud.spanner.Type.Code; import com.google.cloud.spanner.Type.StructField; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -42,6 +43,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Base64; +import java.util.BitSet; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -54,6 +56,9 @@ class GrpcStruct extends Struct implements Serializable { private final Type type; private final List rowData; + private final DecodeMode decodeMode; + private final BitSet colDecoded; + private boolean rowDecoded; /** * Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a @@ -181,9 +186,28 @@ private Object writeReplace() { return builder.build(); } - GrpcStruct(Type type, List rowData) { + GrpcStruct(Type type, List rowData, DecodeMode decodeMode) { + this( + type, + rowData, + decodeMode, + /* rowDecoded = */ false, + /* colDecoded = */ decodeMode == DecodeMode.LAZY_PER_COL + ? new BitSet(type.getStructFields().size()) + : null); + } + + private GrpcStruct( + Type type, + List rowData, + DecodeMode decodeMode, + boolean rowDecoded, + BitSet colDecoded) { this.type = type; this.rowData = rowData; + this.decodeMode = decodeMode; + this.rowDecoded = rowDecoded; + this.colDecoded = colDecoded; } @Override @@ -193,6 +217,11 @@ public String toString() { boolean consumeRow(Iterator iterator) { rowData.clear(); + if (decodeMode == DecodeMode.LAZY_PER_ROW) { + rowDecoded = false; + } else if (decodeMode == DecodeMode.LAZY_PER_COL) { + colDecoded.clear(); + } if (!iterator.hasNext()) { return false; } @@ -203,7 +232,11 @@ boolean consumeRow(Iterator iterator) { "Invalid value stream: end of stream reached before row is complete"); } com.google.protobuf.Value value = iterator.next(); - rowData.add(decodeValue(fieldType.getType(), value)); + if (decodeMode == DecodeMode.DIRECT) { + rowData.add(decodeValue(fieldType.getType(), value)); + } else { + rowData.add(value); + } } return true; } @@ -266,7 +299,7 @@ private static Struct decodeStructValue(Type structType, ListValue structValue) for (int i = 0; i < fieldTypes.size(); ++i) { fields.add(decodeValue(fieldTypes.get(i).getType(), fieldValues.get(i))); } - return new GrpcStruct(structType, fields); + return new GrpcStruct(structType, fields, DecodeMode.DIRECT); } static Object decodeArrayValue(Type elementType, ListValue listValue) { @@ -310,7 +343,12 @@ private static void checkType( } Struct immutableCopy() { - return new GrpcStruct(type, new ArrayList<>(rowData)); + return new GrpcStruct( + type, + new ArrayList<>(rowData), + this.decodeMode, + this.rowDecoded, + this.colDecoded == null ? null : (BitSet) this.colDecoded.clone()); } @Override @@ -320,6 +358,10 @@ public Type getType() { @Override public boolean isNull(int columnIndex) { + if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) { + return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue(); + } return rowData.get(columnIndex) == null; } @@ -355,64 +397,123 @@ protected T getProtoEnumInternal( @Override protected boolean getBooleanInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Boolean) rowData.get(columnIndex); } @Override protected long getLongInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Long) rowData.get(columnIndex); } @Override protected double getDoubleInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Double) rowData.get(columnIndex); } @Override protected BigDecimal getBigDecimalInternal(int columnIndex) { + ensureDecoded(columnIndex); return (BigDecimal) rowData.get(columnIndex); } @Override protected String getStringInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getJsonInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected String getPgJsonbInternal(int columnIndex) { + ensureDecoded(columnIndex); return (String) rowData.get(columnIndex); } @Override protected ByteArray getBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLazyBytesInternal(columnIndex).getByteArray(); } LazyByteArray getLazyBytesInternal(int columnIndex) { + ensureDecoded(columnIndex); return (LazyByteArray) rowData.get(columnIndex); } @Override protected Timestamp getTimestampInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Timestamp) rowData.get(columnIndex); } @Override protected Date getDateInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Date) rowData.get(columnIndex); } + private boolean isUnrecognizedType(int columnIndex) { + return type.getStructFields().get(columnIndex).getType().getCode() == Code.UNRECOGNIZED; + } + + boolean canGetProtoValue(int columnIndex) { + return isUnrecognizedType(columnIndex) + || (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) + || (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)); + } + protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) { + checkProtoValueSupported(columnIndex); return (com.google.protobuf.Value) rowData.get(columnIndex); } + private void checkProtoValueSupported(int columnIndex) { + // Unrecognized types are returned as protobuf values. + if (isUnrecognizedType(columnIndex)) { + return; + } + Preconditions.checkState( + decodeMode != DecodeMode.DIRECT, + "Getting proto value is not supported when DecodeMode#DIRECT is used."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded), + "Getting proto value after the row has been decoded is not supported."); + Preconditions.checkState( + !(decodeMode == DecodeMode.LAZY_PER_COL && colDecoded.get(columnIndex)), + "Getting proto value after the column has been decoded is not supported."); + } + + private void ensureDecoded(int columnIndex) { + if (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) { + for (int i = 0; i < rowData.size(); i++) { + rowData.set( + i, + decodeValue( + type.getStructFields().get(i).getType(), + (com.google.protobuf.Value) rowData.get(i))); + } + rowDecoded = true; + } else if (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)) { + rowData.set( + columnIndex, + decodeValue( + type.getStructFields().get(columnIndex).getType(), + (com.google.protobuf.Value) rowData.get(columnIndex))); + colDecoded.set(columnIndex); + } + } + @Override protected Value getValueInternal(int columnIndex) { + ensureDecoded(columnIndex); final List structFields = getType().getStructFields(); final StructField structField = structFields.get(columnIndex); final Type columnType = structField.getType(); @@ -423,7 +524,8 @@ protected Value getValueInternal(int columnIndex) { case INT64: return Value.int64(isNull ? null : getLongInternal(columnIndex)); case ENUM: - return Value.protoEnum(getLongInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoEnum( + isNull ? null : getLongInternal(columnIndex), columnType.getProtoTypeFqn()); case NUMERIC: return Value.numeric(isNull ? null : getBigDecimalInternal(columnIndex)); case PG_NUMERIC: @@ -439,7 +541,8 @@ protected Value getValueInternal(int columnIndex) { case BYTES: return Value.internalBytes(isNull ? null : getLazyBytesInternal(columnIndex)); case PROTO: - return Value.protoMessage(getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); + return Value.protoMessage( + isNull ? null : getBytesInternal(columnIndex), columnType.getProtoTypeFqn()); case TIMESTAMP: return Value.timestamp(isNull ? null : getTimestampInternal(columnIndex)); case DATE: @@ -494,11 +597,13 @@ protected Value getValueInternal(int columnIndex) { @Override protected Struct getStructInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Struct) rowData.get(columnIndex); } @Override protected boolean[] getBooleanArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); @SuppressWarnings("unchecked") // We know ARRAY produces a List. List values = (List) rowData.get(columnIndex); boolean[] r = new boolean[values.size()]; @@ -514,44 +619,52 @@ protected boolean[] getBooleanArrayInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBooleanListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override protected long[] getLongArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getLongListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Int64Array getLongListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Int64Array) rowData.get(columnIndex); } @Override protected double[] getDoubleArrayInternal(int columnIndex) { + ensureDecoded(columnIndex); return getDoubleListInternal(columnIndex).toPrimitiveArray(columnIndex); } @Override protected Float64Array getDoubleListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (Float64Array) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBigDecimalListInternal(int columnIndex) { + ensureDecoded(columnIndex); return (List) rowData.get(columnIndex); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getStringListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getJsonListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @@ -562,6 +675,7 @@ protected List getProtoMessageListInternal( Preconditions.checkNotNull( message, "Proto message may not be null. Use MyProtoClass.getDefaultInstance() as a parameter value."); + ensureDecoded(columnIndex); List bytesArray = (List) rowData.get(columnIndex); @@ -596,6 +710,7 @@ protected List getProtoEnumListInternal( int columnIndex, Function method) { Preconditions.checkNotNull( method, "Method may not be null. Use 'MyProtoEnum::forNumber' as a parameter value."); + ensureDecoded(columnIndex); List enumIntArray = (List) rowData.get(columnIndex); List protoEnumList = new ArrayList<>(enumIntArray.size()); @@ -613,12 +728,14 @@ protected List getProtoEnumListInternal( @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getPgJsonbListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getBytesListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Lists.transform( (List) rowData.get(columnIndex), l -> l == null ? null : l.getByteArray()); } @@ -626,18 +743,21 @@ protected List getBytesListInternal(int columnIndex) { @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getTimestampListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY produces a List. protected List getDateListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } @Override @SuppressWarnings("unchecked") // We know ARRAY> produces a List. protected List getStructListInternal(int columnIndex) { + ensureDecoded(columnIndex); return Collections.unmodifiableList((List) rowData.get(columnIndex)); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index 57feabbfcca..76d0f24225a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -243,6 +243,10 @@ public static ReadAndQueryOption directedRead(DirectedReadOptions directedReadOp return new DirectedReadOption(directedReadOptions); } + public static ReadAndQueryOption decodeMode(DecodeMode decodeMode) { + return new DecodeOption(decodeMode); + } + /** Option to request {@link CommitStats} for read/write transactions. */ static final class CommitStatsOption extends InternalOption implements TransactionOption { @Override @@ -374,6 +378,19 @@ void appendToOptions(Options options) { } } + static final class DecodeOption extends InternalOption implements ReadAndQueryOption { + private final DecodeMode decodeMode; + + DecodeOption(DecodeMode decodeMode) { + this.decodeMode = Preconditions.checkNotNull(decodeMode, "DecodeMode cannot be null"); + } + + @Override + void appendToOptions(Options options) { + options.decodeMode = decodeMode; + } + } + private boolean withCommitStats; private Duration maxCommitDelay; @@ -391,6 +408,7 @@ void appendToOptions(Options options) { private Boolean withOptimisticLock; private Boolean dataBoostEnabled; private DirectedReadOptions directedReadOptions; + private DecodeMode decodeMode; // Construction is via factory methods below. private Options() {} @@ -507,6 +525,14 @@ DirectedReadOptions directedReadOptions() { return directedReadOptions; } + boolean hasDecodeMode() { + return decodeMode != null; + } + + DecodeMode decodeMode() { + return decodeMode; + } + @Override public String toString() { StringBuilder b = new StringBuilder(); @@ -552,6 +578,9 @@ public String toString() { if (directedReadOptions != null) { b.append("directedReadOptions: ").append(directedReadOptions).append(' '); } + if (decodeMode != null) { + b.append("decodeMode: ").append(decodeMode).append(' '); + } return b.toString(); } @@ -640,6 +669,9 @@ public int hashCode() { if (directedReadOptions != null) { result = 31 * result + directedReadOptions.hashCode(); } + if (decodeMode != null) { + result = 31 * result + decodeMode.hashCode(); + } return result; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java new file mode 100644 index 00000000000..bbd8c41291f --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ProtobufResultSet.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; +import com.google.protobuf.Value; + +/** Interface for {@link ResultSet}s that can return a protobuf value. */ +@InternalApi +public interface ProtobufResultSet extends ResultSet { + + /** Returns true if the protobuf value for the given column is still available. */ + boolean canGetProtobufValue(int columnIndex); + + /** + * Returns the column value as a protobuf value. + * + *

This is an internal method not intended for external usage. + * + *

This method may only be called before the column value has been decoded to a plain Java + * object. This means that the {@link DecodeMode} that is used for the {@link ResultSet} must be + * one of {@link DecodeMode#LAZY_PER_ROW} and {@link DecodeMode#LAZY_PER_COL}, and that the + * corresponding {@link ResultSet#getValue(int)}, {@link ResultSet#getBoolean(int)}, ... method + * may not yet have been called for the column. + */ + @InternalApi + Value getProtobufValue(int columnIndex); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java index d55d4091b9f..a6cc7c729e5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSets.java @@ -109,7 +109,7 @@ public ResultSet get() { } } - private static class PrePopulatedResultSet implements ResultSet { + private static class PrePopulatedResultSet implements ProtobufResultSet { private final List rows; private final Type type; private int index = -1; @@ -137,6 +137,19 @@ public boolean next() throws SpannerException { return ++index < rows.size(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed && index >= 0 && index < rows.size(); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(!closed, "ResultSet is closed"); + Preconditions.checkState(index >= 0, "Must be preceded by a next() call"); + Preconditions.checkElementIndex(index, rows.size(), "All rows have been yielded"); + return getValue(columnIndex).toProto(); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(!closed, "ResultSet is closed"); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 29928f61cec..81b00001105 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -263,6 +263,7 @@ public ReadContext singleUse(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -284,6 +285,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -305,6 +307,7 @@ public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setDefaultDirectedReadOptions(spanner.getOptions().getDirectedReadOptions()) .setSpan(currentSpan) .setTracer(tracer) @@ -423,6 +426,7 @@ TransactionContextImpl newTransaction(Options options) { .setRpc(spanner.getRpc()) .setDefaultQueryOptions(spanner.getDefaultQueryOptions(databaseId)) .setDefaultPrefetchChunks(spanner.getDefaultPrefetchChunks()) + .setDefaultDecodeMode(spanner.getDefaultDecodeMode()) .setSpan(currentSpan) .setTracer(tracer) .setExecutorProvider(spanner.getAsyncExecutorProvider()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 326a51d803e..8fe06f76cc8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -151,6 +151,10 @@ int getDefaultPrefetchChunks() { return getOptions().getPrefetchChunks(); } + DecodeMode getDefaultDecodeMode() { + return getOptions().getDecodeMode(); + } + /** Returns the default query options that should be used for the specified database. */ QueryOptions getDefaultQueryOptions(DatabaseId databaseId) { return getOptions().getDefaultQueryOptions(databaseId); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java index 9c6044aa938..a16be179ce3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerOptions.java @@ -113,6 +113,7 @@ public class SpannerOptions extends ServiceOptions { private final GrpcInterceptorProvider interceptorProvider; private final SessionPoolOptions sessionPoolOptions; private final int prefetchChunks; + private final DecodeMode decodeMode; private final int numChannels; private final String transportChannelExecutorThreadNameFormat; private final String databaseRole; @@ -616,6 +617,7 @@ protected SpannerOptions(Builder builder) { ? builder.sessionPoolOptions : SessionPoolOptions.newBuilder().build(); prefetchChunks = builder.prefetchChunks; + decodeMode = builder.decodeMode; databaseRole = builder.databaseRole; sessionLabels = builder.sessionLabels; try { @@ -704,6 +706,9 @@ public static class Builder extends ServiceOptions.Builder { static final int DEFAULT_PREFETCH_CHUNKS = 4; static final QueryOptions DEFAULT_QUERY_OPTIONS = QueryOptions.getDefaultInstance(); + // TODO: Set the default to DecodeMode.DIRECT before merging to keep the current default. + // It is currently set to LAZY_PER_COL so it is used in all tests. + static final DecodeMode DEFAULT_DECODE_MODE = DecodeMode.LAZY_PER_COL; static final RetrySettings DEFAULT_ADMIN_REQUESTS_LIMIT_EXCEEDED_RETRY_SETTINGS = RetrySettings.newBuilder() .setInitialRetryDelay(Duration.ofSeconds(5L)) @@ -730,6 +735,7 @@ public static class Builder private String transportChannelExecutorThreadNameFormat = "Cloud-Spanner-TransportChannel-%d"; private int prefetchChunks = DEFAULT_PREFETCH_CHUNKS; + private DecodeMode decodeMode = DEFAULT_DECODE_MODE; private SessionPoolOptions sessionPoolOptions; private String databaseRole; private ImmutableMap sessionLabels; @@ -797,6 +803,7 @@ protected Builder() { options.transportChannelExecutorThreadNameFormat; this.sessionPoolOptions = options.sessionPoolOptions; this.prefetchChunks = options.prefetchChunks; + this.decodeMode = options.decodeMode; this.databaseRole = options.databaseRole; this.sessionLabels = options.sessionLabels; this.spannerStubSettingsBuilder = options.spannerStubSettings.toBuilder(); @@ -1224,6 +1231,15 @@ public Builder setPrefetchChunks(int prefetchChunks) { return this; } + /** + * Specifies how values that are returned from a query should be decoded and converted from + * protobuf values into plain Java objects. + */ + public Builder setDecodeMode(DecodeMode decodeMode) { + this.decodeMode = decodeMode; + return this; + } + @Override public Builder setHost(String host) { super.setHost(host); @@ -1568,6 +1584,10 @@ public int getPrefetchChunks() { return prefetchChunks; } + public DecodeMode getDecodeMode() { + return decodeMode; + } + public static GrpcTransportOptions getDefaultGrpcTransportOptions() { return GrpcTransportOptions.newBuilder().build(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java index a845eb118bf..3f0155e4a5e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Value.java @@ -1518,6 +1518,10 @@ void valueToString(StringBuilder b) { @Override boolean valueEquals(Value v) { + // NaN == NaN always returns false, so we need a custom check. + if (Double.isNaN(this.value)) { + return Double.isNaN(((Float64Impl) v).value); + } return ((Float64Impl) v).value == value; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java index dc373cf03bd..c642d7e505a 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ChecksumResultSet.java @@ -16,28 +16,28 @@ package com.google.cloud.spanner.connection; -import com.google.cloud.ByteArray; -import com.google.cloud.Date; -import com.google.cloud.Timestamp; import com.google.cloud.spanner.AbortedException; +import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.Options.QueryOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; -import com.google.cloud.spanner.Struct; +import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Type.Code; +import com.google.cloud.spanner.Type.StructField; import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement; import com.google.cloud.spanner.connection.ReadWriteTransaction.RetriableStatement; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.hash.Funnel; import com.google.common.hash.HashCode; -import com.google.common.hash.HashFunction; -import com.google.common.hash.Hasher; -import com.google.common.hash.Hashing; -import com.google.common.hash.PrimitiveSink; -import java.math.BigDecimal; -import java.util.Objects; +import com.google.protobuf.Value; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Arrays; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicLong; @@ -71,11 +71,11 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria private final ParsedStatement statement; private final AnalyzeMode analyzeMode; private final QueryOption[] options; - private final ChecksumResultSet.ChecksumCalculator checksumCalculator = new ChecksumCalculator(); + private final ChecksumCalculator checksumCalculator = new ChecksumCalculator(); ChecksumResultSet( ReadWriteTransaction transaction, - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -91,6 +91,13 @@ class ChecksumResultSet extends ReplaceableForwardingResultSet implements Retria this.options = options; } + @Override + public Value getProtobufValue(int columnIndex) { + // We can safely cast to ProtobufResultSet here, as the constructor of this class only accepts + // instances of ProtobufResultSet. + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + /** Simple {@link Callable} for calling {@link ResultSet#next()} */ private final class NextCallable implements Callable { @Override @@ -102,7 +109,7 @@ public Boolean call() { boolean res = ChecksumResultSet.super.next(); // Only update the checksum if there was another row to be consumed. if (res) { - checksumCalculator.calculateNextChecksum(getCurrentRowAsStruct()); + checksumCalculator.calculateNextChecksum(ChecksumResultSet.this); } numberOfNextCalls.incrementAndGet(); return res; @@ -118,8 +125,9 @@ public boolean next() { } @VisibleForTesting - HashCode getChecksum() { - // HashCode is immutable and can be safely returned. + byte[] getChecksum() { + // Getting the checksum from the checksumCalculator will create a clone of the current digest + // and return the checksum from the clone, so it is safe to return this value. return checksumCalculator.getChecksum(); } @@ -132,8 +140,8 @@ HashCode getChecksum() { @Override public void retry(AbortedException aborted) throws AbortedException { // Execute the same query and consume the result set to the same point as the original. - ChecksumResultSet.ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); - ResultSet resultSet = null; + ChecksumCalculator newChecksumCalculator = new ChecksumCalculator(); + ProtobufResultSet resultSet = null; long counter = 0L; try { transaction @@ -150,7 +158,7 @@ public void retry(AbortedException aborted) throws AbortedException { statement, StatementExecutionStep.RETRY_NEXT_ON_RESULT_SET, transaction); next = resultSet.next(); if (next) { - newChecksumCalculator.calculateNextChecksum(resultSet.getCurrentRowAsStruct()); + newChecksumCalculator.calculateNextChecksum(resultSet); } counter++; } @@ -168,9 +176,9 @@ public void retry(AbortedException aborted) throws AbortedException { throw e; } // Check that we have the same number of rows and the same checksum. - HashCode newChecksum = newChecksumCalculator.getChecksum(); - HashCode currentChecksum = checksumCalculator.getChecksum(); - if (counter == numberOfNextCalls.get() && Objects.equals(newChecksum, currentChecksum)) { + byte[] newChecksum = newChecksumCalculator.getChecksum(); + byte[] currentChecksum = checksumCalculator.getChecksum(); + if (counter == numberOfNextCalls.get() && Arrays.equals(newChecksum, currentChecksum)) { // Checksum is ok, we only need to replace the delegate result set if it's still open. if (isClosed()) { resultSet.close(); @@ -184,222 +192,165 @@ public void retry(AbortedException aborted) throws AbortedException { } } - /** Calculates and keeps the current checksum of a {@link ChecksumResultSet} */ + /** + * Calculates a running checksum for all the data that has been seen sofar in this result set. The + * calculation is performed on the protobuf values that were returned by Cloud Spanner, which + * means that no decoding of the results is needed (or allowed!) before calculating the checksum. + * This is more efficient, both in terms of CPU usage and memory consumption, especially if the + * consumer of the result set does not read all values, or is only reading the underlying protobuf + * values. + */ private static final class ChecksumCalculator { - private static final HashFunction SHA256_FUNCTION = Hashing.sha256(); - private HashCode currentChecksum; + // Use a buffer of max 1Mb to hash string data. This means that strings of up to 1Mb in size + // will be hashed in one go, while strings larger than 1Mb will be chunked into pieces of at + // most 1Mb and then fed into the digest. The digest internally creates a copy of the string + // that is being hashed, so chunking large strings prevents them from being loaded into memory + // twice. + private static final int MAX_BUFFER_SIZE = 1 << 20; - private void calculateNextChecksum(Struct row) { - Hasher hasher = SHA256_FUNCTION.newHasher(); - if (currentChecksum != null) { - hasher.putBytes(currentChecksum.asBytes()); + private boolean firstRow = true; + private final MessageDigest digest; + private ByteBuffer buffer; + private ByteBuffer float64Buffer; + + ChecksumCalculator() { + try { + // This is safe, as all Java implementations are required to have MD5 implemented. + // See https://docs.oracle.com/javase/8/docs/api/java/security/MessageDigest.html + // MD5 requires less CPU power than SHA-256, and still offers a low enough collision + // probability for the use case at hand here. + digest = MessageDigest.getInstance("MD5"); + } catch (Throwable t) { + throw SpannerExceptionFactory.asSpannerException(t); } - hasher.putObject(row, StructFunnel.INSTANCE); - currentChecksum = hasher.hash(); } - private HashCode getChecksum() { - return currentChecksum; + private byte[] getChecksum() { + try { + // This is safe, as the MD5 MessageDigest is known to be cloneable. + MessageDigest clone = (MessageDigest) digest.clone(); + return clone.digest(); + } catch (CloneNotSupportedException e) { + throw SpannerExceptionFactory.asSpannerException(e); + } } - } - /** - * A {@link Funnel} implementation for calculating a {@link HashCode} for each row in a {@link - * ResultSet}. - */ - private enum StructFunnel implements Funnel { - INSTANCE; - private static final String NULL = "null"; - - @Override - public void funnel(Struct row, PrimitiveSink into) { - for (int i = 0; i < row.getColumnCount(); i++) { - if (row.isNull(i)) { - funnelValue(Code.STRING, null, into); + private void calculateNextChecksum(ProtobufResultSet resultSet) { + if (firstRow) { + for (StructField field : resultSet.getType().getStructFields()) { + digest.update(field.getType().toString().getBytes(StandardCharsets.UTF_8)); + } + } + for (int col = 0; col < resultSet.getColumnCount(); col++) { + Type type = resultSet.getColumnType(col); + if (resultSet.canGetProtobufValue(col)) { + Value value = resultSet.getProtobufValue(col); + digest.update((byte) value.getKindCase().getNumber()); + pushValue(type, value); } else { - Code type = row.getColumnType(i).getCode(); - switch (type) { - case ARRAY: - funnelArray(row.getColumnType(i).getArrayElementType().getCode(), row, i, into); - break; - case BOOL: - funnelValue(type, row.getBoolean(i), into); - break; - case BYTES: - case PROTO: - funnelValue(type, row.getBytes(i), into); - break; - case DATE: - funnelValue(type, row.getDate(i), into); - break; - case FLOAT64: - funnelValue(type, row.getDouble(i), into); - break; - case NUMERIC: - funnelValue(type, row.getBigDecimal(i), into); - break; - case PG_NUMERIC: - funnelValue(type, row.getString(i), into); - break; - case INT64: - case ENUM: - funnelValue(type, row.getLong(i), into); - break; - case STRING: - funnelValue(type, row.getString(i), into); - break; - case JSON: - funnelValue(type, row.getJson(i), into); - break; - case PG_JSONB: - funnelValue(type, row.getPgJsonb(i), into); - break; - case TIMESTAMP: - funnelValue(type, row.getTimestamp(i), into); - break; - - case STRUCT: - default: - throw new IllegalArgumentException("unsupported row type"); - } + // This will normally not happen, unless the user explicitly sets the decoding mode to + // DIRECT for a query in a read/write transaction. The default decoding mode in the + // Connection API is set to LAZY_PER_COL. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Failed to get the underlying protobuf value for the column " + + resultSet.getMetadata().getRowType().getFields(col).getName() + + ". " + + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions."); } } + firstRow = false; } - private void funnelArray( - Code arrayElementType, Struct row, int columnIndex, PrimitiveSink into) { - funnelValue(Code.STRING, "BeginArray", into); - switch (arrayElementType) { - case BOOL: - into.putInt(row.getBooleanList(columnIndex).size()); - for (Boolean value : row.getBooleanList(columnIndex)) { - funnelValue(Code.BOOL, value, into); - } + private void pushValue(Type type, Value value) { + // Protobuf Value has a very limited set of possible types of values. All Cloud Spanner types + // are mapped to one of the protobuf values listed here, meaning that no changes are needed to + // this calculation when a new type is added to Cloud Spanner. + switch (value.getKindCase()) { + case NULL_VALUE: + // nothing needed, writing the KindCase is enough. break; - case BYTES: - case PROTO: - into.putInt(row.getBytesList(columnIndex).size()); - for (ByteArray value : row.getBytesList(columnIndex)) { - funnelValue(Code.BYTES, value, into); - } + case BOOL_VALUE: + digest.update(value.getBoolValue() ? (byte) 1 : 0); break; - case DATE: - into.putInt(row.getDateList(columnIndex).size()); - for (Date value : row.getDateList(columnIndex)) { - funnelValue(Code.DATE, value, into); - } + case STRING_VALUE: + putString(value.getStringValue()); break; - case FLOAT64: - into.putInt(row.getDoubleList(columnIndex).size()); - for (Double value : row.getDoubleList(columnIndex)) { - funnelValue(Code.FLOAT64, value, into); + case NUMBER_VALUE: + if (float64Buffer == null) { + // Create an 8-byte buffer that can be re-used for all float64 values in this result + // set. + float64Buffer = ByteBuffer.allocate(Double.BYTES); + } else { + float64Buffer.clear(); } + float64Buffer.putDouble(value.getNumberValue()); + float64Buffer.flip(); + digest.update(float64Buffer); break; - case NUMERIC: - into.putInt(row.getBigDecimalList(columnIndex).size()); - for (BigDecimal value : row.getBigDecimalList(columnIndex)) { - funnelValue(Code.NUMERIC, value, into); + case LIST_VALUE: + if (type.getCode() == Code.ARRAY) { + for (Value item : value.getListValue().getValuesList()) { + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getArrayElementType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "List values that are not an ARRAY are not supported"); } break; - case PG_NUMERIC: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); + case STRUCT_VALUE: + if (type.getCode() == Code.STRUCT) { + for (int col = 0; col < type.getStructFields().size(); col++) { + String name = type.getStructFields().get(col).getName(); + putString(name); + Value item = value.getStructValue().getFieldsMap().get(name); + digest.update((byte) item.getKindCase().getNumber()); + pushValue(type.getStructFields().get(col).getType(), item); + } + } else { + // This should not be possible. + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.FAILED_PRECONDITION, + "Struct values without a struct type are not supported"); } break; - case INT64: - case ENUM: - into.putInt(row.getLongList(columnIndex).size()); - for (Long value : row.getLongList(columnIndex)) { - funnelValue(Code.INT64, value, into); - } - break; - case STRING: - into.putInt(row.getStringList(columnIndex).size()); - for (String value : row.getStringList(columnIndex)) { - funnelValue(Code.STRING, value, into); - } - break; - case JSON: - into.putInt(row.getJsonList(columnIndex).size()); - for (String value : row.getJsonList(columnIndex)) { - funnelValue(Code.JSON, value, into); - } - break; - case PG_JSONB: - into.putInt(row.getPgJsonbList(columnIndex).size()); - for (String value : row.getPgJsonbList(columnIndex)) { - funnelValue(Code.PG_JSONB, value, into); - } - break; - case TIMESTAMP: - into.putInt(row.getTimestampList(columnIndex).size()); - for (Timestamp value : row.getTimestampList(columnIndex)) { - funnelValue(Code.TIMESTAMP, value, into); - } - break; - - case ARRAY: - case STRUCT: default: - throw new IllegalArgumentException("unsupported array element type"); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.UNIMPLEMENTED, "Unsupported protobuf value: " + value.getKindCase()); } - funnelValue(Code.STRING, "EndArray", into); } - private void funnelValue(Code type, T value, PrimitiveSink into) { - // Include the type name in case the type of a column has changed. - into.putUnencodedChars(type.name()); - if (value == null) { - if (type == Code.BYTES || type == Code.STRING) { - // Put length -1 to distinguish from the string value 'null'. - into.putInt(-1); - } - into.putUnencodedChars(NULL); + /** Hashes a string value in blocks of max MAX_BUFFER_SIZE. */ + private void putString(String stringValue) { + int length = stringValue.length(); + if (buffer == null || (buffer.capacity() < MAX_BUFFER_SIZE && buffer.capacity() < length)) { + // Create a ByteBuffer with a maximum buffer size. + // This buffer is re-used for all string values in the result set. + buffer = ByteBuffer.allocate(Math.min(MAX_BUFFER_SIZE, length)); } else { - switch (type) { - case BOOL: - into.putBoolean((Boolean) value); - break; - case BYTES: - case PROTO: - ByteArray byteArray = (ByteArray) value; - into.putInt(byteArray.length()); - into.putBytes(byteArray.toByteArray()); - break; - case DATE: - Date date = (Date) value; - into.putInt(date.getYear()).putInt(date.getMonth()).putInt(date.getDayOfMonth()); - break; - case FLOAT64: - into.putDouble((Double) value); - break; - case NUMERIC: - String stringRepresentation = value.toString(); - into.putInt(stringRepresentation.length()); - into.putUnencodedChars(stringRepresentation); - break; - case INT64: - case ENUM: - into.putLong((Long) value); - break; - case PG_NUMERIC: - case STRING: - case JSON: - case PG_JSONB: - String stringValue = (String) value; - into.putInt(stringValue.length()); - into.putUnencodedChars(stringValue); - break; - case TIMESTAMP: - Timestamp timestamp = (Timestamp) value; - into.putLong(timestamp.getSeconds()).putInt(timestamp.getNanos()); - break; - case ARRAY: - case STRUCT: - default: - throw new IllegalArgumentException("invalid type for single value"); - } + buffer.clear(); + } + + // Wrap the string in a CharBuffer. This allows us to read from the string in sections without + // creating a new copy of (a part of) the string. E.g. using something like substring(..) + // would create a copy of that part of the string, using CharBuffer.wrap(..) does not. + CharBuffer source = CharBuffer.wrap(stringValue); + CharsetEncoder encoder = StandardCharsets.UTF_8.newEncoder(); + // source.hasRemaining() returns false when all the characters in the string have been + // processed. + while (source.hasRemaining()) { + // Encode the string into bytes and write them into the byte buffer. + // At most MAX_BUFFER_SIZE bytes will be written. + encoder.encode(source, buffer, false); + // Flip the buffer so we can read from the start. + buffer.flip(); + // Put the bytes from the buffer into the digest. + digest.update(buffer); + // Flip the buffer again, so we can repeat and write to the start of the buffer again. + buffer.flip(); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java index dff915e2cce..1b15ec50822 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/DirectExecuteResultSet.java @@ -19,6 +19,7 @@ import com.google.cloud.ByteArray; import com.google.cloud.Date; import com.google.cloud.Timestamp; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.Struct; @@ -40,7 +41,7 @@ * to the actual query execution. It also ensures that any invalid query will throw an exception at * execution instead of the first next() call by a client. */ -class DirectExecuteResultSet implements ResultSet { +class DirectExecuteResultSet implements ProtobufResultSet { private static final String MISSING_NEXT_CALL = "Must be preceded by a next() call"; private final ResultSet delegate; private boolean nextCalledByClient = false; @@ -79,6 +80,21 @@ public boolean next() throws SpannerException { return initialNextResult; } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return nextCalledByClient + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) delegate).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { Preconditions.checkState(nextCalledByClient, MISSING_NEXT_CALL); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java index e1fb87e4ade..6c4290c3b18 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReadWriteTransaction.java @@ -39,6 +39,7 @@ import com.google.cloud.spanner.Options.QueryOption; import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.Options.UpdateOption; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; @@ -427,7 +428,7 @@ public ApiFuture executeQueryAsync( statement, StatementExecutionStep.EXECUTE_STATEMENT, ReadWriteTransaction.this); - ResultSet delegate = + DirectExecuteResultSet delegate = DirectExecuteResultSet.ofResultSet( internalExecuteQuery(statement, analyzeMode, options)); return createAndAddRetryResultSet( @@ -797,7 +798,7 @@ T runWithRetry(Callable callable) throws SpannerException { * returns a retryable {@link ResultSet}. */ private ResultSet createAndAddRetryResultSet( - ResultSet resultSet, + ProtobufResultSet resultSet, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { @@ -1091,7 +1092,7 @@ interface RetriableStatement { /** Creates a {@link ChecksumResultSet} for this {@link ReadWriteTransaction}. */ @VisibleForTesting ChecksumResultSet createChecksumResultSet( - ResultSet delegate, + ProtobufResultSet delegate, ParsedStatement statement, AnalyzeMode analyzeMode, QueryOption... options) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java index 7370551a46f..a8de14e5121 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSet.java @@ -20,6 +20,7 @@ import com.google.cloud.Date; import com.google.cloud.Timestamp; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; @@ -42,7 +43,7 @@ * that is fetched using the new transaction. This is achieved by wrapping the returned result sets * in a {@link ReplaceableForwardingResultSet} that replaces its delegate after a transaction retry. */ -class ReplaceableForwardingResultSet implements ResultSet { +class ReplaceableForwardingResultSet implements ProtobufResultSet { private ResultSet delegate; private boolean closed; @@ -60,6 +61,10 @@ void replaceDelegate(ResultSet delegate) { this.delegate = delegate; } + protected ResultSet getDelegate() { + return this.delegate; + } + private void checkClosed() { if (closed) { throw SpannerExceptionFactory.newSpannerException( @@ -77,6 +82,21 @@ public boolean next() throws SpannerException { return delegate.next(); } + @Override + public boolean canGetProtobufValue(int columnIndex) { + return !closed + && delegate instanceof ProtobufResultSet + && ((ProtobufResultSet) delegate).canGetProtobufValue(columnIndex); + } + + @Override + public com.google.protobuf.Value getProtobufValue(int columnIndex) { + checkClosed(); + Preconditions.checkState( + delegate instanceof ProtobufResultSet, "The result set does not support protobuf values"); + return ((ProtobufResultSet) getDelegate()).getProtobufValue(columnIndex); + } + @Override public Struct getCurrentRowAsStruct() { checkClosed(); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java index 2a5a805c2c7..da8da78d92e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerPool.java @@ -17,6 +17,7 @@ package com.google.cloud.spanner.connection; import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.DecodeMode; import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SessionPoolOptions; import com.google.cloud.spanner.Spanner; @@ -342,6 +343,9 @@ Spanner createSpanner(SpannerPoolKey key, ConnectionOptions options) { .setClientLibToken(MoreObjects.firstNonNull(key.userAgent, CONNECTION_API_CLIENT_LIB_TOKEN)) .setHost(key.host) .setProjectId(key.projectId) + // Use lazy decoding, so we can use the protobuf values for calculating the checksum that is + // needed for read/write transactions. + .setDecodeMode(DecodeMode.LAZY_PER_COL) .setDatabaseRole(options.getDatabaseRole()) .setCredentials(options.getCredentials()); builder.setSessionPoolOption(key.sessionPoolOptions); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java new file mode 100644 index 00000000000..6a6125e1dda --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DecodeModeTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.connection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.DecodeMode; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.MockSpannerServiceImpl; +import com.google.cloud.spanner.Options; +import com.google.cloud.spanner.ResultSet; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DecodeModeTest extends AbstractMockServerTest { + + @After + public void clearRequests() { + mockSpanner.clearRequests(); + } + + @Test + public void testAllDecodeModes() { + int numRows = 10; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + for (boolean readonly : new boolean[] {true, false}) { + for (boolean autocommit : new boolean[] {true, false}) { + connection.setReadOnly(readonly); + connection.setAutocommit(autocommit); + + int receivedRows = 0; + // DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf value is + // used for checksum calculation. + try (ResultSet direct = + connection.executeQuery( + statement, + !readonly && !autocommit + ? Options.decodeMode(DecodeMode.LAZY_PER_ROW) + : Options.decodeMode(DecodeMode.DIRECT)); + ResultSet lazyPerRow = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW)); + ResultSet lazyPerCol = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) { + while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) { + assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount()); + assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount()); + for (int col = 0; col < direct.getColumnCount(); col++) { + // Test getting the entire row as a struct both as the first thing we do, and as the + // last thing we do. This ensures that the method works as expected both when a row + // is lazily decoded by this method, and when it has already been decoded by another + // method. + if (col % 2 == 0) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + assertEquals(direct.isNull(col), lazyPerRow.isNull(col)); + assertEquals(direct.isNull(col), lazyPerCol.isNull(col)); + assertEquals(direct.getValue(col), lazyPerRow.getValue(col)); + assertEquals(direct.getValue(col), lazyPerCol.getValue(col)); + if (col % 2 == 1) { + assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct()); + assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct()); + } + } + receivedRows++; + } + assertEquals(numRows, receivedRows); + } + if (!autocommit) { + connection.commit(); + } + } + } + } + } + + @Test + public void testDecodeModeDirect_failsInReadWriteTransaction() { + int numRows = 1; + RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows); + String sql = "select * from random"; + Statement statement = Statement.of(sql); + mockSpanner.putStatementResult( + MockSpannerServiceImpl.StatementResult.query(statement, generator.generate())); + + try (Connection connection = createConnection()) { + connection.setAutocommit(false); + try (ResultSet resultSet = + connection.executeQuery(statement, Options.decodeMode(DecodeMode.DIRECT))) { + SpannerException exception = assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode()); + assertTrue( + exception.getMessage(), + exception + .getMessage() + .contains( + "Executing queries with DecodeMode#DIRECT is not supported in read/write transactions.")); + } + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java index 1e4f96d1568..b14f837ff7b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/DirectExecuteResultSetTest.java @@ -59,6 +59,7 @@ public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -79,6 +80,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -101,6 +103,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java index 3091364e17a..2067d36b5ea 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/RandomResultSetGenerator.java @@ -34,6 +34,7 @@ import com.google.spanner.v1.TypeAnnotationCode; import com.google.spanner.v1.TypeCode; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -239,7 +240,10 @@ private void setRandomValue(Value.Builder builder, Type type) { if (dialect == Dialect.POSTGRESQL && randomNaN()) { builder.setStringValue("NaN"); } else { - builder.setStringValue(BigDecimal.valueOf(random.nextDouble()).toString()); + builder.setStringValue( + BigDecimal.valueOf(random.nextDouble()) + .setScale(9, RoundingMode.HALF_UP) + .toString()); } break; case INT64: diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java index 0f083fd1e50..8e643cf6e24 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReadWriteTransactionTest.java @@ -37,6 +37,7 @@ import com.google.cloud.spanner.CommitResponse; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ProtobufResultSet; import com.google.cloud.spanner.ReadContext.QueryAnalyzeMode; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.ResultSets; @@ -518,193 +519,197 @@ public void testChecksumResultSet() { .setGenre(Genre.FOLK) .build(); ProtocolMessageEnum protoEnumVal = Genre.ROCK; - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(BigDecimal.valueOf(550, 2)) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(BigDecimal.valueOf(750, 2)) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(BigDecimal.valueOf(550, 2)) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(BigDecimal.valueOf(750, 2)) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); // rs1 and rs2 are equal, rs3 contains the same rows, but in a different order - ResultSet delegate3 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build())); + ProtobufResultSet delegate3 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build())); ChecksumResultSet rs3 = transaction.createChecksumResultSet(delegate3, parsedStatement, AnalyzeMode.NONE); // rs4 contains the same rows as rs1 and rs2, but also an additional row - ResultSet delegate4 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("NAME", Type.string()), - StructField.of("AMOUNT", Type.numeric()), - StructField.of("JSON", Type.json()), - StructField.of( - "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), - StructField.of( - "PROTOENUM", - Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("NAME") - .to("TEST 1") - .set("AMOUNT") - .to(new BigDecimal("5.50")) - .set("JSON") - .to(Value.json(simpleJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(protoEnumVal) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("NAME") - .to("TEST 2") - .set("AMOUNT") - .to(new BigDecimal("7.50")) - .set("JSON") - .to(Value.json(arrayJson)) - .set("PROTO") - .to(protoMessageVal) - .set("PROTOENUM") - .to(Genre.JAZZ) - .build(), - Struct.newBuilder() - .set("ID") - .to(3L) - .set("NAME") - .to("TEST 3") - .set("AMOUNT") - .to(new BigDecimal("9.99")) - .set("JSON") - .to(Value.json(emptyArrayJson)) - .set("PROTO") - .to(null, SingerInfo.getDescriptor()) - .set("PROTOENUM") - .to(Genre.POP) - .build())); + ProtobufResultSet delegate4 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("NAME", Type.string()), + StructField.of("AMOUNT", Type.numeric()), + StructField.of("JSON", Type.json()), + StructField.of( + "PROTO", Type.proto(protoMessageVal.getDescriptorForType().getFullName())), + StructField.of( + "PROTOENUM", + Type.protoEnum(protoEnumVal.getDescriptorForType().getFullName()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("NAME") + .to("TEST 1") + .set("AMOUNT") + .to(new BigDecimal("5.50")) + .set("JSON") + .to(Value.json(simpleJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(protoEnumVal) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("NAME") + .to("TEST 2") + .set("AMOUNT") + .to(new BigDecimal("7.50")) + .set("JSON") + .to(Value.json(arrayJson)) + .set("PROTO") + .to(protoMessageVal) + .set("PROTOENUM") + .to(Genre.JAZZ) + .build(), + Struct.newBuilder() + .set("ID") + .to(3L) + .set("NAME") + .to("TEST 3") + .set("AMOUNT") + .to(new BigDecimal("9.99")) + .set("JSON") + .to(Value.json(emptyArrayJson)) + .set("PROTO") + .to(null, SingerInfo.getDescriptor()) + .set("PROTOENUM") + .to(Genre.POP) + .build())); ChecksumResultSet rs4 = transaction.createChecksumResultSet(delegate4, parsedStatement, AnalyzeMode.NONE); @@ -736,44 +741,46 @@ public void testChecksumResultSetWithArray() { ParsedStatement parsedStatement = mock(ParsedStatement.class); Statement statement = Statement.of("SELECT * FROM FOO"); when(parsedStatement.getStatement()).thenReturn(statement); - ResultSet delegate1 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 4L}) - .build())); + ProtobufResultSet delegate1 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 4L}) + .build())); ChecksumResultSet rs1 = transaction.createChecksumResultSet(delegate1, parsedStatement, AnalyzeMode.NONE); - ResultSet delegate2 = - ResultSets.forRows( - Type.struct( - StructField.of("ID", Type.int64()), - StructField.of("PRICES", Type.array(Type.int64()))), - Arrays.asList( - Struct.newBuilder() - .set("ID") - .to(1L) - .set("PRICES") - .toInt64Array(new long[] {1L, 2L}) - .build(), - Struct.newBuilder() - .set("ID") - .to(2L) - .set("PRICES") - .toInt64Array(new long[] {3L, 5L}) - .build())); + ProtobufResultSet delegate2 = + (ProtobufResultSet) + ResultSets.forRows( + Type.struct( + StructField.of("ID", Type.int64()), + StructField.of("PRICES", Type.array(Type.int64()))), + Arrays.asList( + Struct.newBuilder() + .set("ID") + .to(1L) + .set("PRICES") + .toInt64Array(new long[] {1L, 2L}) + .build(), + Struct.newBuilder() + .set("ID") + .to(2L) + .set("PRICES") + .toInt64Array(new long[] {3L, 5L}) + .build())); ChecksumResultSet rs2 = transaction.createChecksumResultSet(delegate2, parsedStatement, AnalyzeMode.NONE); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java index bbb34675147..4617c47bc6b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ReplaceableForwardingResultSetTest.java @@ -104,7 +104,14 @@ public void testReplace() { public void testMethodCallBeforeNext() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = - Arrays.asList("getStats", "getMetadata", "next", "close", "equals", "hashCode"); + Arrays.asList( + "canGetProtobufValue", + "getStats", + "getMetadata", + "next", + "close", + "equals", + "hashCode"); ReplaceableForwardingResultSet subject = createSubject(); // Test that all methods throw an IllegalStateException except the excluded methods when called // before a call to ResultSet#next(). @@ -116,6 +123,7 @@ public void testMethodCallAfterClose() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", @@ -140,6 +148,7 @@ public void testMethodCallAfterNextHasReturnedFalse() throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { List excludedMethods = Arrays.asList( + "canGetProtobufValue", "getStats", "getMetadata", "next", From b1d977078cfdeab40c79660623df3014bd837bb6 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Fri, 9 Feb 2024 15:03:05 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20po?= =?UTF-8?q?st-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 288afeef263..56d668b0a65 100644 --- a/README.md +++ b/README.md @@ -57,13 +57,13 @@ implementation 'com.google.cloud:google-cloud-spanner' If you are using Gradle without BOM, add this to your dependencies: ```Groovy -implementation 'com.google.cloud:google-cloud-spanner:6.57.0' +implementation 'com.google.cloud:google-cloud-spanner:6.58.0' ``` If you are using SBT, add this to your dependencies: ```Scala -libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.57.0" +libraryDependencies += "com.google.cloud" % "google-cloud-spanner" % "6.58.0" ``` @@ -444,7 +444,7 @@ Java is a registered trademark of Oracle and/or its affiliates. [kokoro-badge-link-5]: http://storage.googleapis.com/cloud-devrel-public/java/badges/java-spanner/java11.html [stability-image]: https://img.shields.io/badge/stability-stable-green [maven-version-image]: https://img.shields.io/maven-central/v/com.google.cloud/google-cloud-spanner.svg -[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.57.0 +[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-spanner/6.58.0 [authentication]: https://github.com/googleapis/google-cloud-java#authentication [auth-scopes]: https://developers.google.com/identity/protocols/oauth2/scopes [predefined-iam-roles]: https://cloud.google.com/iam/docs/understanding-roles#predefined_roles