From d102659dce45a56020423f2802478ebd0c3ef341 Mon Sep 17 00:00:00 2001
From: Costin Leau
Date: Tue, 15 Oct 2024 00:35:14 -0700
Subject: [PATCH] ESQL: Introduce per agg filter (#113735)
Add support for aggregation scoped filters that work dynamically on the
data in each group.
| STATS
success = COUNT(*) WHERE 200 <= code AND code < 300,
redirect = COUNT(*) WHERE 300 <= code AND code < 400,
client_err = COUNT(*) WHERE 400 <= code AND code < 500,
server_err = COUNT(*) WHERE 500 <= code AND code < 600,
total_count = COUNT(*)
Implementation wise, the base AggregateFunction has been extended to
allow a filter to be passed on. This is required to incorporate the
filter as part of the aggregate equality/identify which would fail with
the filter as an external component.
As part of the process, the serialization for the existing aggregations
had to be fixed so AggregateFunction implementations so that it
delegates to their parent first.
---
docs/changelog/113735.yaml | 28 +
.../org/elasticsearch/TransportVersions.java | 1 +
.../xpack/esql/core/util/CollectionUtils.java | 15 +
.../src/main/resources/stats.csv-spec | 183 ++
.../esql/src/main/antlr/EsqlBaseLexer.g4 | 1 +
.../esql/src/main/antlr/EsqlBaseParser.g4 | 20 +-
.../xpack/esql/action/EsqlCapabilities.java | 7 +-
.../xpack/esql/analysis/Analyzer.java | 1 +
.../xpack/esql/analysis/Verifier.java | 26 +-
.../function/EsqlFunctionRegistry.java | 74 +-
.../function/aggregate/AggregateFunction.java | 66 +-
.../expression/function/aggregate/Avg.java | 21 +-
.../expression/function/aggregate/Count.java | 19 +-
.../function/aggregate/CountDistinct.java | 34 +-
.../function/aggregate/EnclosedAgg.java | 13 -
.../aggregate/FilteredExpression.java | 95 +
.../function/aggregate/FromPartial.java | 32 +-
.../expression/function/aggregate/Max.java | 17 +-
.../expression/function/aggregate/Median.java | 16 +-
.../aggregate/MedianAbsoluteDeviation.java | 18 +-
.../expression/function/aggregate/Min.java | 17 +-
.../function/aggregate/NumericAggregate.java | 4 +
.../function/aggregate/Percentile.java | 35 +-
.../expression/function/aggregate/Rate.java | 52 +-
.../aggregate/SpatialAggregateFunction.java | 6 +-
.../function/aggregate/SpatialCentroid.java | 14 +-
.../expression/function/aggregate/Sum.java | 16 +-
.../function/aggregate/ToPartial.java | 34 +-
.../expression/function/aggregate/Top.java | 41 +-
.../expression/function/aggregate/Values.java | 17 +-
.../function/aggregate/WeightedAvg.java | 36 +-
.../esql/optimizer/LogicalPlanOptimizer.java | 4 +
.../optimizer/rules/logical/FoldNull.java | 11 +
.../ReplaceStatsAggExpressionWithEval.java | 14 +-
.../logical/SubstituteFilteredExpression.java | 27 +
.../xpack/esql/parser/EsqlBaseLexer.interp | 3 +-
.../xpack/esql/parser/EsqlBaseLexer.java | 1991 +++++++++--------
.../xpack/esql/parser/EsqlBaseParser.interp | 4 +-
.../xpack/esql/parser/EsqlBaseParser.java | 1950 ++++++++--------
.../parser/EsqlBaseParserBaseListener.java | 24 +
.../parser/EsqlBaseParserBaseVisitor.java | 14 +
.../esql/parser/EsqlBaseParserListener.java | 20 +
.../esql/parser/EsqlBaseParserVisitor.java | 12 +
.../xpack/esql/parser/ExpressionBuilder.java | 37 +-
.../xpack/esql/parser/LogicalPlanBuilder.java | 13 +-
.../xpack/esql/plan/logical/Aggregate.java | 1 +
.../AbstractPhysicalOperationProviders.java | 42 +-
.../xpack/esql/planner/AggregateMapper.java | 16 +-
.../elasticsearch/xpack/esql/CsvTests.java | 4 -
.../xpack/esql/analysis/AnalyzerTests.java | 4 +-
.../xpack/esql/analysis/VerifierTests.java | 38 +
.../aggregate/RateSerializationTests.java | 5 +
.../aggregate/TopSerializationTests.java | 5 +
.../optimizer/LogicalPlanOptimizerTests.java | 18 +
.../xpack/esql/parser/ExpressionTests.java | 4 +-
.../esql/parser/StatementParserTests.java | 59 +
.../esql/tree/EsqlNodeSubclassTests.java | 15 +
57 files changed, 3181 insertions(+), 2113 deletions(-)
create mode 100644 docs/changelog/113735.yaml
delete mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/EnclosedAgg.java
create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java
create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java
diff --git a/docs/changelog/113735.yaml b/docs/changelog/113735.yaml
new file mode 100644
index 0000000000000..4f6579c7cb9e0
--- /dev/null
+++ b/docs/changelog/113735.yaml
@@ -0,0 +1,28 @@
+pr: 113735
+summary: "ESQL: Introduce per agg filter"
+area: ES|QL
+type: feature
+issues: []
+highlight:
+ title: "ESQL: Introduce per agg filter"
+ body: |-
+ Add support for aggregation scoped filters that work dynamically on the
+ data in each group.
+
+ [source,esql]
+ ----
+ | STATS success = COUNT(*) WHERE 200 <= code AND code < 300,
+ redirect = COUNT(*) WHERE 300 <= code AND code < 400,
+ client_err = COUNT(*) WHERE 400 <= code AND code < 500,
+ server_err = COUNT(*) WHERE 500 <= code AND code < 600,
+ total_count = COUNT(*)
+ ----
+
+ Implementation wise, the base AggregateFunction has been extended to
+ allow a filter to be passed on. This is required to incorporate the
+ filter as part of the aggregate equality/identity which would fail with
+ the filter as an external component.
+ As part of the process, the serialization for the existing aggregations
+ had to be fixed so AggregateFunction implementations so that it
+ delegates to their parent first.
+ notable: true
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index ab4321edd3f71..3cb4695e867df 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -243,6 +243,7 @@ static TransportVersion def(int id) {
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0);
+ public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
index 48b5fd1605edf..8bfcf4ca5c405 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
@@ -79,4 +79,19 @@ public static int mapSize(int size) {
}
return (int) (size / 0.75f + 1f);
}
+
+ @SafeVarargs
+ @SuppressWarnings("varargs")
+ public static List nullSafeList(T... entries) {
+ if (entries == null || entries.length == 0) {
+ return emptyList();
+ }
+ List list = new ArrayList<>(entries.length);
+ for (T entry : entries) {
+ if (entry != null) {
+ list.add(entry);
+ }
+ }
+ return list;
+ }
}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
index 8a2e9b402fbca..496a747fd9c2b 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
@@ -2290,3 +2290,186 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;
+
+
+statsWithFiltering
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
+ min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
+;
+
+max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
+74999 |49818 |74999 | 25324 | 50064 | 25324
+;
+
+statsWithEverythingFiltered
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_a = max(salary) where salary < 100,
+ min = min(salary), min_a = min(salary) where salary > 99999
+;
+
+max:integer |max_a:integer|min:integer | min_a:integer
+74999 |null |25324 | null
+;
+
+statsWithNullFilter
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_a = max(salary) where null,
+ min = min(salary), min_a = min(salary) where to_string(null) == "abc"
+;
+
+max:integer |max_a:integer|min:integer | min_a:integer
+74999 |null |25324 | null
+;
+
+statsWithBasicExpressionFiltered
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000,
+ min = min(salary), min_f = min(salary) where salary > 50000,
+ exp_p = max(salary) + 10000 where salary < 50000,
+ exp_m = min(salary) % 10000 where salary > 50000
+;
+
+max:integer |max_f:integer|min:integer | min_f:integer|exp_p:integer | exp_m:integer
+74999 |49818 |25324 | 50064 |59818 | 64
+;
+
+statsWithExpressionOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000,
+ min = min(salary), min_f = min(salary) where salary > 50000,
+ exp_gt = max(salary) - min(salary) where salary > 50000,
+ exp_lt = max(salary) - min(salary) where salary < 50000
+
+;
+
+max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
+74999 |49818 | 25324 | 50064 |24935 | 24494
+;
+
+
+statsWithExpressionOfExpressionsOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary + 1), max_f = max(salary + 2) where salary < 50000,
+ min = min(salary - 1), min_f = min(salary - 2) where salary > 50000,
+ exp_gt = max(salary + 3) - min(salary - 3) where salary > 50000,
+ exp_lt = max(salary + 4) - min(salary - 4) where salary < 50000
+
+;
+
+max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
+75000 |49820 | 25323 | 50062 |24941 | 24502
+;
+
+statsWithSubstitutedExpressionOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats sum = sum(salary), s_l = sum(salary) where salary < 50000, s_u = sum(salary) where salary > 50000,
+ count = count(salary), c_l = count(salary) where salary < 50000, c_u = count(salary) where salary > 50000,
+ avg = round(avg(salary), 2), a_l = round(avg(salary), 2) where salary < 50000, a_u = round(avg(salary),2) where salary > 50000
+;
+
+sum:l |s_l:l | s_u:l | count:l |c_l:l |c_u:l |avg:double |a_l:double | a_u:double
+4824855 |2220951 | 2603904 | 100 |58 |42 |48248.55 |38292.26 | 61997.71
+;
+
+
+statsWithFilterAndGroupBy
+required_capability: per_agg_filtering
+from employees
+| stats m = max(height),
+ m_f = max(height + 1) where gender == "M" OR is_rehired is null
+ BY gender, is_rehired
+| sort gender, is_rehired
+;
+
+m:d |m_f:d |gender:s|is_rehired:bool
+2.1 |null |F |false
+2.1 |null |F |true
+1.85|2.85 |F |null
+2.1 |3.1 |M |false
+2.1 |3.1 |M |true
+2.01|3.01 |M |null
+2.06|null |null |false
+1.97|null |null |true
+1.99|2.99 |null |null
+;
+
+statsWithFilterOnGroupBy
+required_capability: per_agg_filtering
+from employees
+| stats m_f = max(height) where gender == "M" BY gender
+| sort gender
+;
+
+m_f:d |gender:s
+null |F
+2.1 |M
+null |null
+;
+
+statsWithGroupByLiteral
+required_capability: per_agg_filtering
+from employees
+| stats m = max(languages) by salary = 2
+;
+
+m:i |salary:i
+5 |2
+;
+
+
+statsWithFilterOnSameColumn
+required_capability: per_agg_filtering
+from employees
+| stats m = max(languages), m_f = max(languages) where salary > 50000 by salary = 2
+| sort salary
+;
+
+m:i |m_f:i |salary:i
+5 |null |2
+;
+
+# the query is reused below in a multi-stats
+statsWithFilteringAndGrouping
+required_capability: per_agg_filtering
+from employees
+| stats c = count(), c_f = count(languages) where l > 1,
+ m_f = max(height) where salary > 50000
+ by l = languages
+| sort c
+;
+
+c:l |c_f:l |m_f:d |l:i
+10 |0 |2.08 |null
+15 |0 |2.06 |1
+17 |17 |2.1 |3
+18 |18 |1.83 |4
+19 |19 |2.03 |2
+21 |21 |2.1 |5
+;
+
+multiStatsWithFiltering
+required_capability: per_agg_filtering
+from employees
+| stats c = count(), c_f = count(languages) where l > 1,
+ m_f = max(height) where salary > 50000
+ by l = languages
+| stats c2 = count(), c2_f = count() where m_f > 2.06 , m2 = max(l), m2_f = max(l) where l > 1 by c
+| sort c
+;
+
+c2:l |c2_f:l |m2:i |m2_f:i |c:l
+1 |1 |null |null |10
+1 |0 |1 |null |15
+1 |1 |3 |3 |17
+1 |0 |4 |4 |18
+1 |0 |2 |2 |19
+1 |1 |5 |5 |21
+;
diff --git a/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4 b/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
index d6d45097a1d07..b13606befd2a4 100644
--- a/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
+++ b/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
@@ -209,6 +209,7 @@ SLASH : '/';
PERCENT : '%';
MATCH : 'match';
+NESTED_WHERE : {this.isDevVersion()}? WHERE -> type(WHERE);
NAMED_OR_POSITIONAL_PARAM
: PARAM (LETTER | UNDERSCORE) UNQUOTED_ID_BODY*
diff --git a/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4 b/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
index 77568d5527cd1..9a95e0e6726ba 100644
--- a/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
+++ b/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
@@ -123,8 +123,7 @@ fields
;
field
- : booleanExpression
- | qualifiedName ASSIGN booleanExpression
+ : (qualifiedName ASSIGN)? booleanExpression
;
fromCommand
@@ -132,8 +131,7 @@ fromCommand
;
indexPattern
- : clusterString COLON indexString
- | indexString
+ : (clusterString COLON)? indexString
;
clusterString
@@ -159,7 +157,7 @@ deprecated_metadata
;
metricsCommand
- : DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=fields? (BY grouping=fields)?
+ : DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=aggFields? (BY grouping=fields)?
;
evalCommand
@@ -167,7 +165,15 @@ evalCommand
;
statsCommand
- : STATS stats=fields? (BY grouping=fields)?
+ : STATS stats=aggFields? (BY grouping=fields)?
+ ;
+
+aggFields
+ : aggField (COMMA aggField)*
+ ;
+
+aggField
+ : field {this.isDevVersion()}? (WHERE booleanExpression)?
;
qualifiedName
@@ -316,5 +322,5 @@ lookupCommand
;
inlinestatsCommand
- : DEV_INLINESTATS stats=fields (BY grouping=fields)?
+ : DEV_INLINESTATS stats=aggFields (BY grouping=fields)?
;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 9dc17b020e426..f5baaef4f579d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -370,7 +370,12 @@ public enum Cap {
/**
* Fix sorting not allowed on _source and counters.
*/
- SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN;
+ SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN,
+
+ /**
+ * Allow filter per individual aggregation.
+ */
+ PER_AGG_FILTERING;
private final boolean snapshotOnly;
private final FeatureFlag featureFlag;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 90957f55141b9..fe7b945a9b3c1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -488,6 +488,7 @@ private LogicalPlan resolveStats(Stats stats, List childrenOutput) {
newAggregates.add(agg);
}
+ // TODO: remove this when Stats interface is removed
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
index dd2b72b4d35d9..ef39220d7ffcc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
@@ -30,6 +30,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
@@ -308,6 +309,29 @@ private static void checkInvalidNamedExpressionUsage(
Set failures,
int level
) {
+ // unwrap filtered expression
+ if (e instanceof FilteredExpression fe) {
+ e = fe.delegate();
+ // make sure they work on aggregate functions
+ if (e.anyMatch(AggregateFunction.class::isInstance) == false) {
+ Expression filter = fe.filter();
+ failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", fe.sourceText()));
+ }
+ // but that the filter doesn't use grouping or aggregate functions
+ fe.filter().forEachDown(c -> {
+ if (c instanceof AggregateFunction af) {
+ failures.add(
+ fail(af, "cannot use aggregate function [{}] in aggregate WHERE clause [{}]", af.sourceText(), fe.sourceText())
+ );
+ }
+ // check the bucketing function against the group
+ else if (c instanceof GroupingFunction gf) {
+ if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
+ failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+ }
+ }
+ });
+ }
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
@@ -319,7 +343,7 @@ private static void checkInvalidNamedExpressionUsage(
} else if (e instanceof GroupingFunction gf) {
// optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer
// be verified (by check above in checkAggregate()), so do it explicitly here
- if (groups.stream().anyMatch(ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
+ if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
} else if (level == 0) {
addFailureOnGroupingUsedNakedInAggs(failures, gf, "function");
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index faf99d6bd65bc..66151275fc2e8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -259,19 +259,21 @@ private FunctionDefinition[][] functions() {
// grouping functions
new FunctionDefinition[] { def(Bucket.class, Bucket::new, "bucket", "bin"), },
// aggregate functions
+ // since they declare two public constructors - one with filter (for nested where) and one without
+ // use casting to disambiguate between the two
new FunctionDefinition[] {
- def(Avg.class, Avg::new, "avg"),
- def(Count.class, Count::new, "count"),
- def(CountDistinct.class, CountDistinct::new, "count_distinct"),
- def(Max.class, Max::new, "max"),
- def(Median.class, Median::new, "median"),
- def(MedianAbsoluteDeviation.class, MedianAbsoluteDeviation::new, "median_absolute_deviation"),
- def(Min.class, Min::new, "min"),
- def(Percentile.class, Percentile::new, "percentile"),
- def(Sum.class, Sum::new, "sum"),
- def(Top.class, Top::new, "top"),
- def(Values.class, Values::new, "values"),
- def(WeightedAvg.class, WeightedAvg::new, "weighted_avg") },
+ def(Avg.class, uni(Avg::new), "avg"),
+ def(Count.class, uni(Count::new), "count"),
+ def(CountDistinct.class, bi(CountDistinct::new), "count_distinct"),
+ def(Max.class, uni(Max::new), "max"),
+ def(Median.class, uni(Median::new), "median"),
+ def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"),
+ def(Min.class, uni(Min::new), "min"),
+ def(Percentile.class, bi(Percentile::new), "percentile"),
+ def(Sum.class, uni(Sum::new), "sum"),
+ def(Top.class, tri(Top::new), "top"),
+ def(Values.class, uni(Values::new), "values"),
+ def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg") },
// math
new FunctionDefinition[] {
def(Abs.class, Abs::new, "abs"),
@@ -482,11 +484,10 @@ public static DataType getTargetType(String[] names) {
}
public static FunctionDescription description(FunctionDefinition def) {
- var constructors = def.clazz().getConstructors();
- if (constructors.length == 0) {
+ Constructor> constructor = constructorFor(def.clazz());
+ if (constructor == null) {
return new FunctionDescription(def.name(), List.of(), null, null, false, false);
}
- Constructor> constructor = constructors[0];
FunctionInfo functionInfo = functionInfo(def);
String functionDescription = functionInfo == null ? "" : functionInfo.description().replace('\n', ' ');
String[] returnType = functionInfo == null ? new String[] { "?" } : removeUnderConstruction(functionInfo.returnType());
@@ -523,14 +524,29 @@ private static String[] removeUnderConstruction(String[] types) {
}
public static FunctionInfo functionInfo(FunctionDefinition def) {
- var constructors = def.clazz().getConstructors();
- if (constructors.length == 0) {
+ Constructor> constructor = constructorFor(def.clazz());
+ if (constructor == null) {
return null;
}
- Constructor> constructor = constructors[0];
return constructor.getAnnotation(FunctionInfo.class);
}
+ private static Constructor> constructorFor(Class extends Function> clazz) {
+ Constructor>[] constructors = clazz.getConstructors();
+ if (constructors.length == 0) {
+ return null;
+ }
+ // when dealing with multiple, pick the constructor exposing the FunctionInfo annotation
+ if (constructors.length > 1) {
+ for (Constructor> constructor : constructors) {
+ if (constructor.getAnnotation(FunctionInfo.class) != null) {
+ return constructor;
+ }
+ }
+ }
+ return constructors[0];
+ }
+
private void buildDataTypesForStringLiteralConversion(FunctionDefinition[]... groupFunctions) {
for (FunctionDefinition[] group : groupFunctions) {
for (FunctionDefinition def : group) {
@@ -913,15 +929,19 @@ protected interface TernaryConfigurationAwareBuilder {
}
//
- // Utility method for extra argument extraction.
+ // Utility functions to help disambiguate the method handle passed in.
+ // They work by providing additional method information to help the compiler know which method to pick.
//
- protected static Boolean asBool(Object[] extras) {
- if (CollectionUtils.isEmpty(extras)) {
- return null;
- }
- if (extras.length != 1 || (extras[0] instanceof Boolean) == false) {
- throw new QlIllegalArgumentException("Invalid number and types of arguments given to function definition");
- }
- return (Boolean) extras[0];
+ private static BiFunction
*/
@Override public T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitAggFields(EsqlBaseParser.AggFieldsContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitAggField(EsqlBaseParser.AggFieldContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index c6dcaca736e1f..cf658c4a73141 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -465,6 +465,26 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ */
+ void enterAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ */
+ void exitAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ */
+ void enterAggField(EsqlBaseParser.AggFieldContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ */
+ void exitAggField(EsqlBaseParser.AggFieldContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index 310d3dc76dd6d..86c1d1aafc33a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -284,6 +284,18 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitAggField(EsqlBaseParser.AggFieldContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
index bcbd28aced939..7ff09c23a1403 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar;
+import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.expression.predicate.fulltext.MatchQueryPredicate;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
@@ -44,6 +45,7 @@
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.expression.function.FunctionResolutionStrategy;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
@@ -742,9 +744,12 @@ private NamedExpression enrichFieldName(EsqlBaseParser.QualifiedNamePatternConte
@Override
public Alias visitField(EsqlBaseParser.FieldContext ctx) {
+ return visitField(ctx, source(ctx));
+ }
+
+ private Alias visitField(EsqlBaseParser.FieldContext ctx, Source source) {
UnresolvedAttribute id = visitQualifiedName(ctx.qualifiedName());
Expression value = expression(ctx.booleanExpression());
- var source = source(ctx);
String name = id == null ? source.text() : id.name();
return new Alias(source, name, value);
}
@@ -754,6 +759,36 @@ public List visitFields(EsqlBaseParser.FieldsContext ctx) {
return ctx != null ? visitList(this, ctx.field(), Alias.class) : new ArrayList<>();
}
+ @Override
+ public NamedExpression visitAggField(EsqlBaseParser.AggFieldContext ctx) {
+ Source source = source(ctx);
+ Alias field = visitField(ctx.field(), source);
+ var filterExpression = ctx.booleanExpression();
+
+ if (filterExpression != null) {
+ Expression condition = expression(filterExpression);
+ Expression child = field.child();
+ // basic check as the filter can be specified only on a function (should be an aggregate but we can't determine that yet)
+ if (field.child().anyMatch(Function.class::isInstance)) {
+ field = field.replaceChild(new FilteredExpression(field.source(), child, condition));
+ }
+ // allow condition only per aggregated function
+ else {
+ throw new ParsingException(
+ condition.source(),
+ "WHERE clause allowed only for aggregate functions [{}]",
+ field.sourceText()
+ );
+ }
+ }
+ return field;
+ }
+
+ @Override
+ public List visitAggFields(EsqlBaseParser.AggFieldsContext ctx) {
+ return ctx != null ? visitList(this, ctx.aggField(), Alias.class) : new ArrayList<>();
+ }
+
/**
* Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the expression
* into an Alias.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
index c90c3cba4ef24..dc913cd2f14f4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
@@ -298,13 +298,12 @@ public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
return input -> new Aggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates);
}
- private record Stats(List groupings, List extends NamedExpression> aggregates) {
+ private record Stats(List groupings, List extends NamedExpression> aggregates) {}
- }
-
- private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, EsqlBaseParser.FieldsContext aggregatesCtx) {
+ private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, EsqlBaseParser.AggFieldsContext aggregatesCtx) {
List groupings = visitGrouping(groupingsCtx);
- List aggregates = new ArrayList<>(visitFields(aggregatesCtx));
+ List aggregates = new ArrayList<>(visitAggFields(aggregatesCtx));
+
if (aggregates.isEmpty() && groupings.isEmpty()) {
throw new ParsingException(source, "At least one aggregation or grouping expression required in [{}]", source.text());
}
@@ -341,9 +340,11 @@ public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandCont
if (false == EsqlPlugin.INLINESTATS_FEATURE_FLAG.isEnabled()) {
throw new ParsingException(source(ctx), "INLINESTATS command currently requires a snapshot build");
}
- List aggregates = new ArrayList<>(visitFields(ctx.stats));
+ List aggFields = visitAggFields(ctx.stats);
+ List aggregates = new ArrayList<>(aggFields);
List groupings = visitGrouping(ctx.grouping);
aggregates.addAll(groupings);
+ // TODO: add support for filters
return input -> new InlineStats(source(ctx), input, new ArrayList<>(groupings), aggregates);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
index 8445c8236c45a..3b7240dcd693b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
@@ -59,6 +59,7 @@ static AggregateType readType(StreamInput in) throws IOException {
private final AggregateType aggregateType;
private final List groupings;
private final List extends NamedExpression> aggregates;
+
private List lazyOutput;
public Aggregate(
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 0e71963e29270..94a9246a56f83 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -10,10 +10,12 @@
import org.elasticsearch.compute.aggregation.Aggregator;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.AggregationOperator;
+import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
@@ -24,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@@ -231,11 +234,14 @@ private void aggregatesToFactory(
boolean grouping,
Consumer consumer
) {
+ // extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase
for (NamedExpression ne : aggregates) {
+ // a filter can only appear on aggregate function, not on the grouping columns
+
if (ne instanceof Alias alias) {
var child = alias.child();
if (child instanceof AggregateFunction aggregateFunction) {
- List extends NamedExpression> sourceAttr;
+ List sourceAttr = new ArrayList<>();
if (mode == AggregatorMode.INITIAL) {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
@@ -251,19 +257,22 @@ private void aggregatesToFactory(
);
}
} else {
- sourceAttr = aggregateFunction.inputExpressions().stream().map(e -> {
- Attribute attr = Expressions.attribute(e);
+ // extra dependencies like TS ones (that require a timestamp)
+ for (Expression input : aggregateFunction.references()) {
+ Attribute attr = Expressions.attribute(input);
if (attr == null) {
throw new EsqlIllegalArgumentException(
"Cannot work with target field [{}] for agg [{}]",
- e.sourceText(),
+ input.sourceText(),
aggregateFunction.sourceText()
);
}
- return attr;
- }).toList();
+ sourceAttr.add(attr);
+ }
}
- } else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
+ }
+ // coordinator/exchange phase
+ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(aggregateFunction);
} else {
@@ -274,16 +283,27 @@ private void aggregatesToFactory(
}
List inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
assert inputChannels.stream().allMatch(i -> i >= 0) : inputChannels;
- if (aggregateFunction instanceof ToAggregator agg) {
- consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), mode));
- } else {
- throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
+
+ AggregatorFunctionSupplier aggSupplier = supplier(aggregateFunction, inputChannels);
+
+ // apply the filter only in the initial phase - as the rest of the data is already filtered
+ if (aggregateFunction.hasFilter() && mode.isInputPartial() == false) {
+ EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(aggregateFunction.filter(), layout);
+ aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory);
}
+ consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode));
}
}
}
}
+ private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFunction, List inputChannels) {
+ if (aggregateFunction instanceof ToAggregator delegate) {
+ return delegate.supplier(inputChannels);
+ }
+ throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
+ }
+
private record GroupSpec(Integer channel, Attribute attribute) {
BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
index 13ce9ba77cc71..c322135198262 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
@@ -98,39 +98,39 @@ private record AggDef(Class> aggClazz, String type, String extra, boolean grou
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
/** Cache of aggregates to intermediate expressions. */
- private final HashMap> cache;
+ private final HashMap> cache;
AggregateMapper() {
cache = new HashMap<>();
}
- public List extends NamedExpression> mapNonGrouping(List extends Expression> aggregates) {
+ public List mapNonGrouping(List extends Expression> aggregates) {
return doMapping(aggregates, false);
}
- public List extends NamedExpression> mapNonGrouping(Expression aggregate) {
+ public List mapNonGrouping(Expression aggregate) {
return map(aggregate, false).toList();
}
- public List extends NamedExpression> mapGrouping(List extends Expression> aggregates) {
+ public List mapGrouping(List extends Expression> aggregates) {
return doMapping(aggregates, true);
}
- private List extends NamedExpression> doMapping(List extends Expression> aggregates, boolean grouping) {
+ private List doMapping(List extends Expression> aggregates, boolean grouping) {
AttributeMap attrToExpressions = new AttributeMap<>();
aggregates.stream().flatMap(agg -> map(agg, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
return attrToExpressions.values().stream().toList();
}
- public List extends NamedExpression> mapGrouping(Expression aggregate) {
+ public List mapGrouping(Expression aggregate) {
return map(aggregate, true).toList();
}
- private Stream extends NamedExpression> map(Expression aggregate, boolean grouping) {
+ private Stream map(Expression aggregate, boolean grouping) {
return cache.computeIfAbsent(Alias.unwrap(aggregate), aggKey -> computeEntryForAgg(aggKey, grouping)).stream();
}
- private static List extends NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
+ private static List computeEntryForAgg(Expression aggregate, boolean grouping) {
var aggDef = aggDefOrNull(aggregate, grouping);
if (aggDef != null) {
var is = getNonNull(aggDef);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index f881c0e1a9bba..ce072e7b0a438 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -67,10 +67,7 @@
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.TestLocalPhysicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.optimizer.TestPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -167,7 +164,6 @@ public class CsvTests extends ESTestCase {
private final EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry();
private final EsqlParser parser = new EsqlParser();
private final Mapper mapper = new Mapper(functionRegistry);
- private final PhysicalPlanOptimizer physicalPlanOptimizer = new TestPhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
private ThreadPool threadPool;
private Executor executor;
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 6644f9b17055e..d365ee3bb2e51 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -1219,7 +1219,7 @@ public void testAggsOverGroupingKey() throws Exception {
assertThat(output, hasSize(2));
var aggs = agg.aggregates();
var min = as(Alias.unwrap(aggs.get(0)), Min.class);
- assertThat(min.arguments(), hasSize(1));
+ assertThat(min.arguments(), hasSize(2)); // field + filter
var group = Alias.unwrap(agg.groupings().get(0));
assertEquals(min.arguments().get(0), group);
}
@@ -1241,7 +1241,7 @@ public void testAggsOverGroupingKeyWithAlias() throws Exception {
assertThat(output, hasSize(2));
var aggs = agg.aggregates();
var min = as(Alias.unwrap(aggs.get(0)), Min.class);
- assertThat(min.arguments(), hasSize(1));
+ assertThat(min.arguments(), hasSize(2)); // field + filter
assertEquals(Expressions.attribute(min.arguments().get(0)), Expressions.attribute(agg.groupings().get(0)));
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index ecf012718eaf8..63f7629f3c720 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -360,6 +360,40 @@ public void testAggsInsideGrouping() {
);
}
+ public void testAggFilterOnNonAggregates() {
+ assertEquals(
+ "1:36: WHERE clause allowed only for aggregate functions, none found in [emp_no + 1 where languages > 1]",
+ error("from test | stats emp_no + 1 where languages > 1 by emp_no")
+ );
+ assertEquals(
+ "1:53: WHERE clause allowed only for aggregate functions, none found in [abs(emp_no + languages) % 2 WHERE languages > 1]",
+ error("from test | stats abs(emp_no + languages) % 2 WHERE languages > 1 by emp_no, languages")
+ );
+ }
+
+ public void testAggFilterOnBucketingOrAggFunctions() {
+ // query passes when the bucket function is part of the BY clause
+ query("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by bucket(salary, 10)");
+
+ // but fails if it's different
+ assertEquals(
+ "1:40: can only use grouping function [bucket(salary, 10)] part of the BY clause",
+ error("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by emp_no")
+ );
+
+ assertEquals(
+ "1:40: cannot use aggregate function [max(salary)] in aggregate WHERE clause [max(languages) WHERE max(salary) > 1]",
+ error("from test | stats max(languages) WHERE max(salary) > 1 by emp_no")
+ );
+
+ assertEquals(
+ "1:40: cannot use aggregate function [max(salary)] in aggregate WHERE clause [max(languages) WHERE max(salary) + 2 > 1]",
+ error("from test | stats max(languages) WHERE max(salary) + 2 > 1 by emp_no")
+ );
+
+ assertEquals("1:60: Unknown column [m]", error("from test | stats m = max(languages), min(languages) WHERE m + 2 > 1 by emp_no"));
+ }
+
public void testGroupingInsideAggsAsAgg() {
assertEquals(
"1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
@@ -1507,6 +1541,10 @@ public void testToDatePeriodToTimeDurationWithInvalidType() {
);
}
+ private void query(String query) {
+ defaultAnalyzer.analyze(parser.createStatement(query));
+ }
+
private String error(String query) {
return error(query, defaultAnalyzer);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
index 94b2a81b308d7..ea7c480817317 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
@@ -36,4 +36,9 @@ protected Rate mutateInstance(Rate instance) throws IOException {
}
return new Rate(source, field, timestamp, unit);
}
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
index 82bf57d1a194e..e74b26c87c84f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
@@ -36,4 +36,9 @@ protected Top mutateInstance(Top instance) throws IOException {
}
return new Top(source, field, limit, order);
}
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index c05b5dd165485..8d7c1997f78e3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -537,6 +537,24 @@ public void testCombineProjectionWithDuplicateAggregation() {
assertThat(Expressions.names(agg.groupings()), contains("last_name", "first_name"));
}
+ /**
+ * Limit[1000[INTEGER]]
+ * \_Aggregate[STANDARD,[],[SUM(salary{f}#12,true[BOOLEAN]) AS sum(salary), SUM(salary{f}#12,last_name{f}#11 == [44 6f 65][KEYW
+ * ORD]) AS sum(salary) WheRe last_name == "Doe"]]
+ * \_EsRelation[test][_meta_field{f}#13, emp_no{f}#7, first_name{f}#8, ge..]
+ */
+ public void testStatsWithFilteringDefaultAliasing() {
+ var plan = plan("""
+ from test
+ | stats sum(salary), sum(salary) WheRe last_name == "Doe"
+ """);
+
+ var limit = as(plan, Limit.class);
+ var agg = as(limit.child(), Aggregate.class);
+ assertThat(agg.aggregates(), hasSize(2));
+ assertThat(Expressions.names(agg.aggregates()), contains("sum(salary)", "sum(salary) WheRe last_name == \"Doe\""));
+ }
+
public void testQlComparisonOptimizationsApply() {
var plan = plan("""
from test
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
index 80a2d49d0d94a..67b4dd71260aa 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
@@ -208,7 +208,7 @@ public void testParenthesizedExpression() {
}
public void testCommandNamesAsIdentifiers() {
- Expression expr = whereExpression("from and where");
+ Expression expr = whereExpression("from and limit");
assertThat(expr, instanceOf(And.class));
And and = (And) expr;
@@ -216,7 +216,7 @@ public void testCommandNamesAsIdentifiers() {
assertThat(((UnresolvedAttribute) and.left()).name(), equalTo("from"));
assertThat(and.right(), instanceOf(UnresolvedAttribute.class));
- assertThat(((UnresolvedAttribute) and.right()).name(), equalTo("where"));
+ assertThat(((UnresolvedAttribute) and.right()).name(), equalTo("limit"));
}
public void testIdentifiersCaseSensitive() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index 53621a79aedac..c797f426d2ae5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -20,14 +20,18 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
+import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -321,6 +325,61 @@ public void testAggsWithGroupKeyAsAgg() throws Exception {
}
}
+ public void testStatsWithGroupKeyAndAggFilter() throws Exception {
+ var a = attribute("a");
+ var f = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var filter = new Alias(EMPTY, "min(a) where a > 1", new FilteredExpression(EMPTY, f, new GreaterThan(EMPTY, a, integer(1))));
+ assertEquals(
+ new Aggregate(EMPTY, PROCESSING_CMD_INPUT, Aggregate.AggregateType.STANDARD, List.of(a), List.of(filter, a)),
+ processingCommand("stats min(a) where a > 1 by a")
+ );
+ }
+
+ public void testStatsWithGroupKeyAndMixedAggAndFilter() throws Exception {
+ var a = attribute("a");
+ var min = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var max = new UnresolvedFunction(EMPTY, "max", DEFAULT, List.of(a));
+ var avg = new UnresolvedFunction(EMPTY, "avg", DEFAULT, List.of(a));
+ var min_alias = new Alias(EMPTY, "min", min);
+
+ var max_filter_ex = new Or(
+ EMPTY,
+ new GreaterThan(EMPTY, new Mod(EMPTY, a, integer(3)), integer(10)),
+ new GreaterThan(EMPTY, new Div(EMPTY, a, integer(2)), integer(100))
+ );
+ var max_filter = new Alias(EMPTY, "max", new FilteredExpression(EMPTY, max, max_filter_ex));
+
+ var avg_filter_ex = new GreaterThan(EMPTY, new Div(EMPTY, a, integer(2)), integer(100));
+ var avg_filter = new Alias(EMPTY, "avg", new FilteredExpression(EMPTY, avg, avg_filter_ex));
+
+ assertEquals(
+ new Aggregate(
+ EMPTY,
+ PROCESSING_CMD_INPUT,
+ Aggregate.AggregateType.STANDARD,
+ List.of(a),
+ List.of(min_alias, max_filter, avg_filter, a)
+ ),
+ processingCommand("""
+ stats
+ min = min(a),
+ max = max(a) WHERE (a % 3 > 10 OR a / 2 > 100),
+ avg = avg(a) WHERE a / 2 > 100
+ BY a
+ """)
+ );
+ }
+
+ public void testStatsWithoutGroupKeyMixedAggAndFilter() throws Exception {
+ var a = attribute("a");
+ var f = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var filter = new Alias(EMPTY, "min(a) where a > 1", new FilteredExpression(EMPTY, f, new GreaterThan(EMPTY, a, integer(1))));
+ assertEquals(
+ new Aggregate(EMPTY, PROCESSING_CMD_INPUT, Aggregate.AggregateType.STANDARD, List.of(), List.of(filter)),
+ processingCommand("stats min(a) where a > 1")
+ );
+ }
+
public void testInlineStatsWithGroups() {
var query = "inlinestats b = min(a) by c, d.e";
if (Build.current().isSnapshot() == false) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
index d186b4c199d77..7075c9fe58d63 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
@@ -21,6 +21,8 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttributeTests;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedNamedExpression;
@@ -164,6 +166,16 @@ public void testInfoParameters() throws Exception {
* in the parameters and not included.
*/
expectedCount -= 1;
+
+ // special exceptions with private constructors
+ if (MetadataAttribute.class.equals(subclass) || ReferenceAttribute.class.equals(subclass)) {
+ expectedCount++;
+ }
+
+ if (FieldAttribute.class.equals(subclass)) {
+ expectedCount += 2;
+ }
+
assertEquals(expectedCount, info(node).properties().size());
}
@@ -174,6 +186,9 @@ public void testInfoParameters() throws Exception {
* implementations in the process.
*/
public void testTransform() throws Exception {
+ if (FieldAttribute.class.equals(subclass)) {
+ assumeTrue("FieldAttribute private constructor", false);
+ }
Constructor ctor = longestCtor(subclass);
Object[] nodeCtorArgs = ctorArgs(ctor);
T node = ctor.newInstance(nodeCtorArgs);