Skip to content

Commit

Permalink
[BEAM-13202] Reuse Count transform code since CountIf is a specific case
Browse files Browse the repository at this point in the history
  • Loading branch information
iemejia committed Feb 15, 2022
1 parent 1105c34 commit cf75357
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;

/**
* Returns the count of TRUE values for expression. Returns 0 if there are zero input rows, or if
Expand All @@ -31,62 +30,41 @@
public class CountIf {
private CountIf() {}

public static CountIfFn combineFn() {
return new CountIf.CountIfFn();
public static Combine.CombineFn<Boolean, ?, Long> combineFn() {
return new CountIfFn();
}

public static class CountIfFn extends Combine.CombineFn<Boolean, CountIfFn.Accum, Long> {

@AutoValue
public abstract static class Accum implements Serializable {
abstract boolean isExpressionFalse();

abstract long countIfResult();

static Accum empty() {
return of(true, 0L);
}

static Accum of(boolean isExpressionFalse, long countIfResult) {
return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult);
}
}
public static class CountIfFn extends Combine.CombineFn<Boolean, long[], Long> {
private final Combine.CombineFn<Boolean, long[], Long> countFn =
(Combine.CombineFn<Boolean, long[], Long>) Count.<Boolean>combineFn();

@Override
public Accum createAccumulator() {
return Accum.empty();
public long[] createAccumulator() {
return countFn.createAccumulator();
}

@Override
public Accum addInput(Accum accum, Boolean input) {
public long[] addInput(long[] accumulator, Boolean input) {
if (Boolean.TRUE.equals(input)) {
return Accum.of(false, accum.countIfResult() + 1);
countFn.addInput(accumulator, input);
}
return accum;
return accumulator;
}

@Override
public Accum mergeAccumulators(Iterable<Accum> accums) {
CountIfFn.Accum merged = createAccumulator();
for (CountIfFn.Accum accum : accums) {
if (!accum.isExpressionFalse()) {
merged = Accum.of(false, merged.countIfResult() + accum.countIfResult());
}
}
return merged;
public long[] mergeAccumulators(Iterable<long[]> accumulators) {
return countFn.mergeAccumulators(accumulators);
}

@Override
public Coder<Accum> getAccumulatorCoder(CoderRegistry registry, Coder<Boolean> inputCoder) {
return SerializableCoder.of(Accum.class);
public Long extractOutput(long[] accumulator) {
return countFn.extractOutput(accumulator);
}

@Override
public Long extractOutput(Accum accum) {
if (!accum.isExpressionFalse()) {
return accum.countIfResult();
}
return 0L;
public Coder<long[]> getAccumulatorCoder(CoderRegistry registry, Coder<Boolean> inputCoder)
throws CannotProvideCoderException {
return countFn.getAccumulatorCoder(registry, inputCoder);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,56 @@
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.Combine;
import org.junit.Test;

/** Unit tests for {@link CountIf}. */
public class CountIfTest {

@Test
public void testCreatesEmptyAccumulator() {
assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator());
long[] accumulator = (long[]) CountIf.combineFn().createAccumulator();

assertEquals(0, accumulator[0]);
}

@Test
public void testReturnsAccumulatorUnchangedForNullInput() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator();
assertEquals(accumulator, countIfFn.addInput(accumulator, null));
Combine.CombineFn countIfFn = CountIf.combineFn();
long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), null);

assertEquals(0L, accumulator[0]);
}

@Test
public void testAddsInputToAccumulator() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1);
Combine.CombineFn countIfFn = CountIf.combineFn();
long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE);

assertEquals(
expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE));
assertEquals(1L, accumulator[0]);
}

@Test
public void testCreatesAccumulatorCoder() {
public void testCreatesAccumulatorCoder() throws CannotProvideCoderException {
assertNotNull(
CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of()));
}

@Test
public void testMergeAccumulators() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
List<CountIf.CountIfFn.Accum> accums =
Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2));
CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4);
Combine.CombineFn countIfFn = CountIf.combineFn();
List<long[]> accums = Arrays.asList(new long[] {2}, new long[] {2});
long[] accumulator = (long[]) countIfFn.mergeAccumulators(accums);

assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums));
assertEquals(4L, accumulator[0]);
}

@Test
public void testExtractsOutput() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator();
assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator));
Combine.CombineFn countIfFn = CountIf.combineFn();

assertEquals(0L, countIfFn.extractOutput(countIfFn.createAccumulator()));
}
}

0 comments on commit cf75357

Please sign in to comment.