Skip to content

Commit

Permalink
Core, Spark: Add Aggregate expressions (apache#5961)
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxingao authored Oct 25, 2022
1 parent e1070c6 commit 8271791
Show file tree
Hide file tree
Showing 12 changed files with 503 additions and 24 deletions.
58 changes: 58 additions & 0 deletions api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.expressions;

/**
 * An aggregate function that can be pushed down to and evaluated by Iceberg.
 *
 * <p>An aggregate pairs an {@link Operation} with the term it is computed over. The operations
 * that {@link #toString()} knows how to render are COUNT, COUNT_STAR, MAX and MIN.
 *
 * @param <C> the kind of {@link Term} this aggregate is computed over
 */
public abstract class Aggregate<C extends Term> implements Expression {
  private final Operation op;
  private final C term;

  /**
   * Creates an aggregate.
   *
   * @param op the aggregate operation
   * @param term the term being aggregated
   */
  Aggregate(Operation op, C term) {
    this.op = op;
    this.term = term;
  }

  /** Returns the aggregate operation of this expression. */
  @Override
  public Operation op() {
    return op;
  }

  /** Returns the term this aggregate is computed over. */
  public C term() {
    return term;
  }

  /**
   * Renders this aggregate as {@code function(term)}, or {@code count(*)} for COUNT_STAR.
   *
   * @throws UnsupportedOperationException if the operation is not one of the aggregate operations
   */
  @Override
  public String toString() {
    final String function;
    switch (op()) {
      case COUNT_STAR:
        // count(*) has no term to print, so it is rendered directly
        return "count(*)";
      case COUNT:
        function = "count";
        break;
      case MAX:
        function = "max";
        break;
      case MIN:
        function = "min";
        break;
      default:
        throw new UnsupportedOperationException("Invalid aggregate: " + op());
    }

    return function + "(" + term() + ")";
  }
}
10 changes: 10 additions & 0 deletions api/src/main/java/org/apache/iceberg/expressions/Binder.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ public <T> Expression predicate(BoundPredicate<T> pred) {
public <T> Expression predicate(UnboundPredicate<T> pred) {
// An unbound predicate is bound by delegating to the predicate itself, passing this visitor's
// struct and caseSensitive values.
return pred.bind(struct, caseSensitive);
}

@Override
public <T> Expression aggregate(UnboundAggregate<T> agg) {
  // The aggregate knows how to bind itself; hand it this visitor's struct and caseSensitive flag.
  Expression bound = agg.bind(struct, caseSensitive);
  return bound;
}

@Override
public <T, C> Expression aggregate(BoundAggregate<T, C> agg) {
  // Binding an already bound aggregate is rejected rather than passed through.
  String message = "Found already bound aggregate: " + agg;
  throw new IllegalStateException(message);
}
}

private static class ReferenceVisitor extends ExpressionVisitor<Set<Integer>> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.expressions;

import org.apache.iceberg.StructLike;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;

/**
 * An {@link Aggregate} whose term is a {@link BoundTerm}.
 *
 * @param <T> the value type of the bound term
 * @param <C> the type returned by {@link #eval(StructLike)}
 */
public class BoundAggregate<T, C> extends Aggregate<BoundTerm<T>> implements Bound<C> {
  /**
   * Creates a bound aggregate.
   *
   * @param op the aggregate operation
   * @param term the bound term being aggregated
   */
  protected BoundAggregate(Operation op, BoundTerm<T> term) {
    super(op, term);
  }

  /**
   * Not implemented by this class.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public C eval(StructLike struct) {
    String className = this.getClass().getName();
    throw new UnsupportedOperationException(className + " does not implement eval");
  }

  /**
   * Returns the reference of the underlying bound term.
   *
   * <p>NOTE(review): {@code UnboundAggregate#bind} creates COUNT_STAR aggregates with a null term,
   * so this call would fail with a NullPointerException for them; confirm callers guard against
   * that.
   */
  @Override
  public BoundReference<?> ref() {
    BoundTerm<T> boundTerm = term();
    return boundTerm.ref();
  }

  /**
   * Returns the result type of this aggregate: long for COUNT and COUNT_STAR, otherwise the type
   * of the underlying term.
   */
  public Type type() {
    Operation operation = op();
    boolean isCount = operation == Operation.COUNT || operation == Operation.COUNT_STAR;
    if (isCount) {
      return Types.LongType.get();
    }

    return term().type();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ enum Operation {
AND,
OR,
STARTS_WITH,
NOT_STARTS_WITH;
NOT_STARTS_WITH,
COUNT,
COUNT_STAR,
MAX,
MIN;

public static Operation fromString(String operationType) {
Preconditions.checkArgument(null != operationType, "Invalid operation type: null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ public <T> R predicate(BoundPredicate<T> pred) {
public <T> R predicate(UnboundPredicate<T> pred) {
// This implementation produces no result for unbound predicates; it simply returns null.
return null;
}

/**
 * Visits a bound aggregate. This implementation does not support aggregates.
 *
 * @throws UnsupportedOperationException always
 */
public <T, C> R aggregate(BoundAggregate<T, C> agg) {
  String message = "Cannot visit aggregate expression";
  throw new UnsupportedOperationException(message);
}

/**
 * Visits an unbound aggregate. This implementation does not support aggregates.
 *
 * @throws UnsupportedOperationException always
 */
public <T> R aggregate(UnboundAggregate<T> agg) {
  String message = "Cannot visit aggregate expression";
  throw new UnsupportedOperationException(message);
}
}

public abstract static class BoundExpressionVisitor<R> extends ExpressionVisitor<R> {
Expand Down Expand Up @@ -338,6 +346,12 @@ public static <R> R visit(Expression expr, ExpressionVisitor<R> visitor) {
} else {
return visitor.predicate((UnboundPredicate<?>) expr);
}
} else if (expr instanceof Aggregate) {
if (expr instanceof BoundAggregate) {
return visitor.aggregate((BoundAggregate<?, ?>) expr);
} else {
return visitor.aggregate((UnboundAggregate<?>) expr);
}
} else {
switch (expr.op()) {
case TRUE:
Expand Down
16 changes: 16 additions & 0 deletions api/src/main/java/org/apache/iceberg/expressions/Expressions.java
Original file line number Diff line number Diff line change
Expand Up @@ -308,4 +308,20 @@ public static <T> NamedReference<T> ref(String name) {
public static <T> UnboundTerm<T> transform(String name, Transform<?, T> transform) {
return new UnboundTransform<>(ref(name), transform);
}

/**
 * Creates an unbound COUNT aggregate over the named column.
 *
 * @param name the column name to reference
 * @return an unbound aggregate with operation COUNT
 */
public static UnboundAggregate<String> count(String name) {
  UnboundAggregate<String> aggregate = new UnboundAggregate<>(Operation.COUNT, ref(name));
  return aggregate;
}

/**
 * Creates an unbound COUNT_STAR aggregate.
 *
 * @return an unbound aggregate with operation COUNT_STAR and a null term
 */
public static UnboundAggregate<String> countStar() {
  // count(*) does not reference a column, so the aggregate is created without a term
  UnboundAggregate<String> aggregate = new UnboundAggregate<>(Operation.COUNT_STAR, null);
  return aggregate;
}

/**
 * Creates an unbound MAX aggregate over the named column.
 *
 * @param name the column name to reference
 * @return an unbound aggregate with operation MAX
 */
public static UnboundAggregate<String> max(String name) {
  UnboundAggregate<String> aggregate = new UnboundAggregate<>(Operation.MAX, ref(name));
  return aggregate;
}

/**
 * Creates an unbound MIN aggregate over the named column.
 *
 * @param name the column name to reference
 * @return an unbound aggregate with operation MIN
 */
public static UnboundAggregate<String> min(String name) {
  UnboundAggregate<String> aggregate = new UnboundAggregate<>(Operation.MIN, ref(name));
  return aggregate;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.expressions;

import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.types.Types;

/**
 * An {@link Aggregate} over an {@link UnboundTerm}, which can be bound to a struct type to
 * produce a {@link BoundAggregate}.
 *
 * @param <T> the value type of the term
 */
public class UnboundAggregate<T> extends Aggregate<UnboundTerm<T>>
    implements Unbound<T, Expression> {

  /**
   * Creates an unbound aggregate.
   *
   * @param op the aggregate operation
   * @param term the unbound term being aggregated
   */
  UnboundAggregate(Operation op, UnboundTerm<T> term) {
    super(op, term);
  }

  /** Returns the named reference of the underlying term. */
  @Override
  public NamedReference<?> ref() {
    UnboundTerm<T> unboundTerm = term();
    return unboundTerm.ref();
  }

  /**
   * Binds this aggregate to a struct type.
   *
   * <p>A COUNT_STAR aggregate is bound without a term and without consulting {@code struct}. Every
   * other aggregate binds its term against {@code struct} and wraps the result.
   *
   * @param struct The {@link Types.StructType struct type} to resolve references by name.
   * @param caseSensitive A boolean flag to control whether the bind should enforce case
   *     sensitivity.
   * @return a {@link BoundAggregate} with the same operation as this aggregate
   * @throws IllegalArgumentException if the operation is not COUNT_STAR and the term is null
   * @throws ValidationException if the term cannot be bound, e.g. the referenced field is not
   *     found in the struct
   */
  @Override
  public Expression bind(Types.StructType struct, boolean caseSensitive) {
    if (op() != Operation.COUNT_STAR) {
      Preconditions.checkArgument(term() != null, "Invalid aggregate term: null");
      BoundTerm<T> boundTerm = term().bind(struct, caseSensitive);
      return new BoundAggregate<>(op(), boundTerm);
    }

    // count(*) has no term, so there is nothing to resolve against the struct
    return new BoundAggregate<>(op(), null);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.expressions;

import java.util.Arrays;
import java.util.List;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.types.Types.StructType;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Test;

/** Tests binding of unbound aggregate expressions against a struct type. */
public class TestAggregateBinding {
  // COUNT_STAR is covered separately because it binds without a column reference
  private static final List<Expression.Operation> AGGREGATES =
      Arrays.asList(Expression.Operation.COUNT, Expression.Operation.MAX, Expression.Operation.MIN);
  private static final StructType STRUCT =
      StructType.of(Types.NestedField.required(10, "x", Types.IntegerType.get()));

  /** Each column aggregate binds to field 10 ("x") and keeps its operation. */
  @Test
  public void testAggregateBinding() {
    for (Expression.Operation op : AGGREGATES) {
      // assigned exactly once per iteration; the default branch throws, so no null placeholder
      final UnboundAggregate<String> unbound;
      switch (op) {
        case COUNT:
          unbound = Expressions.count("x");
          break;
        case MAX:
          unbound = Expressions.max("x");
          break;
        case MIN:
          unbound = Expressions.min("x");
          break;
        default:
          throw new UnsupportedOperationException("Invalid aggregate: " + op);
      }

      Expression expr = unbound.bind(STRUCT, true);
      BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(expr);

      Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
      Assert.assertEquals("Should not change the aggregate operation", op, bound.op());
    }
  }

  /** count(*) binds without a struct because it does not reference a column. */
  @Test
  public void testCountStarBinding() {
    UnboundAggregate<String> unbound = Expressions.countStar();
    Expression expr = unbound.bind(null, false);
    BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(expr);

    Assert.assertEquals(
        "Should not change the aggregate operation", Expression.Operation.COUNT_STAR, bound.op());
  }

  /** Binding an expression that is already bound is rejected. */
  @Test
  public void testBoundAggregateFails() {
    Expression unbound = Expressions.count("x");
    Assertions.assertThatThrownBy(() -> Binder.bind(STRUCT, Binder.bind(STRUCT, unbound)))
        .isInstanceOf(IllegalStateException.class)
        .hasMessageContaining("Found already bound aggregate");
  }

  /** A differently-cased column name resolves when binding is case insensitive. */
  @Test
  public void testCaseInsensitiveReference() {
    Expression expr = Expressions.max("X");
    Expression boundExpr = Binder.bind(STRUCT, expr, false);
    BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(boundExpr);
    Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
    Assert.assertEquals(
        "Should not change the aggregate operation", Expression.Operation.MAX, bound.op());
  }

  /** A differently-cased column name is not found when binding is case sensitive. */
  @Test
  public void testCaseSensitiveReference() {
    Expression expr = Expressions.max("X");
    Assertions.assertThatThrownBy(() -> Binder.bind(STRUCT, expr, true))
        .isInstanceOf(ValidationException.class)
        .hasMessageContaining("Cannot find field 'X' in struct");
  }

  /** Binding an aggregate over a column that is not in the struct fails validation. */
  @Test
  public void testMissingField() {
    UnboundAggregate<String> unbound = Expressions.count("missing");
    Assertions.assertThatThrownBy(() -> unbound.bind(STRUCT, false))
        .isInstanceOf(ValidationException.class)
        .hasMessageContaining("Cannot find field 'missing' in struct:");
  }

  /**
   * Asserts that the expression is a {@link BoundAggregate} and returns it as one.
   *
   * <p>The wildcard return type avoids an unchecked cast; callers only need the reference and the
   * operation.
   */
  private static BoundAggregate<?, ?> assertAndUnwrapAggregate(Expression expr) {
    Assert.assertTrue(
        "Expression should be a bound aggregate: " + expr, expr instanceof BoundAggregate);
    return (BoundAggregate<?, ?>) expr;
  }
}
Loading

0 comments on commit 8271791

Please sign in to comment.