diff --git a/client/trino-client/src/main/java/io/trino/client/ResultRowsDecoder.java b/client/trino-client/src/main/java/io/trino/client/ResultRowsDecoder.java index 50b0529eefb4..2d8100840be8 100644 --- a/client/trino-client/src/main/java/io/trino/client/ResultRowsDecoder.java +++ b/client/trino-client/src/main/java/io/trino/client/ResultRowsDecoder.java @@ -13,7 +13,6 @@ */ package io.trino.client; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; import io.trino.client.spooling.DataAttributes; import io.trino.client.spooling.EncodedQueryData; @@ -30,11 +29,11 @@ import java.util.List; import java.util.Optional; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.client.FixJsonDataUtils.fixData; import static io.trino.client.ResultRows.NULL_ROWS; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; /** @@ -44,8 +43,7 @@ public class ResultRowsDecoder implements AutoCloseable { private final SegmentLoader loader; - private QueryDataDecoder.Factory decoderFactory; - private List columns = emptyList(); + private QueryDataDecoder decoder; public ResultRowsDecoder() { @@ -57,38 +55,35 @@ public ResultRowsDecoder(SegmentLoader loader) this.loader = requireNonNull(loader, "loader is null"); } - public ResultRowsDecoder withEncoding(String encoding) + private void setEncoding(List columns, String encoding) { - if (decoderFactory != null) { - if (!encoding.equals(decoderFactory.encoding())) { - throw new IllegalStateException("Already set encoding " + encoding + " is not equal to " + decoderFactory.encoding()); - } + if (decoder != null) { + checkState(decoder.encoding().equals(encoding), "Decoder is configured for encoding %s but got %s", decoder.encoding(), encoding); } else { - this.decoderFactory = QueryDataDecoders.get(encoding); + checkState(!columns.isEmpty(), "Columns must be set when decoding data"); + this.decoder = QueryDataDecoders.get(encoding) + // we don't use query-level attributes for now + .create(columns, DataAttributes.empty()); } - return this; } - public ResultRowsDecoder withColumns(List columns) + public ResultRows toRows(QueryResults results) { - if (this.columns.isEmpty()) { - this.columns = ImmutableList.copyOf(columns); - } - else if (!columns.equals(this.columns)) { - throw new IllegalStateException("Already set columns " + columns + " are not equal to " + this.columns); + if (results == null || results.getData() == null) { + return NULL_ROWS; } - return this; + return toRows(results.getColumns(), results.getData()); } - public ResultRows toRows(QueryData data) + public ResultRows toRows(List columns, QueryData data) { - if (data == null) { + if (data == null || data.isNull()) { return NULL_ROWS; // for backward compatibility instead of null } - verify(!columns.isEmpty(), "columns must be set"); + verify(columns != null && !columns.isEmpty(), "Columns must be set when decoding data"); if (data instanceof RawQueryData) { RawQueryData rawData = (RawQueryData) data; if (rawData.isNull()) { @@ -98,15 +93,12 @@ public ResultRows toRows(QueryData data) } if (data instanceof EncodedQueryData) { - verify(decoderFactory != null, "decoderFactory must be set"); - // we don't need query-level attributes for now - QueryDataDecoder decoder = decoderFactory.create(columns, DataAttributes.empty()); EncodedQueryData encodedData = (EncodedQueryData) data; - verify(decoder.encoding().equals(encodedData.getEncoding()), "encoding %s is not equal to %s", encodedData.getEncoding(), decoder.encoding()); + setEncoding(columns, encodedData.getEncoding()); List resultRows = encodedData.getSegments() .stream() - .map(segment -> segmentToRows(decoder, segment)) + .map(this::segmentToRows) .collect(toImmutableList()); return concat(resultRows); @@ -115,7 +107,7 @@ public ResultRows toRows(QueryData data) throw new UnsupportedOperationException("Unsupported data type: " + data.getClass().getName()); } - private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment) + private ResultRows segmentToRows(Segment segment) { if (segment instanceof InlineSegment) { InlineSegment inlineSegment = (InlineSegment) segment; @@ -131,6 +123,7 @@ private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment) SpooledSegment spooledSegment = (SpooledSegment) segment; try { + // The returned rows are lazy which means that decoder is responsible for closing input stream InputStream stream = loader.load(spooledSegment); return decoder.decode(stream, spooledSegment.getMetadata()); } @@ -144,8 +137,8 @@ private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment) public Optional getEncoding() { - return Optional.ofNullable(decoderFactory) - .map(QueryDataDecoder.Factory::encoding); + return Optional.ofNullable(decoder) + .map(QueryDataDecoder::encoding); } @Override diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index 19a5de427544..c5e8972e5b66 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -258,7 +258,8 @@ public QueryStatusInfo currentStatusInfo() @Nonnull public ResultRows currentRows() { - return resultRowsDecoder.toRows(currentData()); + checkState(isRunning(), "current position is not valid (cursor past end)"); + return resultRowsDecoder.toRows(currentResults.get()); } @Override @@ -465,16 +466,6 @@ private void processResponse(Headers headers, QueryResults results) setCatalog.set(headers.get(TRINO_HEADERS.responseSetCatalog())); setSchema.set(headers.get(TRINO_HEADERS.responseSetSchema())); setPath.set(safeSplitToList(headers.get(TRINO_HEADERS.responseSetPath()))); - - String responseEncoding = headers.get(TRINO_HEADERS.responseQueryDataEncoding()); - if (responseEncoding != null) { - resultRowsDecoder.withEncoding(responseEncoding); - } - - if (results.getColumns() != null) { - resultRowsDecoder.withColumns(results.getColumns()); - } - String setAuthorizationUser = headers.get(TRINO_HEADERS.responseSetAuthorizationUser()); if (setAuthorizationUser != null) { this.setAuthorizationUser.set(setAuthorizationUser); diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java index ba6ace7b31f1..8fb55c7bf563 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java @@ -265,9 +265,9 @@ private void assertDataEquals(List columns, QueryData left, QueryData ri return; } - try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(columns)) { - assertThat(decoder.toRows(left)) - .containsAll(decoder.toRows(right)); + try (ResultRowsDecoder decoder = new ResultRowsDecoder()) { + assertThat(decoder.toRows(columns, left)) + .containsAll(decoder.toRows(columns, right)); } catch (Exception e) { fail(e); diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java index bfd242159943..3df70f2113da 100644 --- a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java @@ -256,11 +256,8 @@ private static void assertEquals(QueryData left, QueryData right) private static Iterable> decodeData(QueryData data) { - try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(COLUMNS_LIST)) { - if (data instanceof EncodedQueryData encodedQueryData) { - return decoder.withEncoding(encodedQueryData.getEncoding()).toRows(data); - } - return decoder.toRows(data); + try (ResultRowsDecoder decoder = new ResultRowsDecoder()) { + return decoder.toRows(COLUMNS_LIST, data); } catch (Exception e) { return fail(e); diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java index 7c59f593c407..fbfd2b11e9cd 100644 --- a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java @@ -44,9 +44,6 @@ public class TestQueryResultsSerialization { private static final List COLUMNS = ImmutableList.of(new Column("_col0", BIGINT, new ClientTypeSignature("bigint"))); - private static final ResultRowsDecoder DATA_DECODER = new ResultRowsDecoder() - .withColumns(COLUMNS); - // As close as possible to the server mapper (client mapper differs) private static final io.airlift.json.JsonCodec SERVER_CODEC = new JsonCodecFactory(new ObjectMapperProvider() .withModules(Set.of(new QueryDataJacksonModule()))) @@ -97,6 +94,7 @@ public void testNullDataSerialization() @Test public void testEmptyArraySerialization() + throws Exception { testRoundTrip(RawQueryData.of(ImmutableList.of()), "[]"); @@ -107,20 +105,22 @@ public void testEmptyArraySerialization() @Test public void testSerialization() + throws Exception { QueryData values = RawQueryData.of(ImmutableList.of(ImmutableList.of(1L), ImmutableList.of(5L))); testRoundTrip(values, "[[1],[5]]"); } private void testRoundTrip(QueryData results, String expectedDataRepresentation) + throws Exception { assertThat(serialize(results)) .isEqualToIgnoringWhitespace(queryResultsJson(expectedDataRepresentation)); String serialized = serialize(results); - try { - assertThat(DATA_DECODER.toRows(CLIENT_CODEC.fromJson(serialized).getData())) - .containsAll(DATA_DECODER.toRows(results)); + try (ResultRowsDecoder decoder = new ResultRowsDecoder()) { + assertThat(decoder.toRows(COLUMNS, CLIENT_CODEC.fromJson(serialized).getData())) + .containsAll(decoder.toRows(COLUMNS, results)); } catch (JsonProcessingException e) { throw new UncheckedIOException(e); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java index b50157ff509a..bd18e0e0e1c8 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java @@ -177,8 +177,8 @@ public void testFirstResponseColumns() QueryResults results = data.orElseThrow(); - try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(results.getColumns())) { - assertThat(decoder.toRows(results.getData())).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system")); + try (ResultRowsDecoder decoder = new ResultRowsDecoder()) { + assertThat(decoder.toRows(results)).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system")); } } @@ -209,8 +209,8 @@ public void testQuery() .peek(result -> assertThat(result.getError()).isNull()) .peek(results -> { if (results.getData() != null) { - try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(results.getColumns())) { - data.addAll(decoder.toRows(results.getData())); + try (ResultRowsDecoder decoder = new ResultRowsDecoder()) { + data.addAll(decoder.toRows(results)); } catch (Exception e) { throw new RuntimeException(e);