Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client cleanup / improvements #23569

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ public CompressedQueryDataDecoder(QueryDataDecoder delegate)
this.delegate = requireNonNull(delegate, "delegate is null");
}

abstract InputStream decompress(InputStream inputStream, int uncompressedSize)
abstract InputStream decompress(InputStream inputStream, int expectedDecompressedSize)
throws IOException;

@Override
public Iterable<List<Object>> decode(InputStream stream, DataAttributes metadata)
throws IOException
{
Optional<Integer> uncompressedSize = metadata.getOptional(DataAttribute.UNCOMPRESSED_SIZE, Integer.class);
if (uncompressedSize.isPresent()) {
return delegate.decode(decompress(stream, uncompressedSize.get()), metadata);
Optional<Integer> expectedDecompressedSize = metadata.getOptional(DataAttribute.UNCOMPRESSED_SIZE, Integer.class);
if (expectedDecompressedSize.isPresent()) {
return delegate.decode(decompress(stream, expectedDecompressedSize.get()), metadata);
}
// Data not compressed - below threshold
return delegate.decode(stream, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
*/
package io.trino.client.spooling.encoding;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.trino.client.Column;
Expand Down Expand Up @@ -105,29 +103,4 @@ public String encoding()
return super.encoding() + "+lz4";
}
}

public static class JsonSchema
{
private final int[] offsets;
private final int step;

@JsonCreator
public JsonSchema(int[] offsets, int step)
{
this.offsets = offsets;
this.step = step;
}

@JsonProperty("offsets")
public int[] getOffsets()
{
return offsets;
}

@JsonProperty("step")
public int getStep()
{
return step;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ public Lz4QueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(InputStream stream, int uncompressedSize)
InputStream decompress(InputStream stream, int expectedDecompressedSize)
throws IOException
{
Lz4Decompressor decompressor = new Lz4Decompressor();
byte[] bytes = ByteStreams.toByteArray(stream);
byte[] output = new byte[uncompressedSize];
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != uncompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, uncompressedSize));
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
}
return new ByteArrayInputStream(output);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
*/
package io.trino.client.spooling.encoding;

import io.airlift.compress.zstd.ZstdInputStream;
import com.google.common.io.ByteStreams;
import io.airlift.compress.zstd.ZstdDecompressor;
import io.trino.client.QueryDataDecoder;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

import static java.lang.String.format;

public class ZstdQueryDataDecoder
extends CompressedQueryDataDecoder
{
Expand All @@ -27,9 +32,18 @@ public ZstdQueryDataDecoder(QueryDataDecoder delegate)
}

@Override
InputStream decompress(InputStream inputStream, int uncompressedSize)
InputStream decompress(InputStream stream, int expectedDecompressedSize)
throws IOException
{
return new ZstdInputStream(inputStream);
ZstdDecompressor decompressor = new ZstdDecompressor();
byte[] bytes = ByteStreams.toByteArray(stream);
byte[] output = new byte[expectedDecompressedSize];

int decompressedSize = decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length);
if (decompressedSize != expectedDecompressedSize) {
throw new IOException(format("Decompressed size does not match expected segment size, expected %d, got %d", decompressedSize, expectedDecompressedSize));
}
return new ByteArrayInputStream(output);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ abstract class AbstractTrinoResultSet

private static final int MAX_DATETIME_PRECISION = 12;

private static final DateTimeZone CURRENT_TIME_ZONE = DateTimeZone.forID(ZoneId.systemDefault().getId());
private static final TimeZone CURRENT_JAVA_TIME_ZONE = TimeZone.getTimeZone(ZoneId.of(CURRENT_TIME_ZONE.getID()));

private static final int MILLISECONDS_PER_SECOND = 1000;
private static final int MILLISECONDS_PER_MINUTE = 60 * MILLISECONDS_PER_SECOND;
private static final long NANOSECONDS_PER_SECOND = 1_000_000_000;
Expand Down Expand Up @@ -150,8 +153,8 @@ abstract class AbstractTrinoResultSet
TypeConversions.builder()
.add("decimal", String.class, BigDecimal.class, AbstractTrinoResultSet::parseBigDecimal)
.add("varbinary", byte[].class, String.class, value -> "0x" + BaseEncoding.base16().encode(value))
.add("date", String.class, Date.class, string -> parseDate(string, DateTimeZone.forID(ZoneId.systemDefault().getId())))
.add("date", String.class, java.time.LocalDate.class, string -> parseDate(string, DateTimeZone.forID(ZoneId.systemDefault().getId())).toLocalDate())
.add("date", String.class, Date.class, string -> parseDate(string, CURRENT_TIME_ZONE, CURRENT_JAVA_TIME_ZONE))
.add("date", String.class, java.time.LocalDate.class, string -> parseDate(string, CURRENT_TIME_ZONE, CURRENT_JAVA_TIME_ZONE).toLocalDate())
.add("time", String.class, Time.class, string -> parseTime(string, ZoneId.systemDefault()))
.add("time with time zone", String.class, Time.class, AbstractTrinoResultSet::parseTimeWithTimeZone)
.add("timestamp", String.class, Timestamp.class, string -> parseTimestampAsSqlTimestamp(string, ZoneId.systemDefault()))
Expand Down Expand Up @@ -179,8 +182,6 @@ abstract class AbstractTrinoResultSet
return result;
})
.build();

private final DateTimeZone resultTimeZone;
protected final Iterator<List<Object>> results;
private final Map<String, Integer> fieldMap;
private final List<ColumnInfo> columnInfoList;
Expand All @@ -193,8 +194,6 @@ abstract class AbstractTrinoResultSet
AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, Iterator<List<Object>> results)
{
this.statement = requireNonNull(statement, "statement is null");
this.resultTimeZone = DateTimeZone.forID(ZoneId.systemDefault().getId());

requireNonNull(columns, "columns is null");
this.fieldMap = getFieldMap(columns);
this.columnInfoList = getColumnInfo(columns);
Expand Down Expand Up @@ -333,10 +332,10 @@ public byte[] getBytes(int columnIndex)
public Date getDate(int columnIndex)
throws SQLException
{
return getDate(columnIndex, resultTimeZone);
return getDate(columnIndex, CURRENT_TIME_ZONE, CURRENT_JAVA_TIME_ZONE);
}

private Date getDate(int columnIndex, DateTimeZone localTimeZone)
private Date getDate(int columnIndex, DateTimeZone localTimeZone, TimeZone localJavaTimeZone)
throws SQLException
{
Object value = column(columnIndex);
Expand All @@ -345,16 +344,16 @@ private Date getDate(int columnIndex, DateTimeZone localTimeZone)
}

try {
return parseDate(String.valueOf(value), localTimeZone);
return parseDate(String.valueOf(value), localTimeZone, localJavaTimeZone);
}
catch (IllegalArgumentException e) {
throw new SQLException("Expected value to be a date but is: " + value, e);
}
}

private static Date parseDate(String value, DateTimeZone localTimeZone)
private static Date parseDate(String value, DateTimeZone localTimeZone, TimeZone localJavaTimeZone)
{
LocalDate localDate = DATE_FORMATTER.parseLocalDate(String.valueOf(value));
LocalDate localDate = DATE_FORMATTER.parseLocalDate(value);
long millis = localDate.toDateTimeAtStartOfDay(localTimeZone).getMillis();
if (millis >= START_OF_MODERN_ERA_SECONDS * MILLISECONDS_PER_SECOND) {
return new Date(millis);
Expand All @@ -367,9 +366,8 @@ private static Date parseDate(String value, DateTimeZone localTimeZone)
// expensive GregorianCalendar; note that Joda also has a chronology that works for
// older dates, but it uses a slightly different algorithm and yields results that
// are not compatible with java.sql.Date.
LocalDate preGregorianDate = DATE_FORMATTER.parseLocalDate(String.valueOf(value));
Calendar calendar = new GregorianCalendar(preGregorianDate.getYear(), preGregorianDate.getMonthOfYear() - 1, preGregorianDate.getDayOfMonth());
calendar.setTimeZone(TimeZone.getTimeZone(ZoneId.of(localTimeZone.getID())));
Calendar calendar = new GregorianCalendar(localDate.getYear(), localDate.getMonthOfYear() - 1, localDate.getDayOfMonth());
calendar.setTimeZone(localJavaTimeZone);

return new Date(calendar.getTimeInMillis());
}
Expand All @@ -378,7 +376,7 @@ private static Date parseDate(String value, DateTimeZone localTimeZone)
public Time getTime(int columnIndex)
throws SQLException
{
return getTime(columnIndex, resultTimeZone);
return getTime(columnIndex, CURRENT_TIME_ZONE);
}

private Time getTime(int columnIndex, DateTimeZone localTimeZone)
Expand Down Expand Up @@ -415,7 +413,7 @@ private Time getTime(int columnIndex, DateTimeZone localTimeZone)
public Timestamp getTimestamp(int columnIndex)
throws SQLException
{
return getTimestamp(columnIndex, resultTimeZone);
return getTimestamp(columnIndex, CURRENT_TIME_ZONE);
}

private Timestamp getTimestamp(int columnIndex, DateTimeZone localTimeZone)
Expand Down Expand Up @@ -1351,7 +1349,7 @@ public Array getArray(String columnLabel)
public Date getDate(int columnIndex, Calendar cal)
throws SQLException
{
return getDate(columnIndex, DateTimeZone.forTimeZone(cal.getTimeZone()));
return getDate(columnIndex, DateTimeZone.forTimeZone(cal.getTimeZone()), cal.getTimeZone());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
public class Lz4QueryDataEncoder
extends CompressedQueryDataEncoder
{
private static final int COMPRESSION_THRESHOLD = 2048;
private static final int COMPRESSION_THRESHOLD = 8192;

public Lz4QueryDataEncoder(QueryDataEncoder delegate)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
public class ZstdQueryDataEncoder
extends CompressedQueryDataEncoder
{
private static final int COMPRESSION_THRESHOLD = 2048;
private static final int COMPRESSION_THRESHOLD = 8192;

public ZstdQueryDataEncoder(QueryDataEncoder delegate)
{
Expand Down
Loading