Skip to content

Commit

Permalink
[BEAM-13202] Add Coder to CountIfFn.Accum
Browse files Browse the repository at this point in the history
  • Loading branch information
iemejia committed Feb 15, 2022
1 parent add7bbc commit 1105c34
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.transforms.Combine;

/**
Expand All @@ -33,21 +37,30 @@ public static CountIfFn combineFn() {

public static class CountIfFn extends Combine.CombineFn<Boolean, CountIfFn.Accum, Long> {

public static class Accum implements Serializable {
boolean isExpressionFalse = true;
long countIfResult = 0L;
@AutoValue
public abstract static class Accum implements Serializable {
abstract boolean isExpressionFalse();

abstract long countIfResult();

static Accum empty() {
return of(true, 0L);
}

static Accum of(boolean isExpressionFalse, long countIfResult) {
return new AutoValue_CountIf_CountIfFn_Accum(isExpressionFalse, countIfResult);
}
}

@Override
public Accum createAccumulator() {
return new Accum();
return Accum.empty();
}

@Override
public Accum addInput(Accum accum, Boolean input) {
if (input) {
accum.isExpressionFalse = false;
accum.countIfResult += 1;
if (Boolean.TRUE.equals(input)) {
return Accum.of(false, accum.countIfResult() + 1);
}
return accum;
}
Expand All @@ -56,18 +69,22 @@ public Accum addInput(Accum accum, Boolean input) {
public Accum mergeAccumulators(Iterable<Accum> accums) {
CountIfFn.Accum merged = createAccumulator();
for (CountIfFn.Accum accum : accums) {
if (!accum.isExpressionFalse) {
merged.isExpressionFalse = false;
merged.countIfResult += accum.countIfResult;
if (!accum.isExpressionFalse()) {
merged = Accum.of(false, merged.countIfResult() + accum.countIfResult());
}
}
return merged;
}

@Override
public Coder<Accum> getAccumulatorCoder(CoderRegistry registry, Coder<Boolean> inputCoder) {
return SerializableCoder.of(Accum.class);
}

@Override
public Long extractOutput(Accum accum) {
if (!accum.isExpressionFalse) {
return accum.countIfResult;
if (!accum.isExpressionFalse()) {
return accum.countIfResult();
}
return 0L;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.transform.agg;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.BooleanCoder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.junit.Test;

/** Unit tests for {@link CountIf}. */
public class CountIfTest {

@Test
public void testCreatesEmptyAccumulator() {
assertEquals(CountIf.CountIfFn.Accum.empty(), CountIf.combineFn().createAccumulator());
}

@Test
public void testReturnsAccumulatorUnchangedForNullInput() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum accumulator = countIfFn.createAccumulator();
assertEquals(accumulator, countIfFn.addInput(accumulator, null));
}

@Test
public void testAddsInputToAccumulator() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 1);

assertEquals(
expectedAccumulator, countIfFn.addInput(countIfFn.createAccumulator(), Boolean.TRUE));
}

@Test
public void testCreatesAccumulatorCoder() {
assertNotNull(
CountIf.combineFn().getAccumulatorCoder(CoderRegistry.createDefault(), BooleanCoder.of()));
}

@Test
public void testMergeAccumulators() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
List<CountIf.CountIfFn.Accum> accums =
Arrays.asList(CountIf.CountIfFn.Accum.of(false, 2), CountIf.CountIfFn.Accum.of(false, 2));
CountIf.CountIfFn.Accum expectedAccumulator = CountIf.CountIfFn.Accum.of(false, 4);

assertEquals(expectedAccumulator, countIfFn.mergeAccumulators(accums));
}

@Test
public void testExtractsOutput() {
CountIf.CountIfFn countIfFn = new CountIf.CountIfFn();
CountIf.CountIfFn.Accum expectedAccumulator = countIfFn.createAccumulator();
assertEquals(Long.valueOf(0), countIfFn.extractOutput(expectedAccumulator));
}
}

0 comments on commit 1105c34

Please sign in to comment.