diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java index cc8559c33300..98d9dd99bfed 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java @@ -17,12 +17,11 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform.agg; -import com.google.auto.value.AutoValue; -import java.io.Serializable; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Count; /** * Returns the count of TRUE values for expression. Returns 0 if there are zero input rows, or if @@ -31,62 +30,41 @@ public class CountIf { private CountIf() {} - public static CountIfFn combineFn() { - return new CountIf.CountIfFn(); + public static Combine.CombineFn combineFn() { + return new CountIfFn(); } - public static class CountIfFn extends Combine.CombineFn { - - @AutoValue - public abstract static class Accum implements Serializable { - abstract boolean isExpressionFalse(); - - abstract long countIfResult(); - - static Accum empty() { - return of(true, 0L); - } - - static Accum of(boolean isExpressionFalse, long countIfResult) { - return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult); - } - } + public static class CountIfFn extends Combine.CombineFn { + private final Combine.CombineFn countFn = + (Combine.CombineFn) Count.combineFn(); @Override - public Accum createAccumulator() { - return Accum.empty(); + public long[] createAccumulator() { + return countFn.createAccumulator(); } @Override - public Accum addInput(Accum accum, Boolean input) { + public long[] addInput(long[] accumulator, Boolean input) { if (Boolean.TRUE.equals(input)) { - return Accum.of(false, accum.countIfResult() + 1); + countFn.addInput(accumulator, input); } - return accum; + return accumulator; } @Override - public Accum mergeAccumulators(Iterable accums) { - CountIfFn.Accum merged = createAccumulator(); - for (CountIfFn.Accum accum : accums) { - if (!accum.isExpressionFalse()) { - merged = Accum.of(false, merged.countIfResult() + accum.countIfResult()); - } - } - return merged; + public long[] mergeAccumulators(Iterable accumulators) { + return countFn.mergeAccumulators(accumulators); } @Override - public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { - return SerializableCoder.of(Accum.class); + public Long extractOutput(long[] accumulator) { + return countFn.extractOutput(accumulator); } @Override - public Long extractOutput(Accum accum) { - if (!accum.isExpressionFalse()) { - return accum.countIfResult(); - } - return 0L; + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return countFn.getAccumulatorCoder(registry, inputCoder); } } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java index 3e5d56ab6014..4c0724716e16 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java @@ -23,7 +23,9 @@ import java.util.Arrays; import java.util.List; import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine; import org.junit.Test; /** Unit tests for {@link CountIf}. */ @@ -31,45 +33,46 @@ public class CountIfTest { @Test public void testCreatesEmptyAccumulator() { - assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator()); + long[] accumulator = (long[]) CountIf.combineFn().createAccumulator(); + + assertEquals(0, accumulator[0]); } @Test public void testReturnsAccumulatorUnchangedForNullInput() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator(); - assertEquals(accumulator, countIfFn.addInput(accumulator, null)); + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), null); + + assertEquals(0L, accumulator[0]); } @Test public void testAddsInputToAccumulator() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1); + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE); - assertEquals( - expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE)); + assertEquals(1L, accumulator[0]); } @Test - public void testCreatesAccumulatorCoder() { + public void testCreatesAccumulatorCoder() throws CannotProvideCoderException { assertNotNull( CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of())); } @Test public void testMergeAccumulators() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - List accums = - Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2)); - CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4); + Combine.CombineFn countIfFn = CountIf.combineFn(); + List accums = Arrays.asList(new long[] {2}, new long[] {2}); + long[] accumulator = (long[]) countIfFn.mergeAccumulators(accums); - assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums)); + assertEquals(4L, accumulator[0]); } @Test public void testExtractsOutput() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator(); - assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator)); + Combine.CombineFn countIfFn = CountIf.combineFn(); + + assertEquals(0L, countIfFn.extractOutput(countIfFn.createAccumulator())); } }