Skip to content

Commit

Permalink
Support distinct count aggregation (#167)
Browse files Browse the repository at this point in the history
* Support construct AggregationResponseParser during Aggregator build stage (#108)

* Support construct AggregationResponseParser during Aggregator build stage

* modify the doc

Signed-off-by: penghuo <[email protected]>

* support distinct count aggregation

Signed-off-by: chloe-zh <[email protected]>

* fixed tests

Signed-off-by: chloe-zh <[email protected]>

* Merge remote-tracking branch 'upstream/develop' into issue/#100

Signed-off-by: chloe-zh <[email protected]>

# Conflicts:
#	opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java

* update

Signed-off-by: chloe-zh <[email protected]>

* updated user doc

Signed-off-by: chloe-zh <[email protected]>

* Update: support only count for distinct aggregations

Signed-off-by: chloe-zh <[email protected]>

* Update doc; removed distinct start

Signed-off-by: chloe-zh <[email protected]>

* Removed unnecessary methods

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* Impl stddev and variance function in SQL and PPL (#115)

* impl variance frontend and backend

* Support construct AggregationResponseParser during Aggregator build stage

* add var and varp for PPL

Signed-off-by: penghuo <[email protected]>

* add UT

Signed-off-by: penghuo <[email protected]>

* fix UT

Signed-off-by: penghuo <[email protected]>

* fix doc format

Signed-off-by: penghuo <[email protected]>

* fix doc format

Signed-off-by: penghuo <[email protected]>

* fix the doc

Signed-off-by: penghuo <[email protected]>

* add stddev_samp and stddev_pop

Signed-off-by: penghuo <[email protected]>

* fix UT coverage

* address comments

Signed-off-by: penghuo <[email protected]>

* Fix the aggregation filter missing in named aggregators (#123)

* Take the condition expression as property to the named aggregator when wrapping the delegated aggregator

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* Added test case where filtered agg is not pushed down

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* modified comparison test

Signed-off-by: chloe-zh <[email protected]>

* removed a comparison test and added it to aggregationIT

Signed-off-by: chloe-zh <[email protected]>

* added ppl IT test cases; added window function test cases

Signed-off-by: chloe-zh <[email protected]>

* moved distinct window function test cases to WindowsIT

Signed-off-by: chloe-zh <[email protected]>

* added ut

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* update

Signed-off-by: chloe-zh <[email protected]>

* addressed comments

Signed-off-by: chloe-zh <[email protected]>

* added test cases to meet the coverage requirement

Signed-off-by: chloe-zh <[email protected]>

* added test cases for distinct count map and array types

Signed-off-by: chloe-zh <[email protected]>

Co-authored-by: Peng Huo <[email protected]>
  • Loading branch information
chloe-zh and penghuo authored Jul 29, 2021
1 parent bcdd3f5 commit b3dfc49
Show file tree
Hide file tree
Showing 29 changed files with 428 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
}
return aggregator;
} else {
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,16 @@ public static UnresolvedExpression aggregate(

public static UnresolvedExpression filteredAggregate(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, condition);
return new AggregateFunction(func, field).condition(condition);
}

/**
 * Builds an {@link AggregateFunction} for {@code func} over {@code field} with the
 * distinct flag set to {@code true}.
 *
 * @param func aggregate function name
 * @param field expression the aggregate is applied to
 * @return the distinct aggregate function node
 */
public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) {
return new AggregateFunction(func, field, true);
}

/**
 * Builds an {@link AggregateFunction} for {@code func} over {@code field} with the
 * distinct flag set to {@code true}, then attaches {@code condition} through the
 * fluent {@code condition(...)} setter.
 * NOTE(review): the method name says "count" but {@code func} is passed through as-is;
 * confirm whether callers are expected to pass only "count".
 *
 * @param func aggregate function name
 * @param field expression the aggregate is applied to
 * @param condition filter condition attached to the aggregate
 * @return the filtered distinct aggregate function node
 */
public static UnresolvedExpression filteredDistinctCount(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, true).condition(condition);
}

public static Function function(String funcName, UnresolvedExpression... funcArgs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@

import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.common.utils.StringUtils;

Expand All @@ -45,7 +48,10 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
@Setter
@Accessors(fluent = true)
private UnresolvedExpression condition;
private Boolean distinct = false;

/**
* Constructor.
Expand All @@ -62,14 +68,13 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
* Constructor.
* @param funcName function name.
* @param field {@link UnresolvedExpression}.
* @param condition condition in aggregation filter.
* @param distinct whether distinct field is specified or not.
*/
public AggregateFunction(String funcName, UnresolvedExpression field,
UnresolvedExpression condition) {
public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) {
this.funcName = funcName;
this.field = field;
this.argList = Collections.emptyList();
this.condition = condition;
this.distinct = distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) {
}
}

/**
 * Returns the result of {@code exprValue.byteValue()}.
 *
 * @param exprValue value to read
 * @return the byte value reported by {@code exprValue}
 */
public static Byte getByteValue(ExprValue exprValue) {
return exprValue.byteValue();
}

/**
 * Returns the result of {@code exprValue.shortValue()}.
 *
 * @param exprValue value to read
 * @return the short value reported by {@code exprValue}
 */
public static Short getShortValue(ExprValue exprValue) {
return exprValue.shortValue();
}

public static Integer getIntegerValue(ExprValue exprValue) {
return exprValue.integerValue();
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ public Aggregator count(Expression... expressions) {
return aggregate(BuiltinFunctionName.COUNT, expressions);
}

/**
 * Builds a COUNT aggregator over the given expressions via {@link #count(Expression...)}
 * and sets its distinct flag to {@code true}.
 *
 * @param expressions arguments of the count aggregation
 * @return the count aggregator with distinct enabled
 */
public Aggregator distinctCount(Expression... expressions) {
return count(expressions).distinct(true);
}

public Aggregator varSamp(Expression... expressions) {
return aggregate(BuiltinFunctionName.VARSAMP, expressions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
@Accessors(fluent = true)
protected Expression condition;
@Setter
@Getter
@Accessors(fluent = true)
protected Boolean distinct = false;

/**
* Create an {@link AggregationState} which will be used for aggregation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand All @@ -45,33 +47,51 @@ public CountAggregator(List<Expression> arguments, ExprCoreType returnType) {

@Override
public CountAggregator.CountState create() {
return new CountState();
return distinct ? new DistinctCountState() : new CountState();
}

@Override
protected CountState iterate(ExprValue value, CountState state) {
state.count++;
state.count(value);
return state;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "count(%s)", format(getArguments()));
return distinct
? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments()))
: String.format(Locale.ROOT, "count(%s)", format(getArguments()));
}

/**
* Count State.
*/
protected static class CountState implements AggregationState {
private int count;
protected int count;

CountState() {
this.count = 0;
}

public void count(ExprValue value) {
count++;
}

@Override
public ExprValue result() {
return ExprValueUtils.integerValue(count);
}
}

/**
 * Count state that only counts values it has not seen before.
 * Seen values are kept in a {@link HashSet}, so two values are treated as the same
 * when the {@code equals}/{@code hashCode} of {@link ExprValue} say so.
 */
protected static class DistinctCountState extends CountState {
  private final Set<ExprValue> distinctValues = new HashSet<>();

  /**
   * Increments the inherited counter only the first time {@code value} is seen.
   *
   * @param value value to record
   */
  @Override
  public void count(ExprValue value) {
    // Set#add returns true only when the element was not already present,
    // so a separate contains() lookup (and second hash computation) is unnecessary.
    if (distinctValues.add(value)) {
      count++;
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public class NamedAggregator extends Aggregator<AggregationState> {

/**
* NamedAggregator.
* The aggregator properties {@link #condition} is inherited by named aggregator
* to avoid errors introduced by the property inconsistency.
* The aggregator properties {@link #condition} and {@link #distinct}
* are inherited by named aggregator to avoid errors introduced by the property inconsistency.
*
* @param name name
* @param delegated delegated
Expand All @@ -67,6 +67,7 @@ public NamedAggregator(
this.name = name;
this.delegated = delegated;
this.condition = delegated.condition;
this.distinct = delegated.distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,24 @@ public void variance_mapto_varPop() {
);
}

// A distinct "count" AST node over integer_value should analyze to the
// distinct count aggregator built by the expression DSL.
@Test
public void distinct_count() {
assertAnalyzeEqual(
dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
AstDSL.distinctAggregate("count", qualifiedName("integer_value"))
);
}

// A distinct "count" AST node carrying the filter integer_value > 1 should analyze to
// the distinct count aggregator with the equivalent condition expression attached.
@Test
public void filtered_distinct_count() {
assertAnalyzeEqual(
dsl.distinctCount(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
AstDSL.filteredDistinctCount("count", qualifiedName("integer_value"), function(
">", qualifiedName("integer_value"), intLiteral(1)))
);
}

protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ public class ExprValueUtilsTest {
Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues));

private static List<Function<ExprValue, Object>> numberValueExtractor = Arrays.asList(
ExprValue::byteValue,
ExprValue::shortValue,
ExprValueUtils::getByteValue,
ExprValueUtils::getShortValue,
ExprValueUtils::getIntegerValue,
ExprValueUtils::getLongValue,
ExprValueUtils::getFloatValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,29 @@ public class AggregationTest extends ExpressionTestBase {
"timestamp_value",
"2040-01-01 07:00:00")));

/**
 * Four rows for distinct-count tests. The first two rows share the same
 * integer_value (1), struct_value ({"str": 1}) and array_value ([1]), so each of those
 * columns has three distinct values; double_value (4, 3, 2, 1) is unique per row.
 */
protected static List<ExprValue> tuples_with_duplicates =
Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 1,
"double_value", 4d,
"struct_value", ImmutableMap.of("str", 1),
"array_value", ImmutableList.of(1))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 1,
"double_value", 3d,
"struct_value", ImmutableMap.of("str", 1),
"array_value", ImmutableList.of(1))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 2,
"double_value", 2d,
"struct_value", ImmutableMap.of("str", 2),
"array_value", ImmutableList.of(2))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 3,
"double_value", 1d,
"struct_value", ImmutableMap.of("str1", 1),
"array_value", ImmutableList.of(1, 2))));

protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,35 @@ public void filtered_count() {
assertEquals(3, result.value());
}

// integer_value in tuples_with_duplicates is 1, 1, 2, 3, so three distinct values are expected.
@Test
public void distinct_count() {
ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
tuples_with_duplicates);
assertEquals(3, result.value());
}

// The condition double_value > 1 matches the rows with double_value 4, 3 and 2, whose
// integer_value is 1, 1 and 2, so two distinct values are expected.
@Test
public void filtered_distinct_count() {
ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
tuples_with_duplicates);
assertEquals(2, result.value());
}

// struct_value in tuples_with_duplicates is {str=1}, {str=1}, {str=2}, {str1=1},
// so three distinct maps are expected.
@Test
public void distinct_count_map() {
ExprValue result = aggregation(dsl.distinctCount(DSL.ref("struct_value", STRUCT)),
tuples_with_duplicates);
assertEquals(3, result.value());
}

// array_value in tuples_with_duplicates is [1], [1], [2], [1, 2],
// so three distinct arrays are expected.
@Test
public void distinct_count_array() {
ExprValue result = aggregation(dsl.distinctCount(DSL.ref("array_value", ARRAY)),
tuples_with_duplicates);
assertEquals(3, result.value());
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down Expand Up @@ -166,6 +195,9 @@ public void valueOf() {
public void test_to_string() {
Aggregator countAggregator = dsl.count(DSL.ref("integer_value", INTEGER));
assertEquals("count(integer_value)", countAggregator.toString());

countAggregator = dsl.distinctCount(DSL.ref("integer_value", INTEGER));
assertEquals("count(distinct integer_value)", countAggregator.toString());
}

@Test
Expand Down
26 changes: 26 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ Example::
| 2.8613807855648994 |
+--------------------+

DISTINCT COUNT Aggregation
--------------------------

To get the count of distinct values of a field, add the ``DISTINCT`` keyword before the field in the count aggregation. Example::

os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts;
fetched rows / total rows = 1/1
+--------------------------+-----------------+
| COUNT(DISTINCT gender) | COUNT(gender) |
|--------------------------+-----------------|
| 2 | 4 |
+--------------------------+-----------------+

HAVING Clause
=============

Expand Down Expand Up @@ -456,3 +469,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w
| 4 | 1 |
+--------------+------------+

Distinct count aggregate with FILTER
------------------------------------

The ``FILTER`` clause can also be used with a distinct count aggregation to filter rows before counting the distinct values of a specific field. For example::

os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts
fetched rows / total rows = 1/1
+------------------+
| distinct_count |
|------------------|
| 3 |
+------------------+

15 changes: 15 additions & 0 deletions docs/user/ppl/cmd/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,18 @@ PPL query::
| 36 | 32 | M |
+------------+------------+----------+

Example 7: Calculate the distinct count of a field
==================================================

To get the count of distinct values of a field, you can use the ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of the gender field across all the accounts.

PPL query::

os> source=accounts | stats count(gender), distinct_count(gender);
fetched rows / total rows = 1/1
+-----------------+--------------------------+
| count(gender) | distinct_count(gender) |
|-----------------+--------------------------|
| 4 | 2 |
+-----------------+--------------------------+

Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException {
verifyDataRows(response, rows(1000));
}

// Exercises both spellings of the PPL distinct count function, distinct_count(...) and
// dc(...), and checks that each reports an integer column with the expected value.
@Test
public void testStatsDistinctCount() throws IOException {
JSONObject response =
executeQuery(String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT));
verifySchema(response, schema("distinct_count(gender)", null, "integer"));
verifyDataRows(response, rows(2));

response =
executeQuery(String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT));
verifySchema(response, schema("dc(age)", null, "integer"));
verifyDataRows(response, rows(21));
}

@Test
public void testStatsMin() throws IOException {
JSONObject response = executeQuery(String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ protected void init() throws Exception {
}

@Test
void filteredAggregateWithSubquery() throws IOException {
void filteredAggregatePushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK);
verifySchema(response, schema("COUNT(*)", null, "integer"));
verifyDataRows(response, rows(3));
}

@Test
void filteredAggregateNotPushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK
+ ") AS a");
Expand Down
Loading

0 comments on commit b3dfc49

Please sign in to comment.