Skip to content

Commit

Permalink
SQL: Fix ORDER BY on aggregates and GROUPed BY fields
Browse files Browse the repository at this point in the history
Previously, in the in-memory sorting module
`LocalAggregationSorterListener` only the aggregate functions where used
(grabbed by the `sortingColumns`). As a consequence, if the ORDER BY
was also using columns of the GROUP BY clause, (especially in the case
of higher priority - before the aggregate functions) wrong results were
produced. E.g.:
```
SELECT gender, MAX(salary) AS max FROM test_emp
GROUP BY gender
ORDER BY gender, max
```

Add all columns of the ORDER BY to the `sortingColumns` so that the
`LocalAggregationSorterListener` can use the correct comparators in
the underlying PriorityQueue used to implement the in-memory sorting.

Fixes: elastic#50355
  • Loading branch information
matriv committed Feb 4, 2020
1 parent 12b24bf commit b49a23c
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 74 deletions.
27 changes: 24 additions & 3 deletions x-pack/plugin/sql/qa/src/main/resources/agg-ordering.sql-spec
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ aggNotSpecifiedInTheAggregateAndGroupWithHavingWithLimitAndDirection
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY MAX(salary) ASC, c DESC LIMIT 5;

groupAndAggNotSpecifiedInTheAggregateWithHaving
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender, MAX(salary);
SELECT gender, MIN(salary) AS min, COUNT(*) AS c FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender NULLS FIRST, MAX(salary);

multipleAggsThatGetRewrittenWithAliasOnAMediumGroupBy
SELECT languages, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY languages ORDER BY max;
Expand Down Expand Up @@ -136,5 +136,26 @@ SELECT gender AS g, first_name AS f, last_name AS l FROM test_emp GROUP BY f, ge
multipleGroupingsAndOrderingByGroups_8
SELECT gender AS g, first_name, last_name FROM test_emp GROUP BY g, last_name, first_name ORDER BY gender ASC, first_name DESC, last_name ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY gender, c DESC, first_name, last_name ASC;
multipleGroupingsAndOrderingByGroupsAndAggs_1
SELECT gender, MIN(salary) AS min, COUNT(*) AS c, MAX(salary) AS max FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender ASC NULLS FIRST, MAX(salary) DESC;

multipleGroupingsAndOrderingByGroupsAndAggs_2
SELECT gender, MIN(salary) AS min, COUNT(*) AS c, MAX(salary) AS max FROM test_emp GROUP BY gender HAVING c > 1 ORDER BY gender DESC NULLS LAST, MAX(salary) ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions_1
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY gender NULLS FIRST, c DESC, first_name, last_name ASC;

multipleGroupingsAndOrderingByGroupsWithFunctions_2
SELECT first_name f, last_name l, gender g, CONCAT(first_name, last_name) c FROM test_emp GROUP BY gender, l, f, c ORDER BY c DESC, gender DESC NULLS LAST, first_name, last_name ASC;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_1
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 1 NULLS FIRST, 2, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_2
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 1 DESC NULLS LAST, 2, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_3
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 2, 1 NULLS FIRST, 3;

multipleGroupingsAndOrderingByGroupsAndAggregatesWithFunctions_4
SELECT CONCAT('foo', gender) g, MAX(salary) AS max, MIN(salary) AS min FROM test_emp GROUP BY g ORDER BY 3 DESC, 1 NULLS FIRST, 2;
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.elasticsearch.xpack.sql.querydsl.container.GlobalCountRef;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef.Property;
import org.elasticsearch.xpack.sql.querydsl.container.GroupingFunctionSort;
import org.elasticsearch.xpack.sql.querydsl.container.MetricAggRef;
import org.elasticsearch.xpack.sql.querydsl.container.PivotColumnRef;
import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
Expand Down Expand Up @@ -682,37 +683,34 @@ protected PhysicalPlan rule(OrderExec plan) {

// TODO: might need to validate whether the target field or group actually exist
if (group != null && group != Aggs.IMPLICIT_GROUP_KEY) {
// check whether the lookup matches a group
if (group.id().equals(lookup)) {
qContainer = qContainer.updateGroup(group.with(direction));
}
// else it's a leafAgg
else {
qContainer = qContainer.updateGroup(group.with(direction));
}
qContainer = qContainer.updateGroup(group.with(direction));
}

// field
if (orderExpression instanceof FieldAttribute) {
qContainer = qContainer.addSort(new AttributeSort((FieldAttribute) orderExpression, direction, missing));
}
// scalar functions typically require script ordering
else if (orderExpression instanceof ScalarFunction) {
ScalarFunction sf = (ScalarFunction) orderExpression;
// nope, use scripted sorting
qContainer = qContainer.addSort(new ScriptSort(Expressions.id(sf), sf.asScript(), direction, missing));
}
// histogram
else if (orderExpression instanceof Histogram) {
qContainer = qContainer.addSort(new GroupingFunctionSort(Expressions.id(orderExpression), direction, missing));
}
// score
else if (orderExpression instanceof Score) {
qContainer = qContainer.addSort(new ScoreSort(Expressions.id(orderExpression), direction, missing));
}
// agg function
else if (orderExpression instanceof AggregateFunction) {
qContainer = qContainer.addSort(new AggregateSort((AggregateFunction) orderExpression, direction, missing));
}
// unknown
else {
// scalar functions typically require script ordering
if (orderExpression instanceof ScalarFunction) {
ScalarFunction sf = (ScalarFunction) orderExpression;
// nope, use scripted sorting
qContainer = qContainer.addSort(new ScriptSort(sf.asScript(), direction, missing));
}
// score
else if (orderExpression instanceof Score) {
qContainer = qContainer.addSort(new ScoreSort(direction, missing));
}
// field
else if (orderExpression instanceof FieldAttribute) {
qContainer = qContainer.addSort(new AttributeSort((FieldAttribute) orderExpression, direction, missing));
}
// agg function
else if (orderExpression instanceof AggregateFunction) {
qContainer = qContainer.addSort(new AggregateSort((AggregateFunction) orderExpression, direction, missing));
} else {
// unknown
throw new SqlIllegalArgumentException("unsupported sorting expression {}", orderExpression);
}
throw new SqlIllegalArgumentException("unsupported sorting expression {}", orderExpression);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package org.elasticsearch.xpack.sql.querydsl.container;

import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;

import java.util.Objects;
Expand All @@ -23,6 +24,11 @@ public AggregateFunction agg() {
return agg;
}

@Override
public String id() {
return Expressions.id(agg);
}

@Override
public int hashCode() {
return Objects.hash(agg, direction(), missing());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.sql.querydsl.container;

import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.Expressions;

import java.util.Objects;

Expand All @@ -22,6 +23,11 @@ public Attribute attribute() {
return attribute;
}

@Override
public String id() {
return Expressions.id(attribute);
}

@Override
public int hashCode() {
return Objects.hash(attribute, direction(), missing());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.container;

import java.util.Objects;

public class GroupingFunctionSort extends Sort {

private final String id;

public GroupingFunctionSort(String id, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
}

@Override
public String id() {
return id;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing(), id);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}

if (obj == null || getClass() != obj.getClass()) {
return false;
}

GroupingFunctionSort other = (GroupingFunctionSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing())
&& Objects.equals(id(), other.id());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.ConstantInput;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
Expand Down Expand Up @@ -134,45 +133,39 @@ public List<Tuple<Integer, Comparator>> sortingColumns() {
return emptyList();
}

List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());

boolean aggSort = false;
for (Sort s : sort) {
Tuple<Integer, Comparator> tuple = new Tuple<>(Integer.valueOf(-1), null);

if (s instanceof AggregateSort) {
AggregateSort as = (AggregateSort) s;
// find the relevant column of each aggregate function
AggregateFunction af = as.agg();

aggSort = true;
int atIndex = -1;
String id = Expressions.id(af);

for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> field = fields.get(i);
if (field.v2().equals(id)) {
atIndex = i;
break;
}
}
if (atIndex == -1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", s);
}
// assemble a comparator for it
Comparator comp = s.direction() == Sort.Direction.ASC ? Comparator.naturalOrder() : Comparator.reverseOrder();
comp = s.missing() == Sort.Missing.FIRST ? Comparator.nullsFirst(comp) : Comparator.nullsLast(comp);

tuple = new Tuple<>(Integer.valueOf(atIndex), comp);
customSort = Boolean.TRUE;
}
sortingColumns.add(tuple);
}


// If no custom sort is used break early
if (customSort == null) {
customSort = Boolean.valueOf(aggSort);
customSort = Boolean.FALSE;
return emptyList();
}

return aggSort ? sortingColumns : emptyList();
List<Tuple<Integer, Comparator>> sortingColumns = new ArrayList<>(sort.size());
for (Sort s: sort) {
int atIndex = -1;
for (int i = 0; i < fields.size(); i++) {
Tuple<FieldExtraction, String> field = fields.get(i);
if (field.v2().equals(s.id())) {
atIndex = i;
break;
}
}
if (atIndex==-1) {
throw new SqlIllegalArgumentException("Cannot find backing column for ordering aggregation [{}]", s);
}
// assemble a comparator for it
Comparator comp = s.direction()==Sort.Direction.ASC ? Comparator.naturalOrder():Comparator.reverseOrder();
comp = s.missing()==Sort.Missing.FIRST ? Comparator.nullsFirst(comp):Comparator.nullsLast(comp);

sortingColumns.add(new Tuple<>(Integer.valueOf(atIndex), comp));
}

return sortingColumns;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import java.util.Objects;

public class ScoreSort extends Sort {
public ScoreSort(Direction direction, Missing missing) {

final String id;

public ScoreSort(String id, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
}

@Override
public String id() {
return id;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing());
return Objects.hash(direction(), missing(), id());
}

@Override
Expand All @@ -29,6 +38,7 @@ public boolean equals(Object obj) {

ScriptSort other = (ScriptSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing());
&& Objects.equals(missing(), other.missing())
&& Objects.equals(id(), other.id());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,27 @@

public class ScriptSort extends Sort {

private final String id;
private final ScriptTemplate script;

public ScriptSort(ScriptTemplate script, Direction direction, Missing missing) {
public ScriptSort(String id, ScriptTemplate script, Direction direction, Missing missing) {
super(direction, missing);
this.id = id;
this.script = Scripts.nullSafeSort(script);
}

@Override
public String id() {
return id;
}

public ScriptTemplate script() {
return script;
}

@Override
public int hashCode() {
return Objects.hash(direction(), missing(), script);
return Objects.hash(direction(), missing(), id(), script());
}

@Override
Expand All @@ -37,10 +44,11 @@ public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}

ScriptSort other = (ScriptSort) obj;
return Objects.equals(direction(), other.direction())
&& Objects.equals(missing(), other.missing())
&& Objects.equals(script, other.script);
&& Objects.equals(id(), other.id())
&& Objects.equals(script(), other.script());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ protected Sort(Direction direction, Missing nulls) {
this.missing = nulls;
}

public abstract String id();

public Direction direction() {
return direction;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void testSelectScoreForcesTrackingScore() {

public void testSortScoreSpecified() {
QueryContainer container = new QueryContainer()
.addSort(new ScoreSort(Direction.DESC, null));
.addSort(new ScoreSort("id", Direction.DESC, null));
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertEquals(singletonList(scoreSort()), sourceBuilder.sorts());
}
Expand Down Expand Up @@ -137,4 +137,4 @@ public void testNoSortIfAgg() {
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
assertNull(sourceBuilder.sorts());
}
}
}

0 comments on commit b49a23c

Please sign in to comment.