From 2a79bd0666eded013d6287c751414a58f0366bcf Mon Sep 17 00:00:00 2001
From: Yuya Ebihara <ebyhry@gmail.com>
Date: Thu, 20 Jul 2023 15:03:46 +0900
Subject: [PATCH] Support writing timestamp tz type on partitioned columns in
 Delta Lake

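Partition values for a timestamp with time zone column are converted
to UTC and written with millisecond precision, omitting the fraction
when it is zero, matching the directory names Databricks produces.
For example (values taken from the tests below), inserting
TIMESTAMP '2023-07-20 01:02:03.9999 -01:00' rounds to
2023-07-20 02:02:04.000 UTC and is stored under
part=2023-07-20 02%3A02%3A04/.
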
Co-Authored-By: alberic <cnuliuweiren@gmail.com>
---
 .../deltalake/TestDeltaLakeConnectorTest.java | 40 +++++++++++++++++
 .../plugin/hive/util/HiveWriteUtils.java      | 14 ++++++
 ...eltaLakeDatabricksInsertCompatibility.java | 44 +++++++++++++++++++
 3 files changed, 98 insertions(+)

diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java
index 3f97d1f0dce0..b2b26af59136 100644
--- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java
+++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeConnectorTest.java
@@ -424,6 +424,46 @@ public void testTimestampPredicatePushdown(String value)
                 results -> {});
     }
 
+    @Test
+    public void testTimestampWithTimeZonePartition()
+    {
+        String tableName = "test_timestamp_tz_partition_" + randomNameSuffix();
+
+        assertUpdate("DROP TABLE IF EXISTS " + tableName);
+        assertUpdate("CREATE TABLE " + tableName + "(id INT, part TIMESTAMP WITH TIME ZONE) WITH (partitioned_by = ARRAY['part'])");
+        assertUpdate(
+                "INSERT INTO " + tableName + " VALUES " +
+                        "(1, NULL)," +
+                        "(2, TIMESTAMP '0001-01-01 00:00:00.000 UTC')," +
+                        "(3, TIMESTAMP '2023-07-20 01:02:03.9999 -01:00')," +
+                        "(4, TIMESTAMP '9999-12-31 23:59:59.999 UTC')",
+                4);
+
+        assertThat(query("SELECT * FROM " + tableName))
+                .matches("VALUES " +
+                        "(1, NULL)," +
+                        "(2, TIMESTAMP '0001-01-01 00:00:00.000 UTC')," +
+                        "(3, TIMESTAMP '2023-07-20 02:02:04.000 UTC')," +
+                        "(4, TIMESTAMP '9999-12-31 23:59:59.999 UTC')");
+        assertQuery(
+                "SHOW STATS FOR " + tableName,
+                "VALUES " +
+                        "('id', null, 4.0, 0.0, null, 1, 4)," +
+                        "('part', null, 3.0, 0.25, null, null, null)," +
+                        "(null, null, null, null, 4.0, null, null)");
+
+        assertThat((String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE id = 1"))
+                .contains("/part=__HIVE_DEFAULT_PARTITION__/");
+        assertThat((String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE id = 2"))
+                .contains("/part=0001-01-01 00%3A00%3A00/");
+        assertThat((String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE id = 3"))
+                .contains("/part=2023-07-20 02%3A02%3A04/");
+        assertThat((String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE id = 4"))
+                .contains("/part=9999-12-31 23%3A59%3A59.999/");
+
+        assertUpdate("DROP TABLE " + tableName);
+    }
+
     @DataProvider
     public Object[][] timestampValues()
     {
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java
index 89152c58c3f4..c05d7091470b 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveWriteUtils.java
@@ -52,6 +52,7 @@
 import io.trino.spi.type.MapType;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.TimestampType;
+import io.trino.spi.type.TimestampWithTimeZoneType;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.VarcharType;
 import org.apache.hadoop.conf.Configuration;
@@ -86,6 +87,7 @@
 import java.util.Optional;
 import java.util.Properties;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.io.BaseEncoding.base16;
@@ -113,11 +115,13 @@
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.BooleanType.BOOLEAN;
 import static io.trino.spi.type.Chars.padSpaces;
+import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
 import static io.trino.spi.type.DateType.DATE;
 import static io.trino.spi.type.DoubleType.DOUBLE;
 import static io.trino.spi.type.IntegerType.INTEGER;
 import static io.trino.spi.type.RealType.REAL;
 import static io.trino.spi.type.SmallintType.SMALLINT;
+import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
 import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND;
@@ -334,6 +338,10 @@ public static Object getField(DateTimeZone localZone, Type type, Block block, in
         if (type instanceof TimestampType timestampType) {
             return getHiveTimestamp(localZone, timestampType, block, position);
         }
+        if (type instanceof TimestampWithTimeZoneType) {
+            checkArgument(type.equals(TIMESTAMP_TZ_MILLIS), "Unsupported type: %s", type);
+            return getHiveTimestampTz(block, position);
+        }
         if (type instanceof DecimalType decimalType) {
             return getHiveDecimal(decimalType, block, position);
         }
@@ -780,4 +788,10 @@ private static Timestamp getHiveTimestamp(DateTimeZone localZone, TimestampType
         int nanosOfSecond = microsOfSecond * NANOSECONDS_PER_MICROSECOND + nanosOfMicro;
         return Timestamp.ofEpochSecond(epochSeconds, nanosOfSecond);
     }
+
+    private static Timestamp getHiveTimestampTz(Block block, int position)
+    {
+        long epochMillis = unpackMillisUtc(block.getLong(position, 0));
+        return Timestamp.ofEpochMilli(epochMillis);
+    }
 }
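
The new getHiveTimestampTz() leans on Trino's packed representation of
TIMESTAMP_TZ_MILLIS values. Below is a minimal standalone sketch of that
encoding, mirroring io.trino.spi.type.DateTimeEncoding (UTC epoch millis
in the upper 52 bits, a time zone key in the lower 12 bits); the class
name and main() driver are illustrative only:

    public final class PackedTimestampTzSketch
    {
        private static final int TIME_ZONE_BITS = 12;
        private static final int TIME_ZONE_MASK = (1 << TIME_ZONE_BITS) - 1;

        // Pack UTC epoch millis together with a time zone key, as
        // DateTimeEncoding.packDateTimeWithZone does.
        static long pack(long epochMillisUtc, short timeZoneKey)
        {
            return (epochMillisUtc << TIME_ZONE_BITS) | (timeZoneKey & TIME_ZONE_MASK);
        }

        // Equivalent of unpackMillisUtc(): the arithmetic shift drops the
        // zone key and preserves the sign of pre-1970 instants.
        static long unpackMillisUtc(long packed)
        {
            return packed >> TIME_ZONE_BITS;
        }

        public static void main(String[] args)
        {
            long packed = pack(1689818524000L, (short) 0); // 2023-07-20 02:02:04 UTC, key 0 = UTC
            System.out.println(unpackMillisUtc(packed)); // 1689818524000
        }
    }

Only the UTC millis are needed to build the Hive Timestamp for the
partition value, so the zone key is discarded.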
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java
index 63eec996c0f2..7cc363362bc0 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeDatabricksInsertCompatibility.java
@@ -130,6 +130,50 @@ public void testPartitionedInsertCompatibility()
         }
     }
 
+    @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_OSS, PROFILE_SPECIFIC_TESTS})
+    @Flaky(issue = DATABRICKS_COMMUNICATION_FAILURE_ISSUE, match = DATABRICKS_COMMUNICATION_FAILURE_MATCH)
+    public void testTimestampWithTimeZonePartitionedInsertCompatibility()
+    {
+        String tableName = "test_dl_timestamp_tz_partitioned_insert_" + randomNameSuffix();
+
+        onTrino().executeQuery("" +
+                "CREATE TABLE delta.default." + tableName +
+                "(id INT, part TIMESTAMP WITH TIME ZONE)" +
+                "WITH (partitioned_by = ARRAY['part'], location = 's3://" + bucketName + "/databricks-compatibility-test-" + tableName + "')");
+        try {
+            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES" +
+                    "(1, TIMESTAMP '0001-01-01 00:00:00.000 UTC')," +
+                    "(2, TIMESTAMP '2023-01-02 01:02:03.999 +01:00')");
+            onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES" +
+                    "(3, TIMESTAMP '2023-03-04 01:02:03.999 -01:00')," +
+                    "(4, TIMESTAMP '9999-12-31 23:59:59.999 UTC')");
+
+            List<Row> expectedRows = ImmutableList.<Row>builder()
+                    .add(row(1, "0001-01-01 00:00:00.000"))
+                    .add(row(2, "2023-01-02 00:02:03.999"))
+                    .add(row(3, "2023-03-04 02:02:03.999"))
+                    .add(row(4, "9999-12-31 23:59:59.999"))
+                    .build();
+
+        assertThat(onDelta().executeQuery("SELECT id, date_format(part, 'yyyy-MM-dd HH:mm:ss.SSS') FROM default." + tableName))
+                    .containsOnly(expectedRows);
+            assertThat(onTrino().executeQuery("SELECT id, format_datetime(part, 'yyyy-MM-dd HH:mm:ss.SSS') FROM delta.default." + tableName))
+                    .containsOnly(expectedRows);
+
+            assertThat((String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE id = 1").getOnlyValue())
+                    .contains("/part=0001-01-01 00%3A00%3A00/");
+            assertThat((String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE id = 2").getOnlyValue())
+                    .contains("/part=2023-01-02 00%3A02%3A03.999/");
+            assertThat((String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE id = 3").getOnlyValue())
+                    .contains("/part=2023-03-04 02%3A02%3A03.999/");
+            assertThat((String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE id = 4").getOnlyValue())
+                    .contains("/part=9999-12-31 23%3A59%3A59.999/");
+        }
+        finally {
+            onTrino().executeQuery("DROP TABLE delta.default." + tableName);
+        }
+    }
+
     @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_OSS, PROFILE_SPECIFIC_TESTS})
     @Flaky(issue = DATABRICKS_COMMUNICATION_FAILURE_ISSUE, match = DATABRICKS_COMMUNICATION_FAILURE_MATCH)
     public void testTrinoPartitionedDifferentOrderInsertCompatibility()
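
A rough sketch of the directory-name rule the paths in these tests
exercise, assuming Hive-style escaping (':' is percent-encoded, the
space is not) and UTC rendering with a zero fraction omitted; the class
and method names are hypothetical:

    import java.time.Instant;
    import java.time.ZoneOffset;
    import java.time.format.DateTimeFormatter;

    public class PartitionPathSketch
    {
        private static final DateTimeFormatter MILLIS =
                DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS").withZone(ZoneOffset.UTC);
        private static final DateTimeFormatter SECONDS =
                DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").withZone(ZoneOffset.UTC);

        // Render the UTC instant, dropping a ".000" fraction, then escape ':'
        // the way Hive escapes partition directory names.
        static String partitionDirName(long epochMillisUtc)
        {
            Instant instant = Instant.ofEpochMilli(epochMillisUtc);
            String value = (epochMillisUtc % 1000 == 0 ? SECONDS : MILLIS).format(instant);
            return "part=" + value.replace(":", "%3A");
        }

        public static void main(String[] args)
        {
            System.out.println(partitionDirName(253402300799999L)); // part=9999-12-31 23%3A59%3A59.999
            System.out.println(partitionDirName(-62135596800000L)); // part=0001-01-01 00%3A00%3A00
        }
    }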