diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java index 51977ed5135b..98d9dd99bfed 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java @@ -17,8 +17,11 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform.agg; -import java.io.Serializable; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Count; /** * Returns the count of TRUE values for expression. Returns 0 if there are zero input rows, or if @@ -27,49 +30,41 @@ public class CountIf { private CountIf() {} - public static CountIfFn combineFn() { - return new CountIf.CountIfFn(); + public static Combine.CombineFn combineFn() { + return new CountIfFn(); } - public static class CountIfFn extends Combine.CombineFn { + public static class CountIfFn extends Combine.CombineFn { + private final Combine.CombineFn countFn = + (Combine.CombineFn) Count.combineFn(); - public static class Accum implements Serializable { - boolean isExpressionFalse = true; - long countIfResult = 0L; + @Override + public long[] createAccumulator() { + return countFn.createAccumulator(); } @Override - public Accum createAccumulator() { - return new Accum(); + public long[] addInput(long[] accumulator, Boolean input) { + if (Boolean.TRUE.equals(input)) { + countFn.addInput(accumulator, input); + } + return accumulator; } @Override - public Accum addInput(Accum accum, Boolean input) { - if (input) { - accum.isExpressionFalse = false; - accum.countIfResult += 1; - } - return accum; + public long[] mergeAccumulators(Iterable accumulators) { + return countFn.mergeAccumulators(accumulators); } @Override - public Accum mergeAccumulators(Iterable accums) { - CountIfFn.Accum merged = createAccumulator(); - for (CountIfFn.Accum accum : accums) { - if (!accum.isExpressionFalse) { - merged.isExpressionFalse = false; - merged.countIfResult += accum.countIfResult; - } - } - return merged; + public Long extractOutput(long[] accumulator) { + return countFn.extractOutput(accumulator); } @Override - public Long extractOutput(Accum accum) { - if (!accum.isExpressionFalse) { - return accum.countIfResult; - } - return 0L; + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return countFn.getAccumulatorCoder(registry, inputCoder); } } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java new file mode 100644 index 000000000000..4c0724716e16 --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.impl.transform.agg; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine; +import org.junit.Test; + +/** Unit tests for {@link CountIf}. */ +public class CountIfTest { + + @Test + public void testCreatesEmptyAccumulator() { + long[] accumulator = (long[]) CountIf.combineFn().createAccumulator(); + + assertEquals(0, accumulator[0]); + } + + @Test + public void testReturnsAccumulatorUnchangedForNullInput() { + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), null); + + assertEquals(0L, accumulator[0]); + } + + @Test + public void testAddsInputToAccumulator() { + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE); + + assertEquals(1L, accumulator[0]); + } + + @Test + public void testCreatesAccumulatorCoder() throws CannotProvideCoderException { + assertNotNull( + CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of())); + } + + @Test + public void testMergeAccumulators() { + Combine.CombineFn countIfFn = CountIf.combineFn(); + List accums = Arrays.asList(new long[] {2}, new long[] {2}); + long[] accumulator = (long[]) countIfFn.mergeAccumulators(accums); + + assertEquals(4L, accumulator[0]); + } + + @Test + public void testExtractsOutput() { + Combine.CombineFn countIfFn = CountIf.combineFn(); + + assertEquals(0L, countIfFn.extractOutput(countIfFn.createAccumulator())); + } +} diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java index 4b72fca933eb..f7a8ad1fa06b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java @@ -92,12 +92,12 @@ public void testAddsInputToAccumulator() { } @Test - public void testCeatesAccumulatorCoder() { + public void testCreatesAccumulatorCoder() { assertNotNull(varianceFn.getAccumulatorCoder(CoderRegistry.createDefault(), VarIntCoder.of())); } @Test - public void testReturnsOutput() { + public void testExtractsOutput() { assertEquals(expectedExtractedResult, varianceFn.extractOutput(testAccumulatorInput)); } }