diff --git a/docs/reference/sql/functions/operators.asciidoc b/docs/reference/sql/functions/operators.asciidoc index 9c90d12320ed0..aae9d47ec7e0d 100644 --- a/docs/reference/sql/functions/operators.asciidoc +++ b/docs/reference/sql/functions/operators.asciidoc @@ -3,7 +3,7 @@ [[sql-operators]] === Comparison Operators -Boolean operator for comparing one or two expressions. +Boolean operator for comparing against one or multiple expressions. * Equality (`=`) @@ -40,6 +40,13 @@ include-tagged::{sql-specs}/filter.sql-spec[whereBetween] include-tagged::{sql-specs}/filter.sql-spec[whereIsNotNullAndIsNull] -------------------------------------------------- +* `IN (, , ...)` + +["source","sql",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{sql-specs}/filter.sql-spec[whereWithInAndMultipleValues] +-------------------------------------------------- + [[sql-operators-logical]] === Logical Operators diff --git a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/type/DataType.java b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/type/DataType.java index 4087a81a424fa..1c08c6e1c9fa1 100644 --- a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/type/DataType.java +++ b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/type/DataType.java @@ -225,4 +225,16 @@ public static DataType fromODBCType(String odbcType) { public static DataType fromEsType(String esType) { return DataType.valueOf(esType.toUpperCase(Locale.ROOT)); } + + public boolean isCompatibleWith(DataType other) { + if (this == other) { + return true; + } else if (isString() && other.isString()) { + return true; + } else if (isNumeric() && other.isNumeric()) { + return true; + } else { + return false; + } + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java index 4915a25a55bc7..e5ab3ce082b71 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.sql.expression.function.Functions; import org.elasticsearch.xpack.sql.expression.function.Score; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.sql.expression.predicate.In; import org.elasticsearch.xpack.sql.plan.logical.Aggregate; import org.elasticsearch.xpack.sql.plan.logical.Distinct; import org.elasticsearch.xpack.sql.plan.logical.Filter; @@ -40,7 +41,9 @@ import static java.lang.String.format; -abstract class Verifier { +final class Verifier { + + private Verifier() {} static class Failure { private final Node source; @@ -188,6 +191,8 @@ static Collection verify(LogicalPlan plan) { Set localFailures = new LinkedHashSet<>(); + validateInExpression(p, localFailures); + if (!groupingFailures.contains(p)) { checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures); } @@ -488,4 +493,19 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names())); } } -} \ No newline at end of file + + private static void validateInExpression(LogicalPlan p, Set localFailures) { + p.forEachExpressions(e -> + e.forEachUp((In in) -> { + DataType dt = in.value().dataType(); + for (Expression value : in.list()) { + if (!in.value().dataType().isCompatibleWith(value.dataType())) { + localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]", + dt, value.dataType())); + return; + } + } + }, + In.class)); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java index c95e08f087dda..e9a37240be091 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java @@ -67,6 +67,15 @@ public static boolean nullable(List exps) { return true; } + public static boolean foldable(List exps) { + for (Expression exp : exps) { + if (!exp.foldable()) { + return false; + } + } + return true; + } + public static AttributeSet references(List exps) { if (exps.isEmpty()) { return AttributeSet.EMPTY; diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/pipeline/Pipe.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/pipeline/Pipe.java index 4d1604ff535d3..5c96d2c9244ab 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/pipeline/Pipe.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/pipeline/Pipe.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.sql.expression.gen.pipeline; +import org.elasticsearch.xpack.sql.capabilities.Resolvable; import org.elasticsearch.xpack.sql.execution.search.FieldExtraction; import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; @@ -24,7 +25,7 @@ * Is an {@code Add} operator with left {@code ABS} over an aggregate (MAX), and * right being a {@code CAST} function. */ -public abstract class Pipe extends Node implements FieldExtraction { +public abstract class Pipe extends Node implements FieldExtraction, Resolvable { private final Expression expression; @@ -37,8 +38,6 @@ public Expression expression() { return expression; } - public abstract boolean resolved(); - public abstract Processor asProcessor(); /** diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java index fb04f6d438a91..a820833d1a013 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java @@ -5,43 +5,55 @@ */ package org.elasticsearch.xpack.sql.expression.predicate; -import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.NamedExpression; +import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; +import org.elasticsearch.xpack.sql.expression.gen.script.Params; +import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; +import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver; +import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons; +import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe; import org.elasticsearch.xpack.sql.tree.Location; import org.elasticsearch.xpack.sql.tree.NodeInfo; import org.elasticsearch.xpack.sql.type.DataType; import org.elasticsearch.xpack.sql.util.CollectionUtils; +import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Locale; import java.util.Objects; +import java.util.StringJoiner; +import java.util.stream.Collectors; -public class In extends NamedExpression { +import static java.lang.String.format; +import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder; + +public class In extends NamedExpression implements ScriptWeaver { private final Expression value; private final List list; - private final boolean nullable, foldable; + private Attribute lazyAttribute; public In(Location location, Expression value, List list) { super(location, null, CollectionUtils.combine(list, value), null); this.value = value; - this.list = list; - - this.nullable = children().stream().anyMatch(Expression::nullable); - this.foldable = children().stream().allMatch(Expression::foldable); + this.list = list.stream().distinct().collect(Collectors.toList()); } @Override protected NodeInfo info() { - return NodeInfo.create(this, In::new, value(), list()); + return NodeInfo.create(this, In::new, value, list); } @Override public Expression replaceChildren(List newChildren) { - if (newChildren.size() < 1) { - throw new IllegalArgumentException("expected one or more children but received [" + newChildren.size() + "]"); + if (newChildren.size() < 2) { + throw new IllegalArgumentException("expected at least [2] children but received [" + newChildren.size() + "]"); } return new In(location(), newChildren.get(newChildren.size() - 1), newChildren.subList(0, newChildren.size() - 1)); } @@ -61,22 +73,75 @@ public DataType dataType() { @Override public boolean nullable() { - return nullable; + return Expressions.nullable(children()); } @Override public boolean foldable() { - return foldable; + return Expressions.foldable(children()); + } + + @Override + public Object fold() { + Object foldedLeftValue = value.fold(); + + for (Expression rightValue : list) { + Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold()); + if (compResult != null && compResult) { + return true; + } + } + return false; + } + + @Override + public String name() { + StringJoiner sj = new StringJoiner(", ", " IN(", ")"); + list.forEach(e -> sj.add(Expressions.name(e))); + return Expressions.name(value) + sj.toString(); } @Override public Attribute toAttribute() { - throw new SqlIllegalArgumentException("not implemented yet"); + if (lazyAttribute == null) { + lazyAttribute = new ScalarFunctionAttribute(location(), name(), dataType(), null, + false, id(), false, "IN", asScript(), null, asPipe()); + } + return lazyAttribute; } @Override public ScriptTemplate asScript() { - throw new SqlIllegalArgumentException("not implemented yet"); + StringJoiner sj = new StringJoiner(" || "); + ScriptTemplate leftScript = asScript(value); + List rightParams = new ArrayList<>(); + String scriptPrefix = leftScript + "=="; + LinkedHashSet values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new)); + for (Object valueFromList : values) { + if (valueFromList instanceof Expression) { + ScriptTemplate rightScript = asScript((Expression) valueFromList); + sj.add(scriptPrefix + rightScript.template()); + rightParams.add(rightScript.params()); + } else { + if (valueFromList instanceof String) { + sj.add(scriptPrefix + '"' + valueFromList + '"'); + } else { + sj.add(scriptPrefix + valueFromList.toString()); + } + } + } + + ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params()); + for (Params p : rightParams) { + paramsBuilder = paramsBuilder.script(p); + } + + return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType()); + } + + @Override + protected Pipe makePipe() { + return new InPipe(location(), this, children().stream().map(Expressions::pipe).collect(Collectors.toList())); } @Override @@ -97,4 +162,4 @@ public boolean equals(Object obj) { return Objects.equals(value, other.value) && Objects.equals(list, other.list); } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/Comparisons.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/Comparisons.java index cdd293cb1afb1..79d3f2b318b59 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/Comparisons.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/Comparisons.java @@ -5,12 +5,16 @@ */ package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison; +import java.util.Set; + /** * Comparison utilities. */ -abstract class Comparisons { +public final class Comparisons { + + private Comparisons() {} - static Boolean eq(Object l, Object r) { + public static Boolean eq(Object l, Object r) { Integer i = compare(l, r); return i == null ? null : i.intValue() == 0; } @@ -35,6 +39,10 @@ static Boolean gte(Object l, Object r) { return i == null ? null : i.intValue() >= 0; } + static Boolean in(Object l, Set r) { + return r.contains(l); + } + /** * Compares two expression arguments (typically Numbers), if possible. * Otherwise returns null (the arguments are not comparable or at least @@ -73,4 +81,4 @@ private static Integer compare(Number l, Number r) { return Integer.valueOf(Integer.compare(l.intValue(), r.intValue())); } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InPipe.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InPipe.java new file mode 100644 index 0000000000000..4ae72b4b49e7a --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InPipe.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison; + +import org.elasticsearch.xpack.sql.capabilities.Resolvables; +import org.elasticsearch.xpack.sql.execution.search.FieldExtraction; +import org.elasticsearch.xpack.sql.execution.search.SqlSourceBuilder; +import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; +import org.elasticsearch.xpack.sql.tree.Location; +import org.elasticsearch.xpack.sql.tree.NodeInfo; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class InPipe extends Pipe { + + private List pipes; + + public InPipe(Location location, Expression expression, List pipes) { + super(location, expression, pipes); + this.pipes = pipes; + } + + @Override + public final Pipe replaceChildren(List newChildren) { + if (newChildren.size() < 2) { + throw new IllegalArgumentException("expected at least [2] children but received [" + newChildren.size() + "]"); + } + return new InPipe(location(), expression(), newChildren); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, InPipe::new, expression(), pipes); + } + + @Override + public boolean supportedByAggsOnlyQuery() { + return pipes.stream().allMatch(FieldExtraction::supportedByAggsOnlyQuery); + } + + @Override + public final Pipe resolveAttributes(AttributeResolver resolver) { + List newPipes = new ArrayList<>(pipes.size()); + for (Pipe p : pipes) { + newPipes.add(p.resolveAttributes(resolver)); + } + return replaceChildren(newPipes); + } + + @Override + public boolean resolved() { + return Resolvables.resolved(pipes); + } + + @Override + public final void collectFields(SqlSourceBuilder sourceBuilder) { + pipes.forEach(p -> p.collectFields(sourceBuilder)); + } + + @Override + public int hashCode() { + return Objects.hash(pipes); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + InPipe other = (InPipe) obj; + return Objects.equals(pipes, other.pipes); + } + + @Override + public InProcessor asProcessor() { + return new InProcessor(pipes.stream().map(Pipe::asProcessor).collect(Collectors.toList())); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InProcessor.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InProcessor.java new file mode 100644 index 0000000000000..5ebf8870965b5 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/InProcessor.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.sql.expression.gen.processor.Processor; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class InProcessor implements Processor { + + public static final String NAME = "in"; + + private final List processsors; + + public InProcessor(List processors) { + this.processsors = processors; + } + + public InProcessor(StreamInput in) throws IOException { + processsors = in.readNamedWriteableList(Processor.class); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public final void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteableList(processsors); + } + + @Override + public Object process(Object input) { + Object leftValue = processsors.get(processsors.size() - 1).process(input); + + for (int i = 0; i < processsors.size() - 1; i++) { + Boolean compResult = Comparisons.eq(leftValue, processsors.get(i).process(input)); + if (compResult != null && compResult) { + return true; + } + } + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InProcessor that = (InProcessor) o; + return Objects.equals(processsors, that.processsors); + } + + @Override + public int hashCode() { + return Objects.hash(processsors); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java index 18ba4ff41b702..8443358a12cb2 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java @@ -1892,4 +1892,4 @@ protected LogicalPlan rule(LogicalPlan plan) { enum TransformDirection { UP, DOWN }; -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java index 8616205b00359..1d61cb1be46a9 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryFolder.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; import org.elasticsearch.xpack.sql.expression.gen.pipeline.UnaryPipe; import org.elasticsearch.xpack.sql.expression.gen.processor.Processor; +import org.elasticsearch.xpack.sql.expression.predicate.In; import org.elasticsearch.xpack.sql.plan.physical.AggregateExec; import org.elasticsearch.xpack.sql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.sql.plan.physical.FilterExec; @@ -138,6 +139,9 @@ protected PhysicalPlan rule(ProjectExec project) { if (pj instanceof ScalarFunction) { ScalarFunction f = (ScalarFunction) pj; processors.put(f.toAttribute(), Expressions.pipe(f)); + } else if (pj instanceof In) { + In in = (In) pj; + processors.put(in.toAttribute(), Expressions.pipe(in)); } } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java index 806944e3a7905..453660f07da8a 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeFunction; import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeHistogramFunction; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; +import org.elasticsearch.xpack.sql.expression.predicate.In; import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull; import org.elasticsearch.xpack.sql.expression.predicate.Range; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate; @@ -80,6 +81,7 @@ import org.elasticsearch.xpack.sql.querydsl.query.RegexQuery; import org.elasticsearch.xpack.sql.querydsl.query.ScriptQuery; import org.elasticsearch.xpack.sql.querydsl.query.TermQuery; +import org.elasticsearch.xpack.sql.querydsl.query.TermsQuery; import org.elasticsearch.xpack.sql.querydsl.query.WildcardQuery; import org.elasticsearch.xpack.sql.tree.Location; import org.elasticsearch.xpack.sql.util.Check; @@ -90,16 +92,20 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.sql.expression.Foldables.doubleValuesOf; import static org.elasticsearch.xpack.sql.expression.Foldables.stringValueOf; import static org.elasticsearch.xpack.sql.expression.Foldables.valueOf; -abstract class QueryTranslator { +final class QueryTranslator { - static final List> QUERY_TRANSLATORS = Arrays.asList( + private QueryTranslator(){} + + private static final List> QUERY_TRANSLATORS = Arrays.asList( new BinaryComparisons(), + new InComparisons(), new Ranges(), new BinaryLogic(), new Nots(), @@ -110,7 +116,7 @@ abstract class QueryTranslator { new MultiMatches() ); - static final List> AGG_TRANSLATORS = Arrays.asList( + private static final List> AGG_TRANSLATORS = Arrays.asList( new Maxes(), new Mins(), new Avgs(), @@ -235,7 +241,7 @@ static GroupingContext groupBy(List groupings) { } aggId = ne.id().toString(); - GroupByKey key = null; + GroupByKey key; // handle functions differently if (exp instanceof Function) { @@ -281,7 +287,7 @@ static QueryTranslation and(Location loc, QueryTranslation left, QueryTranslatio newQ = and(loc, left.query, right.query); } - AggFilter aggFilter = null; + AggFilter aggFilter; if (left.aggFilter == null) { aggFilter = right.aggFilter; @@ -533,7 +539,7 @@ protected QueryTranslation asQuery(BinaryComparison bc, boolean onAggs) { // if the code gets here it's a bug // else { - throw new UnsupportedOperationException("No idea how to translate " + bc.left()); + throw new SqlIllegalArgumentException("No idea how to translate " + bc.left()); } } @@ -572,6 +578,55 @@ private static Query translateQuery(BinaryComparison bc) { } } + // assume the Optimizer properly orders the predicates to ease the translation + static class InComparisons extends ExpressionTranslator { + + @Override + protected QueryTranslation asQuery(In in, boolean onAggs) { + Optional firstNotFoldable = in.list().stream().filter(expression -> !expression.foldable()).findFirst(); + + if (firstNotFoldable.isPresent()) { + throw new SqlIllegalArgumentException( + "Line {}:{}: Comparisons against variables are not (currently) supported; offender [{}] in [{}]", + firstNotFoldable.get().location().getLineNumber(), + firstNotFoldable.get().location().getColumnNumber(), + Expressions.name(firstNotFoldable.get()), + in.name()); + } + + if (in.value() instanceof NamedExpression) { + NamedExpression ne = (NamedExpression) in.value(); + + Query query = null; + AggFilter aggFilter = null; + + Attribute at = ne.toAttribute(); + // + // Agg context means HAVING -> PipelineAggs + // + ScriptTemplate script = in.asScript(); + if (onAggs) { + aggFilter = new AggFilter(at.id().toString(), script); + } + else { + // query directly on the field + if (at instanceof FieldAttribute) { + query = wrapIfNested(new TermsQuery(in.location(), ne.name(), in.list()), ne); + } else { + query = new ScriptQuery(at.location(), script); + } + } + return new QueryTranslation(query, aggFilter); + } + // + // if the code gets here it's a bug + // + else { + throw new SqlIllegalArgumentException("No idea how to translate " + in.value()); + } + } + } + static class Ranges extends ExpressionTranslator { @Override @@ -759,4 +814,4 @@ protected static Query wrapIfNested(Query query, Expression exp) { return query; } } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java new file mode 100644 index 0000000000000..412df4e8ca682 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.querydsl.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.expression.Foldables; +import org.elasticsearch.xpack.sql.tree.Location; + +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.index.query.QueryBuilders.termsQuery; + +public class TermsQuery extends LeafQuery { + + private final String term; + private final LinkedHashSet values; + + public TermsQuery(Location location, String term, List values) { + super(location); + this.term = term; + this.values = new LinkedHashSet<>(Foldables.valuesOf(values, values.get(0).dataType())); + } + + @Override + public QueryBuilder asBuilder() { + return termsQuery(term, values); + } + + @Override + public int hashCode() { + return Objects.hash(term, values); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + TermsQuery other = (TermsQuery) obj; + return Objects.equals(term, other.term) + && Objects.equals(values, other.values); + } + + @Override + protected String innerToString() { + return term + ":" + values; + } +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java index 05e88cfb66b39..c193dcfd5461f 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java @@ -174,4 +174,44 @@ public void testHavingOnScalar() { assertEquals("1:42: Cannot filter HAVING on non-aggregate [int]; consider using WHERE instead", verify("SELECT int FROM test GROUP BY int HAVING 2 < ABS(int)")); } + + public void testInWithDifferentDataTypes_SelectClause() { + assertEquals("1:17: expected data type [INTEGER], value provided is of type [KEYWORD]", + verify("SELECT 1 IN (2, '3', 4)")); + } + + public void testInNestedWithDifferentDataTypes_SelectClause() { + assertEquals("1:27: expected data type [INTEGER], value provided is of type [KEYWORD]", + verify("SELECT 1 = 1 OR 1 IN (2, '3', 4)")); + } + + public void testInWithDifferentDataTypesFromLeftValue_SelectClause() { + assertEquals("1:14: expected data type [INTEGER], value provided is of type [KEYWORD]", + verify("SELECT 1 IN ('foo', 'bar')")); + } + + public void testInNestedWithDifferentDataTypesFromLeftValue_SelectClause() { + assertEquals("1:29: expected data type [KEYWORD], value provided is of type [INTEGER]", + verify("SELECT 1 = 1 OR 'foo' IN (2, 3)")); + } + + public void testInWithDifferentDataTypes_WhereClause() { + assertEquals("1:49: expected data type [TEXT], value provided is of type [INTEGER]", + verify("SELECT * FROM test WHERE text IN ('foo', 'bar', 4)")); + } + + public void testInNestedWithDifferentDataTypes_WhereClause() { + assertEquals("1:60: expected data type [TEXT], value provided is of type [INTEGER]", + verify("SELECT * FROM test WHERE int = 1 OR text IN ('foo', 'bar', 2)")); + } + + public void testInWithDifferentDataTypesFromLeftValue_WhereClause() { + assertEquals("1:35: expected data type [TEXT], value provided is of type [INTEGER]", + verify("SELECT * FROM test WHERE text IN (1, 2)")); + } + + public void testInNestedWithDifferentDataTypesFromLeftValue_WhereClause() { + assertEquals("1:46: expected data type [TEXT], value provided is of type [INTEGER]", + verify("SELECT * FROM test WHERE int = 1 OR text IN (1, 2)")); + } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/InProcessorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/InProcessorTests.java new file mode 100644 index 0000000000000..3e71ac90f8127 --- /dev/null +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/InProcessorTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.sql.expression.predicate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable.Reader; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.sql.expression.Literal; +import org.elasticsearch.xpack.sql.expression.function.scalar.Processors; +import org.elasticsearch.xpack.sql.expression.gen.processor.ConstantProcessor; +import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InProcessor; + +import java.util.Arrays; + +import static org.elasticsearch.xpack.sql.tree.Location.EMPTY; + +public class InProcessorTests extends AbstractWireSerializingTestCase { + + private static final Literal ONE = L(1); + private static final Literal TWO = L(2); + private static final Literal THREE = L(3); + + public static InProcessor randomProcessor() { + return new InProcessor(Arrays.asList(new ConstantProcessor(randomLong()), new ConstantProcessor(randomLong()))); + } + + @Override + protected InProcessor createTestInstance() { + return randomProcessor(); + } + + @Override + protected Reader instanceReader() { + return InProcessor::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(Processors.getNamedWriteables()); + } + + public void testEq() { + assertEquals(true, new In(EMPTY, TWO, Arrays.asList(ONE, TWO, THREE)).makePipe().asProcessor().process(null)); + assertEquals(false, new In(EMPTY, THREE, Arrays.asList(ONE, TWO)).makePipe().asProcessor().process(null)); + } + + private static Literal L(Object value) { + return Literal.of(EMPTY, value); + } +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 608be8ab86f49..acd0378ee010f 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expressions; import org.elasticsearch.xpack.sql.expression.FieldAttribute; +import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.expression.Literal; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.Order; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.math.E; import org.elasticsearch.xpack.sql.expression.function.scalar.math.Floor; import org.elasticsearch.xpack.sql.expression.predicate.BinaryOperator; +import org.elasticsearch.xpack.sql.expression.predicate.In; import org.elasticsearch.xpack.sql.expression.predicate.IsNotNull; import org.elasticsearch.xpack.sql.expression.predicate.Range; import org.elasticsearch.xpack.sql.expression.predicate.logical.And; @@ -81,6 +83,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.sql.tree.Location.EMPTY; +import static org.hamcrest.Matchers.contains; public class OptimizerTests extends ESTestCase { @@ -147,6 +150,11 @@ private static Literal L(Object value) { return Literal.of(EMPTY, value); } + private static FieldAttribute getFieldAttribute() { + return new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + } + + public void testPruneSubqueryAliases() { ShowTables s = new ShowTables(EMPTY, null, null); SubQueryAlias plan = new SubQueryAlias(EMPTY, s, "show"); @@ -298,6 +306,23 @@ public void testConstantFoldingDatetime() { new WeekOfYear(EMPTY, new Literal(EMPTY, null, DataType.NULL), UTC))); } + public void testConstantFoldingIn() { + In in = new In(EMPTY, ONE, + Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); + Literal result= (Literal) new ConstantFolding().rule(in); + assertEquals(true, result.value()); + } + + public void testConstantFoldingIn_LeftValueNotFoldable() { + Project p = new Project(EMPTY, FROM(), Collections.singletonList( + new In(EMPTY, getFieldAttribute(), + Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))))); + p = (Project) new ConstantFolding().apply(p); + assertEquals(1, p.projections().size()); + In in = (In) p.projections().get(0); + assertThat(Foldables.valuesOf(in.list(), DataType.INTEGER), contains(1 ,2 ,3 ,4)); + } + public void testArithmeticFolding() { assertEquals(10, foldOperator(new Add(EMPTY, L(7), THREE))); assertEquals(4, foldOperator(new Sub(EMPTY, L(7), THREE))); @@ -389,7 +414,7 @@ public void testBoolCommonFactorExtraction() { // 6 < a <= 5 -> FALSE public void testFoldExcludingRangeToFalse() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r = new Range(EMPTY, fa, SIX, false, FIVE, true); assertTrue(r.foldable()); @@ -398,7 +423,7 @@ public void testFoldExcludingRangeToFalse() { // 6 < a <= 5.5 -> FALSE public void testFoldExcludingRangeWithDifferentTypesToFalse() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r = new Range(EMPTY, fa, SIX, false, L(5.5d), true); assertTrue(r.foldable()); @@ -408,7 +433,7 @@ public void testFoldExcludingRangeWithDifferentTypesToFalse() { // Conjunction public void testCombineBinaryComparisonsNotComparable() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); LessThanOrEqual lte = new LessThanOrEqual(EMPTY, fa, SIX); LessThan lt = new LessThan(EMPTY, fa, Literal.FALSE); @@ -420,7 +445,7 @@ public void testCombineBinaryComparisonsNotComparable() { // a <= 6 AND a < 5 -> a < 5 public void testCombineBinaryComparisonsUpper() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); LessThanOrEqual lte = new LessThanOrEqual(EMPTY, fa, SIX); LessThan lt = new LessThan(EMPTY, fa, FIVE); @@ -434,7 +459,7 @@ public void testCombineBinaryComparisonsUpper() { // 6 <= a AND 5 < a -> 6 <= a public void testCombineBinaryComparisonsLower() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, SIX); GreaterThan gt = new GreaterThan(EMPTY, fa, FIVE); @@ -448,7 +473,7 @@ public void testCombineBinaryComparisonsLower() { // 5 <= a AND 5 < a -> 5 < a public void testCombineBinaryComparisonsInclude() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, FIVE); GreaterThan gt = new GreaterThan(EMPTY, fa, FIVE); @@ -462,7 +487,7 @@ public void testCombineBinaryComparisonsInclude() { // 3 <= a AND 4 < a AND a <= 7 AND a < 6 -> 4 < a < 6 public void testCombineMultipleBinaryComparisons() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, THREE); GreaterThan gt = new GreaterThan(EMPTY, fa, FOUR); LessThanOrEqual lte = new LessThanOrEqual(EMPTY, fa, L(7)); @@ -481,7 +506,7 @@ public void testCombineMultipleBinaryComparisons() { // 3 <= a AND TRUE AND 4 < a AND a != 5 AND a <= 7 -> 4 < a <= 7 AND a != 5 AND TRUE public void testCombineMixedMultipleBinaryComparisons() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, THREE); GreaterThan gt = new GreaterThan(EMPTY, fa, FOUR); LessThanOrEqual lte = new LessThanOrEqual(EMPTY, fa, L(7)); @@ -503,7 +528,7 @@ public void testCombineMixedMultipleBinaryComparisons() { // 1 <= a AND a < 5 -> 1 <= a < 5 public void testCombineComparisonsIntoRange() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, ONE); LessThan lt = new LessThan(EMPTY, fa, FIVE); @@ -520,7 +545,7 @@ public void testCombineComparisonsIntoRange() { // a != NULL AND a > 1 AND a < 5 AND a == 10 -> (a != NULL AND a == 10) AND 1 <= a < 5 public void testCombineUnbalancedComparisonsMixedWithEqualsIntoRange() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); IsNotNull isn = new IsNotNull(EMPTY, fa); GreaterThanOrEqual gte = new GreaterThanOrEqual(EMPTY, fa, ONE); @@ -544,7 +569,7 @@ public void testCombineUnbalancedComparisonsMixedWithEqualsIntoRange() { // (2 < a < 3) AND (1 < a < 4) -> (2 < a < 3) public void testCombineBinaryComparisonsConjunctionOfIncludedRange() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, ONE, false, FOUR, false); @@ -558,7 +583,7 @@ public void testCombineBinaryComparisonsConjunctionOfIncludedRange() { // (2 < a < 3) AND a < 2 -> 2 < a < 2 public void testCombineBinaryComparisonsConjunctionOfNonOverlappingBoundaries() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, ONE, false, TWO, false); @@ -578,7 +603,7 @@ public void testCombineBinaryComparisonsConjunctionOfNonOverlappingBoundaries() // (2 < a < 3) AND (2 < a <= 3) -> 2 < a < 3 public void testCombineBinaryComparisonsConjunctionOfUpperEqualsOverlappingBoundaries() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, true); @@ -592,7 +617,7 @@ public void testCombineBinaryComparisonsConjunctionOfUpperEqualsOverlappingBound // (2 < a < 3) AND (1 < a < 3) -> 2 < a < 3 public void testCombineBinaryComparisonsConjunctionOverlappingUpperBoundary() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r1 = new Range(EMPTY, fa, ONE, false, THREE, false); @@ -606,7 +631,7 @@ public void testCombineBinaryComparisonsConjunctionOverlappingUpperBoundary() { // (2 < a <= 3) AND (1 < a < 3) -> 2 < a < 3 public void testCombineBinaryComparisonsConjunctionWithDifferentUpperLimitInclusion() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, ONE, false, THREE, false); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, true); @@ -625,7 +650,7 @@ public void testCombineBinaryComparisonsConjunctionWithDifferentUpperLimitInclus // (0 < a <= 1) AND (0 <= a < 2) -> 0 < a <= 1 public void testRangesOverlappingConjunctionNoLowerBoundary() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, L(0), false, ONE, true); Range r2 = new Range(EMPTY, fa, L(0), true, TWO, false); @@ -640,7 +665,7 @@ public void testRangesOverlappingConjunctionNoLowerBoundary() { // Disjunction public void testCombineBinaryComparisonsDisjunctionNotComparable() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThan gt1 = new GreaterThan(EMPTY, fa, ONE); GreaterThan gt2 = new GreaterThan(EMPTY, fa, Literal.FALSE); @@ -655,7 +680,7 @@ public void testCombineBinaryComparisonsDisjunctionNotComparable() { // 2 < a OR 1 < a OR 3 < a -> 1 < a public void testCombineBinaryComparisonsDisjunctionLowerBound() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThan gt1 = new GreaterThan(EMPTY, fa, ONE); GreaterThan gt2 = new GreaterThan(EMPTY, fa, TWO); @@ -673,7 +698,7 @@ public void testCombineBinaryComparisonsDisjunctionLowerBound() { // 2 < a OR 1 < a OR 3 <= a -> 1 < a public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); GreaterThan gt1 = new GreaterThan(EMPTY, fa, ONE); GreaterThan gt2 = new GreaterThan(EMPTY, fa, TWO); @@ -691,7 +716,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() { // a < 1 OR a < 2 OR a < 3 -> a < 3 public void testCombineBinaryComparisonsDisjunctionUpperBound() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); LessThan lt1 = new LessThan(EMPTY, fa, ONE); LessThan lt2 = new LessThan(EMPTY, fa, TWO); @@ -709,7 +734,7 @@ public void testCombineBinaryComparisonsDisjunctionUpperBound() { // a < 2 OR a <= 2 OR a < 1 -> a <= 2 public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); LessThan lt1 = new LessThan(EMPTY, fa, ONE); LessThan lt2 = new LessThan(EMPTY, fa, TWO); @@ -727,7 +752,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() { // a < 2 OR 3 < a OR a < 1 OR 4 < a -> a < 2 OR 3 < a public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); LessThan lt1 = new LessThan(EMPTY, fa, ONE); LessThan lt2 = new LessThan(EMPTY, fa, TWO); @@ -753,7 +778,7 @@ public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() { // (2 < a < 3) OR (1 < a < 4) -> (1 < a < 4) public void testCombineBinaryComparisonsDisjunctionOfIncludedRangeNotComparable() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, ONE, false, Literal.FALSE, false); @@ -765,10 +790,9 @@ public void testCombineBinaryComparisonsDisjunctionOfIncludedRangeNotComparable( assertEquals(or, exp); } - // (2 < a < 3) OR (1 < a < 4) -> (1 < a < 4) public void testCombineBinaryComparisonsDisjunctionOfIncludedRange() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); @@ -789,7 +813,7 @@ public void testCombineBinaryComparisonsDisjunctionOfIncludedRange() { // (2 < a < 3) OR (1 < a < 2) -> same public void testCombineBinaryComparisonsDisjunctionOfNonOverlappingBoundaries() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, ONE, false, TWO, false); @@ -803,7 +827,7 @@ public void testCombineBinaryComparisonsDisjunctionOfNonOverlappingBoundaries() // (2 < a < 3) OR (2 < a <= 3) -> 2 < a <= 3 public void testCombineBinaryComparisonsDisjunctionOfUpperEqualsOverlappingBoundaries() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, true); @@ -817,7 +841,7 @@ public void testCombineBinaryComparisonsDisjunctionOfUpperEqualsOverlappingBound // (2 < a < 3) OR (1 < a < 3) -> 1 < a < 3 public void testCombineBinaryComparisonsOverlappingUpperBoundary() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, false); Range r1 = new Range(EMPTY, fa, ONE, false, THREE, false); @@ -831,7 +855,7 @@ public void testCombineBinaryComparisonsOverlappingUpperBoundary() { // (2 < a <= 3) OR (1 < a < 3) -> same (the <= prevents the ranges from being combined) public void testCombineBinaryComparisonsWithDifferentUpperLimitInclusion() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r1 = new Range(EMPTY, fa, ONE, false, THREE, false); Range r2 = new Range(EMPTY, fa, TWO, false, THREE, true); @@ -845,7 +869,7 @@ public void testCombineBinaryComparisonsWithDifferentUpperLimitInclusion() { // (0 < a <= 1) OR (0 < a < 2) -> 0 < a < 2 public void testRangesOverlappingNoLowerBoundary() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Range r2 = new Range(EMPTY, fa, L(0), false, TWO, false); Range r1 = new Range(EMPTY, fa, L(0), false, ONE, true); @@ -861,7 +885,7 @@ public void testRangesOverlappingNoLowerBoundary() { // a == 1 AND a == 2 -> FALSE public void testDualEqualsConjunction() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Equals eq1 = new Equals(EMPTY, fa, ONE); Equals eq2 = new Equals(EMPTY, fa, TWO); @@ -872,7 +896,7 @@ public void testDualEqualsConjunction() { // 1 <= a < 10 AND a == 1 -> a == 1 public void testEliminateRangeByEqualsInInterval() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Equals eq1 = new Equals(EMPTY, fa, ONE); Range r = new Range(EMPTY, fa, ONE, true, L(10), false); @@ -883,7 +907,7 @@ public void testEliminateRangeByEqualsInInterval() { // 1 < a < 10 AND a == 10 -> FALSE public void testEliminateRangeByEqualsOutsideInterval() { - FieldAttribute fa = new FieldAttribute(EMPTY, "a", new EsField("af", DataType.INTEGER, emptyMap(), true)); + FieldAttribute fa = getFieldAttribute(); Equals eq1 = new Equals(EMPTY, fa, L(10)); Range r = new Range(EMPTY, fa, ONE, false, L(10), false); @@ -891,4 +915,4 @@ public void testEliminateRangeByEqualsOutsideInterval() { Expression exp = rule.rule(new And(EMPTY, eq1, r)); assertEquals(Literal.FALSE, rule.rule(exp)); } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index 71f4dab679c99..8d5db634ff073 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -5,7 +5,7 @@ */ package org.elasticsearch.xpack.sql.planner; -import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.AbstractBuilderTestCase; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; import org.elasticsearch.xpack.sql.analysis.index.EsIndex; @@ -20,30 +20,40 @@ import org.elasticsearch.xpack.sql.planner.QueryTranslator.QueryTranslation; import org.elasticsearch.xpack.sql.querydsl.query.Query; import org.elasticsearch.xpack.sql.querydsl.query.RangeQuery; +import org.elasticsearch.xpack.sql.querydsl.query.ScriptQuery; import org.elasticsearch.xpack.sql.querydsl.query.TermQuery; +import org.elasticsearch.xpack.sql.querydsl.query.TermsQuery; import org.elasticsearch.xpack.sql.type.EsField; import org.elasticsearch.xpack.sql.type.TypesTests; import org.joda.time.DateTime; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import java.io.IOException; import java.util.Map; import java.util.TimeZone; -public class QueryTranslatorTests extends ESTestCase { +import static org.hamcrest.core.StringStartsWith.startsWith; - private SqlParser parser; - private IndexResolution getIndexResult; - private FunctionRegistry functionRegistry; - private Analyzer analyzer; - - public QueryTranslatorTests() { +public class QueryTranslatorTests extends AbstractBuilderTestCase { + + private static SqlParser parser; + private static Analyzer analyzer; + + @BeforeClass + public static void init() { parser = new SqlParser(); - functionRegistry = new FunctionRegistry(); Map mapping = TypesTests.loadMapping("mapping-multi-field-variation.json"); - EsIndex test = new EsIndex("test", mapping); - getIndexResult = IndexResolution.valid(test); - analyzer = new Analyzer(functionRegistry, getIndexResult, TimeZone.getTimeZone("UTC")); + IndexResolution getIndexResult = IndexResolution.valid(test); + analyzer = new Analyzer(new FunctionRegistry(), getIndexResult, TimeZone.getTimeZone("UTC")); + } + + @AfterClass + public static void destroy() { + parser = null; + analyzer = null; } private LogicalPlan plan(String sql) { @@ -149,4 +159,41 @@ public void testLikeConstructsNotSupported() { SqlIllegalArgumentException ex = expectThrows(SqlIllegalArgumentException.class, () -> QueryTranslator.toQuery(condition, false)); assertEquals("Scalar function (LTRIM(keyword)) not allowed (yet) as arguments for LIKE", ex.getMessage()); } -} \ No newline at end of file + + public void testTranslateInExpression_WhereClause() throws IOException { + LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'lala', 'foo', concat('la', 'la'))"); + assertTrue(p instanceof Project); + assertTrue(p.children().get(0) instanceof Filter); + Expression condition = ((Filter) p.children().get(0)).condition(); + assertFalse(condition.foldable()); + QueryTranslation translation = QueryTranslator.toQuery(condition, false); + Query query = translation.query; + assertTrue(query instanceof TermsQuery); + TermsQuery tq = (TermsQuery) query; + assertEquals("keyword:(bar foo lala)", tq.asBuilder().toQuery(createShardContext()).toString()); + } + + public void testTranslateInExpressionInvalidValues_WhereClause() { + LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', keyword)"); + assertTrue(p instanceof Project); + assertTrue(p.children().get(0) instanceof Filter); + Expression condition = ((Filter) p.children().get(0)).condition(); + assertFalse(condition.foldable()); + SqlIllegalArgumentException ex = expectThrows(SqlIllegalArgumentException.class, () -> QueryTranslator.toQuery(condition, false)); + assertEquals("Line 1:52: Comparisons against variables are not (currently) supported; " + + "offender [keyword] in [keyword IN(foo, bar, keyword)]", ex.getMessage()); + } + + public void testTranslateInExpression_HavingClause_Painless() { + LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, 20, 30 - 10)"); + assertTrue(p instanceof Project); + assertTrue(p.children().get(0) instanceof Filter); + Expression condition = ((Filter) p.children().get(0)).condition(); + assertFalse(condition.foldable()); + QueryTranslation translation = QueryTranslator.toQuery(condition, false); + assertTrue(translation.query instanceof ScriptQuery); + ScriptQuery sq = (ScriptQuery) translation.query; + assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString()); + assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->")); + } +} diff --git a/x-pack/qa/sql/src/main/resources/agg.sql-spec b/x-pack/qa/sql/src/main/resources/agg.sql-spec index daf97ebd78788..2c6248059f5fb 100644 --- a/x-pack/qa/sql/src/main/resources/agg.sql-spec +++ b/x-pack/qa/sql/src/main/resources/agg.sql-spec @@ -426,6 +426,11 @@ SELECT MIN(emp_no) AS a, 1 + MIN(emp_no) AS b, ABS(MIN(emp_no)) AS c FROM test_e aggRepeatFunctionBetweenSelectAndHaving SELECT gender, COUNT(DISTINCT languages) AS c FROM test_emp GROUP BY gender HAVING count(DISTINCT languages) > 0 ORDER BY gender; +// filter with IN +aggMultiWithHavingUsingIn +SELECT MIN(salary) min, MAX(salary) max, gender g, COUNT(*) c FROM "test_emp" WHERE languages > 0 GROUP BY g HAVING max IN(74999, 74600) ORDER BY gender; +aggMultiGroupByMultiWithHavingUsingIn +SELECT MIN(salary) min, MAX(salary) max, gender g, languages l, COUNT(*) c FROM "test_emp" WHERE languages > 0 GROUP BY g, languages HAVING max IN (74500, 74600) ORDER BY gender, languages; // @@ -444,4 +449,4 @@ SELECT hire_date HD, COUNT(*) c FROM test_emp GROUP BY hire_date ORDER BY hire_d selectHireDateGroupByHireDate SELECT hire_date HD, COUNT(*) c FROM test_emp GROUP BY hire_date ORDER BY hire_date DESC; selectSalaryGroupBySalary -SELECT salary, COUNT(*) c FROM test_emp GROUP BY salary ORDER BY salary DESC; \ No newline at end of file +SELECT salary, COUNT(*) c FROM test_emp GROUP BY salary ORDER BY salary DESC; diff --git a/x-pack/qa/sql/src/main/resources/filter.sql-spec b/x-pack/qa/sql/src/main/resources/filter.sql-spec index 5112fbc15511d..1a564ecb9ad82 100644 --- a/x-pack/qa/sql/src/main/resources/filter.sql-spec +++ b/x-pack/qa/sql/src/main/resources/filter.sql-spec @@ -78,3 +78,21 @@ SELECT last_name l FROM "test_emp" WHERE emp_no BETWEEN 9990 AND 10003 ORDER BY // end::whereBetween whereNotBetween SELECT last_name l FROM "test_emp" WHERE emp_no NOT BETWEEN 10010 AND 10020 ORDER BY emp_no LIMIT 5; + +// +// IN expression +// +whereWithInAndOneValue +SELECT last_name l FROM "test_emp" WHERE emp_no IN (10001); +whereWithInAndMultipleValues +// tag::whereWithInAndMultipleValues +SELECT last_name l FROM "test_emp" WHERE emp_no IN (10000, 10001, 10002, 999) ORDER BY emp_no LIMIT 5; +// end::whereWithInAndMultipleValues + +whereWithInAndOneValueWithNegation +SELECT last_name l FROM "test_emp" WHERE emp_no NOT IN (10001) ORDER BY emp_no LIMIT 5; +whereWithInAndMultipleValuesAndNegation +SELECT last_name l FROM "test_emp" WHERE emp_no NOT IN (10000, 10001, 10002, 999) ORDER BY emp_no LIMIT 5; + +whereWithInAndComplexFunctions +SELECT last_name l FROM "test_emp" WHERE emp_no NOT IN (10000, abs(2 - 10003), 10002, 999) AND lcase(first_name) IN ('sumant', 'mary', 'patricio', 'No''Match') ORDER BY emp_no LIMIT 5; diff --git a/x-pack/qa/sql/src/main/resources/select.csv-spec b/x-pack/qa/sql/src/main/resources/select.csv-spec new file mode 100644 index 0000000000000..b3888abd47bf3 --- /dev/null +++ b/x-pack/qa/sql/src/main/resources/select.csv-spec @@ -0,0 +1,67 @@ +// SELECT with IN +inWithLiterals +SELECT 1 IN (1, 2, 3), 1 IN (2, 3); + + 1 IN (1, 2, 3) | 1 IN (2, 3) +-----------------+------------- +true |false +; + +inWithLiteralsAndFunctions +SELECT 1 IN (2 - 1, 2, 3), abs(-1) IN (2, 3, abs(4 - 5)); + + 1 IN (1, 2, 3) | 1 IN (2, 3) +-----------------+------------- +true |false +; + + +inWithLiteralsAndNegation +SELECT NOT 1 IN (1, 1 + 1, 3), NOT 1 IN (2, 3); + + 1 IN (1, 2, 3) | 1 IN (2, 3) +-----------------+------------- +false |true +; + + +// +// SELECT with IN and table columns +// +inWithTableColumn +SELECT emp_no IN (10000, 10001, 10002) FROM test_emp ORDER BY 1; + + emp_no +------- +10001 +10002 +; + +inWithTableColumnAndFunction +SELECT emp_no IN (10000, 10000 + 1, abs(-10000 - 2)) FROM test_emp; + + emp_no +------- +10001 +10002 +; + +inWithTableColumnAndNegation +SELECT emp_no NOT IN (10000, 10000 + 1, 10002) FROM test_emp ORDER BY 1 LIMIT 3; + + emp_no +------- +10003 +10004 +10005 +; + +inWithTableColumnAndComplexFunctions +SELECT 1 IN (1, abs(2 - 4), 3) OR emp_no NOT IN (10000, 10000 + 1, 10002) FROM test_emp ORDER BY 1 LIMIT 3; + + emp_no +------- +10003 +10004 +10005 +; \ No newline at end of file