diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java index 51977ed5135b..cc8559c33300 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java @@ -17,7 +17,11 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform.agg; +import com.google.auto.value.AutoValue; import java.io.Serializable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine; /** @@ -33,21 +37,30 @@ public static CountIfFn combineFn() { public static class CountIfFn extends Combine.CombineFn { - public static class Accum implements Serializable { - boolean isExpressionFalse = true; - long countIfResult = 0L; + @AutoValue + public abstract static class Accum implements Serializable { + abstract boolean isExpressionFalse(); + + abstract long countIfResult(); + + static Accum empty() { + return of(true, 0L); + } + + static Accum of(boolean isExpressionFalse, long countIfResult) { + return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult); + } } @Override public Accum createAccumulator() { - return new Accum(); + return Accum.empty(); } @Override public Accum addInput(Accum accum, Boolean input) { - if (input) { - accum.isExpressionFalse = false; - accum.countIfResult += 1; + if (Boolean.TRUE.equals(input)) { + return Accum.of(false, accum.countIfResult() + 1); } return accum; } @@ -56,18 +69,22 @@ public Accum addInput(Accum accum, Boolean input) { public Accum mergeAccumulators(Iterable accums) { CountIfFn.Accum merged = createAccumulator(); for (CountIfFn.Accum accum : accums) { - if (!accum.isExpressionFalse) { - merged.isExpressionFalse = false; - merged.countIfResult += accum.countIfResult; + if (!accum.isExpressionFalse()) { + merged = Accum.of(false, merged.countIfResult() + accum.countIfResult()); } } return merged; } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return SerializableCoder.of(Accum.class); + } + @Override public Long extractOutput(Accum accum) { - if (!accum.isExpressionFalse) { - return accum.countIfResult; + if (!accum.isExpressionFalse()) { + return accum.countIfResult(); } return 0L; } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java new file mode 100644 index 000000000000..3e5d56ab6014 --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.impl.transform.agg; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.junit.Test; + +/** Unit tests for {@link CountIf}. */ +public class CountIfTest { + + @Test + public void testCreatesEmptyAccumulator() { + assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator()); + } + + @Test + public void testReturnsAccumulatorUnchangedForNullInput() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator(); + assertEquals(accumulator, countIfFn.addInput(accumulator, null)); + } + + @Test + public void testAddsInputToAccumulator() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1); + + assertEquals( + expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE)); + } + + @Test + public void testCreatesAccumulatorCoder() { + assertNotNull( + CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of())); + } + + @Test + public void testMergeAccumulators() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + List accums = + Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2)); + CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4); + + assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums)); + } + + @Test + public void testExtractsOutput() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator(); + assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator)); + } +}