Skip to content

Commit

Permalink
Add MySQL case sensitive collation varchar LIKE push down support
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-lyutenko committed Aug 1, 2023
1 parent 380484d commit 7a2300b
Show file tree
Hide file tree
Showing 7 changed files with 566 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.CaseSensitivity;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.VarcharType;

import java.util.Optional;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE;
import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.lang.String.format;

public class RewriteLikeEscapeWithCaseSensitivity
implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
private static final Capture<ConnectorExpression> LIKE_VALUE = newCapture();
private static final Capture<ConnectorExpression> LIKE_PATTERN = newCapture();
private static final Capture<ConnectorExpression> ESCAPE_PATTERN = newCapture();
private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(LIKE_FUNCTION_NAME))
.with(type().equalTo(BOOLEAN))
.with(argumentCount().equalTo(3))
.with(argument(0).matching(expression().capturedAs(LIKE_VALUE).with(type().matching(VarcharType.class::isInstance))))
.with(argument(1).matching(expression().capturedAs(LIKE_PATTERN).with(type().matching(VarcharType.class::isInstance))))
.with(argument(2).matching(expression().capturedAs(ESCAPE_PATTERN).with(type().matching(VarcharType.class::isInstance))));

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
}

@Override
public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
ConnectorExpression capturedValue = captures.get(LIKE_VALUE);
if (capturedValue instanceof Variable variable) {
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName());
Optional<CaseSensitivity> caseSensitivity = columnHandle.getJdbcTypeHandle().getCaseSensitivity();
if (caseSensitivity.orElse(CASE_INSENSITIVE) == CASE_INSENSITIVE) {
return Optional.empty();
}
}
Optional<ParameterizedExpression> value = context.defaultRewrite(capturedValue);
if (value.isEmpty()) {
return Optional.empty();
}

ImmutableList.Builder<QueryParameter> parameters = ImmutableList.builder();
parameters.addAll(value.get().parameters());
Optional<ParameterizedExpression> pattern = context.defaultRewrite(captures.get(LIKE_PATTERN));
if (pattern.isEmpty()) {
return Optional.empty();
}
parameters.addAll(pattern.get().parameters());

Optional<ParameterizedExpression> escape = context.defaultRewrite(captures.get(ESCAPE_PATTERN));
if (escape.isEmpty()) {
return Optional.empty();
}
parameters.addAll(escape.get().parameters());
return Optional.of(new ParameterizedExpression(format("%s LIKE %s ESCAPE %s", value.get().expression(), pattern.get().expression(), escape.get().expression()), parameters.build()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.CaseSensitivity;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.VarcharType;

import java.util.Optional;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_INSENSITIVE;
import static io.trino.spi.expression.StandardFunctions.LIKE_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.lang.String.format;

public class RewriteLikeWithCaseSensitivity
implements ConnectorExpressionRule<Call, ParameterizedExpression>
{
private static final Capture<ConnectorExpression> LIKE_VALUE = newCapture();
private static final Capture<ConnectorExpression> LIKE_PATTERN = newCapture();
private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(LIKE_FUNCTION_NAME))
.with(type().equalTo(BOOLEAN))
.with(argumentCount().equalTo(2))
.with(argument(0).matching(expression().capturedAs(LIKE_VALUE).with(type().matching(VarcharType.class::isInstance))))
.with(argument(1).matching(expression().capturedAs(LIKE_PATTERN).with(type().matching(VarcharType.class::isInstance))));

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
}

@Override
public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
ConnectorExpression capturedValue = captures.get(LIKE_VALUE);
if (capturedValue instanceof Variable variable) {
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(variable.getName());
Optional<CaseSensitivity> caseSensitivity = columnHandle.getJdbcTypeHandle().getCaseSensitivity();
if (caseSensitivity.orElse(CASE_INSENSITIVE) == CASE_INSENSITIVE) {
return Optional.empty();
}
}
Optional<ParameterizedExpression> value = context.defaultRewrite(capturedValue);
if (value.isEmpty()) {
return Optional.empty();
}

ImmutableList.Builder<QueryParameter> parameters = ImmutableList.builder();
parameters.addAll(value.get().parameters());
Optional<ParameterizedExpression> pattern = context.defaultRewrite(captures.get(LIKE_PATTERN));
if (pattern.isEmpty()) {
return Optional.empty();
}
parameters.addAll(pattern.get().parameters());
return Optional.of(new ParameterizedExpression(format("%s LIKE %s", value.get().expression(), pattern.get().expression()), parameters.build()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Match;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.MoreCollectors.toOptional;
import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE;
import static io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static org.assertj.core.api.Assertions.assertThat;

public abstract class BaseTestRewriteLikeWithCaseSensitivity
{
protected abstract ConnectorExpressionRule<Call, ParameterizedExpression> getRewrite();

protected Optional<ParameterizedExpression> apply(Call expression)
{
Optional<Match> match = getRewrite().getPattern().match(expression).collect(toOptional());
if (match.isEmpty()) {
return Optional.empty();
}
return getRewrite().rewrite(expression, match.get().captures(), new ConnectorExpressionRule.RewriteContext<>()
{
@Override
public Map<String, ColumnHandle> getAssignments()
{
return Map.of("case_insensitive_value", new JdbcColumnHandle("case_insensitive_value", JDBC_BIGINT, VARCHAR),
"case_sensitive_value", new JdbcColumnHandle("case_sensitive_value", new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.of(CASE_SENSITIVE)), VARCHAR));
}

@Override
public ConnectorSession getSession()
{
throw new UnsupportedOperationException();
}

@Override
public Optional<ParameterizedExpression> defaultRewrite(ConnectorExpression expression)
{
if (expression instanceof Variable) {
String name = ((Variable) expression).getName();
return Optional.of(new ParameterizedExpression("\"" + name.replace("\"", "\"\"") + "\"", ImmutableList.of(new QueryParameter(expression.getType(), Optional.of(name)))));
}
return Optional.empty();
}
});
}

protected void assertNoRewrite(Call expression)
{
Optional<ParameterizedExpression> rewritten = apply(expression);
assertThat(rewritten).isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Optional;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static org.assertj.core.api.Assertions.assertThat;

public class TestRewriteLikeEscapeWithCaseSensitivity
extends BaseTestRewriteLikeWithCaseSensitivity
{
private final RewriteLikeEscapeWithCaseSensitivity rewrite = new RewriteLikeEscapeWithCaseSensitivity();

@Override
protected ConnectorExpressionRule<Call, ParameterizedExpression> getRewrite()
{
return rewrite;
}

@Test
public void testRewriteLikeEscapeCallInvalidNumberOfArguments()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(new Variable("case_sensitive_value", VARCHAR)));

assertNoRewrite(expression);
}

@Test
public void testRewriteLikeEscapeCallInvalidTypeValue()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(
new Variable("case_sensitive_value", BIGINT),
new Variable("pattern", VARCHAR),
new Variable("escape", VARCHAR)));

assertNoRewrite(expression);
}

@Test
public void testRewriteLikeEscapeCallInvalidTypePattern()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(
new Variable("case_sensitive_value", VARCHAR),
new Variable("pattern", BIGINT),
new Variable("escape", VARCHAR)));

assertNoRewrite(expression);
}

@Test
public void testRewriteLikeEscapeCallInvalidTypeEscape()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(
new Variable("case_sensitive_value", VARCHAR),
new Variable("pattern", VARCHAR),
new Variable("escape", BIGINT)));

assertNoRewrite(expression);
}

@Test
public void testRewriteLikeEscapeCallOnCaseInsensitiveValue()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(
new Variable("case_insensitive_value", VARCHAR),
new Variable("pattern", VARCHAR),
new Variable("escape", VARCHAR)));

assertNoRewrite(expression);
}

@Test
public void testRewriteLikeEscapeCallOnCaseSensitiveValue()
{
Call expression = new Call(
BOOLEAN,
new FunctionName("$like"),
List.of(
new Variable("case_sensitive_value", VARCHAR),
new Variable("pattern", VARCHAR),
new Variable("escape", VARCHAR)));

ParameterizedExpression rewritten = apply(expression).orElseThrow();
assertThat(rewritten.expression()).isEqualTo("\"case_sensitive_value\" LIKE \"pattern\" ESCAPE \"escape\"");
assertThat(rewritten.parameters()).isEqualTo(List.of(
new QueryParameter(VARCHAR, Optional.of("case_sensitive_value")),
new QueryParameter(VARCHAR, Optional.of("pattern")),
new QueryParameter(VARCHAR, Optional.of("escape"))));
}
}
Loading

0 comments on commit 7a2300b

Please sign in to comment.