Skip to content

Commit

Permalink
Move collation-aware pushdown to query builder
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Feb 23, 2022
1 parent 9c3c6a7 commit 8c00d95
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ protected String toPredicate(JdbcClient client, ConnectorSession session, Connec
else {
List<String> rangeConjuncts = new ArrayList<>();
if (!range.isLowUnbounded()) {
rangeConjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator));
rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), accumulator));
}
if (!range.isHighUnbounded()) {
rangeConjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator));
rangeConjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), accumulator));
}
// If rangeConjuncts is empty, then the range was ALL, which should already have been checked for
checkState(!rangeConjuncts.isEmpty());
Expand All @@ -354,7 +354,7 @@ protected String toPredicate(JdbcClient client, ConnectorSession session, Connec

// Add back all of the possible single values either as an equality or an IN predicate
if (singleValues.size() == 1) {
disjuncts.add(toPredicate(client, column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator));
disjuncts.add(toPredicate(client, session, column, jdbcType, type, writeFunction, "=", getOnlyElement(singleValues), accumulator));
}
else if (singleValues.size() > 1) {
for (Object value : singleValues) {
Expand All @@ -371,7 +371,7 @@ else if (singleValues.size() > 1) {
return "(" + Joiner.on(" OR ").join(disjuncts) + ")";
}

protected String toPredicate(JdbcClient client, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator)
protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator)
{
accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
return format("%s %s %s", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.spi.type.CharType;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.WriteFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static io.trino.plugin.postgresql.PostgreSqlClient.isCollatable;
import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.isEnableStringPushdownWithCollate;
import static java.lang.String.format;

public class CollationAwareQueryBuilder
Expand All @@ -31,11 +37,10 @@ public class CollationAwareQueryBuilder
@Override
protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition)
{
boolean needsCollation = Stream.of(condition.getLeftColumn(), condition.getRightColumn())
.map(JdbcColumnHandle::getColumnType)
.anyMatch(CollationAwareQueryBuilder::isCharType);
boolean isCollatable = Stream.of(condition.getLeftColumn(), condition.getRightColumn())
.anyMatch(PostgreSqlClient::isCollatable);

if (needsCollation) {
if (isCollatable) {
return format(
"%s.%s COLLATE \"C\" %s %s.%s COLLATE \"C\"",
leftRelationAlias,
Expand All @@ -48,8 +53,14 @@ protected String formatJoinCondition(JdbcClient client, String leftRelationAlias
return super.formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition);
}

private static boolean isCharType(Type type)
@Override
protected String toPredicate(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, JdbcTypeHandle jdbcType, Type type, WriteFunction writeFunction, String operator, Object value, Consumer<QueryParameter> accumulator)
{
return type instanceof CharType || type instanceof VarcharType;
if (isCollatable(column) && isEnableStringPushdownWithCollate(session)) {
accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value)));
return format("%s %s %s COLLATE \"C\"", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression());
}

return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,13 @@ public class PostgreSqlClient
private final List<String> tableTypes;
private final AggregateFunctionRewriter<JdbcExpression> aggregateFunctionRewriter;

private static final PredicatePushdownController POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE = (session, domain) -> {
checkArgument(
domain.getType() instanceof VarcharType || domain.getType() instanceof CharType,
"This PredicatePushdownController can be used only for chars and varchars");

private static final PredicatePushdownController POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN = (session, domain) -> {
if (domain.isOnlyNull()) {
return FULL_PUSHDOWN.apply(session, domain);
}

// PostgreSQL is case sensitive by default
// PostgreSQL by default orders lowercase letters before uppercase, which is different from Trino
// TODO We could still push the predicates down if we could inject a PostgreSQL-specific syntax for selecting a collation for given comparison.
if (!domain.getValues().isDiscreteSet()) {
// Push down of range predicate for varchar/char types could lead to incorrect results
// due to different sort ordering of lowercase and uppercase letters in PostgreSQL
return DISABLE_PUSHDOWN.apply(session, domain);
if (isEnableStringPushdownWithCollate(session)) {
return FULL_PUSHDOWN.apply(session, domain);
}

Domain simplifiedDomain = domain.simplify(getDomainCompactionThreshold(session));
Expand Down Expand Up @@ -509,22 +500,13 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
}

case Types.CHAR:
if (isEnableStringPushdownWithCollate(session)) {
return Optional.of(charColumnMappingWithCollate(typeHandle.getRequiredColumnSize()));
}
return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.VARCHAR:
if (!jdbcTypeName.equals("varchar")) {
// This can be e.g. an ENUM
if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) {
return Optional.of(typedVarcharColumnMappingWithCollate(jdbcTypeName));
}
return Optional.of(typedVarcharColumnMapping(jdbcTypeName));
}
if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) {
return Optional.of(varcharColumnMappingWithCollate(typeHandle.getRequiredColumnSize()));
}
return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize()));

case Types.BINARY:
Expand Down Expand Up @@ -759,7 +741,7 @@ protected Optional<TopNFunction> topNFunction()
});
}

private boolean isCollatable(JdbcColumnHandle column)
protected static boolean isCollatable(JdbcColumnHandle column)
{
if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) {
String jdbcTypeName = column.getJdbcTypeHandle().getJdbcTypeName()
Expand All @@ -771,7 +753,7 @@ private boolean isCollatable(JdbcColumnHandle column)
return false;
}

private boolean isCollatable(String jdbcTypeName)
private static boolean isCollatable(String jdbcTypeName)
{
// Only char (internally named bpchar)/varchar/text are the built-in collatable types
return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName);
Expand Down Expand Up @@ -852,20 +834,7 @@ private static ColumnMapping charColumnMapping(int charLength)
charType,
charReadFunction(charType),
charWriteFunction(),
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

private static ColumnMapping charColumnMappingWithCollate(int charLength)
{
if (charLength > CharType.MAX_LENGTH) {
return varcharColumnMappingWithCollate(charLength);
}
CharType charType = createCharType(charLength);
return ColumnMapping.sliceMapping(
charType,
charReadFunction(charType),
stringWriteFunctionWithCollate(),
FULL_PUSHDOWN);
POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN);
}

private static ColumnMapping varcharColumnMapping(int varcharLength)
Expand All @@ -877,38 +846,7 @@ private static ColumnMapping varcharColumnMapping(int varcharLength)
varcharType,
varcharReadFunction(varcharType),
varcharWriteFunction(),
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

/**
 * Builds the slice column mapping for a varchar column of the given length, using
 * the write function returned by {@code stringWriteFunctionWithCollate()} and
 * {@code FULL_PUSHDOWN} as the predicate pushdown controller.
 *
 * @param varcharLength declared length of the column
 * @return the column mapping for the resulting varchar type
 */
private static ColumnMapping varcharColumnMappingWithCollate(int varcharLength)
{
    // Lengths above VarcharType.MAX_LENGTH are mapped to the unbounded varchar type
    VarcharType mappedType;
    if (varcharLength > VarcharType.MAX_LENGTH) {
        mappedType = createUnboundedVarcharType();
    }
    else {
        mappedType = createVarcharType(varcharLength);
    }
    return ColumnMapping.sliceMapping(mappedType, varcharReadFunction(mappedType), stringWriteFunctionWithCollate(), FULL_PUSHDOWN);
}

private static SliceWriteFunction stringWriteFunctionWithCollate()
{
return new SliceWriteFunction()
{
@Override
public String getBindExpression()
{
return "? COLLATE \"C\"";
}

@Override
public void set(PreparedStatement statement, int index, Slice value)
throws SQLException
{
statement.setString(index, value.toStringUtf8());
}
};
POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN);
}

private static ColumnMapping timeColumnMapping(int precision)
Expand Down Expand Up @@ -1225,16 +1163,7 @@ private static ColumnMapping typedVarcharColumnMapping(String jdbcTypeName)
VARCHAR,
(resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)),
typedVarcharWriteFunction(jdbcTypeName),
POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE);
}

private static ColumnMapping typedVarcharColumnMappingWithCollate(String jdbcTypeName)
{
return ColumnMapping.sliceMapping(
VARCHAR,
(resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)),
typedVarcharWriteFunctionWithCollate(jdbcTypeName),
FULL_PUSHDOWN);
POSTGRESQL_STRING_COLLATION_AWARE_PUSHDOWN);
}

private static SliceWriteFunction typedVarcharWriteFunction(String jdbcTypeName)
Expand All @@ -1258,28 +1187,6 @@ public void set(PreparedStatement statement, int index, Slice value)
};
}

/**
 * Creates a write function whose bind expression casts the bound parameter to the
 * given type name and appends a {@code COLLATE "C"} clause to the cast expression.
 *
 * @param jdbcTypeName type name placed inside the CAST; must not be null
 * @return a write function that binds the value as a UTF-8 decoded string
 * @throws NullPointerException if {@code jdbcTypeName} is null
 */
private static SliceWriteFunction typedVarcharWriteFunctionWithCollate(String jdbcTypeName)
{
    // Reject a missing type name before building the expression
    requireNonNull(jdbcTypeName, "jdbcTypeName is null");
    String collation = "COLLATE \"C\"";
    String castWithCollation = format("CAST(? AS %s) %s", jdbcTypeName, collation);

    return new SliceWriteFunction()
    {
        @Override
        public String getBindExpression()
        {
            // Computed once above and shared by every call
            return castWithCollation;
        }

        @Override
        public void set(PreparedStatement statement, int index, Slice value)
                throws SQLException
        {
            String text = value.toStringUtf8();
            statement.setString(index, text);
        }
    };
}

private static ColumnMapping moneyColumnMapping()
{
/*
Expand Down

0 comments on commit 8c00d95

Please sign in to comment.