From 8c00d95e4e5af99a0725462f1933ab8028c45d6b Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 5 Jan 2022 12:59:02 +0100 Subject: [PATCH] Move collation-aware pushdown to query builder --- .../plugin/jdbc/DefaultQueryBuilder.java | 8 +- .../CollationAwareQueryBuilder.java | 27 +++-- .../plugin/postgresql/PostgreSqlClient.java | 109 ++---------------- 3 files changed, 31 insertions(+), 113 deletions(-) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index 704f6e1ebfe9..47befb647f06 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -336,10 +336,10 @@ protected String toPredicate(JdbcClient client, ConnectorSession session, Connec else { List rangeConjuncts = new ArrayList<>(); if (!range.isLowUnbounded()) { - rangeConjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator)); + rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator)); } if (!range.isHighUnbounded()) { - rangeConjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator)); + rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator)); } // If rangeConjuncts is null, then the range was ALL, which should already have been checked for checkState(!rangeConjuncts.isEmpty()); @@ -354,7 +354,7 @@ protected String toPredicate(JdbcClient client, ConnectorSession session, Connec // Add back all of the possible single values either as an equality or an IN predicate if (singleValues.size() == 1) { - disjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator)); + disjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator)); } else if (singleValues.size() > 1) { for (Object value : singleValues) { @@ -371,7 +371,7 @@ else if (singleValues.size() > 1) { return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; } - protected String toPredicate(JdbcClient client, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) + protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) { accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); return format("%s %s %s", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java index 52cbfb9604cb..118d02dbb8a2 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java @@ -17,12 +17,18 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcJoinCondition; -import io.trino.spi.type.CharType; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.WriteFunction; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.Type; -import io.trino.spi.type.VarcharType; +import java.util.Optional; +import java.util.function.Consumer; import java.util.stream.Stream; +import static io.trino.plugin.postgresql.PostgreSqlClient.isCollatable; +import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.isEnableStringPushdownWithCollate; import static java.lang.String.format; public class CollationAwareQueryBuilder @@ -31,11 +37,10 @@ public class CollationAwareQueryBuilder @Override protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) { - boolean needsCollation = Stream.of(condition.getLeftColumn(), condition.getRightColumn()) - .map(JdbcColumnHandle::getColumnType) - .anyMatch(CollationAwareQueryBuilder::isCharType); + boolean isCollatable = Stream.of(condition.getLeftColumn(), condition.getRightColumn()) + .anyMatch(PostgreSqlClient::isCollatable); - if (needsCollation) { + if (isCollatable) { return format( "%s.%s COLLATE \"C\" %s %s.%s COLLATE \"C\"", leftRelationAlias, @@ -48,8 +53,14 @@ protected String formatJoinCondition(JdbcClient client, String leftRelationAlias return super.formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition); } - private static boolean isCharType(Type type) + @Override + protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer accumulator) { - return type instanceof CharType || type instanceof VarcharType; + if (isCollatable(column) && isEnableStringPushdownWithCollate(session)) { + accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); + return format("%s %s %s COLLATE \"C\"", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); + } + + return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator); } } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 191bd672e2a9..e2299ba31e26 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -236,22 +236,13 @@ public class PostgreSqlClient private final List tableTypes; private final AggregateFunctionRewriter aggregateFunctionRewriter; - private static final PredicatePushdownController POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE = (session, domain) -> { - checkArgument( - domain.getType() instanceof VarcharType || domain.getType() instanceof CharType, - "This PredicatePushdownController can be used only for chars and varchars"); - + private static final PredicatePushdownController POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN = (session, domain) -> { if (domain.isOnlyNull()) { return FULL_PUSHDOWN.apply(session, domain); } - // PostgreSQL is case sensitive by default - // PostgreSQL by default orders lowercase letters before uppercase, which is different from Trino - // TODO We could still push the predicates down if we could inject a PostgreSQL-specific syntax for selecting a collation for given comparison. - if (!domain.getValues().isDiscreteSet()) { - // Push down of range predicate for varchar/char types could lead to incorrect results - // due to different sort ordering of lowercase and uppercase letters in PostgreSQL - return DISABLE_PUSHDOWN.apply(session, domain); + if (isEnableStringPushdownWithCollate(session)) { + return FULL_PUSHDOWN.apply(session, domain); } Domain simplifiedDomain = domain.simplify(getDomainCompactionThreshold(session)); @@ -509,22 +500,13 @@ public Optional toColumnMapping(ConnectorSession session, Connect } case Types.CHAR: - if (isEnableStringPushdownWithCollate(session)) { - return Optional.of(charColumnMappingWithCollate(typeHandle.getRequiredColumnSize())); - } return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize())); case Types.VARCHAR: if (!jdbcTypeName.equals("varchar")) { // This can be e.g. an ENUM - if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) { - return Optional.of(typedVarcharColumnMappingWithCollate(jdbcTypeName)); - } return Optional.of(typedVarcharColumnMapping(jdbcTypeName)); } - if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) { - return Optional.of(varcharColumnMappingWithCollate(typeHandle.getRequiredColumnSize())); - } return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize())); case Types.BINARY: @@ -759,7 +741,7 @@ protected Optional topNFunction() }); } - private boolean isCollatable(JdbcColumnHandle column) + protected static boolean isCollatable(JdbcColumnHandle column) { if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) { String jdbcTypeName = column.getJdbcTypeHandle().getJdbcTypeName() @@ -771,7 +753,7 @@ private boolean isCollatable(JdbcColumnHandle column) return false; } - private boolean isCollatable(String jdbcTypeName) + private static boolean isCollatable(String jdbcTypeName) { // Only char (internally named bpchar)/varchar/text are the built-in collatable types return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName); @@ -852,20 +834,7 @@ private static ColumnMapping charColumnMapping(int charLength) charType, charReadFunction(charType), charWriteFunction(), - POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); - } - - private static ColumnMapping charColumnMappingWithCollate(int charLength) - { - if (charLength > CharType.MAX_LENGTH) { - return varcharColumnMappingWithCollate(charLength); - } - CharType charType = createCharType(charLength); - return ColumnMapping.sliceMapping( - charType, - charReadFunction(charType), - stringWriteFunctionWithCollate(), - FULL_PUSHDOWN); + POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN); } private static ColumnMapping varcharColumnMapping(int varcharLength) @@ -877,38 +846,7 @@ private static ColumnMapping varcharColumnMapping(int varcharLength) varcharType, varcharReadFunction(varcharType), varcharWriteFunction(), - POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); - } - - private static ColumnMapping varcharColumnMappingWithCollate(int varcharLength) - { - VarcharType varcharType = varcharLength <= VarcharType.MAX_LENGTH - ? createVarcharType(varcharLength) - : createUnboundedVarcharType(); - return ColumnMapping.sliceMapping( - varcharType, - varcharReadFunction(varcharType), - stringWriteFunctionWithCollate(), - FULL_PUSHDOWN); - } - - private static SliceWriteFunction stringWriteFunctionWithCollate() - { - return new SliceWriteFunction() - { - @Override - public String getBindExpression() - { - return "? COLLATE \"C\""; - } - - @Override - public void set(PreparedStatement statement, int index, Slice value) - throws SQLException - { - statement.setString(index, value.toStringUtf8()); - } - }; + POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN); } private static ColumnMapping timeColumnMapping(int precision) @@ -1225,16 +1163,7 @@ private static ColumnMapping typedVarcharColumnMapping(String jdbcTypeName) VARCHAR, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), typedVarcharWriteFunction(jdbcTypeName), - POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); - } - - private static ColumnMapping typedVarcharColumnMappingWithCollate(String jdbcTypeName) - { - return ColumnMapping.sliceMapping( - VARCHAR, - (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), - typedVarcharWriteFunctionWithCollate(jdbcTypeName), - FULL_PUSHDOWN); + POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN); } private static SliceWriteFunction typedVarcharWriteFunction(String jdbcTypeName) @@ -1258,28 +1187,6 @@ public void set(PreparedStatement statement, int index, Slice value) }; } - private static SliceWriteFunction typedVarcharWriteFunctionWithCollate(String jdbcTypeName) - { - String collation = "COLLATE \"C\""; - String bindExpression = format("CAST(? AS %s) %s", requireNonNull(jdbcTypeName, "jdbcTypeName is null"), collation); - - return new SliceWriteFunction() - { - @Override - public String getBindExpression() - { - return bindExpression; - } - - @Override - public void set(PreparedStatement statement, int index, Slice value) - throws SQLException - { - statement.setString(index, value.toStringUtf8()); - } - }; - } - private static ColumnMapping moneyColumnMapping() { /*