Skip to content

Commit

Permalink
Combine multiple projects into one (#374)
Browse files Browse the repository at this point in the history
Applies to both aggregations and projections:
from i | project a,b | project a becomes
from i | project a

while
from i | stats count() by a, b | project a which yields
Aggregate[count(), a, b][grouping=a,b] becomes Aggregate[a][grouping=a,b]
  • Loading branch information
costin authored Nov 10, 2022
1 parent 134ceee commit 7bc059a
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
import org.elasticsearch.xpack.esql.session.EsqlSession;
import org.elasticsearch.xpack.esql.session.LocalExecutable;
import org.elasticsearch.xpack.esql.session.Result;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BinaryComparisonSimplification;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination;
Expand All @@ -25,13 +28,16 @@
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PushDownAndCombineFilters;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics;
import org.elasticsearch.xpack.ql.plan.logical.Aggregate;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.Project;
import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.ql.rule.RuleExecutor;
import org.elasticsearch.xpack.ql.type.DataTypes;

import java.util.ArrayList;
import java.util.List;

import static java.util.Arrays.asList;
Expand All @@ -44,9 +50,9 @@ public LogicalPlan optimize(LogicalPlan verified) {

@Override
protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {

Batch operators = new Batch(
"Operator Optimization",
new CombineProjections(),
new ConstantFolding(),
// boolean
new BooleanSimplification(),
Expand All @@ -68,6 +74,73 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
return asList(operators, local, label);
}

/**
 * Merges two stacked projection-like nodes into a single node. Three shapes are handled:
 * <ul>
 *   <li>{@code Project} over {@code Project}: the lower project is dropped and the upper projections survive;</li>
 *   <li>{@code Project} over {@code Aggregate}: the project is folded into the aggregate's output list;</li>
 *   <li>{@code Aggregate} over {@code Project}: the project is dropped and the aggregate is placed directly
 *       on the project's child.</li>
 * </ul>
 * Whenever the lower node is removed, references in the upper node to aliases declared by the lower node
 * are replaced with the alias definition so that no dangling reference is left behind.
 */
static class CombineProjections extends OptimizerRules.OptimizerRule<UnaryPlan> {

    CombineProjections() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Applies the rule to a single unary node.
     *
     * @param plan the node being visited
     * @return the combined node when one of the supported shapes matches, otherwise {@code plan} itself
     */
    @Override
    protected LogicalPlan rule(UnaryPlan plan) {
        LogicalPlan child = plan.child();

        if (plan instanceof Project project) {
            if (child instanceof Project p) {
                // eliminate lower project but first replace the aliases in the upper one
                return new Project(p.source(), p.child(), combineProjections(project.projections(), p.projections()));
            }

            if (child instanceof Aggregate a) {
                // the project only narrows/renames the aggregate output; the groupings stay as they are
                return new Aggregate(a.source(), a.child(), a.groupings(), combineProjections(project.projections(), a.aggregates()));
            }
        }

        // Agg with underlying Project (group by on sub-queries)
        if (plan instanceof Aggregate a) {
            if (child instanceof Project p) {
                // the project disappears, so both the aggregates AND the groupings must stop referring
                // to the aliases it declared - otherwise the groupings point to attributes nobody produces
                return new Aggregate(
                    a.source(),
                    p.child(),
                    combineGroupings(a.groupings(), p.projections()),
                    combineProjections(a.aggregates(), p.projections())
                );
            }
        }
        return plan;
    }

    /**
     * Rewrites the upper projection list so that it can replace both lists.
     * Normally only the upper projections should survive but, since the lower list might declare aliases
     * that the upper one refers to, those references are replaced with the alias definition first.
     *
     * @param upper projections of the node that is kept
     * @param lower projections of the node that is removed
     * @return a new list with one entry per upper projection, in the same order
     */
    private List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {
        AttributeMap<NamedExpression> aliases = lowerAliases(lower);
        List<NamedExpression> replaced = new ArrayList<>(upper.size());

        // replace any matching attribute with a lower alias (if there's a match)
        // but clean-up non-top aliases at the end
        for (NamedExpression ne : upper) {
            NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> aliases.resolve(a, a));
            replaced.add((NamedExpression) trimNonTopLevelAliases(replacedExp));
        }
        return replaced;
    }

    /**
     * Rewrites the groupings of an aggregate whose underlying projection is being removed: every reference to
     * an alias declared by that projection is replaced with the aliased expression itself (groupings are plain
     * expressions, so the alias wrapper is stripped entirely).
     *
     * @param groupings groupings of the aggregate that is kept
     * @param lower     projections of the node that is removed
     * @return {@code groupings} itself when the lower list declares no aliases, otherwise a rewritten copy
     */
    private static List<Expression> combineGroupings(List<Expression> groupings, List<? extends NamedExpression> lower) {
        AttributeMap<NamedExpression> aliases = lowerAliases(lower);
        if (aliases.isEmpty()) {
            // nothing to substitute - keep the original list untouched
            return groupings;
        }

        List<Expression> replaced = new ArrayList<>(groupings.size());
        for (Expression grouping : groupings) {
            Expression resolved = grouping.transformUp(Attribute.class, a -> aliases.resolve(a, a));
            replaced.add(trimAliases(resolved));
        }
        return replaced;
    }

    /**
     * Collects every entry of the given list that is not a plain attribute (i.e. a definition such as an alias),
     * keyed by the attribute it exposes.
     */
    private static AttributeMap<NamedExpression> lowerAliases(List<? extends NamedExpression> lower) {
        AttributeMap.Builder<NamedExpression> aliasesBuilder = AttributeMap.builder();
        for (NamedExpression ne : lower) {
            if ((ne instanceof Attribute) == false) {
                aliasesBuilder.put(ne.toAttribute(), ne);
            }
        }
        return aliasesBuilder.build();
    }

    /**
     * Removes every alias nested inside the expression while preserving a top-level alias (its name, qualifier
     * and id are kept; only its child is cleaned).
     *
     * @param e expression to clean
     * @return the expression with nested aliases replaced by their children
     */
    public static Expression trimNonTopLevelAliases(Expression e) {
        if (e instanceof Alias a) {
            return new Alias(a.source(), a.name(), a.qualifier(), trimAliases(a.child()), a.id());
        }
        return trimAliases(e);
    }

    /** Replaces every alias in the tree, including a top-level one, with its child. */
    private static Expression trimAliases(Expression e) {
        return e.transformDown(Alias.class, Alias::child);
    }
}

static class CombineLimits extends OptimizerRules.OptimizerRule<Limit> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import org.elasticsearch.xpack.esql.session.EmptyExecutable;
import org.elasticsearch.xpack.esql.session.EsqlConfiguration;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.plan.QueryPlan;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.tree.Node;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DateUtils;
import org.elasticsearch.xpack.ql.type.DefaultDataTypeRegistry;
Expand Down Expand Up @@ -41,9 +41,9 @@ public static LogicalPlan emptySource() {
return new LocalRelation(Source.EMPTY, new EmptyExecutable(emptyList()));
}

public static <P extends QueryPlan<P>, T extends P> T as(P plan, Class<T> type) {
Assert.assertThat(plan, instanceOf(type));
return type.cast(plan);
/**
 * Asserts that the given node is an instance of the requested type and returns it cast to that type.
 *
 * @param node the node to check
 * @param type the class the node is expected to be an instance of
 * @return {@code node} cast to {@code type}
 */
public static <P extends Node<P>, T extends P> T as(P node, Class<T> type) {
// fail the test with a matcher message (rather than a ClassCastException) when the type does not match
Assert.assertThat(node, instanceOf(type));
return type.cast(node);
}

public static Map<String, EsField> loadMapping(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,104 @@
package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.index.EsIndex;
import org.elasticsearch.xpack.ql.index.IndexResolution;
import org.elasticsearch.xpack.ql.plan.logical.Aggregate;
import org.elasticsearch.xpack.ql.plan.logical.EsRelation;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.Project;
import org.elasticsearch.xpack.ql.type.EsField;
import org.junit.BeforeClass;

import java.util.Map;

import static org.elasticsearch.xpack.esql.EsqlTestUtils.L;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptySource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;

public class LogicalPlanOptimizerTests extends ESTestCase {

public void testCombineLimits() throws Exception {
private static EsqlParser parser;
private static Analyzer analyzer;
private static LogicalPlanOptimizer logicalOptimizer;
private static Map<String, EsField> mapping;

/**
 * Builds the static parser, mapping, optimizer and analyzer used by the query-based tests.
 */
@BeforeClass
public static void init() {
    parser = new EsqlParser();

    // single "test" index backed by the basic mapping file
    mapping = loadMapping("mapping-basic.json");
    IndexResolution resolution = IndexResolution.valid(new EsIndex("test", mapping));

    logicalOptimizer = new LogicalPlanOptimizer();
    analyzer = new Analyzer(resolution, new EsqlFunctionRegistry(), new Verifier(), TEST_CFG);
}

/**
 * Two consecutive {@code project} commands: expects a single Project exposing only {@code last_name}
 * whose child is the relation.
 */
public void testCombineProjections() {
    var optimized = plan("""
        from test
        | project emp_no, *name, salary
        | project last_name
        """);

    var topProject = as(optimized, Project.class);
    var projectedNames = Expressions.names(topProject.projections());
    assertThat(projectedNames, contains("last_name"));

    // the lower project must be gone: the child has to be the relation itself
    as(topProject.child(), EsRelation.class);
}

/**
 * Two {@code project} commands separated by a {@code where}: expects the root of the optimized plan to be
 * a Project exposing only {@code last_name}.
 */
public void testCombineProjectionWithFilterInBetween() {
    var optimized = plan("""
        from test
        | project *name, salary
        | where salary > 10
        | project last_name
        """);

    var topProject = as(optimized, Project.class);
    var projectedNames = Expressions.names(topProject.projections());
    assertThat(projectedNames, contains("last_name"));
}

/**
 * An alias ({@code y = x}) defined on top of another alias ({@code x = first_name}): expects a single
 * projection named {@code y} that is an Alias whose child name contains {@code first_name}.
 */
public void testCombineProjectionWhilePreservingAlias() {
    var optimized = plan("""
        from test
        | project x = first_name, salary
        | where salary > 10
        | project y = x
        """);

    var topProject = as(optimized, Project.class);
    var projections = topProject.projections();
    assertThat(Expressions.names(projections), contains("y"));

    // the surviving projection has to be the alias, now pointing at the original field
    var first = projections.get(0);
    var yAlias = as(first, Alias.class);
    var aliasedName = Expressions.name(yAlias.child());
    assertThat(aliasedName, containsString("first_name"));
}

/**
 * Checks the aggregates and groupings of the Aggregate produced for a {@code stats ... by} query.
 * Currently muted through {@code @AwaitsFix}.
 *
 * NOTE(review): the query below contains no {@code project} command even though the test name refers to
 * combining a projection with an aggregation, and the first assertion expects {@code last_name} to be the
 * only aggregate - confirm the intended query/assertions before un-muting.
 */
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch-internal/issues/378")
public void testCombineProjectionWithAggregation() {
var plan = plan("""
from test
| stats avg(salary) by last_name, first_name
""");

// the root of the optimized plan is expected to be the aggregate itself
var agg = as(plan, Aggregate.class);
assertThat(Expressions.names(agg.aggregates()), contains("last_name"));
assertThat(Expressions.names(agg.groupings()), contains("last_name", "first_name"));
}

public void testCombineLimits() {
var limitValues = new int[] { randomIntBetween(10, 99), randomIntBetween(100, 1000) };
var firstLimit = randomBoolean() ? 0 : 1;
var secondLimit = firstLimit == 0 ? 1 : 0;
Expand All @@ -28,7 +117,7 @@ public void testCombineLimits() throws Exception {
);
}

public void testMultipleCombineLimits() throws Exception {
public void testMultipleCombineLimits() {
var numberOfLimits = randomIntBetween(3, 10);
var minimum = randomIntBetween(10, 99);
var limitWithMinimum = randomIntBetween(0, numberOfLimits - 1);
Expand All @@ -40,4 +129,8 @@ public void testMultipleCombineLimits() throws Exception {
}
assertEquals(new Limit(EMPTY, L(minimum), emptySource()), new LogicalPlanOptimizer().optimize(plan));
}

/**
 * Runs the query through the parser, the analyzer and the logical optimizer, in that order.
 *
 * @param query the query text
 * @return the optimized logical plan
 */
private LogicalPlan plan(String query) {
    var parsed = parser.createStatement(query);
    var analyzed = analyzer.analyze(parsed);
    return logicalOptimizer.optimize(analyzed);
}
}

0 comments on commit 7bc059a

Please sign in to comment.