Skip to content

Commit

Permalink
Merge pull request #16856: [BEAM-13202] Add Coder to CountIfFn.Accum
Browse files Browse the repository at this point in the history
  • Loading branch information
iemejia authored Feb 16, 2022
2 parents b2f2128 + cf75357 commit 6e98dd4
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import java.io.Serializable;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;

/**
* Returns the count of TRUE values for expression. Returns 0 if there are zero input rows, or if
Expand All @@ -27,49 +30,41 @@
public class CountIf {
private CountIf() {}

public static CountIfFn combineFn() {
return new CountIf.CountIfFn();
public static Combine.CombineFn<Boolean, ?, Long> combineFn() {
return new CountIfFn();
}

public static class CountIfFn extends Combine.CombineFn<Boolean, CountIfFn.Accum, Long> {
public static class CountIfFn extends Combine.CombineFn<Boolean, long[], Long> {
private final Combine.CombineFn<Boolean, long[], Long> countFn =
(Combine.CombineFn<Boolean, long[], Long>) Count.<Boolean>combineFn();

public static class Accum implements Serializable {
boolean isExpressionFalse = true;
long countIfResult = 0L;
@Override
public long[] createAccumulator() {
return countFn.createAccumulator();
}

@Override
public Accum createAccumulator() {
return new Accum();
public long[] addInput(long[] accumulator, Boolean input) {
if (Boolean.TRUE.equals(input)) {
countFn.addInput(accumulator, input);
}
return accumulator;
}

@Override
public Accum addInput(Accum accum, Boolean input) {
if (input) {
accum.isExpressionFalse = false;
accum.countIfResult += 1;
}
return accum;
public long[] mergeAccumulators(Iterable<long[]> accumulators) {
return countFn.mergeAccumulators(accumulators);
}

@Override
public Accum mergeAccumulators(Iterable<Accum> accums) {
CountIfFn.Accum merged = createAccumulator();
for (CountIfFn.Accum accum : accums) {
if (!accum.isExpressionFalse) {
merged.isExpressionFalse = false;
merged.countIfResult += accum.countIfResult;
}
}
return merged;
public Long extractOutput(long[] accumulator) {
return countFn.extractOutput(accumulator);
}

@Override
public Long extractOutput(Accum accum) {
if (!accum.isExpressionFalse) {
return accum.countIfResult;
}
return 0L;
public Coder<long[]> getAccumulatorCoder(CoderRegistry registry, Coder<Boolean> inputCoder)
throws CannotProvideCoderException {
return countFn.getAccumulatorCoder(registry, inputCoder);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.Combine;
import org.junit.Test;

/** Unit tests for {@link CountIf}. */
public class CountIfTest {

@Test
public void testCreatesEmptyAccumulator() {
long[] accumulator = (long[]) CountIf.combineFn().createAccumulator();

assertEquals(0, accumulator[0]);
}

@Test
public void testReturnsAccumulatorUnchangedForNullInput() {
Combine.CombineFn countIfFn = CountIf.combineFn();
long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), null);

assertEquals(0L, accumulator[0]);
}

@Test
public void testAddsInputToAccumulator() {
Combine.CombineFn countIfFn = CountIf.combineFn();
long[] accumulator = (long[]) countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE);

assertEquals(1L, accumulator[0]);
}

@Test
public void testCreatesAccumulatorCoder() throws CannotProvideCoderException {
assertNotNull(
CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of()));
}

@Test
public void testMergeAccumulators() {
Combine.CombineFn countIfFn = CountIf.combineFn();
List<long[]> accums = Arrays.asList(new long[] {2}, new long[] {2});
long[] accumulator = (long[]) countIfFn.mergeAccumulators(accums);

assertEquals(4L, accumulator[0]);
}

@Test
public void testExtractsOutput() {
Combine.CombineFn countIfFn = CountIf.combineFn();

assertEquals(0L, countIfFn.extractOutput(countIfFn.createAccumulator()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ public void testAddsInputToAccumulator() {
}

@Test
public void testCeatesAccumulatorCoder() {
public void testCreatesAccumulatorCoder() {
assertNotNull(varianceFn.getAccumulatorCoder(CoderRegistry.createDefault(), VarIntCoder.of()));
}

@Test
public void testReturnsOutput() {
public void testExtractsOutput() {
assertEquals(expectedExtractedResult, varianceFn.extractOutput(testAccumulatorInput));
}
}

0 comments on commit 6e98dd4

Please sign in to comment.