From add7bbcb0d690b63c3d282de8aef20cb87e23e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 15 Feb 2022 11:20:12 +0100 Subject: [PATCH 1/3] [BEAM-13202] Fix typos on tests names for VarianceFnTest --- .../sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java index 4b72fca933eb..f7a8ad1fa06b 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java @@ -92,12 +92,12 @@ public void testAddsInputToAccumulator() { } @Test - public void testCeatesAccumulatorCoder() { + public void testCreatesAccumulatorCoder() { assertNotNull(varianceFn.getAccumulatorCoder(CoderRegistry.createDefault(), VarIntCoder.of())); } @Test - public void testReturnsOutput() { + public void testExtractsOutput() { assertEquals(expectedExtractedResult, varianceFn.extractOutput(testAccumulatorInput)); } } From 1105c34dea5326d7f366360f43fcbbf103e653f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 15 Feb 2022 12:22:24 +0100 Subject: [PATCH 2/3] [BEAM-13202] Add Coder to CountIfFn.Accum --- .../sql/impl/transform/agg/CountIf.java | 41 +++++++--- .../sql/impl/transform/agg/CountIfTest.java | 75 +++++++++++++++++++ 2 files changed, 104 insertions(+), 12 deletions(-) create mode 100644 sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java index 51977ed5135b..cc8559c33300 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java @@ -17,7 +17,11 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform.agg; +import com.google.auto.value.AutoValue; import java.io.Serializable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine; /** @@ -33,21 +37,30 @@ public static CountIfFn combineFn() { public static class CountIfFn extends Combine.CombineFn { - public static class Accum implements Serializable { - boolean isExpressionFalse = true; - long countIfResult = 0L; + @AutoValue + public abstract static class Accum implements Serializable { + abstract boolean isExpressionFalse(); + + abstract long countIfResult(); + + static Accum empty() { + return of(true, 0L); + } + + static Accum of(boolean isExpressionFalse, long countIfResult) { + return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult); + } } @Override public Accum createAccumulator() { - return new Accum(); + return Accum.empty(); } @Override public Accum addInput(Accum accum, Boolean input) { - if (input) { - accum.isExpressionFalse = false; - accum.countIfResult += 1; + if (Boolean.TRUE.equals(input)) { + return Accum.of(false, accum.countIfResult() + 1); } return accum; } @@ -56,18 +69,22 @@ public Accum addInput(Accum accum, Boolean input) { public Accum mergeAccumulators(Iterable accums) { CountIfFn.Accum merged = createAccumulator(); for (CountIfFn.Accum accum : accums) { - if (!accum.isExpressionFalse) { - merged.isExpressionFalse = false; - merged.countIfResult += accum.countIfResult; + if (!accum.isExpressionFalse()) { + merged = Accum.of(false, merged.countIfResult() + accum.countIfResult()); } } return merged; } + @Override + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { + return SerializableCoder.of(Accum.class); + } + @Override public Long extractOutput(Accum accum) { - if (!accum.isExpressionFalse) { - return accum.countIfResult; + if (!accum.isExpressionFalse()) { + return accum.countIfResult(); } return 0L; } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java new file mode 100644 index 000000000000..3e5d56ab6014 --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.impl.transform.agg; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.Arrays; +import java.util.List; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.junit.Test; + +/** Unit tests for {@link CountIf}. */ +public class CountIfTest { + + @Test + public void testCreatesEmptyAccumulator() { + assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator()); + } + + @Test + public void testReturnsAccumulatorUnchangedForNullInput() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator(); + assertEquals(accumulator, countIfFn.addInput(accumulator, null)); + } + + @Test + public void testAddsInputToAccumulator() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1); + + assertEquals( + expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE)); + } + + @Test + public void testCreatesAccumulatorCoder() { + assertNotNull( + CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of())); + } + + @Test + public void testMergeAccumulators() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + List accums = + Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2)); + CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4); + + assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums)); + } + + @Test + public void testExtractsOutput() { + CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); + CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator(); + assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator)); + } +} From cf75357600196289ed0a22c9aea0a0ebfc4b2c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 15 Feb 2022 18:11:17 +0100 Subject: [PATCH 3/3] [BEAM-13202] Reuse Count transform code since CountIf is a specific case --- .../sql/impl/transform/agg/CountIf.java | 60 ++++++------------- .../sql/impl/transform/agg/CountIfTest.java | 37 ++++++------ 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java index cc8559c33300..98d9dd99bfed 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIf.java @@ -17,12 +17,11 @@ */ package org.apache.beam.sdk.extensions.sql.impl.transform.agg; -import com.google.auto.value.AutoValue; -import java.io.Serializable; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Count; /** * Returns the count of TRUE values for expression. Returns 0 if there are zero input rows, or if @@ -31,62 +30,41 @@ public class CountIf { private CountIf() {} - public static CountIfFn combineFn() { - return new CountIf.CountIfFn(); + public static Combine.CombineFn combineFn() { + return new CountIfFn(); } - public static class CountIfFn extends Combine.CombineFn { - - @AutoValue - public abstract static class Accum implements Serializable { - abstract boolean isExpressionFalse(); - - abstract long countIfResult(); - - static Accum empty() { - return of(true, 0L); - } - - static Accum of(boolean isExpressionFalse, long countIfResult) { - return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult); - } - } + public static class CountIfFn extends Combine.CombineFn { + private final Combine.CombineFn countFn = + (Combine.CombineFn) Count.combineFn(); @Override - public Accum createAccumulator() { - return Accum.empty(); + public long[] createAccumulator() { + return countFn.createAccumulator(); } @Override - public Accum addInput(Accum accum, Boolean input) { + public long[] addInput(long[] accumulator, Boolean input) { if (Boolean.TRUE.equals(input)) { - return Accum.of(false, accum.countIfResult() + 1); + countFn.addInput(accumulator, input); } - return accum; + return accumulator; } @Override - public Accum mergeAccumulators(Iterable accums) { - CountIfFn.Accum merged = createAccumulator(); - for (CountIfFn.Accum accum : accums) { - if (!accum.isExpressionFalse()) { - merged = Accum.of(false, merged.countIfResult() + accum.countIfResult()); - } - } - return merged; + public long[] mergeAccumulators(Iterable accumulators) { + return countFn.mergeAccumulators(accumulators); } @Override - public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) { - return SerializableCoder.of(Accum.class); + public Long extractOutput(long[] accumulator) { + return countFn.extractOutput(accumulator); } @Override - public Long extractOutput(Accum accum) { - if (!accum.isExpressionFalse()) { - return accum.countIfResult(); - } - return 0L; + public Coder getAccumulatorCoder(CoderRegistry registry, Coder inputCoder) + throws CannotProvideCoderException { + return countFn.getAccumulatorCoder(registry, inputCoder); } } } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java index 3e5d56ab6014..4c0724716e16 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CountIfTest.java @@ -23,7 +23,9 @@ import java.util.Arrays; import java.util.List; import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine; import org.junit.Test; /** Unit tests for {@link CountIf}. */ @@ -31,45 +33,46 @@ public class CountIfTest { @Test public void testCreatesEmptyAccumulator() { - assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator()); + long[] accumulator = (long[]) CountIf.combineFn().createAccumulator(); + + assertEquals(0, accumulator[0]); } @Test public void testReturnsAccumulatorUnchangedForNullInput() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator(); - assertEquals(accumulator, countIfFn.addInput(accumulator, null)); + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), null); + + assertEquals(0L, accumulator[0]); } @Test public void testAddsInputToAccumulator() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1); + Combine.CombineFn countIfFn = CountIf.combineFn(); + long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE); - assertEquals( - expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE)); + assertEquals(1L, accumulator[0]); } @Test - public void testCreatesAccumulatorCoder() { + public void testCreatesAccumulatorCoder() throws CannotProvideCoderException { assertNotNull( CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of())); } @Test public void testMergeAccumulators() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - List accums = - Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2)); - CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4); + Combine.CombineFn countIfFn = CountIf.combineFn(); + List accums = Arrays.asList(new long[] {2}, new long[] {2}); + long[] accumulator = (long[]) countIfFn.mergeAccumulators(accums); - assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums)); + assertEquals(4L, accumulator[0]); } @Test public void testExtractsOutput() { - CountIf.CountIfFn countIfFn = new CountIf.CountIfFn(); - CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator(); - assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator)); + Combine.CombineFn countIfFn = CountIf.combineFn(); + + assertEquals(0L, countIfFn.extractOutput(countIfFn.createAccumulator())); } }