diff --git a/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java index 7b549bc114f4..31fe592cea5e 100644 --- a/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java +++ b/plugin/trino-snowflake/src/main/java/io/trino/plugin/snowflake/SnowflakeClient.java @@ -14,7 +14,6 @@ package io.trino.plugin.snowflake; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; @@ -79,7 +78,6 @@ import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; -import java.math.RoundingMode; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; @@ -108,14 +106,48 @@ import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE; import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.smallintColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.smallintWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.toTrinoTimestamp; +import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; +import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TimestampWithTimeZoneType.createTimestampWithTimeZoneType; import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.spi.type.VarcharType.createVarcharType; import static java.lang.String.format; import static java.lang.String.join; +import static java.math.RoundingMode.UNNECESSARY; import static java.util.Objects.requireNonNull; public class SnowflakeClient @@ -133,16 +165,6 @@ public class SnowflakeClient private static final TimeZone UTC_TZ = TimeZone.getTimeZone(ZoneId.of("UTC")); private final AggregateFunctionRewriter aggregateFunctionRewriter; - private interface WriteMappingFunction - { - WriteMapping convert(Type type); - } - - private interface ColumnMappingFunction - { - Optional convert(JdbcTypeHandle typeHandle); - } - @Inject public SnowflakeClient( BaseJdbcConfig config, @@ -189,46 +211,56 @@ public Optional toColumnMapping(ConnectorSession session, Connect jdbcTypeName = jdbcTypeName.toLowerCase(Locale.ENGLISH); int type = typeHandle.jdbcType(); - // Mappings for JDBC column types to internal Trino types - final Map standardColumnMappings = ImmutableMap.builder() - .put(Types.BOOLEAN, StandardColumnMappings.booleanColumnMapping()) - .put(Types.TINYINT, StandardColumnMappings.tinyintColumnMapping()) - .put(Types.SMALLINT, StandardColumnMappings.smallintColumnMapping()) - .put(Types.INTEGER, StandardColumnMappings.integerColumnMapping()) - .put(Types.BIGINT, StandardColumnMappings.bigintColumnMapping()) - .put(Types.REAL, StandardColumnMappings.realColumnMapping()) - .put(Types.DOUBLE, StandardColumnMappings.doubleColumnMapping()) - .put(Types.FLOAT, StandardColumnMappings.doubleColumnMapping()) - .put(Types.BINARY, StandardColumnMappings.varbinaryColumnMapping()) - .put(Types.VARBINARY, StandardColumnMappings.varbinaryColumnMapping()) - .put(Types.LONGVARBINARY, StandardColumnMappings.varbinaryColumnMapping()) - .buildOrThrow(); - - ColumnMapping columnMap = standardColumnMappings.get(type); - if (columnMap != null) { - return Optional.of(columnMap); - } - - final Map snowflakeColumnMappings = ImmutableMap.builder() - .put("time", handle -> Optional.of(timeColumnMapping(handle.requiredDecimalDigits()))) - .put("timestampntz", handle -> Optional.of(timestampColumnMapping(handle.requiredDecimalDigits()))) - .put("timestamptz", handle -> Optional.of(timestampTZColumnMapping(handle.requiredDecimalDigits()))) - .put("date", handle -> Optional.of(ColumnMapping.longMapping(DateType.DATE, (resultSet, columnIndex) -> LocalDate.ofEpochDay(resultSet.getLong(columnIndex)).toEpochDay(), snowFlakeDateWriter()))) - .put("varchar", handle -> Optional.of(varcharColumnMapping(handle.requiredColumnSize(), typeHandle.caseSensitivity()))) - .put("number", handle -> { - int decimalDigits = handle.requiredDecimalDigits(); - int precision = handle.requiredColumnSize() + Math.max(-decimalDigits, 0); - if (precision > 38) { - return Optional.empty(); - } - return Optional.of(columnMappingPushdown( - StandardColumnMappings.decimalColumnMapping(DecimalType.createDecimalType(precision, Math.max(decimalDigits, 0)), RoundingMode.UNNECESSARY))); - }) - .buildOrThrow(); - - ColumnMappingFunction columnMappingFunction = snowflakeColumnMappings.get(jdbcTypeName); - if (columnMappingFunction != null) { - return columnMappingFunction.convert(typeHandle); + switch (type) { + case Types.BOOLEAN: + return Optional.of(booleanColumnMapping()); + case Types.TINYINT: + return Optional.of(tinyintColumnMapping()); + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + case Types.REAL: + return Optional.of(realColumnMapping()); + case Types.FLOAT: + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + case Types.NUMERIC: + case Types.DECIMAL: { + int precision = typeHandle.requiredColumnSize(); + int scale = typeHandle.requiredDecimalDigits(); + if (precision > 38) { + break; + } + DecimalType decimalType = createDecimalType(precision, scale); + return Optional.of(decimalColumnMapping(decimalType, UNNECESSARY)); + } + case Types.VARCHAR: + if (jdbcTypeName.equals("varchar")) { + return Optional.of(varcharColumnMapping(typeHandle.requiredColumnSize(), typeHandle.caseSensitivity())); + } + // Some other Snowflake types (ARRAY, VARIANT, GEOMETRY, etc.) are also mapped to Types.VARCHAR, but they're unsupported. + break; + case Types.BINARY: + // Multiple Snowflake types are mapped into Types.BINARY + if (jdbcTypeName.equals("binary")) { + return Optional.of(varbinaryColumnMapping()); + } + // Some other Snowflake types (GEOMETRY in some cases, etc.) are also mapped to Types.BINARY, but they're unsupported. + break; + case Types.VARBINARY: + case Types.LONGVARBINARY: + return Optional.of(varbinaryColumnMapping()); + case Types.DATE: + return Optional.of(ColumnMapping.longMapping(DateType.DATE, ResultSet::getLong, snowFlakeDateWriteFunction())); + case Types.TIME: + return Optional.of(timeColumnMapping(typeHandle.requiredDecimalDigits())); + case Types.TIMESTAMP: + return Optional.of(timestampColumnMapping(typeHandle.requiredDecimalDigits())); + case Types.TIMESTAMP_WITH_TIMEZONE: + return Optional.of(timestampWithTimeZoneColumnMapping(typeHandle.requiredDecimalDigits())); } return Optional.empty(); @@ -237,44 +269,62 @@ public Optional toColumnMapping(ConnectorSession session, Connect @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { - Class myClass = type.getClass(); - String simple = myClass.getSimpleName(); - - // Mappings for internal Trino types to JDBC column types - final Map standardWriteMappings = ImmutableMap.builder() - .put("BooleanType", WriteMapping.booleanMapping("boolean", StandardColumnMappings.booleanWriteFunction())) - .put("BigintType", WriteMapping.longMapping("number(19)", StandardColumnMappings.bigintWriteFunction())) - .put("IntegerType", WriteMapping.longMapping("number(10)", StandardColumnMappings.integerWriteFunction())) - .put("SmallintType", WriteMapping.longMapping("number(5)", StandardColumnMappings.smallintWriteFunction())) - .put("TinyintType", WriteMapping.longMapping("number(3)", StandardColumnMappings.tinyintWriteFunction())) - .put("DoubleType", WriteMapping.doubleMapping("double precision", StandardColumnMappings.doubleWriteFunction())) - .put("RealType", WriteMapping.longMapping("real", StandardColumnMappings.realWriteFunction())) - .put("VarbinaryType", WriteMapping.sliceMapping("varbinary", StandardColumnMappings.varbinaryWriteFunction())) - .put("DateType", WriteMapping.longMapping("date", snowFlakeDateWriter())) - .buildOrThrow(); - - WriteMapping writeMapping = standardWriteMappings.get(simple); - if (writeMapping != null) { - return writeMapping; + if (type == BOOLEAN) { + return WriteMapping.booleanMapping("BOOLEAN", booleanWriteFunction()); } - - final Map snowflakeWriteMappings = ImmutableMap.builder() - .put("TimeType", writeType -> WriteMapping.longMapping(format("time(%s)", ((TimeType) writeType).getPrecision()), timeWriteFunction(((TimeType) writeType).getPrecision()))) - .put("ShortTimestampType", SnowflakeClient::snowFlakeTimestampWriter) - .put("ShortTimestampWithTimeZoneType", SnowflakeClient::snowFlakeTimestampWithTZWriter) - .put("LongTimestampType", SnowflakeClient::snowFlakeTimestampWithTZWriter) - .put("LongTimestampWithTimeZoneType", SnowflakeClient::snowFlakeTimestampWithTZWriter) - .put("VarcharType", SnowflakeClient::snowFlakeVarCharWriter) - .put("CharType", SnowflakeClient::snowFlakeCharWriter) - .put("LongDecimalType", SnowflakeClient::snowFlakeDecimalWriter) - .put("ShortDecimalType", SnowflakeClient::snowFlakeDecimalWriter) - .buildOrThrow(); - - WriteMappingFunction writeMappingFunction = snowflakeWriteMappings.get(simple); - if (writeMappingFunction != null) { - return writeMappingFunction.convert(type); + if (type == TINYINT) { + return WriteMapping.longMapping("NUMBER(3, 0)", tinyintWriteFunction()); + } + if (type == SMALLINT) { + return WriteMapping.longMapping("NUMBER(5, 0)", smallintWriteFunction()); + } + if (type == INTEGER) { + return WriteMapping.longMapping("NUMBER(10, 0)", integerWriteFunction()); + } + if (type == BIGINT) { + return WriteMapping.longMapping("NUMBER(19, 0)", bigintWriteFunction()); + } + if (type == REAL) { + return WriteMapping.longMapping("DOUBLE", realWriteFunction()); + } + if (type == DOUBLE) { + return WriteMapping.doubleMapping("DOUBLE", doubleWriteFunction()); + } + if (type instanceof DecimalType decimalType) { + String dataType = format("NUMBER(%s, %s)", decimalType.getPrecision(), decimalType.getScale()); + if (decimalType.isShort()) { + return WriteMapping.longMapping(dataType, shortDecimalWriteFunction(decimalType)); + } + return WriteMapping.objectMapping(dataType, longDecimalWriteFunction(decimalType)); + } + if (type instanceof CharType charType) { + return WriteMapping.sliceMapping("VARCHAR(" + charType.getLength() + ")", charWriteFunction(charType)); + } + if (type instanceof VarcharType varcharType) { + String dataType; + if (varcharType.isUnbounded()) { + dataType = "VARCHAR"; + } + else { + dataType = "VARCHAR(" + varcharType.getBoundedLength() + ")"; + } + return WriteMapping.sliceMapping(dataType, varcharWriteFunction()); + } + if (type == VARBINARY) { + return WriteMapping.sliceMapping("VARBINARY", varbinaryWriteFunction()); + } + if (type == DATE) { + return WriteMapping.longMapping("DATE", snowFlakeDateWriteFunction()); + } + if (type instanceof TimeType timeType) { + return WriteMapping.longMapping(format("TIME(%s)", timeType.getPrecision()), timeWriteFunction(timeType.getPrecision())); + } + if (type instanceof TimestampType timestampType) { + return snowflakeTimestampWriteMapping(timestampType.getPrecision()); + } + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + return snowflakeTimestampWithTimeZoneWriteMapping(timestampWithTimeZoneType.getPrecision()); } - throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } @@ -376,15 +426,6 @@ public void setColumnType(ConnectorSession session, JdbcTableHandle handle, Jdbc throw new TrinoException(NOT_SUPPORTED, "This connector does not support setting column types"); } - private static ColumnMapping columnMappingPushdown(ColumnMapping mapping) - { - if (mapping.getPredicatePushdownController() == PredicatePushdownController.DISABLE_PUSHDOWN) { - throw new TrinoException(NOT_SUPPORTED, "mapping.getPredicatePushdownController() is DISABLE_PUSHDOWN. Type was " + mapping.getType()); - } - - return ColumnMapping.mapping(mapping.getType(), mapping.getReadFunction(), mapping.getWriteFunction(), PredicatePushdownController.FULL_PUSHDOWN); - } - private static ColumnMapping timeColumnMapping(int precision) { checkArgument(precision <= MAX_SUPPORTED_TEMPORAL_PRECISION, "The max timestamp precision in Snowflake is " + MAX_SUPPORTED_TEMPORAL_PRECISION); @@ -398,28 +439,17 @@ private static ColumnMapping timeColumnMapping(int precision) PredicatePushdownController.FULL_PUSHDOWN); } - private static ColumnMapping timestampTZColumnMapping(int precision) + private static ColumnMapping timestampWithTimeZoneColumnMapping(int precision) { - if (precision <= 3) { - return ColumnMapping.longMapping(TimestampWithTimeZoneType.createTimestampWithTimeZoneType(precision), + if (precision <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return ColumnMapping.longMapping(createTimestampWithTimeZoneType(precision), (resultSet, columnIndex) -> { ZonedDateTime timestamp = SNOWFLAKE_DATETIME_FORMATTER.parse(resultSet.getString(columnIndex), ZonedDateTime::from); return DateTimeEncoding.packDateTimeWithZone(timestamp.toInstant().toEpochMilli(), timestamp.getZone().getId()); }, - timestampWithTZWriter(), PredicatePushdownController.FULL_PUSHDOWN); - } - else { - return ColumnMapping.objectMapping(TimestampWithTimeZoneType.createTimestampWithTimeZoneType(precision), longTimestampWithTimezoneReadFunction(), longTimestampWithTZWriteFunction()); + shortTimestampWithTimeZoneWriteFunction(), PredicatePushdownController.FULL_PUSHDOWN); } - } - - private static LongWriteFunction timestampWithTZWriter() - { - return (statement, index, encodedTimeWithZone) -> { - Instant timeI = Instant.ofEpochMilli(DateTimeEncoding.unpackMillisUtc(encodedTimeWithZone)); - ZoneId zone = ZoneId.of(DateTimeEncoding.unpackZoneKey(encodedTimeWithZone).getId()); - statement.setString(index, SNOWFLAKE_DATETIME_FORMATTER.format(timeI.atZone(zone))); - }; + return ColumnMapping.objectMapping(createTimestampWithTimeZoneType(precision), longTimestampWithTimezoneReadFunction(), longTimestampWithTimeZoneWriteFunction()); } private static ObjectReadFunction longTimestampWithTimezoneReadFunction() @@ -432,27 +462,12 @@ private static ObjectReadFunction longTimestampWithTimezoneReadFunction() }); } - private static ObjectWriteFunction longTimestampWithTZWriteFunction() - { - return ObjectWriteFunction.of(LongTimestampWithTimeZone.class, (statement, index, value) -> { - long epoMilli = value.getEpochMillis(); - long epoSeconds = Math.floorDiv(epoMilli, Timestamps.MILLISECONDS_PER_SECOND); - long adjNano = (long) Math.floorMod(epoMilli, Timestamps.MILLISECONDS_PER_SECOND) * Timestamps.NANOSECONDS_PER_MILLISECOND + value.getPicosOfMilli() / Timestamps.PICOSECONDS_PER_NANOSECOND; - ZoneId zone = TimeZoneKey.getTimeZoneKey(value.getTimeZoneKey()).getZoneId(); - Instant timeI = Instant.ofEpochSecond(epoSeconds, adjNano); - statement.setString(index, SNOWFLAKE_DATETIME_FORMATTER.format(ZonedDateTime.ofInstant(timeI, zone))); - }); - } - private static ColumnMapping timestampColumnMapping(int precision) { - // <= 6 fits into a long - if (precision <= 6) { - return ColumnMapping.longMapping(TimestampType.createTimestampType(precision), (resultSet, columnIndex) -> StandardColumnMappings.toTrinoTimestamp(TimestampType.createTimestampType(precision), toLocalDateTime(resultSet, columnIndex)), timestampWriteFunction()); + if (precision <= TimestampType.MAX_SHORT_PRECISION) { + return ColumnMapping.longMapping(createTimestampType(precision), (resultSet, columnIndex) -> toTrinoTimestamp(createTimestampType(precision), toLocalDateTime(resultSet, columnIndex)), shortTimestampWriteFunction()); } - - // Too big. Put it in an object - return ColumnMapping.objectMapping(TimestampType.createTimestampType(precision), longTimestampReader(), longTimestampWriter(precision)); + return ColumnMapping.objectMapping(createTimestampType(precision), longTimestampReader(), longTimestampWriteFunction(precision)); } private static LocalDateTime toLocalDateTime(ResultSet resultSet, int columnIndex) @@ -472,7 +487,7 @@ private static ObjectReadFunction longTimestampReader() Timestamp ts = resultSet.getTimestamp(columnIndex, calendar); long epochMillis = ts.getTime(); int nanosInTheSecond = ts.getNanos(); - int nanosInTheMilli = nanosInTheSecond % Timestamps.NANOSECONDS_PER_MILLISECOND; + int nanosInTheMilli = nanosInTheSecond % NANOSECONDS_PER_MILLISECOND; long micro = epochMillis * Timestamps.MICROSECONDS_PER_MILLISECOND + (nanosInTheMilli / Timestamps.NANOSECONDS_PER_MICROSECOND); int picosOfMicro = nanosInTheMilli % 1000 * 1000; return new LongTimestamp(micro, picosOfMicro); @@ -498,96 +513,47 @@ private static ColumnMapping varcharColumnMapping(int varcharLength, Optional { - long epochMilli = value.getEpochMillis(); - long epochSecond = Math.floorDiv(epochMilli, MILLISECONDS_PER_SECOND); - int nanosOfSecond = Math.floorMod(epochMilli, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; - ZoneId zone = TimeZoneKey.getTimeZoneKey(value.getTimeZoneKey()).getZoneId(); - Instant instant = Instant.ofEpochSecond(epochSecond, nanosOfSecond); - statement.setString(index, SNOWFLAKE_DATETIME_FORMATTER.format(ZonedDateTime.ofInstant(instant, zone))); - }); - } - - private static WriteMapping snowFlakeDecimalWriter(Type type) - { - DecimalType decimalType = (DecimalType) type; - String dataType = format("decimal(%s, %s)", decimalType.getPrecision(), decimalType.getScale()); - - if (decimalType.isShort()) { - return WriteMapping.longMapping(dataType, StandardColumnMappings.shortDecimalWriteFunction(decimalType)); - } - return WriteMapping.objectMapping(dataType, StandardColumnMappings.longDecimalWriteFunction(decimalType)); - } - - private static LongWriteFunction snowFlakeDateWriter() + private static LongWriteFunction snowFlakeDateWriteFunction() { return (statement, index, day) -> statement.setString(index, SNOWFLAKE_DATE_FORMATTER.format(LocalDate.ofEpochDay(day))); } - private static WriteMapping snowFlakeCharWriter(Type type) - { - CharType charType = (CharType) type; - return WriteMapping.sliceMapping("char(" + charType.getLength() + ")", charWriteFunction(charType)); - } - - private static WriteMapping snowFlakeVarCharWriter(Type type) - { - String dataType; - VarcharType varcharType = (VarcharType) type; - - if (varcharType.isUnbounded()) { - dataType = "varchar"; - } - else { - dataType = "varchar(" + varcharType.getBoundedLength() + ")"; - } - return WriteMapping.sliceMapping(dataType, StandardColumnMappings.varcharWriteFunction()); - } - private static SliceWriteFunction charWriteFunction(CharType charType) { return (statement, index, value) -> statement.setString(index, Chars.padSpaces(value, charType).toStringUtf8()); } - private static WriteMapping snowFlakeTimestampWriter(Type type) + private static WriteMapping snowflakeTimestampWriteMapping(int precision) { - TimestampType timestampType = (TimestampType) type; - checkArgument( - timestampType.getPrecision() <= MAX_SUPPORTED_TEMPORAL_PRECISION, - "The max timestamp precision in Snowflake is " + MAX_SUPPORTED_TEMPORAL_PRECISION); - - if (timestampType.isShort()) { - return WriteMapping.longMapping(format("timestamp_ntz(%d)", timestampType.getPrecision()), timestampWriteFunction()); + checkArgument(precision <= MAX_SUPPORTED_TEMPORAL_PRECISION, "The max timestamp precision in Snowflake is " + MAX_SUPPORTED_TEMPORAL_PRECISION); + if (precision <= TimestampType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping(format("timestamp_ntz(%d)", precision), shortTimestampWriteFunction()); } - return WriteMapping.objectMapping(format("timestamp_ntz(%d)", timestampType.getPrecision()), longTimestampWriter(timestampType.getPrecision())); + return WriteMapping.objectMapping(format("timestamp_ntz(%d)", precision), longTimestampWriteFunction(precision)); } - private static LongWriteFunction timestampWriteFunction() + private static LongWriteFunction shortTimestampWriteFunction() { return (statement, index, value) -> statement.setString(index, StandardColumnMappings.fromTrinoTimestamp(value).toString()); } - private static ObjectWriteFunction longTimestampWriter(int precision) + private static ObjectWriteFunction longTimestampWriteFunction(int precision) { return ObjectWriteFunction.of( LongTimestamp.class, (statement, index, value) -> statement.setString(index, SNOWFLAKE_TIMESTAMP_FORMATTER.format(StandardColumnMappings.fromLongTrinoTimestamp(value, precision)))); } - private static WriteMapping snowFlakeTimestampWithTZWriter(Type type) + private static WriteMapping snowflakeTimestampWithTimeZoneWriteMapping(int precision) { - TimestampWithTimeZoneType timeTZType = (TimestampWithTimeZoneType) type; - - checkArgument(timeTZType.getPrecision() <= MAX_SUPPORTED_TEMPORAL_PRECISION, "Max Snowflake precision is is " + MAX_SUPPORTED_TEMPORAL_PRECISION); - if (timeTZType.isShort()) { - return WriteMapping.longMapping(format("timestamp_tz(%d)", timeTZType.getPrecision()), timestampWithTimezoneWriteFunction()); + checkArgument(precision <= MAX_SUPPORTED_TEMPORAL_PRECISION, "Max Snowflake precision is is " + MAX_SUPPORTED_TEMPORAL_PRECISION); + if (precision <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping(format("timestamp_tz(%d)", precision), shortTimestampWithTimeZoneWriteFunction()); } - return WriteMapping.objectMapping(format("timestamp_tz(%d)", timeTZType.getPrecision()), longTimestampWithTzWriteFunction()); + return WriteMapping.objectMapping(format("timestamp_tz(%d)", precision), longTimestampWithTimeZoneWriteFunction()); } - private static LongWriteFunction timestampWithTimezoneWriteFunction() + private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction() { return (statement, index, encodedTimeWithZone) -> { Instant instant = Instant.ofEpochMilli(DateTimeEncoding.unpackMillisUtc(encodedTimeWithZone)); @@ -595,4 +561,16 @@ private static LongWriteFunction timestampWithTimezoneWriteFunction() statement.setString(index, SNOWFLAKE_DATETIME_FORMATTER.format(instant.atZone(zone))); }; } + + private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() + { + return ObjectWriteFunction.of(LongTimestampWithTimeZone.class, (statement, index, value) -> { + long epochMillis = value.getEpochMillis(); + long epochSeconds = Math.floorDiv(epochMillis, MILLISECONDS_PER_SECOND); + long adjustNanoSeconds = (long) Math.floorMod(epochMillis, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; + ZoneId zone = TimeZoneKey.getTimeZoneKey(value.getTimeZoneKey()).getZoneId(); + Instant instant = Instant.ofEpochSecond(epochSeconds, adjustNanoSeconds); + statement.setString(index, SNOWFLAKE_DATETIME_FORMATTER.format(ZonedDateTime.ofInstant(instant, zone))); + }); + } }