Skip to content

Commit

Permalink
Disable incorrect character pushdown in PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Mar 4, 2021
1 parent 978405b commit 514272e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.trino.plugin.jdbc.LongWriteFunction;
import io.trino.plugin.jdbc.ObjectReadFunction;
import io.trino.plugin.jdbc.ObjectWriteFunction;
import io.trino.plugin.jdbc.PredicatePushdownController;
import io.trino.plugin.jdbc.ReadFunction;
import io.trino.plugin.jdbc.SliceReadFunction;
import io.trino.plugin.jdbc.SliceWriteFunction;
Expand Down Expand Up @@ -122,16 +123,16 @@
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN;
import static io.trino.plugin.jdbc.PredicatePushdownController.FULL_PUSHDOWN;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.defaultCharColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.defaultVarcharColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.fromTrinoTimestamp;
Expand All @@ -147,6 +148,7 @@
import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping;
import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharReadFunction;
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
Expand All @@ -164,6 +166,7 @@
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.CharType.createCharType;
import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
Expand All @@ -188,6 +191,8 @@
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.String.format;
Expand Down Expand Up @@ -216,6 +221,22 @@ public class PostgreSqlClient
private final List<String> tableTypes;
private final AggregateFunctionRewriter aggregateFunctionRewriter;

/**
 * Predicate pushdown policy for PostgreSQL {@code char} and {@code varchar} columns.
 * <p>
 * Domains that are only-null or consist of a discrete set of values are delegated to
 * {@code FULL_PUSHDOWN}; every other domain is delegated to {@code DISABLE_PUSHDOWN}.
 * Rejects domains of any non-character type with an {@link IllegalArgumentException}.
 */
private static final PredicatePushdownController POSTGRESQL_CHARACTER_PUSHDOWN = (session, domain) -> {
    boolean characterType = domain.getType() instanceof VarcharType || domain.getType() instanceof CharType;
    checkArgument(characterType, "This PredicatePushdownController can be used only for chars and varchars");

    // Null checks and discrete-value comparisons are safe to push down: PostgreSQL is case sensitive by default.
    // Anything else is kept in Trino, because PostgreSQL by default orders lowercase letters before uppercase,
    // which is different from Trino.
    // TODO We could still push the predicates down if we could inject a PostgreSQL-specific syntax for selecting a collation for given comparison.
    PredicatePushdownController delegate = (domain.isOnlyNull() || domain.getValues().isDiscreteSet())
            ? FULL_PUSHDOWN
            : DISABLE_PUSHDOWN;
    return delegate.apply(session, domain);
};

@Inject
public PostgreSqlClient(
BaseJdbcConfig config,
Expand Down Expand Up @@ -455,14 +476,14 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
}

case Types.CHAR:
return Optional.of(defaultCharColumnMapping(typeHandle.getRequiredColumnSize(), true));
return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.VARCHAR:
if (!jdbcTypeName.equals("varchar")) {
// This can be e.g. an ENUM
return Optional.of(typedVarcharColumnMapping(jdbcTypeName));
}
return Optional.of(defaultVarcharColumnMapping(typeHandle.getRequiredColumnSize(), true));
return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.BINARY:
return Optional.of(varbinaryColumnMapping());
Expand Down Expand Up @@ -689,6 +710,31 @@ public boolean isLimitGuaranteed(ConnectorSession session)
return true;
}

/**
 * Builds the column mapping for a PostgreSQL {@code char} column of the given length.
 * <p>
 * Lengths up to {@code CharType.MAX_LENGTH} map to a Trino char type; longer columns
 * are handed over to {@link #varcharColumnMapping(int)}. The resulting mapping uses
 * {@code POSTGRESQL_CHARACTER_PUSHDOWN} as its predicate pushdown controller.
 *
 * @param charLength declared length of the remote column
 * @return slice-based column mapping for the column
 */
private static ColumnMapping charColumnMapping(int charLength)
{
    if (charLength <= CharType.MAX_LENGTH) {
        CharType trinoType = createCharType(charLength);
        return ColumnMapping.sliceMapping(
                trinoType,
                charReadFunction(trinoType),
                charWriteFunction(),
                POSTGRESQL_CHARACTER_PUSHDOWN);
    }
    // Too long to be represented as a Trino char type; fall back to the varchar mapping
    return varcharColumnMapping(charLength);
}

/**
 * Builds the column mapping for a PostgreSQL {@code varchar} column of the given length.
 * <p>
 * Lengths up to {@code VarcharType.MAX_LENGTH} map to a bounded Trino varchar type;
 * anything larger maps to the unbounded varchar type. The resulting mapping uses
 * {@code POSTGRESQL_CHARACTER_PUSHDOWN} as its predicate pushdown controller.
 *
 * @param varcharLength declared length of the remote column
 * @return slice-based column mapping for the column
 */
private static ColumnMapping varcharColumnMapping(int varcharLength)
{
    VarcharType trinoType;
    if (varcharLength > VarcharType.MAX_LENGTH) {
        trinoType = createUnboundedVarcharType();
    }
    else {
        trinoType = createVarcharType(varcharLength);
    }
    return ColumnMapping.sliceMapping(
            trinoType,
            varcharReadFunction(trinoType),
            varcharWriteFunction(),
            POSTGRESQL_CHARACTER_PUSHDOWN);
}

private static ColumnMapping timeColumnMapping(int precision)
{
verify(precision <= 6, "Unsupported precision: %s", precision); // PostgreSQL limit but also assumption within this method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.trino.testing.sql.JdbcSqlExecutor;
import io.trino.testing.sql.TestTable;
import org.intellij.lang.annotations.Language;
import org.testng.SkipException;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -245,7 +244,7 @@ public void testPredicatePushdown()
// varchar range
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'"))
.matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))")
.isFullyPushedDown();
.isNotFullyPushedDown(FilterNode.class);

// varchar different case
assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'romania'"))
Expand Down Expand Up @@ -475,7 +474,7 @@ public void testAggregationPushdown()

// GROUP BY and WHERE on varchar column
// GROUP BY and WHERE on "other" (not aggregation key, not aggregation input)
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown();
assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isNotFullyPushedDown(FilterNode.class);

// GROUP BY above WHERE and LIMIT
assertThat(query("" +
Expand Down Expand Up @@ -750,7 +749,7 @@ public void testLimitPushdown()
assertThat(query("SELECT name FROM nation WHERE regionkey = 3 LIMIT 5")).isFullyPushedDown();

// with filter over varchar column
assertThat(query("SELECT name FROM nation WHERE name < 'EEE' LIMIT 5")).isFullyPushedDown();
assertThat(query("SELECT name FROM nation WHERE name < 'EEE' LIMIT 5")).isNotFullyPushedDown(FilterNode.class);

// with aggregation
assertThat(query("SELECT max(regionkey) FROM nation LIMIT 5")).isFullyPushedDown(); // global aggregation, LIMIT removed
Expand All @@ -759,7 +758,7 @@ public void testLimitPushdown()

// with filter and aggregation
assertThat(query("SELECT regionkey, count(*) FROM nation WHERE nationkey < 5 GROUP BY regionkey LIMIT 3")).isFullyPushedDown();
assertThat(query("SELECT regionkey, count(*) FROM nation WHERE name < 'EGYPT' GROUP BY regionkey LIMIT 3")).isFullyPushedDown();
assertThat(query("SELECT regionkey, count(*) FROM nation WHERE name < 'EGYPT' GROUP BY regionkey LIMIT 3")).isNotFullyPushedDown(FilterNode.class);

// with TopN
assertThat(query("SELECT * FROM (SELECT regionkey FROM nation ORDER BY name ASC LIMIT 10) LIMIT 5")).isFullyPushedDown();
Expand Down Expand Up @@ -811,13 +810,6 @@ public void testTimestampColumnAndTimestampWithTimeZoneConstant()
}
}

/**
 * Overrides the inherited test so that it is always reported as skipped:
 * the method does nothing except throw a {@code SkipException}.
 */
@Override
public void testCaseSensitiveDataMapping(DataMappingTestSetup dataMappingTestSetup)
{
    // TODO - https://github.com/trinodb/trino/issues/3645
    String skipReason = "PostgreSQL has different collation than Trino";
    throw new SkipException(skipReason);
}

private String getLongInClause(int start, int length)
{
String longValues = range(start, start + length)
Expand Down

0 comments on commit 514272e

Please sign in to comment.