Correct numeric type mapping of Snowflake Connector
For fixed-point numbers:
 - All fixed-point numeric types are decimals (NUMBER in Snowflake).
 - INT, INTEGER, BIGINT, SMALLINT, TINYINT, BYTEINT are synonymous
   with NUMBER(38, 0).

For floating-point numbers:
 - FLOAT, FLOAT4, FLOAT8, DOUBLE, DOUBLE PRECISION, REAL are synonymous
   with each other. They're all 64-bit floating-point numbers.
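
A standalone illustration of why the fixed-point half of this matters (the class below is ours, not part of this commit): the largest NUMBER(38, 0) value cannot fit in a 64-bit bigint, so mapping Snowflake's integer synonyms to Trino bigint risks overflow.

```java
import java.math.BigDecimal;

public class Number38OverflowSketch
{
    public static void main(String[] args)
    {
        // Largest value a Snowflake NUMBER(38, 0) column can hold: 38 nines
        BigDecimal max38 = new BigDecimal("99999999999999999999999999999999999999");
        // The largest Trino bigint value has only 19 digits
        System.out.println(Long.MAX_VALUE); // 9223372036854775807
        try {
            max38.longValueExact(); // throws: the value does not fit in a long
        }
        catch (ArithmeticException e) {
            System.out.println("NUMBER(38, 0) does not fit in bigint: " + e.getMessage());
        }
    }
}
```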
lxynov authored and ebyhr committed May 17, 2024
1 parent 1379bce commit 3e1e31c
Showing 5 changed files with 65 additions and 38 deletions.
SnowflakeClient.java
@@ -113,15 +113,11 @@
import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.smallintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.smallintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.toTrinoTimestamp;
import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping;
@@ -220,19 +216,13 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
switch (type) {
case Types.BOOLEAN:
return Optional.of(booleanColumnMapping());
case Types.TINYINT:
return Optional.of(tinyintColumnMapping());
case Types.SMALLINT:
return Optional.of(smallintColumnMapping());
case Types.INTEGER:
return Optional.of(integerColumnMapping());
// This is kept for synthetic columns generated by count() aggregation pushdown
case Types.BIGINT:
return Optional.of(bigintColumnMapping());
case Types.REAL:
return Optional.of(realColumnMapping());
case Types.FLOAT:
case Types.DOUBLE:
return Optional.of(doubleColumnMapping());

// In Snowflake, all fixed-point numeric types are decimals. The JDBC driver always returns a DECIMAL type when JDBC_TREAT_DECIMAL_AS_INT is set to FALSE.
case Types.NUMERIC:
case Types.DECIMAL: {
int precision = typeHandle.requiredColumnSize();
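
The hunk above trims the switch to three numeric branches: a BIGINT case kept for count() pushdown, a shared NUMERIC/DECIMAL case, and a shared FLOAT/DOUBLE case. A rough standalone sketch of the resulting rule (simplified and hypothetical; the real connector builds Trino ColumnMapping objects via StandardColumnMappings):

```java
import java.sql.Types;
import java.util.Optional;

public class SnowflakeNumericMappingSketch
{
    // Hypothetical helper: derive a Trino type name from the JDBC type
    // reported by the Snowflake driver.
    static Optional<String> toTrinoType(int jdbcType, int precision, int scale)
    {
        switch (jdbcType) {
            case Types.BIGINT:
                // Kept only for synthetic columns generated by count() pushdown
                return Optional.of("bigint");
            case Types.NUMERIC:
            case Types.DECIMAL:
                // INT, INTEGER, BIGINT, SMALLINT, TINYINT, BYTEINT all arrive
                // here as NUMBER(38, 0) when JDBC_TREAT_DECIMAL_AS_INT is FALSE
                return precision > 38
                        ? Optional.empty()
                        : Optional.of("decimal(" + precision + ", " + scale + ")");
            case Types.FLOAT:
            case Types.DOUBLE:
                // FLOAT, FLOAT4, FLOAT8, DOUBLE, DOUBLE PRECISION and REAL are
                // all 64-bit floating-point numbers in Snowflake
                return Optional.of("double");
            default:
                return Optional.empty();
        }
    }

    public static void main(String[] args)
    {
        System.out.println(toTrinoType(Types.DECIMAL, 38, 0)); // Optional[decimal(38, 0)]
        System.out.println(toTrinoType(Types.DOUBLE, 0, 0));   // Optional[double]
    }
}
```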
SnowflakeClientModule.java
@@ -64,12 +64,7 @@ public ConnectionFactory getConnectionFactory(BaseJdbcConfig baseJdbcConfig, Sno
snowflakeConfig.getRole().ifPresent(role -> properties.setProperty("role", role));
snowflakeConfig.getWarehouse().ifPresent(warehouse -> properties.setProperty("warehouse", warehouse));

// Set the date/time formats we expect our plugin to parse
properties.setProperty("TIMESTAMP_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_NTZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_TZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_LTZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIME_OUTPUT_FORMAT", "HH24:MI:SS.FF9");
setOutputProperties(properties);

// Support for Corporate proxies
if (snowflakeConfig.getHttpProxy().isPresent()) {
@@ -100,4 +95,16 @@ public ConnectionFactory getConnectionFactory(BaseJdbcConfig baseJdbcConfig, Sno
.setOpenTelemetry(openTelemetry)
.build();
}

protected static void setOutputProperties(Properties properties)
{
// Set the date/time formats we expect our plugin to parse
properties.setProperty("TIMESTAMP_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_NTZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_TZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIMESTAMP_LTZ_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
properties.setProperty("TIME_OUTPUT_FORMAT", "HH24:MI:SS.FF9");
// Don't treat decimals as bigints as they may overflow
properties.setProperty("JDBC_TREAT_DECIMAL_AS_INT", "FALSE");
}
}
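
For reference, a hedged sketch of how the same output properties drive an ordinary JDBC session against Snowflake (assumes the snowflake-jdbc driver is on the classpath; the account URL and credentials are placeholders, not values from this repository):

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.Properties;

public class SnowflakeSessionSketch
{
    public static void main(String[] args) throws Exception
    {
        Properties properties = new Properties();
        properties.setProperty("user", "<user>");         // placeholder
        properties.setProperty("password", "<password>"); // placeholder
        // Same effect as setOutputProperties above: fixed timestamp/time
        // formats, and decimals returned as DECIMAL rather than BIGINT
        properties.setProperty("TIMESTAMP_OUTPUT_FORMAT", "YYYY-MM-DD\"T\"HH24:MI:SS.FF9TZH:TZM");
        properties.setProperty("TIME_OUTPUT_FORMAT", "HH24:MI:SS.FF9");
        properties.setProperty("JDBC_TREAT_DECIMAL_AS_INT", "FALSE");

        try (Connection connection = DriverManager.getConnection(
                "jdbc:snowflake://<account>.snowflakecomputing.com", properties);
                Statement statement = connection.createStatement();
                ResultSet resultSet = statement.executeQuery("SELECT CAST(1 AS INT)")) {
            resultSet.next();
            // With JDBC_TREAT_DECIMAL_AS_INT=FALSE this prints
            // class java.math.BigDecimal instead of class java.lang.Long
            System.out.println(resultSet.getObject(1).getClass());
        }
    }
}
```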
TestSnowflakeConnectorTest.java
@@ -15,6 +15,8 @@

import com.google.common.collect.ImmutableMap;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
@@ -130,16 +132,16 @@ protected boolean isColumnNameRejected(Exception exception, String columnName, b
@Override
protected MaterializedResult getDescribeOrdersResult()
{
// Override this test because the type of row "shippriority" should be bigint rather than integer for snowflake case
// Override this test because the "orderkey", "custkey" and "shippriority" columns are decimal rather than integer types in the Snowflake case
return resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR)
.row("orderkey", "bigint", "", "")
.row("custkey", "bigint", "", "")
.row("orderkey", "decimal(19,0)", "", "")
.row("custkey", "decimal(19,0)", "", "")
.row("orderstatus", "varchar(1)", "", "")
.row("totalprice", "double", "", "")
.row("orderdate", "date", "", "")
.row("orderpriority", "varchar(15)", "", "")
.row("clerk", "varchar(15)", "", "")
.row("shippriority", "bigint", "", "")
.row("shippriority", "decimal(10,0)", "", "")
.row("comment", "varchar(79)", "", "")
.build();
}
@@ -164,17 +166,17 @@ public void testViews()
@Override
public void testShowCreateTable()
{
// Override this test because the type of row "shippriority" should be bigint rather than integer for snowflake case
// Override this test because the "orderkey", "custkey" and "shippriority" columns are decimal rather than integer types in the Snowflake case
assertThat(computeActual("SHOW CREATE TABLE orders").getOnlyValue())
.isEqualTo("CREATE TABLE snowflake.tpch.orders (\n" +
" orderkey bigint,\n" +
" custkey bigint,\n" +
" orderkey decimal(19, 0),\n" +
" custkey decimal(19, 0),\n" +
" orderstatus varchar(1),\n" +
" totalprice double,\n" +
" orderdate date,\n" +
" orderpriority varchar(15),\n" +
" clerk varchar(15),\n" +
" shippriority bigint,\n" +
" shippriority decimal(10, 0),\n" +
" comment varchar(79)\n" +
")");
}
@@ -366,7 +368,7 @@ public void testInformationSchemaFiltering()
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'tpch' AND table_name = 'orders' LIMIT 1",
"SELECT 'orders' table_name");
assertQuery(
"SELECT table_name FROM information_schema.columns WHERE data_type = 'bigint' AND table_schema = 'tpch' AND table_name = 'nation' and column_name = 'nationkey' LIMIT 1",
"SELECT table_name FROM information_schema.columns WHERE data_type = 'decimal(19,0)' AND table_schema = 'tpch' AND table_name = 'nation' and column_name = 'nationkey' LIMIT 1",
"SELECT 'nation' table_name");
}

@@ -377,4 +379,28 @@ public void testSelectInformationSchemaColumns()
{
// TODO https://github.com/trinodb/trino/issues/21157 Enable this test after fixing the timeout issue
}

@Test
@Override // Override because, for approx_set(nationkey), a ProjectNode is present above the TableScanNode; it projects decimals to doubles.
public void testAggregationWithUnsupportedResultType()
{
// TODO array_agg returns array, so it could be supported
assertThat(query("SELECT array_agg(nationkey) FROM nation"))
.skipResultsCorrectnessCheckForPushdown() // array_agg doesn't have a deterministic order of elements in result array
.isNotFullyPushedDown(AggregationNode.class);
// histogram returns map, which is not supported
assertThat(query("SELECT histogram(regionkey) FROM nation")).isNotFullyPushedDown(AggregationNode.class);
// multimap_agg returns multimap, which is not supported
assertThat(query("SELECT multimap_agg(regionkey, nationkey) FROM nation"))
.skipResultsCorrectnessCheckForPushdown() // multimap_agg doesn't have a deterministic order of values for a key
.isNotFullyPushedDown(AggregationNode.class);
// approx_set returns HyperLogLog, which is not supported
assertThat(query("SELECT approx_set(nationkey) FROM nation")).isNotFullyPushedDown(AggregationNode.class, ProjectNode.class);
}

@Override // Override because integers are represented as decimals in the Snowflake connector.
protected String sumDistinctAggregationPushdownExpectedResult()
{
return "VALUES (BIGINT '4', DECIMAL '8')";
}
}
TestSnowflakeTypeMapping.java
@@ -41,7 +41,6 @@
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.IGNORE;
import static io.trino.plugin.snowflake.SnowflakeQueryRunner.createSnowflakeQueryRunner;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.DecimalType.createDecimalType;
@@ -112,21 +111,23 @@ public void testInteger()
private void testInteger(String inputType)
{
SqlDataTypeTest.create()
.addRoundTrip(inputType, "-9223372036854775808", BIGINT, "-9223372036854775808")
.addRoundTrip(inputType, "9223372036854775807", BIGINT, "9223372036854775807")
.addRoundTrip(inputType, "0", BIGINT, "CAST(0 AS BIGINT)")
.addRoundTrip(inputType, "NULL", BIGINT, "CAST(NULL AS BIGINT)")
.addRoundTrip(inputType, "'-9223372036854775808'", createDecimalType(38, 0), "CAST('-9223372036854775808' AS decimal(38, 0))")
.addRoundTrip(inputType, "'9223372036854775807'", createDecimalType(38, 0), "CAST('9223372036854775807' AS decimal(38, 0))")
.addRoundTrip(inputType, "'-99999999999999999999999999999999999999'", createDecimalType(38, 0), "CAST('-99999999999999999999999999999999999999' AS decimal(38, 0))")
.addRoundTrip(inputType, "'99999999999999999999999999999999999999'", createDecimalType(38, 0), "CAST('99999999999999999999999999999999999999' AS decimal(38, 0))")
.addRoundTrip(inputType, "0", createDecimalType(38, 0), "CAST(0 AS decimal(38, 0))")
.addRoundTrip(inputType, "NULL", createDecimalType(38, 0), "CAST(NULL AS decimal(38, 0))")
.execute(getQueryRunner(), snowflakeCreateAndInsert("tpch.integer"));
}

@Test
public void testDecimal()
{
SqlDataTypeTest.create()
.addRoundTrip("decimal(3, 0)", "NULL", BIGINT, "CAST(NULL AS BIGINT)")
.addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", BIGINT, "CAST('193' AS BIGINT)")
.addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", BIGINT, "CAST('19' AS BIGINT)")
.addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", BIGINT, "CAST('-193' AS BIGINT)")
.addRoundTrip("decimal(3, 0)", "NULL", createDecimalType(3, 0), "CAST(NULL AS decimal(3, 0))")
.addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('193' AS decimal(3, 0))")
.addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))")
.addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))")
.addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))")
.addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))")
.addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))")
@@ -138,7 +139,8 @@ public void testDecimal()
.addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))")
.addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))")
.addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))")
.addRoundTrip("decimal(38, 0)", "CAST(NULL AS decimal(38, 0))", BIGINT, "CAST(NULL AS BIGINT)")
.addRoundTrip("decimal(38, 0)", "CAST(NULL AS decimal(38, 0))", createDecimalType(38, 0), "CAST(NULL AS decimal(38, 0))")
.addRoundTrip("decimal(38, 0)", "CAST('99999999999999999999999999999999999999' AS decimal(38, 0))", createDecimalType(38, 0), "CAST('99999999999999999999999999999999999999' AS decimal(38, 0))")
.execute(getQueryRunner(), snowflakeCreateAndInsert("tpch.test_decimal"))
.execute(getQueryRunner(), trinoCreateAsSelect("test_decimal"))
.execute(getQueryRunner(), trinoCreateAndInsert("test_decimal"));
TestingSnowflakeServer.java
@@ -21,6 +21,7 @@
import java.sql.Statement;
import java.util.Properties;

import static io.trino.plugin.snowflake.SnowflakeClientModule.setOutputProperties;
import static io.trino.testing.TestingProperties.requiredNonEmptySystemProperty;

public final class TestingSnowflakeServer
@@ -60,6 +61,7 @@ private static Properties getProperties()
properties.setProperty("schema", TEST_SCHEMA);
properties.setProperty("warehouse", TEST_WAREHOUSE);
properties.setProperty("role", TEST_ROLE);
setOutputProperties(properties);
return properties;
}
}
