Skip to content

Commit

Permalink
Simplify ResultRowsDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 11, 2024
1 parent 7d896cc commit 3f978dc
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.client;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import io.trino.client.spooling.DataAttributes;
import io.trino.client.spooling.EncodedQueryData;
Expand All @@ -30,11 +29,11 @@
import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.client.FixJsonDataUtils.fixData;
import static io.trino.client.ResultRows.NULL_ROWS;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

/**
Expand All @@ -44,8 +43,7 @@ public class ResultRowsDecoder
implements AutoCloseable
{
private final SegmentLoader loader;
private QueryDataDecoder.Factory decoderFactory;
private List<Column> columns = emptyList();
private QueryDataDecoder decoder;

public ResultRowsDecoder()
{
Expand All @@ -57,38 +55,35 @@ public ResultRowsDecoder(SegmentLoader loader)
this.loader = requireNonNull(loader, "loader is null");
}

public ResultRowsDecoder withEncoding(String encoding)
private void setEncoding(List<Column> columns, String encoding)
{
if (decoderFactory != null) {
if (!encoding.equals(decoderFactory.encoding())) {
throw new IllegalStateException("Already set encoding " + encoding + " is not equal to " + decoderFactory.encoding());
}
if (decoder != null) {
checkState(decoder.encoding().equals(encoding), "Decoder is configured for encoding %s but got %s", decoder.encoding(), encoding);
}
else {
this.decoderFactory = QueryDataDecoders.get(encoding);
checkState(!columns.isEmpty(), "Columns must be set when decoding data");
this.decoder = QueryDataDecoders.get(encoding)
// we don't use query-level attributes for now
.create(columns, DataAttributes.empty());
}
return this;
}

public ResultRowsDecoder withColumns(List<Column> columns)
public ResultRows toRows(QueryResults results)
{
if (this.columns.isEmpty()) {
this.columns = ImmutableList.copyOf(columns);
}
else if (!columns.equals(this.columns)) {
throw new IllegalStateException("Already set columns " + columns + " are not equal to " + this.columns);
if (results == null || results.getData() == null) {
return NULL_ROWS;
}

return this;
return toRows(results.getColumns(), results.getData());
}

public ResultRows toRows(QueryData data)
public ResultRows toRows(List<Column> columns, QueryData data)
{
if (data == null) {
if (data == null || data.isNull()) {
return NULL_ROWS; // for backward compatibility instead of null
}

verify(!columns.isEmpty(), "columns must be set");
verify(columns != null && !columns.isEmpty(), "Columns must be set when decoding data");
if (data instanceof RawQueryData) {
RawQueryData rawData = (RawQueryData) data;
if (rawData.isNull()) {
Expand All @@ -98,15 +93,12 @@ public ResultRows toRows(QueryData data)
}

if (data instanceof EncodedQueryData) {
verify(decoderFactory != null, "decoderFactory must be set");
// we don't need query-level attributes for now
QueryDataDecoder decoder = decoderFactory.create(columns, DataAttributes.empty());
EncodedQueryData encodedData = (EncodedQueryData) data;
verify(decoder.encoding().equals(encodedData.getEncoding()), "encoding %s is not equal to %s", encodedData.getEncoding(), decoder.encoding());
setEncoding(columns, encodedData.getEncoding());

List<ResultRows> resultRows = encodedData.getSegments()
.stream()
.map(segment -> segmentToRows(decoder, segment))
.map(this::segmentToRows)
.collect(toImmutableList());

return concat(resultRows);
Expand All @@ -115,7 +107,7 @@ public ResultRows toRows(QueryData data)
throw new UnsupportedOperationException("Unsupported data type: " + data.getClass().getName());
}

private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment)
private ResultRows segmentToRows(Segment segment)
{
if (segment instanceof InlineSegment) {
InlineSegment inlineSegment = (InlineSegment) segment;
Expand All @@ -131,6 +123,7 @@ private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment)
SpooledSegment spooledSegment = (SpooledSegment) segment;

try {
// The returned rows are lazy which means that decoder is responsible for closing input stream
InputStream stream = loader.load(spooledSegment);
return decoder.decode(stream, spooledSegment.getMetadata());
}
Expand All @@ -144,8 +137,8 @@ private ResultRows segmentToRows(QueryDataDecoder decoder, Segment segment)

public Optional<String> getEncoding()
{
return Optional.ofNullable(decoderFactory)
.map(QueryDataDecoder.Factory::encoding);
return Optional.ofNullable(decoder)
.map(QueryDataDecoder::encoding);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ public QueryStatusInfo currentStatusInfo()
@Nonnull
public ResultRows currentRows()
{
return resultRowsDecoder.toRows(currentData());
checkState(isRunning(), "current position is not valid (cursor past end)");
return resultRowsDecoder.toRows(currentResults.get());
}

@Override
Expand Down Expand Up @@ -465,16 +466,6 @@ private void processResponse(Headers headers, QueryResults results)
setCatalog.set(headers.get(TRINO_HEADERS.responseSetCatalog()));
setSchema.set(headers.get(TRINO_HEADERS.responseSetSchema()));
setPath.set(safeSplitToList(headers.get(TRINO_HEADERS.responseSetPath())));

String responseEncoding = headers.get(TRINO_HEADERS.responseQueryDataEncoding());
if (responseEncoding != null) {
resultRowsDecoder.withEncoding(responseEncoding);
}

if (results.getColumns() != null) {
resultRowsDecoder.withColumns(results.getColumns());
}

String setAuthorizationUser = headers.get(TRINO_HEADERS.responseSetAuthorizationUser());
if (setAuthorizationUser != null) {
this.setAuthorizationUser.set(setAuthorizationUser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ private void assertDataEquals(List<Column> columns, QueryData left, QueryData ri
return;
}

try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(columns)) {
assertThat(decoder.toRows(left))
.containsAll(decoder.toRows(right));
try (ResultRowsDecoder decoder = new ResultRowsDecoder()) {
assertThat(decoder.toRows(columns, left))
.containsAll(decoder.toRows(columns, right));
}
catch (Exception e) {
fail(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,8 @@ private static void assertEquals(QueryData left, QueryData right)

private static Iterable<List<Object>> decodeData(QueryData data)
{
try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(COLUMNS_LIST)) {
if (data instanceof EncodedQueryData encodedQueryData) {
return decoder.withEncoding(encodedQueryData.getEncoding()).toRows(data);
}
return decoder.toRows(data);
try (ResultRowsDecoder decoder = new ResultRowsDecoder()) {
return decoder.toRows(COLUMNS_LIST, data);
}
catch (Exception e) {
return fail(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ public class TestQueryResultsSerialization
{
private static final List<Column> COLUMNS = ImmutableList.of(new Column("_col0", BIGINT, new ClientTypeSignature("bigint")));

private static final ResultRowsDecoder DATA_DECODER = new ResultRowsDecoder()
.withColumns(COLUMNS);

// As close as possible to the server mapper (client mapper differs)
private static final io.airlift.json.JsonCodec<QueryResults> SERVER_CODEC = new JsonCodecFactory(new ObjectMapperProvider()
.withModules(Set.of(new QueryDataJacksonModule())))
Expand Down Expand Up @@ -97,6 +94,7 @@ public void testNullDataSerialization()

@Test
public void testEmptyArraySerialization()
throws Exception
{
testRoundTrip(RawQueryData.of(ImmutableList.of()), "[]");

Expand All @@ -107,20 +105,22 @@ public void testEmptyArraySerialization()

@Test
public void testSerialization()
throws Exception
{
QueryData values = RawQueryData.of(ImmutableList.of(ImmutableList.of(1L), ImmutableList.of(5L)));
testRoundTrip(values, "[[1],[5]]");
}

private void testRoundTrip(QueryData results, String expectedDataRepresentation)
throws Exception
{
assertThat(serialize(results))
.isEqualToIgnoringWhitespace(queryResultsJson(expectedDataRepresentation));

String serialized = serialize(results);
try {
assertThat(DATA_DECODER.toRows(CLIENT_CODEC.fromJson(serialized).getData()))
.containsAll(DATA_DECODER.toRows(results));
try (ResultRowsDecoder decoder = new ResultRowsDecoder()) {
assertThat(decoder.toRows(COLUMNS, CLIENT_CODEC.fromJson(serialized).getData()))
.containsAll(decoder.toRows(COLUMNS, results));
}
catch (JsonProcessingException e) {
throw new UncheckedIOException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ public void testFirstResponseColumns()

QueryResults results = data.orElseThrow();

try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(results.getColumns())) {
assertThat(decoder.toRows(results.getData())).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system"));
try (ResultRowsDecoder decoder = new ResultRowsDecoder()) {
assertThat(decoder.toRows(results)).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system"));
}
}

Expand Down Expand Up @@ -209,8 +209,8 @@ public void testQuery()
.peek(result -> assertThat(result.getError()).isNull())
.peek(results -> {
if (results.getData() != null) {
try (ResultRowsDecoder decoder = new ResultRowsDecoder().withColumns(results.getColumns())) {
data.addAll(decoder.toRows(results.getData()));
try (ResultRowsDecoder decoder = new ResultRowsDecoder()) {
data.addAll(decoder.toRows(results));
}
catch (Exception e) {
throw new RuntimeException(e);
Expand Down

0 comments on commit 3f978dc

Please sign in to comment.