From 8271791061e72dbf554028e477746e89fca9ec02 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 25 Oct 2022 14:21:57 -0700
Subject: [PATCH] Core, Spark: Add Aggregate expressions (#5961)

---
 .../apache/iceberg/expressions/Aggregate.java |  58 +++++++++
 .../apache/iceberg/expressions/Binder.java    |  10 ++
 .../iceberg/expressions/BoundAggregate.java   |  47 +++++++
 .../iceberg/expressions/Expression.java       |   6 +-
 .../expressions/ExpressionVisitors.java       |  14 +++
 .../iceberg/expressions/Expressions.java      |  16 +++
 .../iceberg/expressions/UnboundAggregate.java |  57 +++++++++
 .../expressions/TestAggregateBinding.java     | 116 ++++++++++++++++++
 .../apache/iceberg/spark/SparkAggregates.java |  80 ++++++++++++
 .../org/apache/iceberg/spark/SparkUtil.java   |   8 ++
 .../apache/iceberg/spark/SparkV2Filters.java  |  39 +++---
 .../spark/source/TestSparkAggregates.java     |  76 ++++++++++++
 12 files changed, 503 insertions(+), 24 deletions(-)
 create mode 100644 api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
 create mode 100644 api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
 create mode 100644 api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java
 create mode 100644 api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java
 create mode 100644 spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java

diff --git a/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java b/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
new file mode 100644
index 000000000000..7db1822e49e4
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+/**
+ * The aggregate functions that can be pushed and evaluated in Iceberg. Currently only three
+ * aggregate functions Max, Min and Count are supported.
+ */
+public abstract class Aggregate<C extends Term> implements Expression {
+  private final Operation op;
+  private final C term;
+
+  Aggregate(Operation op, C term) {
+    this.op = op;
+    this.term = term;
+  }
+
+  @Override
+  public Operation op() {
+    return op;
+  }
+
+  public C term() {
+    return term;
+  }
+
+  @Override
+  public String toString() {
+    switch (op()) {
+      case COUNT:
+        return "count(" + term() + ")";
+      case COUNT_STAR:
+        return "count(*)";
+      case MAX:
+        return "max(" + term() + ")";
+      case MIN:
+        return "min(" + term() + ")";
+      default:
+        throw new UnsupportedOperationException("Invalid aggregate: " + op());
+    }
+  }
+}
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Binder.java b/api/src/main/java/org/apache/iceberg/expressions/Binder.java
index d2a7b1d09e0b..3454fa14e0b1 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Binder.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Binder.java
@@ -158,6 +158,16 @@ public <T> Expression predicate(BoundPredicate<T> pred) {
     public <T> Expression predicate(UnboundPredicate<T> pred) {
       return pred.bind(struct, caseSensitive);
     }
+
+    @Override
+    public <T> Expression aggregate(UnboundAggregate<T> agg) {
+      return agg.bind(struct, caseSensitive);
+    }
+
+    @Override
+    public <T, C> Expression aggregate(BoundAggregate<T, C> agg) {
+      throw new IllegalStateException("Found already bound aggregate: " + agg);
+    }
   }
 
   private static class ReferenceVisitor extends ExpressionVisitor<Set<Integer>> {
diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
new file mode 100644
index 000000000000..650271b3b78a
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+
+public class BoundAggregate<T, C> extends Aggregate<BoundTerm<T>> implements Bound<C> {
+  protected BoundAggregate(Operation op, BoundTerm<T> term) {
+    super(op, term);
+  }
+
+  @Override
+  public C eval(StructLike struct) {
+    throw new UnsupportedOperationException(this.getClass().getName() + " does not implement eval");
+  }
+
+  @Override
+  public BoundReference<T> ref() {
+    return term().ref();
+  }
+
+  public Type type() {
+    if (op() == Operation.COUNT || op() == Operation.COUNT_STAR) {
+      return Types.LongType.get();
+    } else {
+      return term().type();
+    }
+  }
+}
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expression.java b/api/src/main/java/org/apache/iceberg/expressions/Expression.java
index cd82aa07ad42..dc88172c590d 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Expression.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Expression.java
@@ -43,7 +43,11 @@ enum Operation {
     AND,
     OR,
     STARTS_WITH,
-    NOT_STARTS_WITH;
+    NOT_STARTS_WITH,
+    COUNT,
+    COUNT_STAR,
+    MAX,
+    MIN;
 
     public static Operation fromString(String operationType) {
       Preconditions.checkArgument(null != operationType, "Invalid operation type: null");
diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
index 4076d7febc3e..79ca6a712887 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
@@ -55,6 +55,14 @@ public <T> R predicate(BoundPredicate<T> pred) {
     public <T> R predicate(UnboundPredicate<T> pred) {
       return null;
     }
+
+    public <T, C> R aggregate(BoundAggregate<T, C> agg) {
+      throw new UnsupportedOperationException("Cannot visit aggregate expression");
+    }
+
+    public <T> R aggregate(UnboundAggregate<T> agg) {
+      throw new UnsupportedOperationException("Cannot visit aggregate expression");
+    }
   }
 
   public abstract static class BoundExpressionVisitor<R> extends ExpressionVisitor<R> {
@@ -338,6 +346,12 @@ public static <R> R visit(Expression expr, ExpressionVisitor<R> visitor) {
       } else {
         return visitor.predicate((UnboundPredicate<?>) expr);
       }
+    } else if (expr instanceof Aggregate) {
+      if (expr instanceof BoundAggregate) {
+        return visitor.aggregate((BoundAggregate<?, ?>) expr);
+      } else {
+        return visitor.aggregate((UnboundAggregate<?>) expr);
+      }
     } else {
       switch (expr.op()) {
         case TRUE:
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
index 7fad8324c4ae..171da823cc8f 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
@@ -308,4 +308,20 @@ public static <T> NamedReference<T> ref(String name) {
   public static <T> UnboundTerm<T> transform(String name, Transform<?, T> transform) {
     return new UnboundTransform<>(ref(name), transform);
   }
+
+  public static <T> UnboundAggregate<T> count(String name) {
+    return new UnboundAggregate<>(Operation.COUNT, ref(name));
+  }
+
+  public static <T> UnboundAggregate<T> countStar() {
+    return new UnboundAggregate<>(Operation.COUNT_STAR, null);
+  }
+
+  public static <T> UnboundAggregate<T> max(String name) {
+    return new UnboundAggregate<>(Operation.MAX, ref(name));
+  }
+
+  public static <T> UnboundAggregate<T> min(String name) {
+    return new UnboundAggregate<>(Operation.MIN, ref(name));
+  }
 }
diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java
new file mode 100644
index 000000000000..5e4cce06c7e8
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.types.Types;
+
+public class UnboundAggregate<T> extends Aggregate<UnboundTerm<T>>
+    implements Unbound<T, Expression> {
+
+  UnboundAggregate(Operation op, UnboundTerm<T> term) {
+    super(op, term);
+  }
+
+  @Override
+  public NamedReference<T> ref() {
+    return term().ref();
+  }
+
+  /**
+   * Bind this UnboundAggregate.
+   *
+   * @param struct The {@link Types.StructType struct type} to resolve references by name.
+   * @param caseSensitive A boolean flag to control whether the bind should enforce case
+   *     sensitivity.
+   * @return an {@link Expression}
+   * @throws ValidationException if literals do not match bound references, or if comparison on
+   *     expression is invalid
+   */
+  @Override
+  public Expression bind(Types.StructType struct, boolean caseSensitive) {
+    if (op() == Operation.COUNT_STAR) {
+      return new BoundAggregate<>(op(), null);
+    } else {
+      Preconditions.checkArgument(term() != null, "Invalid aggregate term: null");
+      BoundTerm<T> bound = term().bind(struct, caseSensitive);
+      return new BoundAggregate<>(op(), bound);
+    }
+  }
+}
diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java
new file mode 100644
index 000000000000..7bbe8ad7dad3
--- /dev/null
+++ b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.types.Types.StructType;
+import org.assertj.core.api.Assertions;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestAggregateBinding {
+  private static final List<Expression.Operation> AGGREGATES =
+      Arrays.asList(Expression.Operation.COUNT, Expression.Operation.MAX, Expression.Operation.MIN);
+
+  private static final StructType struct =
+      StructType.of(Types.NestedField.required(10, "x", Types.IntegerType.get()));
+
+  @Test
+  public void testAggregateBinding() {
+    for (Expression.Operation op : AGGREGATES) {
+      UnboundAggregate<?> unbound = null;
+      switch (op) {
+        case COUNT:
+          unbound = Expressions.count("x");
+          break;
+        case MAX:
+          unbound = Expressions.max("x");
+          break;
+        case MIN:
+          unbound = Expressions.min("x");
+          break;
+        default:
+          throw new UnsupportedOperationException("Invalid aggregate: " + op);
+      }
+
+      Expression expr = unbound.bind(struct, true);
+      BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(expr);
+
+      Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
+      Assert.assertEquals("Should not change the comparison operation", op, bound.op());
+    }
+  }
+
+  @Test
+  public void testCountStarBinding() {
+    UnboundAggregate<?> unbound = Expressions.countStar();
+    Expression expr = unbound.bind(null, false);
+    BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(expr);
+
+    Assert.assertEquals(
+        "Should not change the comparison operation", Expression.Operation.COUNT_STAR, bound.op());
+  }
+
+  @Test
+  public void testBoundAggregateFails() {
+    Expression unbound = Expressions.count("x");
+    Assertions.assertThatThrownBy(() -> Binder.bind(struct, Binder.bind(struct, unbound)))
+        .isInstanceOf(IllegalStateException.class)
+        .hasMessageContaining("Found already bound aggregate");
+  }
+
+  @Test
+  public void testCaseInsensitiveReference() {
+    Expression expr = Expressions.max("X");
+    Expression boundExpr = Binder.bind(struct, expr, false);
+    BoundAggregate<?, ?> bound = assertAndUnwrapAggregate(boundExpr);
+    Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
+    Assert.assertEquals(
+        "Should not change the comparison operation", Expression.Operation.MAX, bound.op());
+  }
+
+  @Test
+  public void testCaseSensitiveReference() {
+    Expression expr = Expressions.max("X");
+    Assertions.assertThatThrownBy(() -> Binder.bind(struct, expr, true))
+        .isInstanceOf(ValidationException.class)
+        .hasMessageContaining("Cannot find field 'X' in struct");
+  }
+
+  @Test
+  public void testMissingField() {
+    UnboundAggregate<?> unbound = Expressions.count("missing");
+    try {
+      unbound.bind(struct, false);
+      Assert.fail("Binding a missing field should fail");
+    } catch (ValidationException e) {
+      Assert.assertTrue(
+          "Validation should complain about missing field",
+          e.getMessage().contains("Cannot find field 'missing' in struct:"));
+    }
+  }
+
+  private static BoundAggregate<?, ?> assertAndUnwrapAggregate(Expression expr) {
+    Assert.assertTrue(
+        "Expression should be a bound aggregate: " + expr, expr instanceof BoundAggregate);
+    return (BoundAggregate<?, ?>) expr;
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java
new file mode 100644
index 000000000000..6741e33fa114
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import java.util.Map;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expression.Operation;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+
+public class SparkAggregates {
+  private SparkAggregates() {}
+
+  private static final Map<Class<? extends AggregateFunc>, Operation> AGGREGATES =
+      ImmutableMap.<Class<? extends AggregateFunc>, Operation>builder()
+          .put(Count.class, Operation.COUNT)
+          .put(CountStar.class, Operation.COUNT_STAR)
+          .put(Max.class, Operation.MAX)
+          .put(Min.class, Operation.MIN)
+          .build();
+
+  public static Expression convert(AggregateFunc aggregate) {
+    Operation op = AGGREGATES.get(aggregate.getClass());
+    if (op != null) {
+      switch (op) {
+        case COUNT:
+          Count countAgg = (Count) aggregate;
+          if (countAgg.isDistinct()) {
+            // manifest file doesn't have count distinct so this can't be converted to push down
+            return null;
+          }
+
+          if (countAgg.column() instanceof NamedReference) {
+            return Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column()));
+          } else {
+            return null;
+          }
+        case COUNT_STAR:
+          return Expressions.countStar();
+        case MAX:
+          Max maxAgg = (Max) aggregate;
+          if (maxAgg.column() instanceof NamedReference) {
+            return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column()));
+          } else {
+            return null;
+          }
+        case MIN:
+          Min minAgg = (Min) aggregate;
+          if (minAgg.column() instanceof NamedReference) {
+            return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column()));
+          } else {
+            return null;
+          }
+      }
+    }
+    return null;
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index 950ed7bc87b8..2e8312fd9724 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -32,6 +32,7 @@
 import org.apache.iceberg.Table;
 import org.apache.iceberg.hadoop.HadoopConfigurable;
 import org.apache.iceberg.io.FileIO;
+import org.apache.iceberg.relocated.com.google.common.base.Joiner;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.transforms.Transform;
@@ -45,6 +46,7 @@
 import org.apache.spark.sql.catalyst.expressions.EqualTo;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
@@ -73,6 +75,8 @@ public class SparkUtil {
   private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR =
       SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop.";
 
+  private static final Joiner DOT = Joiner.on(".");
+
   private SparkUtil() {}
 
   public static FileIO serializableFileIO(Table table) {
@@ -287,4 +291,8 @@ public static List<Expression> partitionMapToExpression(
 
     return filterExpressions;
   }
+
+  public static String toColumnName(NamedReference ref) {
+    return DOT.join(ref.fieldNames());
+  }
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
index 2f09e6e9c9c2..072c14c08bb3 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
@@ -40,7 +40,6 @@
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expression.Operation;
 import org.apache.iceberg.expressions.Expressions;
-import org.apache.iceberg.relocated.com.google.common.base.Joiner;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.util.NaNUtil;
@@ -54,8 +53,6 @@
 
 public class SparkV2Filters {
 
-  private static final Joiner DOT = Joiner.on(".");
-
   private static final String TRUE = "ALWAYS_TRUE";
   private static final String FALSE = "ALWAYS_FALSE";
   private static final String EQ = "=";
@@ -105,17 +102,17 @@ public static Expression convert(Predicate predicate) {
         return Expressions.alwaysFalse();
 
       case IS_NULL:
-        return isRef(child(predicate)) ? isNull(toColumnName(child(predicate))) : null;
+        return isRef(child(predicate)) ? isNull(SparkUtil.toColumnName(child(predicate))) : null;
 
       case NOT_NULL:
-        return isRef(child(predicate)) ? notNull(toColumnName(child(predicate))) : null;
+        return isRef(child(predicate)) ? notNull(SparkUtil.toColumnName(child(predicate))) : null;
 
       case LT:
         if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = toColumnName(leftChild(predicate));
+          String columnName = SparkUtil.toColumnName(leftChild(predicate));
           return lessThan(columnName, convertLiteral(rightChild(predicate)));
         } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = toColumnName(rightChild(predicate));
+          String columnName = SparkUtil.toColumnName(rightChild(predicate));
           return greaterThan(columnName, convertLiteral(leftChild(predicate)));
         } else {
           return null;
@@ -123,10 +120,10 @@
 
       case LT_EQ:
         if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = toColumnName(leftChild(predicate));
+          String columnName = SparkUtil.toColumnName(leftChild(predicate));
           return lessThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
         } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = toColumnName(rightChild(predicate));
+          String columnName = SparkUtil.toColumnName(rightChild(predicate));
           return greaterThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
         } else {
           return null;
@@ -134,10 +131,10 @@
 
       case GT:
         if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = toColumnName(leftChild(predicate));
+          String columnName = SparkUtil.toColumnName(leftChild(predicate));
           return greaterThan(columnName, convertLiteral(rightChild(predicate)));
         } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = toColumnName(rightChild(predicate));
+          String columnName = SparkUtil.toColumnName(rightChild(predicate));
           return lessThan(columnName, convertLiteral(leftChild(predicate)));
         } else {
           return null;
@@ -145,10 +142,10 @@
 
      case GT_EQ:
        if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          String columnName = toColumnName(leftChild(predicate));
+          String columnName = SparkUtil.toColumnName(leftChild(predicate));
           return greaterThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
         } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          String columnName = toColumnName(rightChild(predicate));
+          String columnName = SparkUtil.toColumnName(rightChild(predicate));
           return lessThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
         } else {
           return null;
@@ -158,10 +155,10 @@
         Object value;
         String columnName;
         if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
-          columnName = toColumnName(leftChild(predicate));
+          columnName = SparkUtil.toColumnName(leftChild(predicate));
           value = convertLiteral(rightChild(predicate));
         } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
-          columnName = toColumnName(rightChild(predicate));
+          columnName = SparkUtil.toColumnName(rightChild(predicate));
           value = convertLiteral(leftChild(predicate));
         } else {
           return null;
@@ -183,7 +180,7 @@ public static Expression convert(Predicate predicate) {
       case IN:
         if (isSupportedInPredicate(predicate)) {
           return in(
-              toColumnName(childAtIndex(predicate, 0)),
+              SparkUtil.toColumnName(childAtIndex(predicate, 0)),
               Arrays.stream(predicate.children())
                   .skip(1)
                   .map(val -> convertLiteral(((Literal) val)))
                   .collect(Collectors.toList()));
@@ -202,13 +199,13 @@ public static Expression convert(Predicate predicate) {
             // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg
             Expression notIn =
                 notIn(
-                    toColumnName(childAtIndex(childPredicate, 0)),
+                    SparkUtil.toColumnName(childAtIndex(childPredicate, 0)),
                     Arrays.stream(childPredicate.children())
                         .skip(1)
                         .map(val -> convertLiteral(((Literal) val)))
                         .filter(Objects::nonNull)
                         .collect(Collectors.toList()));
-            return and(notNull(toColumnName(childAtIndex(childPredicate, 0))), notIn);
+            return and(notNull(SparkUtil.toColumnName(childAtIndex(childPredicate, 0))), notIn);
           } else if (hasNoInFilter(childPredicate)) {
             Expression child = convert(childPredicate);
             if (child != null) {
@@ -240,7 +237,7 @@ public static Expression convert(Predicate predicate) {
         }
 
       case STARTS_WITH:
-        String colName = toColumnName(leftChild(predicate));
+        String colName = SparkUtil.toColumnName(leftChild(predicate));
         return startsWith(colName, convertLiteral(rightChild(predicate)).toString());
     }
   }
@@ -248,10 +245,6 @@ public static Expression convert(Predicate predicate) {
     return null;
   }
 
-  private static String toColumnName(NamedReference ref) {
-    return DOT.join(ref.fieldNames());
-  }
-
   @SuppressWarnings("unchecked")
   private static <T> T child(Predicate predicate) {
     org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java
new file mode 100644
index 000000000000..e2d6f744f5a5
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Map;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkAggregates;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkAggregates {
+
+  @Test
+  public void testAggregates() {
+    Map<String, String> attrMap = Maps.newHashMap();
+    attrMap.put("id", "id");
+    attrMap.put("`i.d`", "i.d");
+    attrMap.put("`i``d`", "i`d");
+    attrMap.put("`d`.b.`dd```", "d.b.dd`");
+    attrMap.put("a.`aa```.c", "a.aa`.c");
+
+    attrMap.forEach(
+        (quoted, unquoted) -> {
+          NamedReference namedReference = FieldReference.apply(quoted);
+
+          Max max = new Max(namedReference);
+          Expression expectedMax = Expressions.max(unquoted);
+          Expression actualMax = SparkAggregates.convert(max);
+          Assert.assertEquals("Max must match", expectedMax.toString(), actualMax.toString());
+
+          Min min = new Min(namedReference);
+          Expression expectedMin = Expressions.min(unquoted);
+          Expression actualMin = SparkAggregates.convert(min);
+          Assert.assertEquals("Min must match", expectedMin.toString(), actualMin.toString());
+
+          Count count = new Count(namedReference, false);
+          Expression expectedCount = Expressions.count(unquoted);
+          Expression actualCount = SparkAggregates.convert(count);
+          Assert.assertEquals("Count must match", expectedCount.toString(), actualCount.toString());
+
+          Count countDistinct = new Count(namedReference, true);
+          Expression convertedCountDistinct = SparkAggregates.convert(countDistinct);
+          Assert.assertNull("Count Distinct is converted to null", convertedCountDistinct);
+
+          CountStar countStar = new CountStar();
+          Expression expectedCountStar = Expressions.countStar();
+          Expression actualCountStar = SparkAggregates.convert(countStar);
+          Assert.assertEquals(
+              "CountStar must match", expectedCountStar.toString(), actualCountStar.toString());
+        });
+  }
+}
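
Editor's note (not part of the commit): the sketch below shows how the aggregate expressions added by this patch might be constructed and bound, assuming the Iceberg API module is on the classpath. The schema, field id 10, and column name "x" are borrowed from TestAggregateBinding; the class name AggregateExpressionExample is hypothetical and for illustration only.

// Usage sketch of the new aggregate expression API (assumption: mirrors TestAggregateBinding).
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.BoundAggregate;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.types.Types;

public class AggregateExpressionExample {
  public static void main(String[] args) {
    // Schema with a single required integer column "x" (field id 10), as in the test.
    Types.StructType struct =
        Types.StructType.of(Types.NestedField.required(10, "x", Types.IntegerType.get()));

    // Build unbound aggregates with the new factory methods added to Expressions.
    Expression max = Expressions.max("x");
    Expression countStar = Expressions.countStar();

    // Binder.bind dispatches to UnboundAggregate.bind(...) through the new
    // ExpressionVisitor.aggregate(...) hooks and returns a BoundAggregate.
    BoundAggregate<?, ?> boundMax = (BoundAggregate<?, ?>) Binder.bind(struct, max, true);
    System.out.println(boundMax.op());    // MAX
    System.out.println(boundMax.type());  // the bound term's type (int here)

    // COUNT(*) binds without a term and always reports a long result type.
    BoundAggregate<?, ?> boundCountStar =
        (BoundAggregate<?, ?>) Binder.bind(struct, countStar, true);
    System.out.println(boundCountStar.type()); // long
  }
}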